Merge pull request #11 from racker/tasks_for_doing_things
Tasks and Image Download
This commit is contained in:
commit
b3ee1d828c
@ -4,3 +4,4 @@ argparse==1.2.1
|
||||
wsgiref==0.1.2
|
||||
zope.interface==4.0.5
|
||||
structlog==0.3.0
|
||||
treq==0.2.0
|
||||
|
126
teeth_agent/base_task.py
Normal file
126
teeth_agent/base_task.py
Normal file
@ -0,0 +1,126 @@
|
||||
"""
|
||||
Copyright 2013 Rackspace, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from twisted.application.service import MultiService
|
||||
from twisted.application.internet import TimerService
|
||||
from twisted.internet import defer
|
||||
from teeth_agent.logging import get_logger
|
||||
|
||||
__all__ = ['BaseTask', 'MultiTask']
|
||||
|
||||
|
||||
class BaseTask(MultiService, object):
|
||||
"""
|
||||
Task to execute, reporting status periodically to TeethClient instance.
|
||||
"""
|
||||
|
||||
task_name = 'task_undefined'
|
||||
|
||||
def __init__(self, client, task_id, reporting_interval=10):
|
||||
super(BaseTask, self).__init__()
|
||||
self.log = get_logger(task_id=task_id, task_name=self.task_name)
|
||||
self.setName(self.task_name + '.' + task_id)
|
||||
self._client = client
|
||||
self._id = task_id
|
||||
self._percent = 0
|
||||
self._reporting_interval = reporting_interval
|
||||
self._state = 'starting'
|
||||
self._timer = TimerService(self._reporting_interval, self._tick)
|
||||
self._timer.setServiceParent(self)
|
||||
self._error_msg = None
|
||||
self._done = False
|
||||
self._d = defer.Deferred()
|
||||
|
||||
def _run(self):
|
||||
"""Do the actual work here."""
|
||||
|
||||
def run(self):
|
||||
"""Run the Task."""
|
||||
# setServiceParent actually starts the task if it is already running
|
||||
# so we run it in start.
|
||||
if not self.parent:
|
||||
self.setServiceParent(self._client)
|
||||
self._run()
|
||||
return self._d
|
||||
|
||||
def _tick(self):
|
||||
if not self.running:
|
||||
# log.debug("_tick called while not running :()")
|
||||
return
|
||||
|
||||
if self._state in ['error', 'complete']:
|
||||
self.stopService()
|
||||
|
||||
return self._client.update_task_status(self)
|
||||
|
||||
def error(self, message, *args, **kwargs):
|
||||
"""Error out running of the task."""
|
||||
self._error_msg = message
|
||||
self._state = 'error'
|
||||
self.stopService()
|
||||
|
||||
def complete(self, *args, **kwargs):
|
||||
"""Complete running of the task."""
|
||||
self._state = 'complete'
|
||||
self.stopService()
|
||||
|
||||
def startService(self):
|
||||
"""Start the Service."""
|
||||
self._state = 'running'
|
||||
super(BaseTask, self).startService()
|
||||
|
||||
def stopService(self):
|
||||
"""Stop the Service."""
|
||||
super(BaseTask, self).stopService()
|
||||
|
||||
if self._state not in ['error', 'complete']:
|
||||
self.log.err("told to shutdown before task could complete, marking as error.")
|
||||
self._error_msg = 'service being shutdown'
|
||||
self._state = 'error'
|
||||
|
||||
if self._done is False:
|
||||
self._done = True
|
||||
self._d.callback(None)
|
||||
self._client.finish_task(self)
|
||||
|
||||
|
||||
class MultiTask(BaseTask):
|
||||
|
||||
"""Run multiple tasks in parallel."""
|
||||
|
||||
def __init__(self, client, task_id, reporting_interval=10):
|
||||
super(MultiTask, self).__init__(client, task_id, reporting_interval=reporting_interval)
|
||||
self._tasks = []
|
||||
|
||||
def _tick(self):
|
||||
if len(self._tasks):
|
||||
percents = [t._percent for t in self._tasks]
|
||||
self._percent = sum(percents)/float(len(percents))
|
||||
else:
|
||||
self._percent = 0
|
||||
super(MultiTask, self)._tick()
|
||||
|
||||
def _run(self):
|
||||
ds = []
|
||||
for t in self._tasks:
|
||||
ds.append(t.run())
|
||||
dl = defer.DeferredList(ds)
|
||||
dl.addBoth(self.complete, self.error)
|
||||
|
||||
def add_task(self, task):
|
||||
"""Add a task to be ran."""
|
||||
task.setServiceParent(self)
|
||||
self._tasks.append(task)
|
54
teeth_agent/cache_image.py
Normal file
54
teeth_agent/cache_image.py
Normal file
@ -0,0 +1,54 @@
|
||||
"""
|
||||
Copyright 2013 Rackspace, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from teeth_agent.base_task import BaseTask
|
||||
import treq
|
||||
|
||||
|
||||
class ImageDownloaderTask(BaseTask):
|
||||
"""Download image to cache. """
|
||||
task_name = 'image_download'
|
||||
|
||||
def __init__(self, client, task_id, image_info, destination_filename, reporting_interval=10):
|
||||
super(ImageDownloaderTask, self).__init__(client, task_id, reporting_interval=reporting_interval)
|
||||
self._destination_filename = destination_filename
|
||||
self._image_id = image_info.id
|
||||
self._image_hashes = image_info.hashes
|
||||
self._iamge_urls = image_info.urls
|
||||
self._destination_filename = destination_filename
|
||||
|
||||
def _run(self):
|
||||
# TODO: pick by protocol priority.
|
||||
url = self._iamge_urls[0]
|
||||
# TODO: more than just download, sha1 it.
|
||||
return self._download_image_to_file(url)
|
||||
|
||||
def _tick(self):
|
||||
# TODO: get file download percentages.
|
||||
self.percent = 0
|
||||
super(ImageDownloaderTask, self)._tick()
|
||||
|
||||
def _download_image_to_file(self, url):
|
||||
destination = open(self._destination_filename, 'wb')
|
||||
|
||||
def push(data):
|
||||
if self.running:
|
||||
destination.write(data)
|
||||
|
||||
d = treq.get(url)
|
||||
d.addCallback(treq.collect, push)
|
||||
d.addBoth(lambda _: destination.close())
|
||||
return d
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
import time
|
||||
import json
|
||||
import random
|
||||
import tempfile
|
||||
|
||||
from twisted.application.service import MultiService
|
||||
from twisted.application.internet import TCPClient
|
||||
@ -87,6 +88,12 @@ class TeethClient(MultiService, object):
|
||||
}
|
||||
}
|
||||
|
||||
@property
|
||||
def conf_image_cache_path(self):
|
||||
"""Path to iamge cache."""
|
||||
# TODO: improve:
|
||||
return tempfile.gettempdir()
|
||||
|
||||
def startService(self):
|
||||
"""Start the Service."""
|
||||
super(TeethClient, self).startService()
|
||||
|
@ -73,6 +73,17 @@ class RPCError(RPCMessage, RuntimeError):
|
||||
self._raw_message = message
|
||||
|
||||
|
||||
class ImageInfo(object):
|
||||
"""
|
||||
Metadata about a machine image.
|
||||
"""
|
||||
def __init__(self, image_id, image_urls, image_hashes):
|
||||
super(ImageInfo, self).__init__()
|
||||
self.id = image_id
|
||||
self.urls = image_urls
|
||||
self.hashes = image_hashes
|
||||
|
||||
|
||||
class CommandValidationError(RuntimeError):
|
||||
"""
|
||||
Exception class which can be used to return an error when the
|
||||
|
@ -14,82 +14,34 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from twisted.application.service import MultiService
|
||||
from twisted.application.internet import TimerService
|
||||
from teeth_agent.logging import get_logger
|
||||
import os
|
||||
|
||||
from teeth_agent.base_task import MultiTask, BaseTask
|
||||
from teeth_agent.cache_image import ImageDownloaderTask
|
||||
|
||||
|
||||
__all__ = ['Task', 'PrepareImageTask']
|
||||
__all__ = ['CacheImagesTask', 'PrepareImageTask']
|
||||
|
||||
|
||||
class Task(MultiService, object):
|
||||
"""
|
||||
Task to execute, reporting status periodically to TeethClient instance.
|
||||
"""
|
||||
class CacheImagesTask(MultiTask):
|
||||
|
||||
task_name = 'task_undefined'
|
||||
"""Cache an array of images on a machine."""
|
||||
|
||||
def __init__(self, client, task_id, reporting_interval=10):
|
||||
super(Task, self).__init__()
|
||||
self.setName(self.task_name)
|
||||
self._client = client
|
||||
self._id = task_id
|
||||
self._percent = 0
|
||||
self._reporting_interval = reporting_interval
|
||||
self._state = 'starting'
|
||||
self._timer = TimerService(self._reporting_interval, self._tick)
|
||||
self._timer.setServiceParent(self)
|
||||
self._error_msg = None
|
||||
self.log = get_logger(task_id=task_id, task_name=self.task_name)
|
||||
task_name = 'cache_images'
|
||||
|
||||
def _run(self):
|
||||
"""Do the actual work here."""
|
||||
|
||||
def run(self):
|
||||
"""Run the Task."""
|
||||
# setServiceParent actually starts the task if it is already running
|
||||
# so we run it in start.
|
||||
self.setServiceParent(self._client)
|
||||
self._run()
|
||||
|
||||
def _tick(self):
|
||||
if not self.running:
|
||||
# log.debug("_tick called while not running :()")
|
||||
return
|
||||
return self._client.update_task_status(self)
|
||||
|
||||
def error(self, message):
|
||||
"""Error out running of the task."""
|
||||
self._error_msg = message
|
||||
self._state = 'error'
|
||||
self.stopService()
|
||||
|
||||
def complete(self):
|
||||
"""Complete running of the task."""
|
||||
self._state = 'complete'
|
||||
self.stopService()
|
||||
|
||||
def startService(self):
|
||||
"""Start the Service."""
|
||||
super(Task, self).startService()
|
||||
self._state = 'running'
|
||||
|
||||
def stopService(self):
|
||||
"""Stop the Service."""
|
||||
super(Task, self).stopService()
|
||||
|
||||
if not self._client.running:
|
||||
return
|
||||
|
||||
if self._state not in ['error', 'complete']:
|
||||
self.log.err("told to shutdown before task could complete, marking as error.")
|
||||
self._error_msg = 'service being shutdown'
|
||||
self._state = 'error'
|
||||
|
||||
self._client.finish_task(self)
|
||||
def __init__(self, client, task_id, images, reporting_interval=10):
|
||||
super(CacheImagesTask, self).__init__(client, task_id, reporting_interval=reporting_interval)
|
||||
self._images = images
|
||||
for image in self._images:
|
||||
image_path = os.path.join(client.conf_image_cache_path, image.id + '.img')
|
||||
t = ImageDownloaderTask(client,
|
||||
task_id, image,
|
||||
image_path,
|
||||
reporting_interval=reporting_interval)
|
||||
self.add_task(t)
|
||||
|
||||
|
||||
class PrepareImageTask(Task):
|
||||
class PrepareImageTask(BaseTask):
|
||||
|
||||
"""Prepare an image to be ran on the machine."""
|
||||
|
||||
|
@ -15,25 +15,37 @@ limitations under the License.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import shutil
|
||||
import tempfile
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
from mock import Mock, patch
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.trial import unittest
|
||||
from teeth_agent.task import Task
|
||||
from mock import Mock
|
||||
from twisted.web.client import ResponseDone
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from teeth_agent.protocol import ImageInfo
|
||||
from teeth_agent.base_task import BaseTask, MultiTask
|
||||
from teeth_agent.cache_image import ImageDownloaderTask
|
||||
|
||||
|
||||
class FakeClient(object):
|
||||
addService = Mock(return_value=None)
|
||||
running = Mock(return_value=0)
|
||||
update_task_status = Mock(return_value=None)
|
||||
finish_task = Mock(return_value=None)
|
||||
def __init__(self):
|
||||
self.addService = Mock(return_value=None)
|
||||
self.running = Mock(return_value=0)
|
||||
self.update_task_status = Mock(return_value=None)
|
||||
self.finish_task = Mock(return_value=None)
|
||||
|
||||
|
||||
class TestTask(Task):
|
||||
class TestTask(BaseTask):
|
||||
task_name = 'test_task'
|
||||
|
||||
|
||||
class TaskTest(unittest.TestCase):
|
||||
"""Event Emitter tests."""
|
||||
"""Basic tests of the Task API."""
|
||||
|
||||
def setUp(self):
|
||||
self.task_id = str(uuid.uuid4())
|
||||
@ -45,6 +57,15 @@ class TaskTest(unittest.TestCase):
|
||||
del self.task
|
||||
del self.client
|
||||
|
||||
def test_error(self):
|
||||
self.task.run()
|
||||
self.client.addService.assert_called_once_with(self.task)
|
||||
self.task.startService()
|
||||
self.client.update_task_status.assert_called_once_with(self.task)
|
||||
self.task.error('chaos monkey attack')
|
||||
self.assertEqual(self.task._state, 'error')
|
||||
self.client.finish_task.assert_called_once_with(self.task)
|
||||
|
||||
def test_run(self):
|
||||
self.assertEqual(self.task._state, 'starting')
|
||||
self.assertEqual(self.task._id, self.task_id)
|
||||
@ -55,3 +76,118 @@ class TaskTest(unittest.TestCase):
|
||||
self.task.complete()
|
||||
self.assertEqual(self.task._state, 'complete')
|
||||
self.client.finish_task.assert_called_once_with(self.task)
|
||||
|
||||
def test_fast_shutdown(self):
|
||||
self.task.run()
|
||||
self.client.addService.assert_called_once_with(self.task)
|
||||
self.task.startService()
|
||||
self.client.update_task_status.assert_called_once_with(self.task)
|
||||
self.task.stopService()
|
||||
self.assertEqual(self.task._state, 'error')
|
||||
self.client.finish_task.assert_called_once_with(self.task)
|
||||
|
||||
|
||||
class MultiTestTask(MultiTask):
|
||||
task_name = 'test_multitask'
|
||||
|
||||
|
||||
class MultiTaskTest(unittest.TestCase):
|
||||
"""Basic tests of the Multi Task API."""
|
||||
|
||||
def setUp(self):
|
||||
self.task_id = str(uuid.uuid4())
|
||||
self.client = FakeClient()
|
||||
self.task = MultiTestTask(self.client, self.task_id)
|
||||
|
||||
def tearDown(self):
|
||||
del self.task_id
|
||||
del self.task
|
||||
del self.client
|
||||
|
||||
def test_tasks(self):
|
||||
t = TestTask(self.client, self.task_id)
|
||||
self.task.add_task(t)
|
||||
self.assertEqual(self.task._state, 'starting')
|
||||
self.assertEqual(self.task._id, self.task_id)
|
||||
self.task.run()
|
||||
self.client.addService.assert_called_once_with(self.task)
|
||||
self.task.startService()
|
||||
self.client.update_task_status.assert_any_call(self.task)
|
||||
t.complete()
|
||||
self.assertEqual(self.task._state, 'complete')
|
||||
self.client.finish_task.assert_any_call(t)
|
||||
self.client.finish_task.assert_any_call(self.task)
|
||||
|
||||
|
||||
class StubResponse(object):
|
||||
def __init__(self, code, headers, body):
|
||||
self.version = ('HTTP', 1, 1)
|
||||
self.code = code
|
||||
self.status = "ima teapot"
|
||||
self.headers = headers
|
||||
self.body = body
|
||||
self.length = reduce(lambda x, y: x + len(y), body, 0)
|
||||
self.protocol = None
|
||||
|
||||
def deliverBody(self, protocol):
|
||||
self.protocol = protocol
|
||||
|
||||
def run(self):
|
||||
self.protocol.connectionMade()
|
||||
|
||||
for data in self.body:
|
||||
self.protocol.dataReceived(data)
|
||||
|
||||
self.protocol.connectionLost(Failure(ResponseDone("Response body fully received")))
|
||||
|
||||
|
||||
class ImageDownloaderTaskTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
get_patcher = patch('treq.get', autospec=True)
|
||||
self.TreqGet = get_patcher.start()
|
||||
self.addCleanup(get_patcher.stop)
|
||||
|
||||
self.tmpdir = tempfile.mkdtemp('image_download_test')
|
||||
self.task_id = str(uuid.uuid4())
|
||||
self.image_data = str(uuid.uuid4())
|
||||
self.image_md5 = hashlib.md5(self.image_data).hexdigest()
|
||||
self.cache_path = os.path.join(self.tmpdir, 'a1234.img')
|
||||
self.client = FakeClient()
|
||||
self.image_info = ImageInfo('a1234',
|
||||
['http://127.0.0.1/images/a1234.img'], {'md5': self.image_md5})
|
||||
self.task = ImageDownloaderTask(self.client,
|
||||
self.task_id,
|
||||
self.image_info,
|
||||
self.cache_path)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdir)
|
||||
|
||||
def assertFileHash(self, hash_type, path, value):
|
||||
file_hash = hashlib.new(hash_type)
|
||||
with open(path, 'r') as fp:
|
||||
file_hash.update(fp.read())
|
||||
self.assertEqual(value, file_hash.hexdigest())
|
||||
|
||||
def test_download_success(self):
|
||||
resp = StubResponse(200, [], [self.image_data])
|
||||
d = defer.Deferred()
|
||||
self.TreqGet.return_value = d
|
||||
self.task.run()
|
||||
self.client.addService.assert_called_once_with(self.task)
|
||||
|
||||
self.TreqGet.assert_called_once_with('http://127.0.0.1/images/a1234.img')
|
||||
|
||||
self.task.startService()
|
||||
|
||||
d.callback(resp)
|
||||
|
||||
resp.run()
|
||||
|
||||
self.client.update_task_status.assert_called_once_with(self.task)
|
||||
self.assertFileHash('md5', self.cache_path, self.image_md5)
|
||||
|
||||
self.task.stopService()
|
||||
self.assertEqual(self.task._state, 'error')
|
||||
self.client.finish_task.assert_called_once_with(self.task)
|
||||
|
Loading…
x
Reference in New Issue
Block a user