diff --git a/tests/notify/test_listener.py b/tests/notify/test_listener.py
index 70643a7ac..00377b076 100644
--- a/tests/notify/test_listener.py
+++ b/tests/notify/test_listener.py
@@ -26,6 +26,29 @@ from tests import utils as test_utils
 load_tests = testscenarios.load_tests_apply_scenarios
 
 
+class RestartableListenerThread(object):
+    def __init__(self, listener):
+        self.listener = listener
+        self.thread = None
+
+    def start(self):
+        if self.thread is None:
+            self.thread = threading.Thread(target=self.listener.start)
+            self.thread.daemon = True
+            self.thread.start()
+
+    def stop(self):
+        if self.thread is not None:
+            self.listener.stop()
+            self.listener.wait()
+            self.thread.join()
+            self.thread = None
+
+    def wait_end(self):
+        self.thread.join(timeout=15)
+        return self.thread.isAlive()
+
+
 class ListenerSetupMixin(object):
 
     class ListenerTracker(object):
@@ -74,16 +97,10 @@ class ListenerSetupMixin(object):
             allow_requeue=True, pool=pool)
         tracker.listeners.append(listener)
 
-        thread = threading.Thread(target=listener.start)
-        thread.daemon = True
-        thread.listener = listener
+        thread = RestartableListenerThread(listener)
         thread.start()
         return thread
 
-    def _stop_listener(self, thread):
-        thread.join(timeout=15)
-        return thread.isAlive()
-
     def _setup_notifier(self, transport, topic='testtopic',
                         publisher_id='testpublisher'):
         return messaging.Notifier(transport, topic=topic,
@@ -151,7 +168,7 @@ class TestNotifyListener(test_utils.BaseTestCase, ListenerSetupMixin):
         notifier = self._setup_notifier(transport)
         notifier.info({}, 'an_event.start', 'test message')
 
-        self.assertFalse(self._stop_listener(listener_thread))
+        self.assertFalse(listener_thread.wait_end())
 
         endpoint.info.assert_called_once_with(
             {}, 'testpublisher', 'an_event.start', 'test message',
@@ -171,7 +188,7 @@ class TestNotifyListener(test_utils.BaseTestCase, ListenerSetupMixin):
         notifier = self._setup_notifier(transport, topic='topic2')
         notifier.info({'ctxt': '2'}, 'an_event.start2', 'test')
 
-        self.assertFalse(self._stop_listener(listener_thread))
+        self.assertFalse(listener_thread.wait_end())
 
         endpoint.info.assert_has_calls([
             mock.call({'ctxt': '1'}, 'testpublisher',
@@ -214,7 +231,7 @@ class TestNotifyListener(test_utils.BaseTestCase, ListenerSetupMixin):
         notifier.info({'ctxt': '2'},
                       'an_event.start', 'test message exchange2')
 
-        self.assertFalse(self._stop_listener(listener_thread))
+        self.assertFalse(listener_thread.wait_end())
 
         endpoint.info.assert_has_calls([
             mock.call({'ctxt': '1'}, 'testpublisher', 'an_event.start',
@@ -237,7 +254,7 @@ class TestNotifyListener(test_utils.BaseTestCase, ListenerSetupMixin):
         notifier = self._setup_notifier(transport)
         notifier.info({}, 'an_event.start', 'test')
 
-        self.assertFalse(self._stop_listener(listener_thread))
+        self.assertFalse(listener_thread.wait_end())
 
         endpoint1.info.assert_called_once_with(
             {}, 'testpublisher', 'an_event.start', 'test', {
@@ -265,7 +282,7 @@ class TestNotifyListener(test_utils.BaseTestCase, ListenerSetupMixin):
         notifier = self._setup_notifier(transport)
         notifier.info({}, 'an_event.start', 'test')
 
-        self.assertFalse(self._stop_listener(listener_thread))
+        self.assertFalse(listener_thread.wait_end())
 
         endpoint.info.assert_has_calls([
             mock.call({}, 'testpublisher', 'an_event.start', 'test',
@@ -291,8 +308,8 @@ class TestNotifyListener(test_utils.BaseTestCase, ListenerSetupMixin):
         notifier.info({'ctxt': '0'}, 'an_event.start', 'test message0')
         notifier.info({'ctxt': '1'}, 'an_event.start', 'test message1')
 
-        self.assertFalse(self._stop_listener(listener2_thread))
-        self.assertFalse(self._stop_listener(listener1_thread))
+        self.assertFalse(listener2_thread.wait_end())
+        self.assertFalse(listener1_thread.wait_end())
 
         def mocked_endpoint_call(i):
             return mock.call({'ctxt': '%d' % i}, 'testpublisher',
@@ -329,20 +346,42 @@ class TestNotifyListener(test_utils.BaseTestCase, ListenerSetupMixin):
 
         notifier = self._setup_notifier(transport, topic="topic")
         mocked_endpoint1_calls = []
-        for i in range(0, 100):
+        for i in range(0, 25):
             notifier.info({'ctxt': '%d' % i}, 'an_event.start',
                           'test message%d' % i)
             mocked_endpoint1_calls.append(mocked_endpoint_call(i))
 
-        self.assertFalse(self._stop_listener(listener3_thread))
-        self.assertFalse(self._stop_listener(listener2_thread))
-        self.assertFalse(self._stop_listener(listener1_thread))
+        listener2_thread.stop()
+
+        for i in range(0, 25):
+            notifier.info({'ctxt': '%d' % i}, 'an_event.start',
+                          'test message%d' % i)
+            mocked_endpoint1_calls.append(mocked_endpoint_call(i))
+
+        listener2_thread.start()
+        listener3_thread.stop()
+
+        for i in range(0, 25):
+            notifier.info({'ctxt': '%d' % i}, 'an_event.start',
+                          'test message%d' % i)
+            mocked_endpoint1_calls.append(mocked_endpoint_call(i))
+
+        listener3_thread.start()
+
+        for i in range(0, 25):
+            notifier.info({'ctxt': '%d' % i}, 'an_event.start',
+                          'test message%d' % i)
+            mocked_endpoint1_calls.append(mocked_endpoint_call(i))
+
+        self.assertFalse(listener3_thread.wait_end())
+        self.assertFalse(listener2_thread.wait_end())
+        self.assertFalse(listener1_thread.wait_end())
 
         self.assertEqual(100, endpoint1.info.call_count)
         endpoint1.info.assert_has_calls(mocked_endpoint1_calls)
 
-        self.assertNotEqual(0, endpoint2.info.call_count)
-        self.assertNotEqual(0, endpoint3.info.call_count)
+        self.assertLessEqual(25, endpoint2.info.call_count)
+        self.assertLessEqual(25, endpoint3.info.call_count)
 
         self.assertEqual(100, endpoint2.info.call_count +
                          endpoint3.info.call_count)