diff --git a/oslo_messaging/_drivers/protocols/amqp/controller.py b/oslo_messaging/_drivers/protocols/amqp/controller.py
index d89eb0ae7..1dd91f7ec 100644
--- a/oslo_messaging/_drivers/protocols/amqp/controller.py
+++ b/oslo_messaging/_drivers/protocols/amqp/controller.py
@@ -158,11 +158,12 @@ class Server(pyngus.ReceiverEventHandler):
     from a given target.  Messages arriving on the links are placed on the
     'incoming' queue.
     """
-    def __init__(self, addresses, incoming):
+    def __init__(self, addresses, incoming, subscription_id):
         self._incoming = incoming
         self._addresses = addresses
         self._capacity = 500   # credit per link
         self._receivers = None
+        self._id = subscription_id
 
     def attach(self, connection):
         """Create receiver links over the given connection for all the
@@ -267,7 +268,8 @@ class Controller(pyngus.ConnectionEventHandler):
         self._max_task_batch = 50
         # cache of sending links indexed by address:
         self._senders = {}
-        # Servers (set of receiving links), indexed by target:
+        # Servers indexed by target. Each entry is a map indexed by the
+        # specific ProtonListener's identifier:
         self._servers = {}
 
         opt_group = cfg.OptGroup(name='oslo_messaging_amqp',
@@ -329,8 +331,9 @@ class Controller(pyngus.ConnectionEventHandler):
             self.processor = None
         self._tasks = None
         self._senders = None
-        for server in self._servers.values():
-            server.destroy()
+        for servers in self._servers.values():
+            for server in servers.values():
+                server.destroy()
         self._servers.clear()
         self._socket_connection = None
         if self._replies:
@@ -382,7 +385,7 @@ class Controller(pyngus.ConnectionEventHandler):
         LOG.debug("Sending response to %s", address)
         self._send(address, response)
 
-    def subscribe(self, target, in_queue):
+    def subscribe(self, target, in_queue, subscription_id):
         """Subscribe to messages sent to 'target', place received messages on
         'in_queue'.
         """
@@ -391,20 +394,25 @@ class Controller(pyngus.ConnectionEventHandler):
             self._broadcast_address(target),
             self._group_request_address(target)
         ]
-        self._subscribe(target, addresses, in_queue)
+        self._subscribe(target, addresses, in_queue, subscription_id)
 
-    def subscribe_notifications(self, target, in_queue):
+    def subscribe_notifications(self, target, in_queue, subscription_id):
         """Subscribe for notifications on 'target', place received messages on
         'in_queue'.
         """
         addresses = [self._group_request_address(target)]
-        self._subscribe(target, addresses, in_queue)
+        self._subscribe(target, addresses, in_queue, subscription_id)
 
-    def _subscribe(self, target, addresses, in_queue):
+    def _subscribe(self, target, addresses, in_queue, subscription_id):
         LOG.debug("Subscribing to %(target)s (%(addresses)s)",
                   {'target': target, 'addresses': addresses})
-        self._servers[target] = Server(addresses, in_queue)
-        self._servers[target].attach(self._socket_connection.connection)
+        server = Server(addresses, in_queue, subscription_id)
+        servers = self._servers.get(target)
+        if servers is None:
+            servers = {}
+            self._servers[target] = servers
+        servers[subscription_id] = server
+        server.attach(self._socket_connection.connection)
 
     def _resolve(self, target):
         """Return a link address for a given target."""
@@ -583,8 +591,9 @@ class Controller(pyngus.ConnectionEventHandler):
         LOG.debug("Connection active (%(hostname)s:%(port)s), subscribing...",
                   {'hostname': self.hosts.current.hostname,
                    'port': self.hosts.current.port})
-        for s in self._servers.values():
-            s.attach(self._socket_connection.connection)
+        for servers in self._servers.values():
+            for server in servers.values():
+                server.attach(self._socket_connection.connection)
         self._replies = Replies(self._socket_connection.connection,
                                 lambda: self._reply_link_ready())
         self._delay = 0
diff --git a/oslo_messaging/_drivers/protocols/amqp/driver.py b/oslo_messaging/_drivers/protocols/amqp/driver.py
index 04feb2de1..68fbbf4d8 100644
--- a/oslo_messaging/_drivers/protocols/amqp/driver.py
+++ b/oslo_messaging/_drivers/protocols/amqp/driver.py
@@ -25,6 +25,7 @@ import logging
 import os
 import threading
 import time
+import uuid
 
 from oslo_serialization import jsonutils
 from oslo_utils import importutils
@@ -149,6 +150,7 @@ class ProtonListener(base.Listener):
         super(ProtonListener, self).__init__(driver.prefetch_size)
         self.driver = driver
         self.incoming = Queue()
+        self.id = uuid.uuid4().hex
 
     def stop(self):
         self.incoming.stop()
diff --git a/oslo_messaging/_drivers/protocols/amqp/drivertasks.py b/oslo_messaging/_drivers/protocols/amqp/drivertasks.py
index 04943961d..0addc0758 100644
--- a/oslo_messaging/_drivers/protocols/amqp/drivertasks.py
+++ b/oslo_messaging/_drivers/protocols/amqp/drivertasks.py
@@ -83,9 +83,12 @@ class ListenTask(controller.Task):
         """
         if self._notifications:
             controller.subscribe_notifications(self._target,
-                                               self._listener.incoming)
+                                               self._listener.incoming,
+                                               self._listener.id)
         else:
-            controller.subscribe(self._target, self._listener.incoming)
+            controller.subscribe(self._target,
+                                 self._listener.incoming,
+                                 self._listener.id)
 
 
 class ReplyTask(controller.Task):
diff --git a/oslo_messaging/tests/test_amqp_driver.py b/oslo_messaging/tests/test_amqp_driver.py
index 909bc599d..b011a6383 100644
--- a/oslo_messaging/tests/test_amqp_driver.py
+++ b/oslo_messaging/tests/test_amqp_driver.py
@@ -47,6 +47,12 @@ CYRUS_ENABLED = (pyngus and pyngus.VERSION >= (2, 0, 0) and _proton
 LOG = logging.getLogger(__name__)
 
 
+def _wait_until(predicate, timeout):
+    deadline = timeout + time.time()
+    while not predicate() and deadline > time.time():
+        time.sleep(0.1)
+
+
 class _ListenerThread(threading.Thread):
     """Run a blocking listener in a thread."""
     def __init__(self, listener, msg_count):
@@ -55,10 +61,13 @@ class _ListenerThread(threading.Thread):
         self.msg_count = msg_count
         self.messages = moves.queue.Queue()
         self.daemon = True
+        self.started = threading.Event()
         self.start()
+        self.started.wait()
 
     def run(self):
         LOG.debug("Listener started")
+        self.started.set()
         while self.msg_count > 0:
             in_msg = self.listener.poll()[0]
             self.messages.put(in_msg)
@@ -515,12 +524,19 @@ class TestFailover(test_utils.BaseTestCase):
         target = oslo_messaging.Target(topic="my-topic")
         listener = _ListenerThread(driver.listen(target), 2)
 
+        # wait for listener links to come up
+        # 4 == 3 links per listener + 1 for the global reply queue
+        predicate = lambda: self._brokers[0].sender_link_count == 4
+        _wait_until(predicate, 30)
+        self.assertTrue(predicate())
+
         rc = driver.send(target, {"context": "whatever"},
                          {"method": "echo", "id": "echo-1"},
                          wait_for_reply=True,
                          timeout=30)
         self.assertIsNotNone(rc)
         self.assertEqual(rc.get('correlation-id'), 'echo-1')
+
         # 1 request msg, 1 response:
         self.assertEqual(self._brokers[0].topic_count, 1)
         self.assertEqual(self._brokers[0].direct_count, 1)
@@ -528,28 +544,25 @@ class TestFailover(test_utils.BaseTestCase):
         # fail broker 0 and start broker 1:
         self._brokers[0].stop()
         self._brokers[1].start()
-        deadline = time.time() + 30
-        responded = False
-        sequence = 2
-        while deadline > time.time() and not responded:
-            if not listener.isAlive():
-                # listener may have exited after replying to an old correlation
-                # id: restart new listener
-                listener = _ListenerThread(driver.listen(target), 1)
-            try:
-                rc = driver.send(target, {"context": "whatever"},
-                                 {"method": "echo",
-                                  "id": "echo-%d" % sequence},
-                                 wait_for_reply=True,
-                                 timeout=2)
-                self.assertIsNotNone(rc)
-                self.assertEqual(rc.get('correlation-id'),
-                                 'echo-%d' % sequence)
-                responded = True
-            except oslo_messaging.MessagingTimeout:
-                sequence += 1
 
-        self.assertTrue(responded)
+        # wait for listener links to re-establish
+        # 4 = 3 links per listener + 1 for the global reply queue
+        predicate = lambda: self._brokers[1].sender_link_count == 4
+        _wait_until(predicate, 30)
+        self.assertTrue(predicate())
+
+        rc = driver.send(target,
+                         {"context": "whatever"},
+                         {"method": "echo", "id": "echo-2"},
+                         wait_for_reply=True,
+                         timeout=2)
+        self.assertIsNotNone(rc)
+        self.assertEqual(rc.get('correlation-id'), 'echo-2')
+
+        # 1 request msg, 1 response:
+        self.assertEqual(self._brokers[1].topic_count, 1)
+        self.assertEqual(self._brokers[1].direct_count, 1)
+
         listener.join(timeout=30)
         self.assertFalse(listener.isAlive())
 
@@ -558,6 +571,55 @@ class TestFailover(test_utils.BaseTestCase):
         self._brokers[1].stop()
         driver.cleanup()
 
+    def test_listener_failover(self):
+        """Verify that Listeners are re-established after failover.
+        """
+        self._brokers[0].start()
+        driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
+
+        target = oslo_messaging.Target(topic="my-topic")
+        bcast = oslo_messaging.Target(topic="my-topic", fanout=True)
+        listener1 = _ListenerThread(driver.listen(target), 2)
+        listener2 = _ListenerThread(driver.listen(target), 2)
+
+        # wait for 7 sending links to become active on the broker.
+        # 7 = 3 links per Listener + 1 global reply link
+        predicate = lambda: self._brokers[0].sender_link_count == 7
+        _wait_until(predicate, 30)
+        self.assertTrue(predicate())
+
+        driver.send(bcast, {"context": "whatever"},
+                    {"method": "ignore", "id": "echo-1"})
+
+        # 1 message per listener
+        predicate = lambda: self._brokers[0].fanout_sent_count == 2
+        _wait_until(predicate, 30)
+        self.assertTrue(predicate())
+
+        # fail broker 0 and start broker 1:
+        self._brokers[0].stop()
+        self._brokers[1].start()
+
+        # wait again for 7 sending links to re-establish
+        predicate = lambda: self._brokers[1].sender_link_count == 7
+        _wait_until(predicate, 30)
+        self.assertTrue(predicate())
+
+        driver.send(bcast, {"context": "whatever"},
+                    {"method": "ignore", "id": "echo-2"})
+
+        # 1 message per listener
+        predicate = lambda: self._brokers[1].fanout_sent_count == 2
+        _wait_until(predicate, 30)
+        self.assertTrue(predicate())
+
+        listener1.join(timeout=30)
+        listener2.join(timeout=30)
+        self.assertFalse(listener1.isAlive() or listener2.isAlive())
+
+        self._brokers[1].stop()
+        driver.cleanup()
+
 
 class FakeBroker(threading.Thread):
     """A test AMQP message 'broker'."""
@@ -638,12 +700,16 @@ class FakeBroker(threading.Thread):
 
             # Pyngus ConnectionEventHandler callbacks:
 
+            def connection_active(self, connection):
+                self.server.connection_count += 1
+
             def connection_remote_closed(self, connection, reason):
                 """Peer has closed the connection."""
                 self.connection.close()
 
             def connection_closed(self, connection):
                 """Connection close completed."""
+                self.server.connection_count -= 1
                 self.closed = True  # main loop will destroy
 
             def connection_failed(self, connection, error):
@@ -712,6 +778,7 @@ class FakeBroker(threading.Thread):
             # Pyngus SenderEventHandler callbacks:
 
             def sender_active(self, sender_link):
+                self.server.sender_link_count += 1
                 self.server.add_route(self.link.source_address, self)
                 self.routed = True
 
@@ -720,6 +787,7 @@ class FakeBroker(threading.Thread):
                 self.link.close()
 
             def sender_closed(self, sender_link):
+                self.server.sender_link_count -= 1
                 self.destroy()
 
         class ReceiverLink(pyngus.ReceiverEventHandler):
@@ -746,10 +814,14 @@ class FakeBroker(threading.Thread):
 
             # ReceiverEventHandler callbacks:
 
+            def receiver_active(self, receiver_link):
+                self.server.receiver_link_count += 1
+
             def receiver_remote_closed(self, receiver_link, error):
                 self.link.close()
 
             def receiver_closed(self, receiver_link):
+                self.server.receiver_link_count -= 1
                 self.destroy()
 
             def message_received(self, receiver_link, message, handle):
@@ -795,7 +867,12 @@ class FakeBroker(threading.Thread):
         self.direct_count = 0
         self.topic_count = 0
         self.fanout_count = 0
+        self.fanout_sent_count = 0
         self.dropped_count = 0
+        # counts for active links and connections:
+        self.connection_count = 0
+        self.sender_link_count = 0
+        self.receiver_link_count = 0
 
     def start(self):
         """Start the server."""
@@ -907,6 +984,7 @@ class FakeBroker(threading.Thread):
         if dest.startswith(self._broadcast_prefix):
             self.fanout_count += 1
             for link in self._sources[dest]:
+                self.fanout_sent_count += 1
                 LOG.debug("Broadcast to %s", dest)
                 link.send_message(message)
         elif dest.startswith(self._group_prefix):