diff --git a/oslo_messaging/_drivers/base.py b/oslo_messaging/_drivers/base.py
index d183614f9..c09ab6f42 100644
--- a/oslo_messaging/_drivers/base.py
+++ b/oslo_messaging/_drivers/base.py
@@ -46,27 +46,15 @@ def batch_poll_helper(func):
         driver_prefetch = in_self.prefetch_size
         if driver_prefetch > 0:
             batch_size = min(batch_size, driver_prefetch)
+        timeout = batch_timeout or timeout
 
-        with timeutils.StopWatch(timeout) as timeout_watch:
-            # poll first message
-            msg = func(in_self, timeout=timeout_watch.leftover(True))
-            if msg is not None:
-                incomings.append(msg)
-            if batch_size == 1 or msg is None:
-                return incomings
-
-            # update batch_timeout according to timeout for whole operation
-            timeout_left = timeout_watch.leftover(True)
-            if timeout_left is not None and (
-                    batch_timeout is None or timeout_left < batch_timeout):
-                batch_timeout = timeout_left
-
-        with timeutils.StopWatch(batch_timeout) as batch_timeout_watch:
-            # poll remained batch messages
-            while len(incomings) < batch_size and msg is not None:
-                msg = func(in_self, timeout=batch_timeout_watch.leftover(True))
-                if msg is not None:
-                    incomings.append(msg)
+        with timeutils.StopWatch(timeout) as watch:
+            while True:
+                message = func(in_self, timeout=watch.leftover(True))
+                if message is not None:
+                    incomings.append(message)
+                if len(incomings) == batch_size or message is None:
+                    break
 
         return incomings
     return wrapper