
Fresh start for the StarlingX automation framework. Change-Id: Ie265e0791024f45f71faad6315c2b91b022934d1
138 lines
5.7 KiB
Python
138 lines
5.7 KiB
Python
import concurrent
|
|
import threading
|
|
import traceback
|
|
from concurrent import futures
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
from framework.logging.automation_logger import get_logger
|
|
from framework.threading.thread_namer import ThreadNamer
|
|
from framework.threading.thread_object import ThreadObject
|
|
|
|
|
|
def execute_function_in_named_thread(thread_object, function_to_execute, *args, **kwargs):
|
|
"""
|
|
This function will name the thread for logging purposes
|
|
|
|
Args:
|
|
thread_object (ThreadObject): The object representing this thread
|
|
function_to_execute (function): The operation that we want to execute in the thread
|
|
*args: The list of arguments taken by 'function_to_execute'
|
|
**kwargs: The list of kwargs taken by 'function_to_execute'
|
|
Return: function_to_execute's return value
|
|
"""
|
|
try:
|
|
threading.current_thread().name = thread_object.name
|
|
return function_to_execute(*args, **kwargs)
|
|
except Exception as e:
|
|
thread_object.exception = traceback.format_exc()
|
|
raise e
|
|
|
|
|
|
class ThreadManager:
|
|
"""
|
|
This class represents an object that will manage threads for a set of operations.
|
|
- Instantiate the class.
|
|
- Call start_thread by passing in the set of commands that you want to run in parallel.
|
|
- Call join_all_threads to wait for all the threads to complete executing successfully.
|
|
If an exception is encountered during any thread, it will be propagated up
|
|
when we call join_all_threads.
|
|
"""
|
|
|
|
def __init__(self, num_threads: int = 50, timeout: int = 3600, log_thread_status: bool = False):
|
|
"""
|
|
This object is used to spawn threads to perform operations and join them all together,
|
|
Args:
|
|
num_threads: Maximum number of threads available to use at a given time.
|
|
timeout: Number of seconds to wait for the threads to complete before timing out
|
|
log_thread_status: True if we want to log when threads start and complete.
|
|
"""
|
|
self.future_to_thread_object_dict = {}
|
|
self.executor = ThreadPoolExecutor(max_workers=num_threads)
|
|
self.timeout = timeout
|
|
self.log_thread_status = log_thread_status
|
|
|
|
def start_thread(self, thread_name, function_to_execute, *args, **kwargs):
|
|
"""
|
|
This function spawn a new thread to execute the function_to_execute
|
|
|
|
Args:
|
|
thread_name (str): is used to describe this thread for logging purposes
|
|
function_to_execute (function): is the function to Execute
|
|
*args: is the list of arguments to pass to the function_to_execute
|
|
*kwargs is the list of kwargs to pass to the function_to_execute
|
|
Returns:
|
|
An instance of ThreadObject which contains the future promised by the execution of
|
|
this thread.
|
|
"""
|
|
|
|
# Pick the name of the Thread, taking nesting into account.
|
|
thread_namer = ThreadNamer()
|
|
full_thread_name = thread_namer.get_thread_full_name(thread_name)
|
|
thread_object = ThreadObject(full_thread_name, thread_name)
|
|
|
|
# Start running the function_to_execute in a thread
|
|
future = self.executor.submit(execute_function_in_named_thread, thread_object, function_to_execute, *args, **kwargs)
|
|
|
|
# Store the information about the running thread in a ThreadObject
|
|
thread_object.future = future
|
|
self.future_to_thread_object_dict[future] = thread_object
|
|
|
|
def join_all_threads(self):
|
|
"""
|
|
This function will wait for all the threads spawned by this object to complete.
|
|
The function will also raise an exception if any of the threads has encountered an exception.
|
|
|
|
Return:
|
|
results: a map (raw_thread_name, function_to_execute's return values)
|
|
"""
|
|
|
|
exceptions_in_thread = 0
|
|
results = {}
|
|
|
|
try:
|
|
# Wait for all the threads to have completed the operation before moving on.
|
|
for future in futures.as_completed(self.future_to_thread_object_dict.keys(), timeout=self.timeout):
|
|
|
|
if self.log_thread_status:
|
|
get_logger().log_info("Thread {} is done!".format(self.future_to_thread_object_dict[future].name))
|
|
raw_thread_name = self.future_to_thread_object_dict[future].raw_thread_name
|
|
results[raw_thread_name] = future.result()
|
|
|
|
# Log Exceptions to the Console as they happen.
|
|
if future.exception():
|
|
exceptions_in_thread += 1
|
|
self.future_to_thread_object_dict[future].log_exception()
|
|
|
|
except concurrent.futures._base.TimeoutError:
|
|
get_logger().log_error("There is at least one thread that timed out!")
|
|
raise TimeoutError("Thread Timeout")
|
|
|
|
# Generate a Report of Failures that happened in the Threads if any.
|
|
if exceptions_in_thread:
|
|
for thread_object in self.future_to_thread_object_dict.values():
|
|
if thread_object.exception:
|
|
thread_object.log_exception()
|
|
raise AssertionError(f"An error was raised in {exceptions_in_thread} threads.")
|
|
|
|
# Resume back to single-threaded execution.
|
|
if self.log_thread_status:
|
|
get_logger().log_info("Thread joining Complete!")
|
|
|
|
return results
|
|
|
|
def get_thread_object(self, raw_thread_name: str):
|
|
"""
|
|
Gets the Thread Object associated with the raw_thread_name specified.
|
|
|
|
Args:
|
|
raw_thread_name:
|
|
|
|
Returns:
|
|
|
|
"""
|
|
for thread_object in self.future_to_thread_object_dict.values():
|
|
if thread_object.get_raw_thread_name() == raw_thread_name:
|
|
return thread_object
|
|
|
|
raise Exception(f"Failed to find a thread with raw_thread_name='{raw_thread_name}'")
|