diff --git a/nova/network/minidns.py b/nova/network/minidns.py index 42e52a4930..5da77a53e5 100644 --- a/nova/network/minidns.py +++ b/nova/network/minidns.py @@ -13,25 +13,28 @@ # under the License. import os -import shutil -import tempfile from oslo_config import cfg from oslo_log import log as logging +import six from nova import exception from nova.i18n import _ from nova.network import dns_driver + CONF = cfg.CONF LOG = logging.getLogger(__name__) class MiniDNS(dns_driver.DNSDriver): - """Trivial DNS driver. This will read/write to a local, flat file - and have no effect on your actual DNS system. This class is - strictly for testing purposes, and should keep you out of dependency - hell. + """Trivial DNS driver. This will read/write to either a local, + flat file or an in memory StringIO and have no effect on your actual + DNS system. This class is strictly for testing purposes, and should + keep you out of dependency hell. + + A file is used when CONF.log_dir is set. This is relevant for when + two different DNS driver instances share the same data file. Note that there is almost certainly a race condition here that will manifest anytime instances are rapidly created and deleted. @@ -39,25 +42,23 @@ class MiniDNS(dns_driver.DNSDriver): """ def __init__(self): + filename = None if CONF.log_dir: - self.filename = os.path.join(CONF.log_dir, "dnstest.txt") - self.tempdir = None + filename = os.path.join(CONF.log_dir, "dnstest.txt") + self.file = open(filename, 'w+') else: - self.tempdir = tempfile.mkdtemp() - self.filename = os.path.join(self.tempdir, "dnstest.txt") - LOG.debug('minidns file is |%s|', self.filename) - - if not os.path.exists(self.filename): - with open(self.filename, "w+") as f: - f.write("# minidns\n\n\n") + self.file = six.StringIO() + if not filename or not os.path.exists(filename): + self.file.write("# minidns\n\n\n") + self.file.flush() def get_domains(self): entries = [] - with open(self.filename, 'r') as infile: - for line in infile: - entry = self.parse_line(line) - if entry and entry['address'] == 'domain': - entries.append(entry['name']) + self.file.seek(0) + for line in self.file: + entry = self.parse_line(line) + if entry and entry['address'] == 'domain': + entries.append(entry['name']) return entries def qualify(self, name, domain): @@ -79,9 +80,10 @@ class MiniDNS(dns_driver.DNSDriver): if self.get_entries_by_name(name, domain): raise exception.FloatingIpDNSExists(name=name, domain=domain) - with open(self.filename, 'a+') as outfile: - outfile.write("%s %s %s\n" % - (address, self.qualify(name, domain), type)) + self.file.seek(0, os.SEEK_END) + self.file.write("%s %s %s\n" % + (address, self.qualify(name, domain), type)) + self.file.flush() def parse_line(self, line): vals = line.split() @@ -103,17 +105,19 @@ class MiniDNS(dns_driver.DNSDriver): raise exception.InvalidInput(_("Invalid name")) deleted = False - outfile = tempfile.NamedTemporaryFile('w', delete=False) - with open(self.filename, 'r') as infile: - for line in infile: - entry = self.parse_line(line) - if (not entry or - entry['name'] != self.qualify(name, domain)): - outfile.write(line) - else: - deleted = True - outfile.close() - shutil.move(outfile.name, self.filename) + keeps = [] + self.file.seek(0) + for line in self.file: + entry = self.parse_line(line) + if (not entry or + entry['name'] != self.qualify(name, domain)): + keeps.append(line) + else: + deleted = True + self.file.truncate(0) + self.file.seek(0) + self.file.write(''.join(keeps)) + self.file.flush() if not deleted: LOG.warning('Cannot delete entry |%s|', self.qualify(name, domain)) raise exception.NotFound @@ -123,76 +127,80 @@ class MiniDNS(dns_driver.DNSDriver): if not self.get_entries_by_name(name, domain): raise exception.NotFound - outfile = tempfile.NamedTemporaryFile('w', delete=False) - with open(self.filename, 'r') as infile: - for line in infile: - entry = self.parse_line(line) - if (entry and - entry['name'] == self.qualify(name, domain)): - outfile.write("%s %s %s\n" % - (address, self.qualify(name, domain), entry['type'])) - else: - outfile.write(line) - outfile.close() - shutil.move(outfile.name, self.filename) + lines = [] + self.file.seek(0) + for line in self.file: + entry = self.parse_line(line) + if (entry and + entry['name'] == self.qualify(name, domain)): + lines.append("%s %s %s\n" % + (address, self.qualify(name, domain), entry['type'])) + else: + lines.append(line) + self.file.truncate(0) + self.file.seek(0) + self.file.write(''.join(lines)) + self.file.flush() def get_entries_by_address(self, address, domain): entries = [] - with open(self.filename, 'r') as infile: - for line in infile: - entry = self.parse_line(line) - if entry and entry['address'] == address.lower(): - if entry['name'].endswith(domain.lower()): - name = entry['name'].split(".")[0] - if name not in entries: - entries.append(name) + self.file.seek(0) + for line in self.file: + entry = self.parse_line(line) + if entry and entry['address'] == address.lower(): + if entry['name'].endswith(domain.lower()): + name = entry['name'].split(".")[0] + if name not in entries: + entries.append(name) return entries def get_entries_by_name(self, name, domain): entries = [] - with open(self.filename, 'r') as infile: - for line in infile: - entry = self.parse_line(line) - if (entry and - entry['name'] == self.qualify(name, domain)): - entries.append(entry['address']) + self.file.seek(0) + for line in self.file: + entry = self.parse_line(line) + if (entry and + entry['name'] == self.qualify(name, domain)): + entries.append(entry['address']) return entries def delete_dns_file(self): - if os.path.exists(self.filename): - try: - os.remove(self.filename) - except OSError: - pass - if self.tempdir and os.path.exists(self.tempdir): - try: - shutil.rmtree(self.tempdir) - except OSError: - pass + self.file.close() + try: + if os.path.exists(self.file.name): + try: + os.remove(self.file.name) + except OSError: + pass + except AttributeError: + # This was a BytesIO, which has no name. + pass def create_domain(self, fqdomain): if self.get_entries_by_name(fqdomain, ''): raise exception.FloatingIpDNSExists(name=fqdomain, domain='') - with open(self.filename, 'a+') as outfile: - outfile.write("%s %s %s\n" % - ('domain', fqdomain, 'domain')) + self.file.seek(0, os.SEEK_END) + self.file.write("%s %s %s\n" % ('domain', fqdomain, 'domain')) + self.file.flush() def delete_domain(self, fqdomain): deleted = False - outfile = tempfile.NamedTemporaryFile('w', delete=False) - with open(self.filename, 'r') as infile: - for line in infile: - entry = self.parse_line(line) - if (not entry or - entry['domain'] != fqdomain.lower()): - outfile.write(line) - else: - LOG.info("deleted %s", entry) - deleted = True - outfile.close() - shutil.move(outfile.name, self.filename) + keeps = [] + self.file.seek(0) + for line in self.file: + entry = self.parse_line(line) + if (not entry or + entry['domain'] != fqdomain.lower()): + keeps.append(line) + else: + LOG.info("deleted %s", entry) + deleted = True + self.file.truncate(0) + self.file.seek(0) + self.file.write(''.join(keeps)) + self.file.flush() if not deleted: LOG.warning('Cannot delete domain |%s|', fqdomain) raise exception.NotFound