diff --git a/vmware_nsx/plugins/nsx_v3/utils.py b/vmware_nsx/plugins/nsx_v3/utils.py index 46f585f442..685b4b7f7d 100644 --- a/vmware_nsx/plugins/nsx_v3/utils.py +++ b/vmware_nsx/plugins/nsx_v3/utils.py @@ -37,11 +37,13 @@ LOG = logging.getLogger(__name__) class DbCertProvider(client_cert.ClientCertProvider): """Write cert data from DB to file and delete after use - Since several connections may use same filename simultaneously, - this class maintains refcount to write/delete the file only once + New file with random filename is created for each thread. This + is not most efficient, but the safest way to avoid race conditions, + since backend connections can occur both before and after neutron + fork. """ EXPIRATION_ALERT_DAYS = 30 # days prior to expiration - lock = threading.Lock() + _thread_local = threading.local() def __init__(self): # Note: we cannot initialize filename here, because this call @@ -54,11 +56,6 @@ class DbCertProvider(client_cert.ClientCertProvider): super(DbCertProvider, self).__init__(None) random.seed() - with self.lock: - # Initialize refcount if other threads did not do it already - if not hasattr(self, 'refcount'): - self.refcount = 0 - def _check_expiration(self, expires_in_days): if expires_in_days > self.EXPIRATION_ALERT_DAYS: return @@ -72,16 +69,10 @@ class DbCertProvider(client_cert.ClientCertProvider): expires_in_days) def __enter__(self): - with self.lock: - self.refcount += 1 - - if self.refcount > 1: - # The file was already created and not yet deleted, use it - return self - # No certificate file available - need to create one # Choose a random filename to contain the certificate - self._filename = '/tmp/.' + str(random.randint(1, 10000000)) + self._thread_local._filename = '/tmp/.' + str( + random.randint(1, 10000000)) try: context = q_context.get_admin_context() @@ -95,15 +86,15 @@ class DbCertProvider(client_cert.ClientCertProvider): msg = _("Unable to load from nsx-db") raise nsx_exc.ClientCertificateException(err_msg=msg) - if not os.path.exists(os.path.dirname(self._filename)): - if len(os.path.dirname(self._filename)) > 0: - os.makedirs(os.path.dirname(self._filename)) - cert_manager.export_pem(self._filename) + filename = self._thread_local._filename + if not os.path.exists(os.path.dirname(filename)): + if len(os.path.dirname(filename)) > 0: + os.makedirs(os.path.dirname(filename)) + cert_manager.export_pem(filename) expires_in_days = cert_manager.expires_in_days() self._check_expiration(expires_in_days) except Exception as e: - # refcount has to be 1 here self._on_exit() raise e @@ -111,23 +102,17 @@ class DbCertProvider(client_cert.ClientCertProvider): return self def _on_exit(self): - self.refcount -= 1 + if os.path.isfile(self._thread_local._filename): + os.remove(self._thread_local._filename) + LOG.debug("Deleted client certificate file") - if self.refcount == 0: - # I am the last user of this file - if os.path.isfile(self._filename): - os.remove(self._filename) - LOG.debug("Deleted client certificate file") - - self._filename = None + self._thread_local._filename = None def __exit__(self, type, value, traceback): - with self.lock: - self._on_exit() + self._on_exit() def filename(self): - with self.lock: - return self._filename + return self._thread_local._filename def get_client_cert_provider():