diff --git a/oslo_vmware/tests/test_vim_util.py b/oslo_vmware/tests/test_vim_util.py index 2594b763..a542509b 100644 --- a/oslo_vmware/tests/test_vim_util.py +++ b/oslo_vmware/tests/test_vim_util.py @@ -316,6 +316,37 @@ class VimUtilTest(base.TestCase): vim.ContinueRetrievePropertiesEx.assert_called_once_with( vim.service_content.propertyCollector, token=token) + @mock.patch('oslo_vmware.vim_util.continue_retrieval') + @mock.patch('oslo_vmware.vim_util.cancel_retrieval') + def test_with_retrieval(self, cancel_retrieval, continue_retrieval): + vim = mock.Mock() + retrieve_result0 = mock.Mock() + retrieve_result0.objects = [mock.Mock(), mock.Mock()] + retrieve_result1 = mock.Mock() + retrieve_result1.objects = [mock.Mock(), mock.Mock()] + continue_retrieval.side_effect = [retrieve_result1, None] + expected = retrieve_result0.objects + retrieve_result1.objects + + with vim_util.WithRetrieval(vim, retrieve_result0) as iterator: + self.assertEqual(expected, list(iterator)) + + calls = [ + mock.call(vim, retrieve_result0), + mock.call(vim, retrieve_result1)] + continue_retrieval.assert_has_calls(calls) + self.assertFalse(cancel_retrieval.called) + + @mock.patch('oslo_vmware.vim_util.continue_retrieval') + @mock.patch('oslo_vmware.vim_util.cancel_retrieval') + def test_with_retrieval_early_exit(self, cancel_retrieval, + continue_retrieval): + vim = mock.Mock() + retrieve_result = mock.Mock() + with vim_util.WithRetrieval(vim, retrieve_result): + pass + + cancel_retrieval.assert_called_once_with(vim, retrieve_result) + @mock.patch('oslo_vmware.vim_util.get_object_properties') def test_get_object_property(self, get_object_properties): prop = mock.Mock() diff --git a/oslo_vmware/vim_util.py b/oslo_vmware/vim_util.py index 8980a511..e8b6e11f 100644 --- a/oslo_vmware/vim_util.py +++ b/oslo_vmware/vim_util.py @@ -389,6 +389,39 @@ def continue_retrieval(vim, retrieve_result): return vim.ContinueRetrievePropertiesEx(collector, token=token) +class WithRetrieval(object): + """Context to retrieve results. + + This context provides an iterator to retrieve results and cancel (when + needed) retrieve operation on __exit__. + + Example: + + with WithRetrieval(vim, retrieve_result) as objects: + for obj in objects: + # Use obj + """ + + def __init__(self, vim, retrieve_result): + super(WithRetrieval, self).__init__() + self.vim = vim + self.retrieve_result = retrieve_result + + def __enter__(self): + return iter(self) + + def __exit__(self, exc_type, exc_value, traceback): + if self.retrieve_result: + cancel_retrieval(self.vim, self.retrieve_result) + + def __iter__(self): + while self.retrieve_result: + for obj in self.retrieve_result.objects: + yield obj + self.retrieve_result = continue_retrieval( + self.vim, self.retrieve_result) + + def get_object_property(vim, moref, property_name): """Get property of the given managed object. @@ -491,15 +524,14 @@ def get_inventory_path(vim, entity_ref, max_objects=100): entity_name = None propSet = None path = "" - while retrieve_result: - for obj in retrieve_result.objects: + with WithRetrieval(vim, retrieve_result) as objects: + for obj in objects: if hasattr(obj, 'propSet'): propSet = obj.propSet if len(propSet) >= 1 and not entity_name: entity_name = propSet[0].val elif len(propSet) >= 1: path = '%s/%s' % (propSet[0].val, path) - retrieve_result = continue_retrieval(vim, retrieve_result) # NOTE(arnaud): slice to exclude the root folder from the result. if propSet is not None and len(propSet) > 0: path = path[len(propSet[0].val):]