diff --git a/etc/projects.yaml.sample b/etc/projects.yaml.sample index 33ae7516..9640e0fa 100644 --- a/etc/projects.yaml.sample +++ b/etc/projects.yaml.sample @@ -1,8 +1,12 @@ - project: Test-Project description: First project use-storyboard: true - group: Group-1 + groups: + - Group-1 + - Group-2 - project: Test-Project-Two description: Second project use-storyboard: true - group: Group-1 + groups: + - Group-1 + - Group-3 diff --git a/storyboard/db/projects_loader.py b/storyboard/db/projects_loader.py index b2a28ebe..6fb78db7 100644 --- a/storyboard/db/projects_loader.py +++ b/storyboard/db/projects_loader.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import six import warnings import yaml from oslo.config import cfg from sqlalchemy.exc import SADeprecationWarning -from storyboard.db.api import base as db_api from storyboard.common.custom_types import NameType +from storyboard.db.api import base as db_api from storyboard.db.models import Project from storyboard.db.models import ProjectGroup from storyboard.openstack.common import log @@ -33,63 +32,98 @@ LOG = log.getLogger(__name__) def do_load_models(filename): - config_file = open(filename) + session = db_api.get_session(autocommit=False) projects_list = yaml.load(config_file) - validator = NameType() - - project_groups = dict() + project_groups = list() + # Create all the projects. for project in projects_list: if not project.get('use-storyboard'): continue - group_name = project.get("group") or "default" - if group_name not in project_groups: - project_groups[group_name] = list() + project_instance = _get_project(project, session) + project_instance_groups = list() - project_name = project.get("project") - - try: - validator.validate(project_name) - except Exception: - # Skipping invalid project names - LOG.warn("Project %s was not loaded. Validation failed." - % project_name) + if not project_instance: continue - project_description = project.get("description") + groups = project.get("groups") or [] + for group_name in groups: + group_instance = _get_project_group(group_name, session) + project_instance_groups.append(group_instance) - project_groups[group_name].append( - {"name": project_name, - "description": project_description}) + if group_instance not in project_groups: + project_groups.append(group_instance) - session = db_api.get_session() + # Brute force diff + groups_to_remove = set(project_instance.project_groups) - set( + project_instance_groups) + groups_to_add = set(project_instance_groups) - set( + project_instance.project_groups) - with session.begin(): - for project_group_name, projects in six.iteritems(project_groups): - db_project_group = session.query(ProjectGroup)\ - .filter_by(name=project_group_name).first() - if not db_project_group: - db_project_group = ProjectGroup() - db_project_group.name = project_group_name - db_project_group.projects = [] + for group in groups_to_remove: + project_instance.project_groups.remove(group) - for project in projects: - db_project = session.query(Project)\ - .filter_by(name=project["name"]).first() - if not db_project: - db_project = Project() - db_project.name = project["name"] + for group in groups_to_add: + project_instance.project_groups.append(group) - if project['description']: - project['description'] = unicode(project["description"]) + if len(groups_to_remove) + len(groups_to_add) > 0: + session.add(project_instance) - db_project.description = project["description"] - session.add(db_project) + # Now, go through all groups that were not explicitly listed and delete + # them. + project_groups_to_delete = list() + current_groups = session.query(ProjectGroup) + for current_group in current_groups: + if current_group not in project_groups: + project_groups_to_delete.append(current_group) - db_project_group.projects.append(db_project) + for group in project_groups_to_delete: + session.delete(group) - session.add(db_project_group) + session.commit() + + +def _get_project(project, session): + validator = NameType() + name = unicode(project['project']) + if 'description' in project: + description = unicode(project['description']) + else: + description = '' + + try: + validator.validate(name) + except Exception: + # Skipping invalid project names + LOG.warn("Project %s was not loaded. Validation failed." + % [name, ]) + return None + + db_project = session.query(Project) \ + .filter_by(name=name).first() + if not db_project: + db_project = Project() + db_project.name = name + db_project.description = description + db_project.groups = [] + + session.add(db_project) + + return db_project + + +def _get_project_group(project_group_name, session): + db_project_group = session.query(ProjectGroup) \ + .filter_by(name=project_group_name).first() + + if not db_project_group: + db_project_group = ProjectGroup() + db_project_group.name = project_group_name + + session.add(db_project_group) + + return db_project_group diff --git a/storyboard/tests/db/migration/test_cli_project_import.py b/storyboard/tests/db/migration/test_cli_project_import.py new file mode 100644 index 00000000..a44225f2 --- /dev/null +++ b/storyboard/tests/db/migration/test_cli_project_import.py @@ -0,0 +1,34 @@ +# Copyright (c) 2014 Hewlett-Packard Development Company, L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import storyboard.db.api.base as api_base +from storyboard.db.models import ProjectGroup +from storyboard.db import projects_loader +from storyboard.tests import base + + +class TestProjectGroupMigration(base.FunctionalTest): + """Unit tests for the load_projects commandline option, focused on + groups only. + """ + + def setUp(self): + super(TestProjectGroupMigration, self).setUp() + + def testSimpleGroupMigration(self): + projects_loader.do_load_models('./etc/projects.yaml.sample') + + all_groups = api_base.entity_get_all(ProjectGroup) + + self.assertEqual(3, len(all_groups))