diff --git a/oslo_messaging/_drivers/impl_rabbit.py b/oslo_messaging/_drivers/impl_rabbit.py
index 0f57a9279..61905d466 100644
--- a/oslo_messaging/_drivers/impl_rabbit.py
+++ b/oslo_messaging/_drivers/impl_rabbit.py
@@ -461,6 +461,7 @@ class Connection(object):
         # NOTE(sileht): if purpose is PURPOSE_LISTEN
         # the consume code does the heartbeat stuff
         # we don't need a thread
+        self._heartbeat_thread = None
         if purpose == rpc_amqp.PURPOSE_SEND:
             self._heartbeat_start()
 
diff --git a/oslo_messaging/tests/drivers/test_impl_rabbit.py b/oslo_messaging/tests/drivers/test_impl_rabbit.py
index c8f58bdbe..7008f23f5 100644
--- a/oslo_messaging/tests/drivers/test_impl_rabbit.py
+++ b/oslo_messaging/tests/drivers/test_impl_rabbit.py
@@ -261,28 +261,26 @@ class TestRabbitConsume(test_utils.BaseTestCase):
                                                  'kombu+memory:////')
         self.addCleanup(transport.cleanup)
         channel = mock.Mock()
-        conn = transport._driver._get_connection(amqp.PURPOSE_LISTEN
-                                                 ).connection
-        conn.connection.recoverable_channel_errors = (IOError,)
-        with mock.patch.object(conn.connection, 'channel',
-                               side_effect=[IOError, IOError, channel]):
-            conn.reset()
-            self.assertEqual(channel, conn.channel)
+        with transport._driver._get_connection(amqp.PURPOSE_LISTEN) as conn:
+            conn.connection.connection.recoverable_channel_errors = (IOError,)
+            with mock.patch.object(conn.connection.connection, 'channel',
+                                   side_effect=[IOError, IOError, channel]):
+                conn.connection.reset()
+                self.assertEqual(channel, conn.connection.channel)
 
     def test_connection_ack_have_disconnected_kombu_connection(self):
         transport = oslo_messaging.get_transport(self.conf,
                                                  'kombu+memory:////')
         self.addCleanup(transport.cleanup)
-        conn = transport._driver._get_connection(amqp.PURPOSE_LISTEN
-                                                 ).connection
-        channel = conn.channel
-        with mock.patch('kombu.connection.Connection.connected',
-                        new_callable=mock.PropertyMock,
-                        return_value=False):
-            self.assertRaises(driver_common.Timeout,
-                              conn.consume, timeout=0.01)
-            # Ensure a new channel have been setuped
-            self.assertNotEqual(channel, conn.channel)
+        with transport._driver._get_connection(amqp.PURPOSE_LISTEN) as conn:
+            channel = conn.connection.channel
+            with mock.patch('kombu.connection.Connection.connected',
+                            new_callable=mock.PropertyMock,
+                            return_value=False):
+                self.assertRaises(driver_common.Timeout,
+                                  conn.connection.consume, timeout=0.01)
+                # Ensure a new channel have been setuped
+                self.assertNotEqual(channel, conn.connection.channel)
 
 
 class TestRabbitTransportURL(test_utils.BaseTestCase):