diff --git a/octavia/api/common/pagination.py b/octavia/api/common/pagination.py index aef7dd4f07..3ee8424f7b 100644 --- a/octavia/api/common/pagination.py +++ b/octavia/api/common/pagination.py @@ -13,6 +13,7 @@ # under the License. import copy +import itertools from oslo_log import log as logging from pecan import request @@ -184,36 +185,27 @@ class PaginationHelper(object): if not getattr(model, "_tags", None): return query - tag_alias = aliased(base_models.Tags) - if 'tags' in params: tags = params.pop('tags') - if not isinstance(tags, list): - tags = [tags] - first_tag = tags.pop(0) - query = query.join(model._tags) - query = query.filter(base_models.Tags.tag == first_tag) for tag in tags: + # This requires a multi-join to the tags table, + # so me must use aliases for each one. + tag_alias = aliased(base_models.Tags) query = query.join(tag_alias, model._tags) query = query.filter(tag_alias.tag == tag) if 'tags-any' in params: tags = params.pop('tags-any') - if not isinstance(tags, list): - tags = [tags] + tag_alias = aliased(base_models.Tags) query = query.join(tag_alias, model._tags) query = query.filter(tag_alias.tag.in_(tags)) if 'not-tags' in params: tags = params.pop('not-tags') - if not isinstance(tags, list): - tags = [tags] - first_tag = tags.pop(0) - subq = query.session.query(base_models.Tags.resource_id) - subq = subq.join(model._tags) - subq = subq.filter(base_models.Tags.tag == first_tag) + subq = query.session.query(model.id) for tag in tags: + tag_alias = aliased(base_models.Tags) subq = subq.join(tag_alias, model._tags) subq = subq.filter(tag_alias.tag == tag) @@ -221,13 +213,19 @@ class PaginationHelper(object): if 'not-tags-any' in params: tags = params.pop('not-tags-any') - if not isinstance(tags, list): - tags = [tags] query = query.filter( ~model._tags.any(base_models.Tags.tag.in_(tags))) return query + @staticmethod + def _prepare_tags_list(param): + """Split comma seperated tags and return a flat list of tags.""" + if not isinstance(param, list): + param = [param] + return list(itertools.chain.from_iterable( + tag.split(',') for tag in param)) + def apply(self, query, model, enforce_valid_params=True): """Returns a query with sorting / pagination criteria added. @@ -270,8 +268,22 @@ class PaginationHelper(object): filter_params['load_balancer_id'] = filter_params.pop( 'loadbalancer_id') - # Apply tags filtering for the models which support tags. - query = self._apply_tags_filtering(filter_params, model, query) + # Pop the 'tags' related parameters off before handling the + # other filters. Then apply the 'tags' filters after the + # other filters have been applied. + tag_params = {} + if 'tags' in filter_params: + tag_params['tags'] = self._prepare_tags_list( + filter_params.pop('tags')) + if 'tags-any' in filter_params: + tag_params['tags-any'] = self._prepare_tags_list( + filter_params.pop('tags-any')) + if 'not-tags' in filter_params: + tag_params['not-tags'] = self._prepare_tags_list( + filter_params.pop('not-tags')) + if 'not-tags-any' in filter_params: + tag_params['not-tags-any'] = self._prepare_tags_list( + filter_params.pop('not-tags-any')) # Drop invalid arguments self.filters = {k: v for (k, v) in filter_params.items() @@ -287,6 +299,9 @@ class PaginationHelper(object): query = query.filter(model.load_balancer.has( project_id=secondary_query_filter)) + # Apply tags filtering for the models which support tags. + query = self._apply_tags_filtering(tag_params, model, query) + # Add sorting if CONF.api_settings.allow_sorting: # Add default sort keys (if they are OK for the model) diff --git a/octavia/tests/functional/api/v2/test_load_balancer.py b/octavia/tests/functional/api/v2/test_load_balancer.py index f6c272b48e..8aa14e1e7e 100644 --- a/octavia/tests/functional/api/v2/test_load_balancer.py +++ b/octavia/tests/functional/api/v2/test_load_balancer.py @@ -1397,6 +1397,17 @@ class TestLoadBalancer(base.BaseAPITest): [lb.get('id') for lb in lbs] ) + lbs = self.get( + self.LBS_PATH, + params={'tags': ['test_tag2,test_tag3']} + ).json.get(self.root_tag_list) + self.assertIsInstance(lbs, list) + self.assertEqual(1, len(lbs)) + self.assertEqual( + [lb2.get('id')], + [lb.get('id') for lb in lbs] + ) + lbs = self.get( self.LBS_PATH, params={'tags-any': 'test_tag2'} @@ -1445,6 +1456,71 @@ class TestLoadBalancer(base.BaseAPITest): self.assertIsInstance(lbs, list) self.assertEqual(0, len(lbs)) + def test_get_all_tags_mixed_filters(self): + lb1 = self.create_load_balancer( + uuidutils.generate_uuid(), + name='lb1', + project_id=self.project_id, + vip_address='10.0.0.1', + tags=['test_tag1', 'test_tag2'] + ).get(self.root_tag) + self.create_load_balancer( + uuidutils.generate_uuid(), + name='lb2', + project_id=self.project_id, + tags=['test_tag2', 'test_tag3'] + ).get(self.root_tag) + + lbs = self.get( + self.LBS_PATH, + params={'name': 'lb1', 'tags': 'test_tag2', + 'vip_address': '10.0.0.1'} + ).json.get(self.root_tag_list) + self.assertIsInstance(lbs, list) + self.assertEqual(1, len(lbs)) + self.assertEqual(lb1.get('id'), lbs[0].get('id')) + + lbs = self.get( + self.LBS_PATH, + params={'tags': 'test_tag2', 'vip_address': '10.0.0.1'} + ).json.get(self.root_tag_list) + self.assertIsInstance(lbs, list) + self.assertEqual(1, len(lbs)) + self.assertEqual(lb1.get('id'), lbs[0].get('id')) + + lbs = self.get( + self.LBS_PATH, + params={'name': 'lb1', 'tags': 'test_tag2', + 'vip_address': '10.0.0.1', 'tags': 'test_tag1'} + ).json.get(self.root_tag_list) + self.assertIsInstance(lbs, list) + self.assertEqual(1, len(lbs)) + self.assertEqual(lb1.get('id'), lbs[0].get('id')) + + lbs = self.get( + self.LBS_PATH, + params={'name': 'lb1', 'tags': 'test_tag2', + 'vip_address': '10.0.0.1', 'tags': 'test_tag3'} + ).json.get(self.root_tag_list) + self.assertIsInstance(lbs, list) + self.assertEqual(0, len(lbs)) + + lbs = self.get( + self.LBS_PATH, + params={'name': 'lb1', 'tags': 'test_tag3', + 'vip_address': '10.0.0.1'} + ).json.get(self.root_tag_list) + self.assertIsInstance(lbs, list) + self.assertEqual(0, len(lbs)) + + lbs = self.get( + self.LBS_PATH, + params={'name': 'bogus-lb', 'tags': 'test_tag2', + 'vip_address': '10.0.0.1'} + ).json.get(self.root_tag_list) + self.assertIsInstance(lbs, list) + self.assertEqual(0, len(lbs)) + def test_get_all_hides_deleted(self): api_lb = self.create_load_balancer( uuidutils.generate_uuid()).get(self.root_tag)