diff --git a/designate/central/service.py b/designate/central/service.py index 613af8d7c..509857896 100644 --- a/designate/central/service.py +++ b/designate/central/service.py @@ -16,11 +16,13 @@ # under the License. import re import contextlib +import functools from oslo.config import cfg from oslo import messaging from designate.openstack.common import log as logging +from designate.openstack.common import excutils from designate.openstack.common.gettextutils import _LI from designate.openstack.common.gettextutils import _LC from designate import backend @@ -31,7 +33,7 @@ from designate import policy from designate import quota from designate import service from designate import utils -from designate.storage import api as storage_api +from designate import storage LOG = logging.getLogger(__name__) @@ -50,6 +52,22 @@ def wrap_backend_call(): raise exceptions.Backend('Unknown backend failure: %r' % exc) +def transaction(f): + # TODO(kiall): Get this a better home :) + @functools.wraps(f) + def wrapper(self, *args, **kwargs): + self.storage.begin() + try: + result = f(self, *args, **kwargs) + except Exception: + with excutils.save_and_reraise_exception(): + self.storage.rollback() + else: + self.storage.commit() + return result + return wrapper + + class Service(service.Service): RPC_API_VERSION = '4.0' @@ -65,7 +83,7 @@ class Service(service.Service): # Get a storage connection storage_driver = cfg.CONF['service:central'].storage_driver - self.storage_api = storage_api.StorageAPI(storage_driver) + self.storage = storage.get_storage(storage_driver) # Get a quota manager instance self.quota = quota.get_quota() @@ -74,7 +92,7 @@ class Service(service.Service): def start(self): # Check to see if there are any TLDs in the database - tlds = self.storage_api.find_tlds({}) + tlds = self.storage.find_tlds({}) if tlds: self.check_for_tlds = True LOG.info(_LI("Checking for TLDs")) @@ -107,14 +125,14 @@ class Service(service.Service): # Check the TLD for validity if there are entries in the database if self.check_for_tlds: try: - self.storage_api.find_tld(context, {'name': domain_labels[-1]}) + self.storage.find_tld(context, {'name': domain_labels[-1]}) except exceptions.TLDNotFound: raise exceptions.InvalidDomainName('Invalid TLD') # Now check that the domain name is not the same as a TLD try: stripped_domain_name = domain_name.strip('.').lower() - self.storage_api.find_tld( + self.storage.find_tld( context, {'name': stripped_domain_name}) except exceptions.TLDNotFound: @@ -162,7 +180,7 @@ class Service(service.Service): if recordset_type != 'CNAME': criterion['type'] = 'CNAME' - recordsets = self.storage_api.find_recordsets(context, criterion) + recordsets = self.storage.find_recordsets(context, criterion) if ((len(recordsets) == 1 and recordsets[0].id != recordset_id) or len(recordsets) > 1): @@ -189,7 +207,7 @@ class Service(service.Service): if domain.name == recordset_name: return - child_domains = self.storage_api.find_domains( + child_domains = self.storage.find_domains( context, {"parent_domain_id": domain.id}) for child_domain in child_domains: try: @@ -207,7 +225,7 @@ class Service(service.Service): Ensures the provided domain_name is not blacklisted. """ - blacklists = self.storage_api.find_blacklists(context) + blacklists = self.storage.find_blacklists(context) for blacklist in blacklists: if bool(re.search(blacklist.pattern, domain_name)): @@ -229,7 +247,7 @@ class Service(service.Service): name = '.'.join(labels[i:]) try: - domain = self.storage_api.find_domain(context, {'name': name}) + domain = self.storage.find_domain(context, {'name': name}) except exceptions.DomainNotFound: i += 1 else: @@ -247,22 +265,22 @@ class Service(service.Service): % min_ttl) def _increment_domain_serial(self, context, domain_id): - domain = self.storage_api.get_domain(context, domain_id) + domain = self.storage.get_domain(context, domain_id) # Increment the serial number values = {'serial': utils.increment_serial(domain['serial'])} - with self.storage_api.update_domain( - context, domain_id, values) as domain: - with wrap_backend_call(): - self.backend.update_domain(context, domain) + domain = self.storage.update_domain(context, domain_id, values) + + with wrap_backend_call(): + self.backend.update_domain(context, domain) return domain # Quota Enforcement Methods def _enforce_domain_quota(self, context, tenant_id): criterion = {'tenant_id': tenant_id} - count = self.storage_api.count_domains(context, criterion) + count = self.storage.count_domains(context, criterion) self.quota.limit_check(context, tenant_id, domains=count) @@ -273,7 +291,7 @@ class Service(service.Service): def _enforce_record_quota(self, context, domain, recordset): # Ensure the records per domain quota is OK criterion = {'domain_id': domain['id']} - count = self.storage_api.count_records(context, criterion) + count = self.storage.count_records(context, criterion) self.quota.limit_check(context, domain['tenant_id'], domain_records=count) @@ -298,6 +316,7 @@ class Service(service.Service): return self.quota.get_quota(context, tenant_id, resource) + @transaction def set_quota(self, context, tenant_id, resource, hard_limit): target = { 'tenant_id': tenant_id, @@ -309,6 +328,7 @@ class Service(service.Service): return self.quota.set_quota(context, tenant_id, resource, hard_limit) + @transaction def reset_quotas(self, context, tenant_id): target = {'tenant_id': tenant_id} policy.check('reset_quotas', context, target) @@ -316,13 +336,15 @@ class Service(service.Service): self.quota.reset_quotas(context, tenant_id) # Server Methods + @transaction def create_server(self, context, server): policy.check('create_server', context) - with self.storage_api.create_server(context, server) as created_server: - # Update backend with the new server.. - with wrap_backend_call(): - self.backend.create_server(context, created_server) + created_server = self.storage.create_server(context, server) + + # Update backend with the new server.. + with wrap_backend_call(): + self.backend.create_server(context, created_server) self.notifier.info(context, 'dns.server.create', created_server) @@ -332,50 +354,54 @@ class Service(service.Service): sort_key=None, sort_dir=None): policy.check('find_servers', context) - return self.storage_api.find_servers(context, criterion, marker, limit, - sort_key, sort_dir) + return self.storage.find_servers(context, criterion, marker, limit, + sort_key, sort_dir) def get_server(self, context, server_id): policy.check('get_server', context, {'server_id': server_id}) - return self.storage_api.get_server(context, server_id) + return self.storage.get_server(context, server_id) + @transaction def update_server(self, context, server_id, values): policy.check('update_server', context, {'server_id': server_id}) - with self.storage_api.update_server( - context, server_id, values) as server: - # Update backend with the new details.. - with wrap_backend_call(): - self.backend.update_server(context, server) + server = self.storage.update_server(context, server_id, values) + + # Update backend with the new details.. + with wrap_backend_call(): + self.backend.update_server(context, server) self.notifier.info(context, 'dns.server.update', server) return server + @transaction def delete_server(self, context, server_id): policy.check('delete_server', context, {'server_id': server_id}) # don't delete last of servers - servers = self.storage_api.find_servers(context) + servers = self.storage.find_servers(context) if len(servers) == 1 and server_id == servers[0].id: raise exceptions.LastServerDeleteNotAllowed( "Not allowed to delete last of servers") - with self.storage_api.delete_server(context, server_id) as server: - # Update backend with the new server.. - with wrap_backend_call(): - self.backend.delete_server(context, server) + server = self.storage.delete_server(context, server_id) + + # Update backend with the new server.. + with wrap_backend_call(): + self.backend.delete_server(context, server) self.notifier.info(context, 'dns.server.delete', server) # TLD Methods + @transaction def create_tld(self, context, tld): policy.check('create_tld', context) # The TLD is only created on central's storage and not on the backend. - with self.storage_api.create_tld(context, tld) as created_tld: - pass + created_tld = self.storage.create_tld(context, tld) + self.notifier.info(context, 'dns.tld.create', created_tld) # Set check for tlds to be true @@ -386,24 +412,25 @@ class Service(service.Service): sort_key=None, sort_dir=None): policy.check('find_tlds', context) - return self.storage_api.find_tlds(context, criterion, marker, limit, - sort_key, sort_dir) + return self.storage.find_tlds(context, criterion, marker, limit, + sort_key, sort_dir) def get_tld(self, context, tld_id): policy.check('get_tld', context, {'tld_id': tld_id}) - return self.storage_api.get_tld(context, tld_id) + return self.storage.get_tld(context, tld_id) + @transaction def update_tld(self, context, tld_id, values): policy.check('update_tld', context, {'tld_id': tld_id}) - with self.storage_api.update_tld(context, tld_id, values) as tld: - pass + tld = self.storage.update_tld(context, tld_id, values) self.notifier.info(context, 'dns.tld.update', tld) return tld + @transaction def delete_tld(self, context, tld_id): # Known issue - self.check_for_tld is not reset here. So if the last # TLD happens to be deleted, then we would incorrectly do the TLD @@ -412,19 +439,19 @@ class Service(service.Service): # of hitting this issue vs doing the checks for every delete. policy.check('delete_tld', context, {'tld_id': tld_id}) - with self.storage_api.delete_tld(context, tld_id) as tld: - pass + tld = self.storage.delete_tld(context, tld_id) self.notifier.info(context, 'dns.tld.delete', tld) # TSIG Key Methods + @transaction def create_tsigkey(self, context, tsigkey): policy.check('create_tsigkey', context) - with self.storage_api.create_tsigkey(context, tsigkey) \ - as created_tsigkey: - with wrap_backend_call(): - self.backend.create_tsigkey(context, created_tsigkey) + created_tsigkey = self.storage.create_tsigkey(context, tsigkey) + + with wrap_backend_call(): + self.backend.create_tsigkey(context, created_tsigkey) self.notifier.info(context, 'dns.tsigkey.create', created_tsigkey) @@ -434,39 +461,42 @@ class Service(service.Service): sort_key=None, sort_dir=None): policy.check('find_tsigkeys', context) - return self.storage_api.find_tsigkeys(context, criterion, marker, - limit, sort_key, sort_dir) + return self.storage.find_tsigkeys(context, criterion, marker, + limit, sort_key, sort_dir) def get_tsigkey(self, context, tsigkey_id): policy.check('get_tsigkey', context, {'tsigkey_id': tsigkey_id}) - return self.storage_api.get_tsigkey(context, tsigkey_id) + return self.storage.get_tsigkey(context, tsigkey_id) + @transaction def update_tsigkey(self, context, tsigkey_id, values): policy.check('update_tsigkey', context, {'tsigkey_id': tsigkey_id}) - with self.storage_api.update_tsigkey( - context, tsigkey_id, values) as tsigkey: - with wrap_backend_call(): - self.backend.update_tsigkey(context, tsigkey) + tsigkey = self.storage.update_tsigkey(context, tsigkey_id, values) + + with wrap_backend_call(): + self.backend.update_tsigkey(context, tsigkey) self.notifier.info(context, 'dns.tsigkey.update', tsigkey) return tsigkey + @transaction def delete_tsigkey(self, context, tsigkey_id): policy.check('delete_tsigkey', context, {'tsigkey_id': tsigkey_id}) - with self.storage_api.delete_tsigkey(context, tsigkey_id) as tsigkey: - with wrap_backend_call(): - self.backend.delete_tsigkey(context, tsigkey) + tsigkey = self.storage.delete_tsigkey(context, tsigkey_id) + + with wrap_backend_call(): + self.backend.delete_tsigkey(context, tsigkey) self.notifier.info(context, 'dns.tsigkey.delete', tsigkey) # Tenant Methods def find_tenants(self, context): policy.check('find_tenants', context) - return self.storage_api.find_tenants(context) + return self.storage.find_tenants(context) def get_tenant(self, context, tenant_id): target = { @@ -475,13 +505,14 @@ class Service(service.Service): policy.check('get_tenant', context, target) - return self.storage_api.get_tenant(context, tenant_id) + return self.storage.get_tenant(context, tenant_id) def count_tenants(self, context): policy.check('count_tenants', context) - return self.storage_api.count_tenants(context) + return self.storage.count_tenants(context) # Domain Methods + @transaction def create_domain(self, context, domain): # TODO(kiall): Refactor this method into *MUCH* smaller chunks. @@ -522,7 +553,7 @@ class Service(service.Service): # NOTE(kiall): Fetch the servers before creating the domain, this way # we can prevent domain creation if no servers are # configured. - servers = self.storage_api.find_servers(context) + servers = self.storage.find_servers(context) if len(servers) == 0: LOG.critical(_LC('No servers configured. ' @@ -532,16 +563,17 @@ class Service(service.Service): # Set the serial number domain.serial = utils.increment_serial() - with self.storage_api.create_domain(context, domain) as created_domain: - with wrap_backend_call(): - self.backend.create_domain(context, created_domain) + created_domain = self.storage.create_domain(context, domain) + + with wrap_backend_call(): + self.backend.create_domain(context, created_domain) self.notifier.info(context, 'dns.domain.create', created_domain) return created_domain def get_domain(self, context, domain_id): - domain = self.storage_api.get_domain(context, domain_id) + domain = self.storage.get_domain(context, domain_id) target = { 'domain_id': domain_id, @@ -553,7 +585,7 @@ class Service(service.Service): return domain def get_domain_servers(self, context, domain_id, criterion=None): - domain = self.storage_api.get_domain(context, domain_id) + domain = self.storage.get_domain(context, domain_id) target = { 'domain_id': domain_id, @@ -565,25 +597,26 @@ class Service(service.Service): # TODO(kiall): Once we allow domains to be allocated on 1 of N server # pools, return the filtered list here. - return self.storage_api.find_servers(context, criterion) + return self.storage.find_servers(context, criterion) def find_domains(self, context, criterion=None, marker=None, limit=None, sort_key=None, sort_dir=None): target = {'tenant_id': context.tenant} policy.check('find_domains', context, target) - return self.storage_api.find_domains(context, criterion, marker, limit, - sort_key, sort_dir) + return self.storage.find_domains(context, criterion, marker, limit, + sort_key, sort_dir) def find_domain(self, context, criterion=None): target = {'tenant_id': context.tenant} policy.check('find_domain', context, target) - return self.storage_api.find_domain(context, criterion) + return self.storage.find_domain(context, criterion) + @transaction def update_domain(self, context, domain_id, values, increment_serial=True): # TODO(kiall): Refactor this method into *MUCH* smaller chunks. - domain = self.storage_api.get_domain(context, domain_id) + domain = self.storage.get_domain(context, domain_id) target = { 'domain_id': domain_id, @@ -615,17 +648,18 @@ class Service(service.Service): # Increment the serial number values['serial'] = utils.increment_serial(domain.serial) - with self.storage_api.update_domain( - context, domain_id, values) as domain: - with wrap_backend_call(): - self.backend.update_domain(context, domain) + domain = self.storage.update_domain(context, domain_id, values) + + with wrap_backend_call(): + self.backend.update_domain(context, domain) self.notifier.info(context, 'dns.domain.update', domain) return domain + @transaction def delete_domain(self, context, domain_id): - domain = self.storage_api.get_domain(context, domain_id) + domain = self.storage.get_domain(context, domain_id) target = { 'domain_id': domain_id, @@ -638,13 +672,14 @@ class Service(service.Service): # Prevent deletion of a zone which has child zones criterion = {'parent_domain_id': domain_id} - if self.storage_api.count_domains(context, criterion) > 0: + if self.storage.count_domains(context, criterion) > 0: raise exceptions.DomainHasSubdomain('Please delete any subdomains ' 'before deleting this domain') - with self.storage_api.delete_domain(context, domain_id) as domain: - with wrap_backend_call(): - self.backend.delete_domain(context, domain) + domain = self.storage.delete_domain(context, domain_id) + + with wrap_backend_call(): + self.backend.delete_domain(context, domain) self.notifier.info(context, 'dns.domain.delete', domain) @@ -660,10 +695,11 @@ class Service(service.Service): policy.check('count_domains', context, target) - return self.storage_api.count_domains(context, criterion) + return self.storage.count_domains(context, criterion) + @transaction def touch_domain(self, context, domain_id): - domain = self.storage_api.get_domain(context, domain_id) + domain = self.storage.get_domain(context, domain_id) target = { 'domain_id': domain_id, @@ -680,8 +716,9 @@ class Service(service.Service): return domain # RecordSet Methods + @transaction def create_recordset(self, context, domain_id, recordset): - domain = self.storage_api.get_domain(context, domain_id) + domain = self.storage.get_domain(context, domain_id) target = { 'domain_id': domain_id, @@ -707,11 +744,11 @@ class Service(service.Service): self._is_valid_recordset_placement_subdomain( context, domain, recordset.name) - with self.storage_api.create_recordset( - context, domain_id, recordset) as created_recordset: - with wrap_backend_call(): - self.backend.create_recordset( - context, domain, created_recordset) + created_recordset = self.storage.create_recordset(context, domain_id, + recordset) + + with wrap_backend_call(): + self.backend.create_recordset(context, domain, created_recordset) # Send RecordSet creation notification self.notifier.info(context, 'dns.recordset.create', created_recordset) @@ -719,8 +756,8 @@ class Service(service.Service): return created_recordset def get_recordset(self, context, domain_id, recordset_id): - domain = self.storage_api.get_domain(context, domain_id) - recordset = self.storage_api.get_recordset(context, recordset_id) + domain = self.storage.get_domain(context, domain_id) + recordset = self.storage.get_recordset(context, recordset_id) # Ensure the domain_id matches the record's domain_id if domain.id != recordset.domain_id: @@ -742,19 +779,20 @@ class Service(service.Service): target = {'tenant_id': context.tenant} policy.check('find_recordsets', context, target) - return self.storage_api.find_recordsets(context, criterion, marker, - limit, sort_key, sort_dir) + return self.storage.find_recordsets(context, criterion, marker, + limit, sort_key, sort_dir) def find_recordset(self, context, criterion=None): target = {'tenant_id': context.tenant} policy.check('find_recordset', context, target) - return self.storage_api.find_recordset(context, criterion) + return self.storage.find_recordset(context, criterion) + @transaction def update_recordset(self, context, domain_id, recordset_id, values, increment_serial=True): - domain = self.storage_api.get_domain(context, domain_id) - recordset = self.storage_api.get_recordset(context, recordset_id) + domain = self.storage.get_domain(context, domain_id) + recordset = self.storage.get_recordset(context, recordset_id) # Ensure the domain_id matches the recordset's domain_id if domain.id != recordset.domain_id: @@ -787,23 +825,25 @@ class Service(service.Service): self._is_valid_ttl(context, ttl) # Update the recordset - with self.storage_api.update_recordset( - context, recordset_id, values) as recordset: - with wrap_backend_call(): - self.backend.update_recordset(context, domain, recordset) + recordset = self.storage.update_recordset(context, recordset_id, + values) - if increment_serial: - self._increment_domain_serial(context, domain_id) + with wrap_backend_call(): + self.backend.update_recordset(context, domain, recordset) + + if increment_serial: + self._increment_domain_serial(context, domain_id) # Send RecordSet update notification self.notifier.info(context, 'dns.recordset.update', recordset) return recordset + @transaction def delete_recordset(self, context, domain_id, recordset_id, increment_serial=True): - domain = self.storage_api.get_domain(context, domain_id) - recordset = self.storage_api.get_recordset(context, recordset_id) + domain = self.storage.get_domain(context, domain_id) + recordset = self.storage.get_recordset(context, recordset_id) # Ensure the domain_id matches the recordset's domain_id if domain.id != recordset.domain_id: @@ -818,13 +858,13 @@ class Service(service.Service): policy.check('delete_recordset', context, target) - with self.storage_api.delete_recordset(context, recordset_id) \ - as recordset: - with wrap_backend_call(): - self.backend.delete_recordset(context, domain, recordset) + recordset = self.storage.delete_recordset(context, recordset_id) - if increment_serial: - self._increment_domain_serial(context, domain_id) + with wrap_backend_call(): + self.backend.delete_recordset(context, domain, recordset) + + if increment_serial: + self._increment_domain_serial(context, domain_id) # Send Record deletion notification self.notifier.info(context, 'dns.recordset.delete', recordset) @@ -841,13 +881,14 @@ class Service(service.Service): policy.check('count_recordsets', context, target) - return self.storage_api.count_recordsets(context, criterion) + return self.storage.count_recordsets(context, criterion) # Record Methods + @transaction def create_record(self, context, domain_id, recordset_id, record, increment_serial=True): - domain = self.storage_api.get_domain(context, domain_id) - recordset = self.storage_api.get_recordset(context, recordset_id) + domain = self.storage.get_domain(context, domain_id) + recordset = self.storage.get_recordset(context, recordset_id) target = { 'domain_id': domain_id, @@ -862,14 +903,15 @@ class Service(service.Service): # Ensure the tenant has enough quota to continue self._enforce_record_quota(context, domain, recordset) - with self.storage_api.create_record( - context, domain_id, recordset_id, record) as created_record: - with wrap_backend_call(): - self.backend.create_record( - context, domain, recordset, created_record) + created_record = self.storage.create_record(context, domain_id, + recordset_id, record) - if increment_serial: - self._increment_domain_serial(context, domain_id) + with wrap_backend_call(): + self.backend.create_record( + context, domain, recordset, created_record) + + if increment_serial: + self._increment_domain_serial(context, domain_id) # Send Record creation notification self.notifier.info(context, 'dns.record.create', created_record) @@ -877,9 +919,9 @@ class Service(service.Service): return created_record def get_record(self, context, domain_id, recordset_id, record_id): - domain = self.storage_api.get_domain(context, domain_id) - recordset = self.storage_api.get_recordset(context, recordset_id) - record = self.storage_api.get_record(context, record_id) + domain = self.storage.get_domain(context, domain_id) + recordset = self.storage.get_recordset(context, recordset_id) + record = self.storage.get_record(context, record_id) # Ensure the domain_id matches the record's domain_id if domain.id != record.domain_id: @@ -907,20 +949,21 @@ class Service(service.Service): target = {'tenant_id': context.tenant} policy.check('find_records', context, target) - return self.storage_api.find_records(context, criterion, marker, limit, - sort_key, sort_dir) + return self.storage.find_records(context, criterion, marker, limit, + sort_key, sort_dir) def find_record(self, context, criterion=None): target = {'tenant_id': context.tenant} policy.check('find_record', context, target) - return self.storage_api.find_record(context, criterion) + return self.storage.find_record(context, criterion) + @transaction def update_record(self, context, domain_id, recordset_id, record_id, values, increment_serial=True): - domain = self.storage_api.get_domain(context, domain_id) - recordset = self.storage_api.get_recordset(context, recordset_id) - record = self.storage_api.get_record(context, record_id) + domain = self.storage.get_domain(context, domain_id) + recordset = self.storage.get_recordset(context, recordset_id) + record = self.storage.get_record(context, record_id) # Ensure the domain_id matches the record's domain_id if domain.id != record.domain_id: @@ -942,24 +985,25 @@ class Service(service.Service): policy.check('update_record', context, target) # Update the record - with self.storage_api.update_record( - context, record_id, values) as record: - with wrap_backend_call(): - self.backend.update_record(context, domain, recordset, record) + record = self.storage.update_record(context, record_id, values) - if increment_serial: - self._increment_domain_serial(context, domain_id) + with wrap_backend_call(): + self.backend.update_record(context, domain, recordset, record) + + if increment_serial: + self._increment_domain_serial(context, domain_id) # Send Record update notification self.notifier.info(context, 'dns.record.update', record) return record + @transaction def delete_record(self, context, domain_id, recordset_id, record_id, increment_serial=True): - domain = self.storage_api.get_domain(context, domain_id) - recordset = self.storage_api.get_recordset(context, recordset_id) - record = self.storage_api.get_record(context, record_id) + domain = self.storage.get_domain(context, domain_id) + recordset = self.storage.get_recordset(context, recordset_id) + record = self.storage.get_record(context, record_id) # Ensure the domain_id matches the record's domain_id if domain.id != record.domain_id: @@ -980,12 +1024,13 @@ class Service(service.Service): policy.check('delete_record', context, target) - with self.storage_api.delete_record(context, record_id) as record: - with wrap_backend_call(): - self.backend.delete_record(context, domain, recordset, record) + record = self.storage.delete_record(context, record_id) - if increment_serial: - self._increment_domain_serial(context, domain_id) + with wrap_backend_call(): + self.backend.delete_record(context, domain, recordset, record) + + if increment_serial: + self._increment_domain_serial(context, domain_id) # Send Record deletion notification self.notifier.info(context, 'dns.record.delete', record) @@ -1001,11 +1046,11 @@ class Service(service.Service): } policy.check('count_records', context, target) - return self.storage_api.count_records(context, criterion) + return self.storage.count_records(context, criterion) # Diagnostics Methods def _sync_domain(self, context, domain): - recordsets = self.storage_api.find_recordsets( + recordsets = self.storage.find_recordsets( context, criterion={'domain_id': domain['id']}) # Since we now have records as well as recordsets we need to get the @@ -1018,10 +1063,11 @@ class Service(service.Service): with wrap_backend_call(): return self.backend.sync_domain(context, domain, rdata) + @transaction def sync_domains(self, context): policy.check('diagnostics_sync_domains', context) - domains = self.storage_api.find_domains(context) + domains = self.storage.find_domains(context) results = {} for domain in domains: @@ -1029,8 +1075,9 @@ class Service(service.Service): return results + @transaction def sync_domain(self, context, domain_id): - domain = self.storage_api.get_domain(context, domain_id) + domain = self.storage.get_domain(context, domain_id) target = { 'domain_id': domain_id, @@ -1042,9 +1089,10 @@ class Service(service.Service): return self._sync_domain(context, domain) + @transaction def sync_record(self, context, domain_id, recordset_id, record_id): - domain = self.storage_api.get_domain(context, domain_id) - recordset = self.storage_api.get_recordset(context, recordset_id) + domain = self.storage.get_domain(context, domain_id) + recordset = self.storage.get_recordset(context, recordset_id) target = { 'domain_id': domain_id, @@ -1057,7 +1105,7 @@ class Service(service.Service): policy.check('diagnostics_sync_record', context, target) - record = self.storage_api.get_record(context, record_id) + record = self.storage.get_record(context, record_id) with wrap_backend_call(): return self.backend.sync_record(context, domain, recordset, record) @@ -1071,7 +1119,7 @@ class Service(service.Service): backend_status = {'status': False, 'message': str(e)} try: - storage_status = self.storage_api.ping(context) + storage_status = self.storage.ping(context) except Exception as e: storage_status = {'status': False, 'message': str(e)} @@ -1175,7 +1223,7 @@ class Service(service.Service): value[1]['recordset_id'] in recordsets): recordset = recordsets[value[1]['recordset_id']] else: - recordset = self.storage_api.get_recordset( + recordset = self.storage.get_recordset( elevated_context, value[1]['recordset_id']) if recordset['ttl'] is not None: @@ -1269,7 +1317,7 @@ class Service(service.Service): # NOTE: Find existing zone or create it.. try: - zone = self.storage_api.find_domain( + zone = self.storage.find_domain( elevated_context, {'name': zone_name}) except exceptions.DomainNotFound: msg = _LI('Creating zone for %(fip_id)s:%(region)s - ' @@ -1363,7 +1411,7 @@ class Service(service.Service): } try: - record = self.storage_api.find_record( + record = self.storage.find_record( elevated_context, criterion=criterion) except exceptions.RecordNotFound: msg = 'No such FloatingIP %s:%s' % (region, floatingip_id) @@ -1375,6 +1423,7 @@ class Service(service.Service): record['recordset_id'], record['id']) + @transaction def update_floatingip(self, context, region, floatingip_id, values): """ We strictly see if values['ptrdname'] is str or None and set / unset @@ -1387,12 +1436,11 @@ class Service(service.Service): context, region, floatingip_id, values) # Blacklisted Domains + @transaction def create_blacklist(self, context, blacklist): policy.check('create_blacklist', context) - with self.storage_api.create_blacklist(context, blacklist) as \ - created_blacklist: - pass # NOTE: No other systems need updating + created_blacklist = self.storage.create_blacklist(context, blacklist) self.notifier.info(context, 'dns.blacklist.create', created_blacklist) @@ -1401,7 +1449,7 @@ class Service(service.Service): def get_blacklist(self, context, blacklist_id): policy.check('get_blacklist', context) - blacklist = self.storage_api.get_blacklist(context, blacklist_id) + blacklist = self.storage.get_blacklist(context, blacklist_id) return blacklist @@ -1409,36 +1457,34 @@ class Service(service.Service): limit=None, sort_key=None, sort_dir=None): policy.check('find_blacklists', context) - blacklists = self.storage_api.find_blacklists(context, criterion, - marker, limit, - sort_key, sort_dir) + blacklists = self.storage.find_blacklists(context, criterion, + marker, limit, + sort_key, sort_dir) return blacklists def find_blacklist(self, context, criterion): policy.check('find_blacklist', context) - blacklist = self.storage_api.find_blacklist(context, criterion) + blacklist = self.storage.find_blacklist(context, criterion) return blacklist + @transaction def update_blacklist(self, context, blacklist_id, values): policy.check('update_blacklist', context) - with self.storage_api.update_blacklist(context, - blacklist_id, - values) as blacklist: - pass # NOTE: No other systems need updating + blacklist = self.storage.update_blacklist(context, blacklist_id, + values) self.notifier.info(context, 'dns.blacklist.update', blacklist) return blacklist + @transaction def delete_blacklist(self, context, blacklist_id): policy.check('delete_blacklist', context) - with self.storage_api.delete_blacklist(context, - blacklist_id) as blacklist: - pass # NOTE: No other systems need updating + blacklist = self.storage.delete_blacklist(context, blacklist_id) self.notifier.info(context, 'dns.blacklist.delete', blacklist) diff --git a/designate/mdns/handler.py b/designate/mdns/handler.py index 88bea079d..0782a7699 100644 --- a/designate/mdns/handler.py +++ b/designate/mdns/handler.py @@ -17,7 +17,7 @@ import dns from oslo.config import cfg from designate.openstack.common import log as logging -from designate.storage import api as storage_api +from designate import storage LOG = logging.getLogger(__name__) @@ -28,7 +28,7 @@ class RequestHandler(object): def __init__(self): # Get a storage connection storage_driver = cfg.CONF['service:mdns'].storage_driver - self.storage_api = storage_api.StorageAPI(storage_driver) + self.storage = storage.get_storage(storage_driver) def handle(self, payload): request = dns.message.from_wire(payload) diff --git a/designate/quota/impl_storage.py b/designate/quota/impl_storage.py index 9744b9552..5eb1121a0 100644 --- a/designate/quota/impl_storage.py +++ b/designate/quota/impl_storage.py @@ -16,9 +16,9 @@ from oslo.config import cfg from designate import exceptions +from designate import storage from designate.openstack.common import log as logging from designate.quota.base import Quota -from designate.storage import api as sapi LOG = logging.getLogger(__name__) @@ -27,18 +27,15 @@ LOG = logging.getLogger(__name__) class StorageQuota(Quota): __plugin_name__ = 'storage' - def __init__(self, storage_api=None): + def __init__(self): super(StorageQuota, self).__init__() - if storage_api is None: - # TODO(kiall): Should this be tied to central's config? - storage_driver = cfg.CONF['service:central'].storage_driver - storage_api = sapi.StorageAPI(storage_driver) - - self.storage_api = storage_api + # TODO(kiall): Should this be tied to central's config? + storage_driver = cfg.CONF['service:central'].storage_driver + self.storage = storage.get_storage(storage_driver) def _get_quotas(self, context, tenant_id): - quotas = self.storage_api.find_quotas(context, { + quotas = self.storage.find_quotas(context, { 'tenant_id': tenant_id, }) @@ -48,7 +45,7 @@ class StorageQuota(Quota): context = context.deepcopy() context.all_tenants = True - quota = self.storage_api.find_quota(context, { + quota = self.storage.find_quota(context, { 'tenant_id': tenant_id, 'resource': resource, }) @@ -66,21 +63,19 @@ class StorageQuota(Quota): 'hard_limit': hard_limit, } - with self.storage_api.create_quota(context, values): - pass # NOTE(kiall): No other systems need updating. + self.storage.create_quota(context, values) def update_quota(): values = {'hard_limit': hard_limit} - with self.storage_api.update_quota(context, quota['id'], values): - pass # NOTE(kiall): No other systems need updating. + self.storage.update_quota(context, quota['id'], values) if resource not in self.get_default_quotas(context).keys(): raise exceptions.QuotaResourceUnknown("%s is not a valid quota " "resource", resource) try: - quota = self.storage_api.find_quota(context, { + quota = self.storage.find_quota(context, { 'tenant_id': tenant_id, 'resource': resource, }) @@ -95,10 +90,9 @@ class StorageQuota(Quota): context = context.deepcopy() context.all_tenants = True - quotas = self.storage_api.find_quotas(context, { + quotas = self.storage.find_quotas(context, { 'tenant_id': tenant_id, }) for quota in quotas: - with self.storage_api.delete_quota(context, quota['id']): - pass # NOTE(kiall): No other systems need updating. + self.storage.delete_quota(context, quota['id']) diff --git a/designate/storage/api.py b/designate/storage/api.py deleted file mode 100644 index efd334564..000000000 --- a/designate/storage/api.py +++ /dev/null @@ -1,789 +0,0 @@ -# Copyright 2013 Hewlett-Packard Development Company, L.P. -# -# Author: Kiall Mac Innes -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -import contextlib - -from designate import storage -from designate.openstack.common import excutils - - -class StorageAPI(object): - """ Storage API """ - - def __init__(self, storage_driver): - self.storage = storage.get_storage(storage_driver) - - def _extract_dict_subset(self, d, keys): - return dict([(k, d[k]) for k in keys if k in d]) - - @contextlib.contextmanager - def create_quota(self, context, values): - """ - Create a Quota. - - :param context: RPC Context. - :param values: Values to create the new Quota from. - """ - self.storage.begin() - - try: - quota = self.storage.create_quota(context, values) - yield quota - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - def get_quota(self, context, quota_id): - """ - Get a Quota via ID. - - :param context: RPC Context. - :param quota_id: Quota ID to get. - """ - return self.storage.get_quota(context, quota_id) - - def find_quotas(self, context, criterion=None, marker=None, limit=None, - sort_key=None, sort_dir=None): - """ - Find Quotas - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.find_quotas( - context, criterion, marker, limit, sort_key, sort_dir) - - def find_quota(self, context, criterion): - """ - Find a single Quota. - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.find_quota(context, criterion) - - @contextlib.contextmanager - def update_quota(self, context, quota_id, values): - """ - Update a Quota via ID - - :param context: RPC Context. - :param quota_id: Quota ID to update. - :param values: Values to update the Quota from - """ - self.storage.begin() - - try: - quota = self.storage.update_quota(context, quota_id, values) - yield quota - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - @contextlib.contextmanager - def delete_quota(self, context, quota_id): - """ - Delete a Quota via ID. - - :param context: RPC Context. - :param quota_id: Delete a Quota via ID - """ - self.storage.begin() - - try: - quota = self.storage.delete_quota(context, quota_id) - yield quota - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - @contextlib.contextmanager - def create_server(self, context, server): - """ - Create a Server. - - :param context: RPC Context. - :param server: Server object with the values to be created. - """ - self.storage.begin() - - try: - created_server = self.storage.create_server(context, server) - yield created_server - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - def get_server(self, context, server_id): - """ - Get a Server via ID. - - :param context: RPC Context. - :param server_id: Server ID to get. - """ - return self.storage.get_server(context, server_id) - - def find_servers(self, context, criterion=None, marker=None, limit=None, - sort_key=None, sort_dir=None): - """ - Find Servers - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.find_servers( - context, criterion, marker, limit, sort_key, sort_dir) - - def find_server(self, context, criterion): - """ - Find a single Server. - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.find_server(context, criterion) - - @contextlib.contextmanager - def update_server(self, context, server_id, values): - """ - Update a Server via ID - - :param context: RPC Context. - :param server_id: Server ID to update. - :param values: Values to update the Server from - """ - self.storage.begin() - - try: - server = self.storage.update_server(context, server_id, values) - yield server - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - @contextlib.contextmanager - def delete_server(self, context, server_id): - """ - Delete a Server via ID. - - :param context: RPC Context. - :param server_id: Delete a Server via ID - """ - self.storage.begin() - - try: - server = self.storage.delete_server(context, server_id) - yield server - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - @contextlib.contextmanager - def create_tld(self, context, tld): - """ - Create a TLD. - - :param context: RPC Context. - :param tld: Tld object with the values to be created. - """ - self.storage.begin() - - try: - created_tld = self.storage.create_tld(context, tld) - yield created_tld - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - def get_tld(self, context, tld_id): - """ - Get a TLD via ID. - - :param context: RPC Context. - :param tld_id: TLD ID to get. - """ - return self.storage.get_tld(context, tld_id) - - def find_tlds(self, context, criterion=None, marker=None, limit=None, - sort_key=None, sort_dir=None): - """ - Find TLDs - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.find_tlds( - context, criterion, marker, limit, sort_key, sort_dir) - - def find_tld(self, context, criterion): - """ - Find a single TLD. - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.find_tld(context, criterion) - - @contextlib.contextmanager - def update_tld(self, context, tld_id, values): - """ - Update a TLD via ID - - :param context: RPC Context. - :param tld_id: TLD ID to update. - :param values: Values to update the TLD from - """ - self.storage.begin() - - try: - tld = self.storage.update_tld(context, tld_id, values) - yield tld - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - @contextlib.contextmanager - def delete_tld(self, context, tld_id): - """ - Delete a TLD via ID. - - :param context: RPC Context. - :param tld_id: Delete a TLD via ID - """ - self.storage.begin() - - try: - tld = self.storage.delete_tld(context, tld_id) - yield tld - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - @contextlib.contextmanager - def create_tsigkey(self, context, tsigkey): - """ - Create a TSIG Key. - - :param context: RPC Context. - :param tsigkey: TsigKey object with the values to be created. - """ - self.storage.begin() - - try: - created_tsigkey = self.storage.create_tsigkey(context, tsigkey) - yield created_tsigkey - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - def get_tsigkey(self, context, tsigkey_id): - """ - Get a TSIG Key via ID. - - :param context: RPC Context. - :param tsigkey_id: Server ID to get. - """ - return self.storage.get_tsigkey(context, tsigkey_id) - - def find_tsigkeys(self, context, criterion=None, marker=None, limit=None, - sort_key=None, sort_dir=None): - """ - Find Tsigkey - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.find_tsigkeys( - context, criterion, marker, limit, sort_key, sort_dir) - - def find_tsigkey(self, context, criterion): - """ - Find a single Tsigkey. - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.find_tsigkey(context, criterion) - - @contextlib.contextmanager - def update_tsigkey(self, context, tsigkey_id, values): - """ - Update a TSIG Key via ID - - :param context: RPC Context. - :param tsigkey_id: TSIG Key ID to update. - :param values: Values to update the TSIG Key from - """ - self.storage.begin() - - try: - tsigkey = self.storage.update_tsigkey(context, tsigkey_id, values) - yield tsigkey - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - @contextlib.contextmanager - def delete_tsigkey(self, context, tsigkey_id): - """ - Delete a TSIG Key via ID. - - :param context: RPC Context. - :param tsigkey_id: Delete a TSIG Key via ID - """ - self.storage.begin() - - try: - tsigkey = self.storage.delete_tsigkey(context, tsigkey_id) - yield tsigkey - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - def find_tenants(self, context): - """ - Find all Tenants. - - :param context: RPC Context. - """ - return self.storage.find_tenants(context) - - def get_tenant(self, context, tenant_id): - """ - Get all Tenants. - - :param context: RPC Context. - :param tenant_id: ID of the Tenant. - """ - return self.storage.get_tenant(context, tenant_id) - - def count_tenants(self, context): - """ - Count tenants - - :param context: RPC Context. - """ - return self.storage.count_tenants(context) - - @contextlib.contextmanager - def create_domain(self, context, domain): - """ - Create a new Domain. - - :param context: RPC Context. - :param domain: Domain object with the values to be created. - """ - self.storage.begin() - - try: - created_domain = self.storage.create_domain(context, domain) - yield created_domain - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - def get_domain(self, context, domain_id): - """ - Get a Domain via its ID. - - :param context: RPC Context. - :param domain_id: ID of the Domain. - """ - return self.storage.get_domain(context, domain_id) - - def find_domains(self, context, criterion=None, marker=None, limit=None, - sort_key=None, sort_dir=None): - """ - Find Domains - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.find_domains( - context, criterion, marker, limit, sort_key, sort_dir) - - def find_domain(self, context, criterion): - """ - Find a single Domain. - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.find_domain(context, criterion) - - @contextlib.contextmanager - def update_domain(self, context, domain_id, values): - """ - Update a Domain via ID. - - :param context: RPC Context. - :param domain_id: Values to update the Domain with - :param values: Values to update the Domain from. - """ - self.storage.begin() - - try: - domain = self.storage.update_domain(context, domain_id, values) - yield domain - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - @contextlib.contextmanager - def delete_domain(self, context, domain_id): - """ - Delete a Domain - - :param context: RPC Context. - :param domain_id: Domain ID to delete. - """ - self.storage.begin() - - try: - domain = self.storage.delete_domain(context, domain_id) - yield domain - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - def count_domains(self, context, criterion=None): - """ - Count domains - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.count_domains(context, criterion) - - @contextlib.contextmanager - def create_recordset(self, context, domain_id, recordset): - """ - Create a recordset on a given Domain ID - - :param context: RPC Context. - :param domain_id: Domain ID to create the recordset in. - :param recordset: RecordSet object with the values to be created. - """ - self.storage.begin() - - try: - created_recordset = self.storage.create_recordset( - context, domain_id, recordset) - yield created_recordset - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - def get_recordset(self, context, recordset_id): - """ - Get a recordset via ID - - :param context: RPC Context. - :param recordset_id: RecordSet ID to get - """ - return self.storage.get_recordset(context, recordset_id) - - def find_recordsets(self, context, criterion=None, marker=None, limit=None, - sort_key=None, sort_dir=None): - """ - Find RecordSets. - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.find_recordsets( - context, criterion, marker, limit, sort_key, sort_dir) - - def find_recordset(self, context, criterion=None): - """ - Find a single RecordSet. - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.find_recordset(context, criterion) - - @contextlib.contextmanager - def update_recordset(self, context, recordset_id, values): - """ - Update a recordset via ID - - :param context: RPC Context - :param recordset_id: RecordSet ID to update - """ - self.storage.begin() - - try: - recordset = self.storage.update_recordset( - context, recordset_id, values) - yield recordset - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - @contextlib.contextmanager - def delete_recordset(self, context, recordset_id): - """ - Delete a recordset - - :param context: RPC Context - :param recordset_id: RecordSet ID to delete - """ - self.storage.begin() - - try: - recordset = self.storage.delete_recordset(context, recordset_id) - yield recordset - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - def count_recordsets(self, context, criterion=None): - """ - Count recordsets - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.count_recordsets(context, criterion) - - @contextlib.contextmanager - def create_record(self, context, domain_id, recordset_id, record): - """ - Create a record on a given Domain ID - - :param context: RPC Context. - :param domain_id: Domain ID to create the record in. - :param recordset_id: RecordSet ID to create the record in. - :param values: Values to create the new Record from. - :param record: Record object with the values to be created. - """ - self.storage.begin() - - try: - created_record = self.storage.create_record( - context, domain_id, recordset_id, record) - yield created_record - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - def get_record(self, context, record_id): - """ - Get a record via ID - - :param context: RPC Context. - :param record_id: Record ID to get - """ - return self.storage.get_record(context, record_id) - - def find_records(self, context, criterion=None, marker=None, limit=None, - sort_key=None, sort_dir=None): - """ - Find Records. - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.find_records( - context, criterion, marker, limit, sort_key, sort_dir) - - def find_record(self, context, criterion=None): - """ - Find a single Record. - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.find_record(context, criterion) - - @contextlib.contextmanager - def update_record(self, context, record_id, values): - """ - Update a record via ID - - :param context: RPC Context - :param record_id: Record ID to update - """ - self.storage.begin() - - try: - record = self.storage.update_record(context, record_id, values) - yield record - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - @contextlib.contextmanager - def delete_record(self, context, record_id): - """ - Delete a record - - :param context: RPC Context - :param record_id: Record ID to delete - """ - self.storage.begin() - - try: - record = self.storage.delete_record(context, record_id) - yield record - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - def count_records(self, context, criterion=None): - """ - Count records - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.count_records(context, criterion) - - @contextlib.contextmanager - def create_blacklist(self, context, blacklist): - """ - Create a new Blacklisted Domain. - - :param context: RPC Context. - :param blacklist: Blacklist object with the values to be created. - """ - self.storage.begin() - - try: - created_blacklist = self.storage.create_blacklist( - context, blacklist) - yield created_blacklist - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - def get_blacklist(self, context, blacklist_id): - """ - Get a Blacklist via its ID. - - :param context: RPC Context. - :param blacklist_id: ID of the Blacklisted Domain. - """ - return self.storage.get_blacklist(context, blacklist_id) - - def find_blacklists(self, context, criterion=None, marker=None, limit=None, - sort_key=None, sort_dir=None): - """ - Find all Blacklisted Domains - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.find_blacklists( - context, criterion, marker, limit, sort_key, sort_dir) - - def find_blacklist(self, context, criterion): - """ - Find a single Blacklisted Domain. - - :param context: RPC Context. - :param criterion: Criteria to filter by. - """ - return self.storage.find_blacklist(context, criterion) - - @contextlib.contextmanager - def update_blacklist(self, context, blacklist_id, values): - """ - Update a Blacklisted Domain via ID. - - :param context: RPC Context. - :param blacklist_id: Values to update the Blacklist with - :param values: Values to update the Blacklist from. - """ - self.storage.begin() - - try: - blacklist = self.storage.update_blacklist(context, - blacklist_id, - values) - yield blacklist - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - @contextlib.contextmanager - def delete_blacklist(self, context, blacklist_id): - """ - Delete a Blacklisted Domain - - :param context: RPC Context. - :param blacklist_id: Blacklist ID to delete. - """ - self.storage.begin() - - try: - blacklist = self.storage.delete_blacklist(context, blacklist_id) - yield blacklist - except Exception: - with excutils.save_and_reraise_exception(): - self.storage.rollback() - else: - self.storage.commit() - - def ping(self, context): - """ Ping the Storage connection """ - return self.storage.ping(context) diff --git a/designate/tests/test_central/_test_service_ipa.py b/designate/tests/test_central/_test_service_ipa.py index 6b101b452..5e019eaca 100644 --- a/designate/tests/test_central/_test_service_ipa.py +++ b/designate/tests/test_central/_test_service_ipa.py @@ -52,7 +52,7 @@ class CentralServiceTestIPA(designate.tests.test_central. # go directly through storage api to bypass tenant/policy checks save_all_tenants = self.admin_context.all_tenants self.admin_context.all_tenants = True - self.startdomains = self.central_service.storage_api.\ + self.startdomains = self.central_service.storage.\ find_domains(self.admin_context) LOG.debug("%s.setUp: startdomains %d" % (self.__class__, len(self.startdomains))) @@ -62,7 +62,7 @@ class CentralServiceTestIPA(designate.tests.test_central. # delete domains # go directly through storage api to bypass tenant/policy checks self.admin_context.all_tenants = True - domains = self.central_service.storage_api.\ + domains = self.central_service.storage.\ find_domains(self.admin_context) LOG.debug("%s.tearDown: domains %d" % (self.__class__, len(self.startdomains))) diff --git a/designate/tests/test_quota/test_storage.py b/designate/tests/test_quota/test_storage.py index a4c5d76a8..af40a8b13 100644 --- a/designate/tests/test_quota/test_storage.py +++ b/designate/tests/test_quota/test_storage.py @@ -41,7 +41,7 @@ class StorageQuotaTest(tests.TestCase): 'resource': 'domains' } - quota = self.quota.storage_api.find_quota(context, criterion) + quota = self.quota.storage.find_quota(context, criterion) self.assertEqual(quota['tenant_id'], 'tenant_id') self.assertEqual(quota['resource'], 'domains') @@ -64,7 +64,7 @@ class StorageQuotaTest(tests.TestCase): 'resource': 'domains' } - quota = self.quota.storage_api.find_quota(context, criterion) + quota = self.quota.storage.find_quota(context, criterion) self.assertEqual(quota['tenant_id'], 'tenant_id') self.assertEqual(quota['resource'], 'domains') @@ -89,5 +89,5 @@ class StorageQuotaTest(tests.TestCase): 'tenant_id': 'tenant_id' } - quotas = self.quota.storage_api.find_quotas(context, criterion) + quotas = self.quota.storage.find_quotas(context, criterion) self.assertEqual(0, len(quotas)) diff --git a/designate/tests/test_storage/test_api.py b/designate/tests/test_storage/test_api.py deleted file mode 100644 index 47f99ba25..000000000 --- a/designate/tests/test_storage/test_api.py +++ /dev/null @@ -1,776 +0,0 @@ -# Copyright 2013 Hewlett-Packard Development Company, L.P. -# -# Author: Kiall Mac Innes -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -import mock -import testtools - -from designate.openstack.common import log as logging -from designate.tests import TestCase -from designate.storage import api as storage_api - - -LOG = logging.getLogger(__name__) - - -class SentinelException(Exception): - pass - - -class StorageAPITest(TestCase): - def setUp(self): - super(StorageAPITest, self).setUp() - self.storage_api = storage_api.StorageAPI('sqlalchemy') - self.storage_mock = mock.Mock() - self.storage_api.storage = self.storage_mock - - def _set_side_effect(self, method, side_effect): - methodc = getattr(self.storage_mock, method) - methodc.side_effect = side_effect - - def _assert_called_with(self, method, *args, **kwargs): - methodc = getattr(self.storage_mock, method) - methodc.assert_called_with(*args, **kwargs) - - def _assert_has_calls(self, method, *args, **kwargs): - methodc = getattr(self.storage_mock, method) - methodc.assert_has_calls(*args, **kwargs) - - def _assert_call_count(self, method, call_count): - methodc = getattr(self.storage_mock, method) - self.assertEqual(methodc.call_count, call_count) - - # Quota Tests - def test_create_quota(self): - context = mock.sentinel.context - values = mock.sentinel.values - quota = mock.sentinel.quota - - self._set_side_effect('create_quota', [quota]) - - with self.storage_api.create_quota(context, values) as q: - self.assertEqual(quota, q) - - self._assert_called_with('create_quota', context, values) - - def test_create_quota_failure(self): - context = mock.sentinel.context - values = mock.sentinel.values - - self._set_side_effect('create_quota', [{'id': 12345}]) - - with testtools.ExpectedException(SentinelException): - with self.storage_api.create_quota(context, values): - raise SentinelException('Something Went Wrong') - - self._assert_called_with('begin') - self._assert_called_with('rollback') - self._assert_called_with('create_quota', context, values) - - def test_get_quota(self): - context = mock.sentinel.context - quota_id = mock.sentinel.quota_id - quota = mock.sentinel.quota - - self._set_side_effect('get_quota', [quota]) - - result = self.storage_api.get_quota(context, quota_id) - self._assert_called_with('get_quota', context, quota_id) - self.assertEqual(quota, result) - - def test_find_quotas(self): - context = mock.sentinel.context - criterion = mock.sentinel.criterion - marker = mock.sentinel.marker - limit = mock.sentinel.limit - sort_key = mock.sentinel.sort_key - sort_dir = mock.sentinel.sort_dir - quota = mock.sentinel.quota - - self._set_side_effect('find_quotas', [[quota]]) - - result = self.storage_api.find_quotas( - context, criterion, - marker, limit, sort_key, sort_dir) - self._assert_called_with( - 'find_quotas', context, criterion, - marker, limit, sort_key, sort_dir) - self.assertEqual([quota], result) - - def test_find_quota(self): - context = mock.sentinel.context - criterion = mock.sentinel.criterion - quota = mock.sentinel.quota - - self._set_side_effect('find_quota', [quota]) - - result = self.storage_api.find_quota(context, criterion) - self._assert_called_with('find_quota', context, criterion) - self.assertEqual(quota, result) - - def test_update_quota(self): - context = mock.sentinel.context - values = mock.sentinel.values - - with self.storage_api.update_quota(context, 123, values): - pass - - self._assert_called_with('update_quota', context, 123, values) - - def test_update_quota_failure(self): - context = mock.sentinel.context - values = {'test': 2} - - self._set_side_effect('get_quota', [{'id': 123, 'test': 1}]) - - with testtools.ExpectedException(SentinelException): - with self.storage_api.update_quota(context, 123, values): - raise SentinelException('Something Went Wrong') - - self._assert_called_with('begin') - self._assert_called_with('rollback') - self._assert_called_with('update_quota', context, 123, values) - - def test_delete_quota(self): - context = mock.sentinel.context - quota = mock.sentinel.quota - - self._set_side_effect('delete_quota', [quota]) - - with self.storage_api.delete_quota(context, 123) as q: - self.assertEqual(quota, q) - - self._assert_called_with('delete_quota', context, 123) - - def test_delete_quota_failure(self): - context = mock.sentinel.context - quota = mock.sentinel.quota - - self._set_side_effect('delete_quota', [quota]) - - with testtools.ExpectedException(SentinelException): - with self.storage_api.delete_quota(context, 123): - raise SentinelException('Something Went Wrong') - - self._assert_called_with('begin') - self._assert_called_with('rollback') - - # Server Tests - def test_create_server(self): - context = mock.sentinel.context - values = mock.sentinel.values - server = mock.sentinel.server - - self._set_side_effect('create_server', [server]) - - with self.storage_api.create_server(context, values) as q: - self.assertEqual(server, q) - - self._assert_called_with('create_server', context, values) - - def test_create_server_failure(self): - context = mock.sentinel.context - values = mock.sentinel.values - - self._set_side_effect('create_server', [{'id': 12345}]) - - with testtools.ExpectedException(SentinelException): - with self.storage_api.create_server(context, values): - raise SentinelException('Something Went Wrong') - - self._assert_called_with('begin') - self._assert_called_with('rollback') - self._assert_called_with('create_server', context, values) - - def test_get_server(self): - context = mock.sentinel.context - server_id = mock.sentinel.server_id - server = mock.sentinel.server - - self._set_side_effect('get_server', [server]) - - result = self.storage_api.get_server(context, server_id) - self._assert_called_with('get_server', context, server_id) - self.assertEqual(server, result) - - def test_find_servers(self): - context = mock.sentinel.context - criterion = mock.sentinel.criterion - marker = mock.sentinel.marker - limit = mock.sentinel.limit - sort_key = mock.sentinel.sort_key - sort_dir = mock.sentinel.sort_dir - - server = mock.sentinel.server - - self._set_side_effect('find_servers', [[server]]) - - result = self.storage_api.find_servers( - context, criterion, - marker, limit, sort_key, sort_dir) - self._assert_called_with( - 'find_servers', context, criterion, - marker, limit, sort_key, sort_dir) - self.assertEqual([server], result) - - def test_find_server(self): - context = mock.sentinel.context - criterion = mock.sentinel.criterion - server = mock.sentinel.server - - self._set_side_effect('find_server', [server]) - - result = self.storage_api.find_server(context, criterion) - self._assert_called_with('find_server', context, criterion) - self.assertEqual(server, result) - - def test_update_server(self): - context = mock.sentinel.context - values = mock.sentinel.values - - with self.storage_api.update_server(context, 123, values): - pass - - self._assert_called_with('update_server', context, 123, values) - - def test_update_server_failure(self): - context = mock.sentinel.context - values = {'test': 2} - - self._set_side_effect('get_server', [{'id': 123, 'test': 1}]) - - with testtools.ExpectedException(SentinelException): - with self.storage_api.update_server(context, 123, values): - raise SentinelException('Something Went Wrong') - - self._assert_called_with('begin') - self._assert_called_with('rollback') - self._assert_called_with('update_server', context, 123, values) - - def test_delete_server(self): - context = mock.sentinel.context - server = mock.sentinel.server - - self._set_side_effect('delete_server', [server]) - - with self.storage_api.delete_server(context, 123) as q: - self.assertEqual(server, q) - - self._assert_called_with('delete_server', context, 123) - - def test_delete_server_failure(self): - context = mock.sentinel.context - server = mock.sentinel.server - - self._set_side_effect('delete_server', [server]) - - with testtools.ExpectedException(SentinelException): - with self.storage_api.delete_server(context, 123): - raise SentinelException('Something Went Wrong') - - self._assert_called_with('begin') - self._assert_called_with('rollback') - - # Tsigkey Tests - def test_create_tsigkey(self): - context = mock.sentinel.context - values = mock.sentinel.values - tsigkey = mock.sentinel.tsigkey - - self._set_side_effect('create_tsigkey', [tsigkey]) - - with self.storage_api.create_tsigkey(context, values) as q: - self.assertEqual(tsigkey, q) - - self._assert_called_with('create_tsigkey', context, values) - - def test_create_tsigkey_failure(self): - context = mock.sentinel.context - values = mock.sentinel.values - - self._set_side_effect('create_tsigkey', [{'id': 12345}]) - - with testtools.ExpectedException(SentinelException): - with self.storage_api.create_tsigkey(context, values): - raise SentinelException('Something Went Wrong') - - self._assert_called_with('begin') - self._assert_called_with('rollback') - self._assert_called_with('create_tsigkey', context, values) - - def test_get_tsigkey(self): - context = mock.sentinel.context - tsigkey_id = mock.sentinel.tsigkey_id - tsigkey = mock.sentinel.tsigkey - - self._set_side_effect('get_tsigkey', [tsigkey]) - - result = self.storage_api.get_tsigkey(context, tsigkey_id) - self._assert_called_with('get_tsigkey', context, tsigkey_id) - self.assertEqual(tsigkey, result) - - def test_find_tsigkeys(self): - context = mock.sentinel.context - criterion = mock.sentinel.criterion - marker = mock.sentinel.marker - limit = mock.sentinel.limit - sort_key = mock.sentinel.sort_key - sort_dir = mock.sentinel.sort_dir - tsigkey = mock.sentinel.tsigkey - - self._set_side_effect('find_tsigkeys', [[tsigkey]]) - - result = self.storage_api.find_tsigkeys( - context, criterion, marker, limit, sort_key, sort_dir) - self._assert_called_with( - 'find_tsigkeys', context, criterion, - marker, limit, sort_key, sort_dir) - self.assertEqual([tsigkey], result) - - def test_find_tsigkey(self): - context = mock.sentinel.context - criterion = mock.sentinel.criterion - tsigkey = mock.sentinel.tsigkey - - self._set_side_effect('find_tsigkey', [tsigkey]) - - result = self.storage_api.find_tsigkey(context, criterion) - self._assert_called_with('find_tsigkey', context, criterion) - self.assertEqual(tsigkey, result) - - def test_update_tsigkey(self): - context = mock.sentinel.context - values = mock.sentinel.values - - with self.storage_api.update_tsigkey(context, 123, values): - pass - - self._assert_called_with('update_tsigkey', context, 123, values) - - def test_update_tsigkey_failure(self): - context = mock.sentinel.context - values = {'test': 2} - - self._set_side_effect('get_tsigkey', [{'id': 123, 'test': 1}]) - - with testtools.ExpectedException(SentinelException): - with self.storage_api.update_tsigkey(context, 123, values): - raise SentinelException('Something Went Wrong') - - self._assert_called_with('begin') - self._assert_called_with('rollback') - self._assert_called_with('update_tsigkey', context, 123, values) - - def test_delete_tsigkey(self): - context = mock.sentinel.context - tsigkey = mock.sentinel.tsigkey - - self._set_side_effect('delete_tsigkey', [tsigkey]) - - with self.storage_api.delete_tsigkey(context, 123) as q: - self.assertEqual(tsigkey, q) - - self._assert_called_with('delete_tsigkey', context, 123) - - def test_delete_tsigkey_failure(self): - context = mock.sentinel.context - tsigkey = mock.sentinel.tsigkey - - self._set_side_effect('delete_tsigkey', [tsigkey]) - - with testtools.ExpectedException(SentinelException): - with self.storage_api.delete_tsigkey(context, 123): - raise SentinelException('Something Went Wrong') - - self._assert_called_with('begin') - self._assert_called_with('rollback') - - # Tenant Tests - def test_find_tenants(self): - context = mock.sentinel.context - tenant = mock.sentinel.tenant - - self._set_side_effect('find_tenants', [[tenant]]) - - result = self.storage_api.find_tenants(context) - self._assert_called_with('find_tenants', context) - self.assertEqual([tenant], result) - - def test_get_tenant(self): - context = mock.sentinel.context - tenant = mock.sentinel.tenant - - self._set_side_effect('get_tenant', [tenant]) - - result = self.storage_api.get_tenant(context, 123) - self._assert_called_with('get_tenant', context, 123) - self.assertEqual(tenant, result) - - def test_count_tenants(self): - context = mock.sentinel.context - - self._set_side_effect('count_tenants', [1]) - - result = self.storage_api.count_tenants(context) - self._assert_called_with('count_tenants', context) - self.assertEqual(1, result) - - # Domain Tests - def test_create_domain(self): - context = mock.sentinel.context - values = mock.sentinel.values - domain = mock.sentinel.domain - - self._set_side_effect('create_domain', [domain]) - - with self.storage_api.create_domain(context, values) as q: - self.assertEqual(domain, q) - - self._assert_called_with('create_domain', context, values) - - def test_create_domain_failure(self): - context = mock.sentinel.context - values = mock.sentinel.values - - self._set_side_effect('create_domain', [{'id': 12345}]) - - with testtools.ExpectedException(SentinelException): - with self.storage_api.create_domain(context, values): - raise SentinelException('Something Went Wrong') - - self._assert_called_with('begin') - self._assert_called_with('rollback') - self._assert_called_with('create_domain', context, values) - - def test_get_domain(self): - context = mock.sentinel.context - domain_id = mock.sentinel.domain_id - domain = mock.sentinel.domain - - self._set_side_effect('get_domain', [domain]) - - result = self.storage_api.get_domain(context, domain_id) - self._assert_called_with('get_domain', context, domain_id) - self.assertEqual(domain, result) - - def test_find_domains(self): - context = mock.sentinel.context - criterion = mock.sentinel.criterion - marker = mock.sentinel.marker - limit = mock.sentinel.limit - sort_key = mock.sentinel.sort_key - sort_dir = mock.sentinel.sort_dir - domain = mock.sentinel.domain - - self._set_side_effect('find_domains', [[domain]]) - - result = self.storage_api.find_domains( - context, criterion, - marker, limit, sort_key, sort_dir) - self._assert_called_with( - 'find_domains', context, criterion, - marker, limit, sort_key, sort_dir) - self.assertEqual([domain], result) - - def test_find_domain(self): - context = mock.sentinel.context - criterion = mock.sentinel.criterion - domain = mock.sentinel.domain - - self._set_side_effect('find_domain', [domain]) - - result = self.storage_api.find_domain(context, criterion) - self._assert_called_with('find_domain', context, criterion) - self.assertEqual(domain, result) - - def test_update_domain(self): - context = mock.sentinel.context - values = mock.sentinel.values - - with self.storage_api.update_domain(context, 123, values): - pass - - self._assert_called_with('update_domain', context, 123, values) - - def test_update_domain_failure(self): - context = mock.sentinel.context - values = {'test': 2} - - self._set_side_effect('get_domain', [{'id': 123, 'test': 1}]) - - with testtools.ExpectedException(SentinelException): - with self.storage_api.update_domain(context, 123, values): - raise SentinelException('Something Went Wrong') - - self._assert_called_with('begin') - self._assert_called_with('rollback') - self._assert_called_with('update_domain', context, 123, values) - - def test_delete_domain(self): - context = mock.sentinel.context - domain = mock.sentinel.domain - - self._set_side_effect('delete_domain', [domain]) - - with self.storage_api.delete_domain(context, 123) as q: - self.assertEqual(domain, q) - - self._assert_called_with('delete_domain', context, 123) - - def test_delete_domain_failure(self): - context = mock.sentinel.context - domain = mock.sentinel.domain - - self._set_side_effect('delete_domain', [domain]) - - with testtools.ExpectedException(SentinelException): - with self.storage_api.delete_domain(context, 123): - raise SentinelException('Something Went Wrong') - - self._assert_called_with('begin') - self._assert_called_with('rollback') - - # RecordSet Tests - def test_create_recordset(self): - context = mock.sentinel.context - values = mock.sentinel.values - recordset = mock.sentinel.recordset - - self._set_side_effect('create_recordset', [recordset]) - - with self.storage_api.create_recordset(context, 123, values) as q: - self.assertEqual(recordset, q) - - self._assert_called_with('create_recordset', context, 123, values) - - def test_create_recordset_failure(self): - context = mock.sentinel.context - values = mock.sentinel.values - - self._set_side_effect('create_recordset', [{'id': 12345}]) - - with testtools.ExpectedException(SentinelException): - with self.storage_api.create_recordset(context, 123, values): - raise SentinelException('Something Went Wrong') - - self._assert_called_with('begin') - self._assert_called_with('rollback') - self._assert_called_with('create_recordset', context, 123, values) - - def test_get_recordset(self): - context = mock.sentinel.context - recordset_id = mock.sentinel.recordset_id - recordset = mock.sentinel.recordset - - self._set_side_effect('get_recordset', [recordset]) - - result = self.storage_api.get_recordset(context, recordset_id) - self._assert_called_with('get_recordset', context, recordset_id) - self.assertEqual(recordset, result) - - def test_find_recordsets(self): - context = mock.sentinel.context - criterion = mock.sentinel.criterion - marker = mock.sentinel.marker - limit = mock.sentinel.limit - sort_key = mock.sentinel.sort_key - sort_dir = mock.sentinel.sort_dir - recordset = mock.sentinel.recordset - - self._set_side_effect('find_recordsets', [[recordset]]) - - result = self.storage_api.find_recordsets( - context, criterion, - marker, limit, sort_key, sort_dir) - self._assert_called_with( - 'find_recordsets', context, criterion, - marker, limit, sort_key, sort_dir) - self.assertEqual([recordset], result) - - def test_find_recordset(self): - context = mock.sentinel.context - criterion = mock.sentinel.criterion - recordset = mock.sentinel.recordset - - self._set_side_effect('find_recordset', [recordset]) - - result = self.storage_api.find_recordset(context, criterion) - self._assert_called_with('find_recordset', context, criterion) - self.assertEqual(recordset, result) - - def test_update_recordset(self): - context = mock.sentinel.context - values = mock.sentinel.values - - with self.storage_api.update_recordset(context, 123, values): - pass - - self._assert_called_with('update_recordset', context, 123, values) - - def test_update_recordset_failure(self): - context = mock.sentinel.context - values = {'test': 2} - - self._set_side_effect('get_recordset', [{'id': 123, 'test': 1}]) - - with testtools.ExpectedException(SentinelException): - with self.storage_api.update_recordset(context, 123, values): - raise SentinelException('Something Went Wrong') - - self._assert_called_with('begin') - self._assert_called_with('rollback') - self._assert_called_with('update_recordset', context, 123, values) - - def test_delete_recordset(self): - context = mock.sentinel.context - recordset = mock.sentinel.recordset - - self._set_side_effect('delete_recordset', [recordset]) - - with self.storage_api.delete_recordset(context, 123) as q: - self.assertEqual(recordset, q) - - self._assert_called_with('delete_recordset', context, 123) - - def test_delete_recordset_failure(self): - context = mock.sentinel.context - recordset = mock.sentinel.recordset - - self._set_side_effect('delete_recordset', [recordset]) - - with testtools.ExpectedException(SentinelException): - with self.storage_api.delete_recordset(context, 123): - raise SentinelException('Something Went Wrong') - - self._assert_called_with('begin') - self._assert_called_with('rollback') - - # Record Tests - def test_create_record(self): - context = mock.sentinel.context - values = mock.sentinel.values - record = mock.sentinel.record - - self._set_side_effect('create_record', [record]) - - with self.storage_api.create_record(context, 123, 321, values) as q: - self.assertEqual(record, q) - - self._assert_called_with('create_record', context, 123, 321, values) - - def test_create_record_failure(self): - context = mock.sentinel.context - values = mock.sentinel.values - - self._set_side_effect('create_record', [{'id': 12345}]) - - with testtools.ExpectedException(SentinelException): - with self.storage_api.create_record(context, 123, 321, values): - raise SentinelException('Something Went Wrong') - - self._assert_called_with('begin') - self._assert_called_with('rollback') - self._assert_called_with('create_record', context, 123, 321, values) - - def test_get_record(self): - context = mock.sentinel.context - record_id = mock.sentinel.record_id - record = mock.sentinel.record - - self._set_side_effect('get_record', [record]) - - result = self.storage_api.get_record(context, record_id) - self._assert_called_with('get_record', context, record_id) - self.assertEqual(record, result) - - def test_find_records(self): - context = mock.sentinel.context - criterion = mock.sentinel.criterion - marker = mock.sentinel.marker - limit = mock.sentinel.limit - sort_key = mock.sentinel.sort_key - sort_dir = mock.sentinel.sort_dir - record = mock.sentinel.record - - self._set_side_effect('find_records', [[record]]) - - result = self.storage_api.find_records( - context, criterion, - marker, limit, sort_key, sort_dir) - self._assert_called_with( - 'find_records', context, criterion, - marker, limit, sort_key, sort_dir) - - self.assertEqual([record], result) - - def test_find_record(self): - context = mock.sentinel.context - criterion = mock.sentinel.criterion - record = mock.sentinel.record - - self._set_side_effect('find_record', [record]) - - result = self.storage_api.find_record(context, criterion) - self._assert_called_with('find_record', context, criterion) - - self.assertEqual(record, result) - - def test_update_record(self): - context = mock.sentinel.context - record_id = mock.sentinel.record_id - values = mock.sentinel.values - - with self.storage_api.update_record(context, record_id, values): - pass - - self._assert_called_with('update_record', context, record_id, values) - - def test_update_record_failure(self): - context = mock.sentinel.context - record_id = mock.sentinel.record_id - values = {'test': 2} - - self._set_side_effect('get_record', [{'id': record_id, 'test': 1}]) - - with testtools.ExpectedException(SentinelException): - with self.storage_api.update_record(context, record_id, values): - raise SentinelException('Something Went Wrong') - - self._assert_called_with('begin') - self._assert_called_with('rollback') - self._assert_called_with('update_record', context, record_id, values) - - def test_delete_record(self): - context = mock.sentinel.context - record = mock.sentinel.record - - self._set_side_effect('delete_record', [record]) - - with self.storage_api.delete_record(context, 123) as q: - self.assertEqual(record, q) - - self._assert_called_with('delete_record', context, 123) - - def test_delete_record_failure(self): - context = mock.sentinel.context - record = mock.sentinel.record - - self._set_side_effect('delete_record', [record]) - - with testtools.ExpectedException(SentinelException): - with self.storage_api.delete_record(context, 123): - raise SentinelException('Something Went Wrong') - - self._assert_called_with('begin') - self._assert_called_with('rollback')