Merge "Fix duplicated sg rules check for remote_ip_prefix"

This commit is contained in:
Jenkins 2017-02-08 06:16:25 +00:00 committed by Gerrit Code Review
commit 72ed1f0dd9
2 changed files with 55 additions and 3 deletions

View File

@ -498,8 +498,7 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
'direction': [sgr['direction']]} 'direction': [sgr['direction']]}
include_if_present = ['protocol', 'port_range_max', 'port_range_min', include_if_present = ['protocol', 'port_range_max', 'port_range_min',
'ethertype', 'remote_ip_prefix', 'ethertype', 'remote_group_id']
'remote_group_id']
for key in include_if_present: for key in include_if_present:
value = sgr.get(key) value = sgr.get(key)
if value: if value:
@ -538,6 +537,8 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
rule_dict.pop('description', None) rule_dict.pop('description', None)
keys = rule_dict.keys() keys = rule_dict.keys()
fields = list(keys) + ['id'] fields = list(keys) + ['id']
if 'remote_ip_prefix' not in fields:
fields += ['remote_ip_prefix']
db_rules = self.get_security_group_rules(context, filters, db_rules = self.get_security_group_rules(context, filters,
fields=fields) fields=fields)
# Note(arosen): the call to get_security_group_rules wildcards # Note(arosen): the call to get_security_group_rules wildcards
@ -551,6 +552,7 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
# below to check for these corner cases. # below to check for these corner cases.
rule_dict.pop('id', None) rule_dict.pop('id', None)
sg_protocol = rule_dict.pop('protocol', None) sg_protocol = rule_dict.pop('protocol', None)
remote_ip_prefix = rule_dict.pop('remote_ip_prefix', None)
for db_rule in db_rules: for db_rule in db_rules:
rule_id = db_rule.pop('id', None) rule_id = db_rule.pop('id', None)
# remove protocol and match separately for number and type # remove protocol and match separately for number and type
@ -558,9 +560,21 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
is_protocol_matching = ( is_protocol_matching = (
self._get_ip_proto_name_and_num(db_protocol) == self._get_ip_proto_name_and_num(db_protocol) ==
self._get_ip_proto_name_and_num(sg_protocol)) self._get_ip_proto_name_and_num(sg_protocol))
if (is_protocol_matching and rule_dict == db_rule): db_remote_ip_prefix = db_rule.pop('remote_ip_prefix', None)
duplicate_ip_prefix = self._validate_duplicate_ip_prefix(
remote_ip_prefix, db_remote_ip_prefix)
if (is_protocol_matching and duplicate_ip_prefix and
rule_dict == db_rule):
raise ext_sg.SecurityGroupRuleExists(rule_id=rule_id) raise ext_sg.SecurityGroupRuleExists(rule_id=rule_id)
def _validate_duplicate_ip_prefix(self, ip_prefix, other_ip_prefix):
all_address = ['0.0.0.0/0', '::/0', None]
if ip_prefix == other_ip_prefix:
return True
elif ip_prefix in all_address and other_ip_prefix in all_address:
return True
return False
def _validate_ip_prefix(self, rule): def _validate_ip_prefix(self, rule):
"""Check that a valid cidr was specified as remote_ip_prefix """Check that a valid cidr was specified as remote_ip_prefix

View File

@ -133,6 +133,44 @@ class SecurityGroupDbMixinTestCase(testlib_api.SqlTestCase):
self.mixin._check_for_duplicate_rules_in_db, self.mixin._check_for_duplicate_rules_in_db,
context, rule_dict) context, rule_dict)
def test_check_for_duplicate_diff_rules_remote_ip_prefix_ipv4(self):
db_rules = {'id': 'fake', 'tenant_id': 'fake', 'ethertype': 'IPv4',
'direction': 'ingress', 'security_group_id': 'fake',
'remote_ip_prefix': None}
with mock.patch.object(self.mixin, 'get_security_group_rules',
return_value=[db_rules]):
context = mock.Mock()
rule_dict = {
'security_group_rule': {'id': 'fake2',
'tenant_id': 'fake',
'security_group_id': 'fake',
'ethertype': 'IPv4',
'direction': 'ingress',
'remote_ip_prefix': '0.0.0.0/0'}
}
self.assertRaises(securitygroup.SecurityGroupRuleExists,
self.mixin._check_for_duplicate_rules_in_db,
context, rule_dict)
def test_check_for_duplicate_diff_rules_remote_ip_prefix_ipv6(self):
db_rules = {'id': 'fake', 'tenant_id': 'fake', 'ethertype': 'IPv6',
'direction': 'ingress', 'security_group_id': 'fake',
'remote_ip_prefix': None}
with mock.patch.object(self.mixin, 'get_security_group_rules',
return_value=[db_rules]):
context = mock.Mock()
rule_dict = {
'security_group_rule': {'id': 'fake2',
'tenant_id': 'fake',
'security_group_id': 'fake',
'ethertype': 'IPv6',
'direction': 'ingress',
'remote_ip_prefix': '::/0'}
}
self.assertRaises(securitygroup.SecurityGroupRuleExists,
self.mixin._check_for_duplicate_rules_in_db,
context, rule_dict)
def test_delete_security_group_rule_in_use(self): def test_delete_security_group_rule_in_use(self):
with mock.patch.object(registry, "notify") as mock_notify: with mock.patch.object(registry, "notify") as mock_notify:
mock_notify.side_effect = exceptions.CallbackFailure(Exception()) mock_notify.side_effect = exceptions.CallbackFailure(Exception())