Merge "Joined 'tags' column while getting node"

This commit is contained in:
Jenkins 2016-05-24 00:23:09 +00:00 committed by Gerrit Code Review
commit 28a3f4a7b2
3 changed files with 38 additions and 8 deletions

View File

@ -29,6 +29,7 @@ from oslo_utils import strutils
from oslo_utils import timeutils from oslo_utils import timeutils
from oslo_utils import uuidutils from oslo_utils import uuidutils
from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm import joinedload
from sqlalchemy import sql from sqlalchemy import sql
from ironic.common import exception from ironic.common import exception
@ -63,6 +64,10 @@ def _session_for_write():
return enginefacade.writer.using(_CONTEXT) return enginefacade.writer.using(_CONTEXT)
def _get_node_query_with_tags():
return model_query(models.Node).options(joinedload('tags'))
def model_query(model, *args, **kwargs): def model_query(model, *args, **kwargs):
"""Query helper for simpler session usage. """Query helper for simpler session usage.
@ -241,14 +246,14 @@ class Connection(api.Connection):
def get_node_list(self, filters=None, limit=None, marker=None, def get_node_list(self, filters=None, limit=None, marker=None,
sort_key=None, sort_dir=None): sort_key=None, sort_dir=None):
query = model_query(models.Node) query = _get_node_query_with_tags()
query = self._add_nodes_filters(query, filters) query = self._add_nodes_filters(query, filters)
return _paginate_query(models.Node, limit, marker, return _paginate_query(models.Node, limit, marker,
sort_key, sort_dir, query) sort_key, sort_dir, query)
def reserve_node(self, tag, node_id): def reserve_node(self, tag, node_id):
with _session_for_write(): with _session_for_write():
query = model_query(models.Node) query = _get_node_query_with_tags()
query = add_identity_filter(query, node_id) query = add_identity_filter(query, node_id)
# be optimistic and assume we usually create a reservation # be optimistic and assume we usually create a reservation
count = query.filter_by(reservation=None).update( count = query.filter_by(reservation=None).update(
@ -313,24 +318,29 @@ class Connection(api.Connection):
instance_uuid=values['instance_uuid'], instance_uuid=values['instance_uuid'],
node=values['uuid']) node=values['uuid'])
raise exception.NodeAlreadyExists(uuid=values['uuid']) raise exception.NodeAlreadyExists(uuid=values['uuid'])
# Set tags to [] for new created node
node['tags'] = []
return node return node
def get_node_by_id(self, node_id): def get_node_by_id(self, node_id):
query = model_query(models.Node).filter_by(id=node_id) query = _get_node_query_with_tags()
query = query.filter_by(id=node_id)
try: try:
return query.one() return query.one()
except NoResultFound: except NoResultFound:
raise exception.NodeNotFound(node=node_id) raise exception.NodeNotFound(node=node_id)
def get_node_by_uuid(self, node_uuid): def get_node_by_uuid(self, node_uuid):
query = model_query(models.Node).filter_by(uuid=node_uuid) query = _get_node_query_with_tags()
query = query.filter_by(uuid=node_uuid)
try: try:
return query.one() return query.one()
except NoResultFound: except NoResultFound:
raise exception.NodeNotFound(node=node_uuid) raise exception.NodeNotFound(node=node_uuid)
def get_node_by_name(self, node_name): def get_node_by_name(self, node_name):
query = model_query(models.Node).filter_by(name=node_name) query = _get_node_query_with_tags()
query = query.filter_by(name=node_name)
try: try:
return query.one() return query.one()
except NoResultFound: except NoResultFound:
@ -340,8 +350,8 @@ class Connection(api.Connection):
if not uuidutils.is_uuid_like(instance): if not uuidutils.is_uuid_like(instance):
raise exception.InvalidUUID(uuid=instance) raise exception.InvalidUUID(uuid=instance)
query = (model_query(models.Node) query = _get_node_query_with_tags()
.filter_by(instance_uuid=instance)) query = query.filter_by(instance_uuid=instance)
try: try:
result = query.one() result = query.one()

View File

@ -27,6 +27,7 @@ from sqlalchemy import Boolean, Column, DateTime, Index
from sqlalchemy import ForeignKey, Integer from sqlalchemy import ForeignKey, Integer
from sqlalchemy import schema, String, Text from sqlalchemy import schema, String, Text
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import orm
from ironic.common.i18n import _ from ironic.common.i18n import _
from ironic.common import paths from ironic.common import paths
@ -197,3 +198,10 @@ class NodeTag(Base):
node_id = Column(Integer, ForeignKey('nodes.id'), node_id = Column(Integer, ForeignKey('nodes.id'),
primary_key=True, nullable=False) primary_key=True, nullable=False)
tag = Column(String(255), primary_key=True, nullable=False) tag = Column(String(255), primary_key=True, nullable=False)
node = orm.relationship(
"Node",
backref='tags',
primaryjoin='and_(NodeTag.node_id == Node.id)',
foreign_keys=node_id
)

View File

@ -61,22 +61,28 @@ class DbNodeTestCase(base.DbTestCase):
def test_get_node_by_id(self): def test_get_node_by_id(self):
node = utils.create_test_node() node = utils.create_test_node()
self.dbapi.set_node_tags(node.id, ['tag1', 'tag2'])
res = self.dbapi.get_node_by_id(node.id) res = self.dbapi.get_node_by_id(node.id)
self.assertEqual(node.id, res.id) self.assertEqual(node.id, res.id)
self.assertEqual(node.uuid, res.uuid) self.assertEqual(node.uuid, res.uuid)
self.assertItemsEqual(['tag1', 'tag2'], [tag.tag for tag in res.tags])
def test_get_node_by_uuid(self): def test_get_node_by_uuid(self):
node = utils.create_test_node() node = utils.create_test_node()
self.dbapi.set_node_tags(node.id, ['tag1', 'tag2'])
res = self.dbapi.get_node_by_uuid(node.uuid) res = self.dbapi.get_node_by_uuid(node.uuid)
self.assertEqual(node.id, res.id) self.assertEqual(node.id, res.id)
self.assertEqual(node.uuid, res.uuid) self.assertEqual(node.uuid, res.uuid)
self.assertItemsEqual(['tag1', 'tag2'], [tag.tag for tag in res.tags])
def test_get_node_by_name(self): def test_get_node_by_name(self):
node = utils.create_test_node() node = utils.create_test_node()
self.dbapi.set_node_tags(node.id, ['tag1', 'tag2'])
res = self.dbapi.get_node_by_name(node.name) res = self.dbapi.get_node_by_name(node.name)
self.assertEqual(node.id, res.id) self.assertEqual(node.id, res.id)
self.assertEqual(node.uuid, res.uuid) self.assertEqual(node.uuid, res.uuid)
self.assertEqual(node.name, res.name) self.assertEqual(node.name, res.name)
self.assertItemsEqual(['tag1', 'tag2'], [tag.tag for tag in res.tags])
def test_get_node_that_does_not_exist(self): def test_get_node_that_does_not_exist(self):
self.assertRaises(exception.NodeNotFound, self.assertRaises(exception.NodeNotFound,
@ -217,6 +223,8 @@ class DbNodeTestCase(base.DbTestCase):
res = self.dbapi.get_node_list() res = self.dbapi.get_node_list()
res_uuids = [r.uuid for r in res] res_uuids = [r.uuid for r in res]
six.assertCountEqual(self, uuids, res_uuids) six.assertCountEqual(self, uuids, res_uuids)
for r in res:
self.assertEqual([], r.tags)
def test_get_node_list_with_filters(self): def test_get_node_list_with_filters(self):
ch1 = utils.create_test_chassis(uuid=uuidutils.generate_uuid()) ch1 = utils.create_test_chassis(uuid=uuidutils.generate_uuid())
@ -272,9 +280,11 @@ class DbNodeTestCase(base.DbTestCase):
def test_get_node_by_instance(self): def test_get_node_by_instance(self):
node = utils.create_test_node( node = utils.create_test_node(
instance_uuid='12345678-9999-0000-aaaa-123456789012') instance_uuid='12345678-9999-0000-aaaa-123456789012')
self.dbapi.set_node_tags(node.id, ['tag1', 'tag2'])
res = self.dbapi.get_node_by_instance(node.instance_uuid) res = self.dbapi.get_node_by_instance(node.instance_uuid)
self.assertEqual(node.uuid, res.uuid) self.assertEqual(node.uuid, res.uuid)
self.assertItemsEqual(['tag1', 'tag2'], [tag.tag for tag in res.tags])
def test_get_node_by_instance_wrong_uuid(self): def test_get_node_by_instance_wrong_uuid(self):
utils.create_test_node( utils.create_test_node(
@ -446,12 +456,14 @@ class DbNodeTestCase(base.DbTestCase):
def test_reserve_node(self): def test_reserve_node(self):
node = utils.create_test_node() node = utils.create_test_node()
self.dbapi.set_node_tags(node.id, ['tag1', 'tag2'])
uuid = node.uuid uuid = node.uuid
r1 = 'fake-reservation' r1 = 'fake-reservation'
# reserve the node # reserve the node
self.dbapi.reserve_node(r1, uuid) res = self.dbapi.reserve_node(r1, uuid)
self.assertItemsEqual(['tag1', 'tag2'], [tag.tag for tag in res.tags])
# check reservation # check reservation
res = self.dbapi.get_node_by_uuid(uuid) res = self.dbapi.get_node_by_uuid(uuid)