import unittest
from unittest import mock


class CharmTestCase(unittest.TestCase):
    """Base test case with helpers for installing and tracking mock patches.

    Two independent mechanisms are provided:

    * ``patch_all`` patches every attribute named in ``self.patches`` on
      ``self.obj``.  Neither attribute is set by this class, so a subclass
      must assign both before calling ``patch_all``.  These patches are
      undone through ``addCleanup``.
    * ``patch_object`` and ``patch`` start a single patcher, remember it
      under a name, and expose the started mock as ``self.<name>``.  These
      patches are undone in ``tearDown``.
    """

    def setUp(self):
        """Create the empty registries used by ``patch_object``/``patch``."""
        super().setUp()
        # name -> patcher object (needed later to call .stop()).
        self._patches = {}
        # name -> object returned by patcher.start().
        self._patches_start = {}

    def tearDown(self):
        """Stop every registered patcher and clear the matching attributes."""
        for name, patcher in self._patches.items():
            patcher.stop()
            # Drop the reference to the mock so it cannot leak into
            # another test through the instance.
            setattr(self, name, None)
        self._patches = None
        self._patches_start = None
        super().tearDown()

    def _patch(self, method):
        """Patch ``self.obj.<method>`` and return the started mock.

        The patcher is stopped through ``addCleanup`` rather than being
        recorded in ``self._patches``.
        """
        patcher = mock.patch.object(self.obj, method)
        started = patcher.start()
        self.addCleanup(patcher.stop)
        return started

    def patch_all(self):
        """Patch each name in ``self.patches`` and store it as ``self.<name>``."""
        for method in self.patches:
            setattr(self, method, self._patch(method))

    def _start_patch(self, name, patcher, new, return_value):
        """Start *patcher*, register it under *name*, expose it on ``self``."""
        self._patches[name] = patcher
        started = patcher.start()
        if new is None:
            # Only a mock created by the patcher gets a return value; a
            # caller-supplied replacement object is left untouched.
            started.return_value = return_value
        self._patches_start[name] = started
        setattr(self, name, started)

    def patch_object(self, obj, attr, return_value=None, name=None, new=None,
                     **kwargs):
        """Patch ``obj.<attr>`` and expose the mock as ``self.<name>``.

        :param obj: object whose attribute is replaced.
        :param attr: name of the attribute to replace.
        :param return_value: return value set on the created mock; ignored
            when *new* is given.
        :param name: attribute name used on ``self``; defaults to *attr*.
        :param new: explicit replacement object; when ``None`` the patcher
            creates the replacement itself.
        :param kwargs: passed through to ``mock.patch.object``.
        """
        if name is None:
            name = attr
        if new is not None:
            kwargs["new"] = new
        patcher = mock.patch.object(obj, attr, **kwargs)
        self._start_patch(name, patcher, new, return_value)

    def patch(self, item, return_value=None, name=None, new=None, **kwargs):
        """Patch the dotted path *item* and expose the mock as ``self.<name>``.

        :param item: target string handed to ``mock.patch``.
        :param return_value: return value set on the created mock; ignored
            when *new* is given.
        :param name: attribute name used on ``self``; required.
        :param new: explicit replacement object; when ``None`` the patcher
            creates the replacement itself.
        :param kwargs: passed through to ``mock.patch``.
        :raises RuntimeError: if *name* is not supplied.
        """
        if name is None:
            raise RuntimeError("Must pass 'name' to .patch()")
        if new is not None:
            kwargs["new"] = new
        patcher = mock.patch(item, **kwargs)
        # The shared helper also sets ``self.<name>``, which tearDown()
        # resets for every registered patch, matching patch_object().
        self._start_patch(name, patcher, new, return_value)


class TestConfig(object):
    """In-memory stand-in for a config object with "previous value" tracking.

    Current values live in ``self.config`` and previous values in
    ``self.config_prev``; both are plain dicts.
    """

    # This is a helper, not a test case: stop name-based collectors such as
    # pytest from trying to collect it because of the ``Test`` prefix.
    __test__ = False

    def __init__(self):
        self.config = {}
        self.config_prev = {}

    def __call__(self, key=None):
        """Return the value for *key*, or this object when *key* is falsy."""
        if key:
            return self.get(key)
        return self

    def previous(self, k):
        """Return the previous value of *k*, falling back to the current one.

        :raises KeyError: if *k* is in neither the previous nor the current
            config.
        """
        if k in self.config_prev:
            return self.config_prev[k]
        return self.config[k]

    def set_previous(self, k, v):
        """Record *v* as the previous value of *k*."""
        self.config_prev[k] = v

    def unset_previous(self, k):
        """Forget the previous value of *k*, if one was recorded."""
        self.config_prev.pop(k, None)

    def changed(self, k):
        """Return whether *k* differs from its previous value.

        When no previous values are recorded at all, every key counts as
        changed.
        """
        if not self.config_prev:
            return True
        return self.get(k) != self.previous(k)

    def get(self, attr=None):
        """Return the value for *attr*, or ``None`` when it is not set.

        A falsy *attr* returns this object itself.
        """
        if not attr:
            return self
        try:
            return self.config[attr]
        except KeyError:
            return None

    def get_all(self):
        """Return the underlying dict of current values (not a copy)."""
        return self.config

    def set(self, attr, value):
        """Set the current value of *attr*."""
        self.config[attr] = value

    def __getitem__(self, k):
        """Support ``config[k]``; a missing key yields ``None``, as in get()."""
        return self.get(k)