diff --git a/neutron/plugins/ml2/driver_api.py b/neutron/plugins/ml2/driver_api.py
index c1787c84d22..2d87c05f396 100644
--- a/neutron/plugins/ml2/driver_api.py
+++ b/neutron/plugins/ml2/driver_api.py
@@ -36,22 +36,7 @@ BOUND_SEGMENT = 'bound_segment'
 
 
 @six.add_metaclass(abc.ABCMeta)
-class TypeDriver(object):
-    """Define stable abstract interface for ML2 type drivers.
-
-    ML2 type drivers each support a specific network_type for provider
-    and/or tenant network segments. Type drivers must implement this
-    abstract interface, which defines the API by which the plugin uses
-    the driver to manage the persistent type-specific resource
-    allocation state associated with network segments of that type.
-
-    Network segments are represented by segment dictionaries using the
-    NETWORK_TYPE, PHYSICAL_NETWORK, and SEGMENTATION_ID keys defined
-    above, corresponding to the provider attributes.  Future revisions
-    of the TypeDriver API may add additional segment dictionary
-    keys. Attributes not applicable for a particular network_type may
-    either be excluded or stored as None.
-    """
+class _TypeDriverBase(object):
 
     @abc.abstractmethod
     def get_type(self):
@@ -99,6 +84,42 @@ class TypeDriver(object):
         """
         pass
 
+    @abc.abstractmethod
+    def get_mtu(self, physical):
+        """Get driver's network MTU.
+
+        :returns mtu: maximum transmission unit
+
+        Returns the mtu for the network based on the config values and
+        the network type.
+        """
+        pass
+
+
+@six.add_metaclass(abc.ABCMeta)
+class TypeDriver(_TypeDriverBase):
+    """Define abstract interface for ML2 type drivers.
+
+    ML2 type drivers each support a specific network_type for provider
+    and/or tenant network segments. Type drivers must implement this
+    abstract interface, which defines the API by which the plugin uses
+    the driver to manage the persistent type-specific resource
+    allocation state associated with network segments of that type.
+
+    Network segments are represented by segment dictionaries using the
+    NETWORK_TYPE, PHYSICAL_NETWORK, and SEGMENTATION_ID keys defined
+    above, corresponding to the provider attributes.  Future revisions
+    of the TypeDriver API may add additional segment dictionary
+    keys. Attributes not applicable for a particular network_type may
+    either be excluded or stored as None.
+
+    TypeDriver passes session as argument for:
+    - reserve_provider_segment
+    - allocate_tenant_segment
+    - release_segment
+    - get_allocation
+    """
+
     @abc.abstractmethod
     def reserve_provider_segment(self, session, segment):
         """Reserve resource associated with a provider network segment.
@@ -144,14 +165,73 @@ class TypeDriver(object):
         """
         pass
 
+
+@six.add_metaclass(abc.ABCMeta)
+class ML2TypeDriver(_TypeDriverBase):
+    """Define abstract interface for ML2 type drivers.
+
+    ML2 type drivers each support a specific network_type for provider
+    and/or tenant network segments. Type drivers must implement this
+    abstract interface, which defines the API by which the plugin uses
+    the driver to manage the persistent type-specific resource
+    allocation state associated with network segments of that type.
+
+    Network segments are represented by segment dictionaries using the
+    NETWORK_TYPE, PHYSICAL_NETWORK, and SEGMENTATION_ID keys defined
+    above, corresponding to the provider attributes.  Future revisions
+    of the TypeDriver API may add additional segment dictionary
+    keys. Attributes not applicable for a particular network_type may
+    either be excluded or stored as None.
+
+    ML2TypeDriver passes context as argument for:
+    - reserve_provider_segment
+    - allocate_tenant_segment
+    - release_segment
+    - get_allocation
+    """
+
     @abc.abstractmethod
-    def get_mtu(self, physical):
-        """Get driver's network MTU.
+    def reserve_provider_segment(self, context, segment):
+        """Reserve resource associated with a provider network segment.
 
-        :returns mtu: maximum transmission unit
+        :param context: instance of neutron context with DB session
+        :param segment: segment dictionary
+        :returns: segment dictionary
 
-        Returns the mtu for the network based on the config values and
-        the network type.
+        Called inside transaction context on session to reserve the
+        type-specific resource for a provider network segment. The
+        segment dictionary passed in was returned by a previous
+        validate_provider_segment() call.
+        """
+        pass
+
+    @abc.abstractmethod
+    def allocate_tenant_segment(self, context):
+        """Allocate resource for a new tenant network segment.
+
+        :param context: instance of neutron context with DB session
+        :returns: segment dictionary using keys defined above
+
+        Called inside transaction context on session to allocate a new
+        tenant network, typically from a type-specific resource
+        pool. If successful, return a segment dictionary describing
+        the segment. If tenant network segment cannot be allocated
+        (i.e. tenant networks not supported or resource pool is
+        exhausted), return None.
+        """
+        pass
+
+    @abc.abstractmethod
+    def release_segment(self, context, segment):
+        """Release network segment.
+
+        :param context: instance of neutron context with DB session
+        :param segment: segment dictionary using keys defined above
+
+        Called inside transaction context on session to release a
+        tenant or provider network's type-specific resource. Runtime
+        errors are not expected, but raising an exception will result
+        in rollback of the transaction.
         """
         pass
 
diff --git a/neutron/plugins/ml2/driver_context.py b/neutron/plugins/ml2/driver_context.py
index afc73911615..42ba876e462 100644
--- a/neutron/plugins/ml2/driver_context.py
+++ b/neutron/plugins/ml2/driver_context.py
@@ -263,4 +263,4 @@ class PortContext(MechanismDriverContext, api.PortContext):
 
     def release_dynamic_segment(self, segment_id):
         return self._plugin.type_manager.release_dynamic_segment(
-                self._plugin_context.session, segment_id)
+                self._plugin_context, segment_id)
diff --git a/neutron/plugins/ml2/drivers/helpers.py b/neutron/plugins/ml2/drivers/helpers.py
index 331ceadbaa4..c58e5bbec24 100644
--- a/neutron/plugins/ml2/drivers/helpers.py
+++ b/neutron/plugins/ml2/drivers/helpers.py
@@ -31,7 +31,7 @@ LOG = log.getLogger(__name__)
 IDPOOL_SELECT_SIZE = 100
 
 
-class BaseTypeDriver(api.TypeDriver):
+class BaseTypeDriver(api.ML2TypeDriver):
     """BaseTypeDriver for functions common to Segment and flat."""
 
     def __init__(self):
@@ -60,7 +60,7 @@ class SegmentTypeDriver(BaseTypeDriver):
         self.primary_keys = set(dict(model.__table__.columns))
         self.primary_keys.remove("allocated")
 
-    def allocate_fully_specified_segment(self, session, **raw_segment):
+    def allocate_fully_specified_segment(self, context, **raw_segment):
         """Allocate segment fully specified by raw_segment.
 
         If segment exists, then try to allocate it and return db object
@@ -70,9 +70,10 @@ class SegmentTypeDriver(BaseTypeDriver):
 
         network_type = self.get_type()
         try:
-            with session.begin(subtransactions=True):
-                alloc = (session.query(self.model).filter_by(**raw_segment).
-                         first())
+            with context.session.begin(subtransactions=True):
+                alloc = (
+                    context.session.query(self.model).filter_by(**raw_segment).
+                    first())
                 if alloc:
                     if alloc.allocated:
                         # Segment already allocated
@@ -83,7 +84,7 @@ class SegmentTypeDriver(BaseTypeDriver):
                                   "started ",
                                   {"type": network_type,
                                    "segment": raw_segment})
-                        count = (session.query(self.model).
+                        count = (context.session.query(self.model).
                                  filter_by(allocated=False, **raw_segment).
                                  update({"allocated": True}))
                         if count:
@@ -104,7 +105,7 @@ class SegmentTypeDriver(BaseTypeDriver):
                 LOG.debug("%(type)s segment %(segment)s create started",
                           {"type": network_type, "segment": raw_segment})
                 alloc = self.model(allocated=True, **raw_segment)
-                alloc.save(session)
+                alloc.save(context.session)
                 LOG.debug("%(type)s segment %(segment)s create done",
                           {"type": network_type, "segment": raw_segment})
 
@@ -116,15 +117,15 @@ class SegmentTypeDriver(BaseTypeDriver):
 
         return alloc
 
-    def allocate_partially_specified_segment(self, session, **filters):
+    def allocate_partially_specified_segment(self, context, **filters):
         """Allocate model segment from pool partially specified by filters.
 
         Return allocated db object or None.
         """
 
         network_type = self.get_type()
-        with session.begin(subtransactions=True):
-            select = (session.query(self.model).
+        with context.session.begin(subtransactions=True):
+            select = (context.session.query(self.model).
                       filter_by(allocated=False, **filters))
 
             # Selected segment can be allocated before update by someone else,
@@ -140,7 +141,7 @@ class SegmentTypeDriver(BaseTypeDriver):
                       "started with %(segment)s ",
                       {"type": network_type,
                        "segment": raw_segment})
-            count = (session.query(self.model).
+            count = (context.session.query(self.model).
                      filter_by(allocated=False, **raw_segment).
                      update({"allocated": True}))
             if count:
diff --git a/neutron/plugins/ml2/drivers/type_flat.py b/neutron/plugins/ml2/drivers/type_flat.py
index e369c90c267..78716a6c768 100644
--- a/neutron/plugins/ml2/drivers/type_flat.py
+++ b/neutron/plugins/ml2/drivers/type_flat.py
@@ -90,29 +90,29 @@ class FlatTypeDriver(helpers.BaseTypeDriver):
                 msg = _("%s prohibited for flat provider network") % key
                 raise exc.InvalidInput(error_message=msg)
 
-    def reserve_provider_segment(self, session, segment):
+    def reserve_provider_segment(self, context, segment):
         physical_network = segment[api.PHYSICAL_NETWORK]
-        with session.begin(subtransactions=True):
+        with context.session.begin(subtransactions=True):
             try:
                 LOG.debug("Reserving flat network on physical "
                           "network %s", physical_network)
                 alloc = type_flat_model.FlatAllocation(
                     physical_network=physical_network)
-                alloc.save(session)
+                alloc.save(context.session)
             except db_exc.DBDuplicateEntry:
                 raise n_exc.FlatNetworkInUse(
                     physical_network=physical_network)
             segment[api.MTU] = self.get_mtu(alloc.physical_network)
         return segment
 
-    def allocate_tenant_segment(self, session):
+    def allocate_tenant_segment(self, context):
         # Tenant flat networks are not supported.
         return
 
-    def release_segment(self, session, segment):
+    def release_segment(self, context, segment):
         physical_network = segment[api.PHYSICAL_NETWORK]
-        with session.begin(subtransactions=True):
-            count = (session.query(type_flat_model.FlatAllocation).
+        with context.session.begin(subtransactions=True):
+            count = (context.session.query(type_flat_model.FlatAllocation).
                      filter_by(physical_network=physical_network).
                      delete())
         if count:
diff --git a/neutron/plugins/ml2/drivers/type_local.py b/neutron/plugins/ml2/drivers/type_local.py
index 12f3f9e4883..93335c1c342 100644
--- a/neutron/plugins/ml2/drivers/type_local.py
+++ b/neutron/plugins/ml2/drivers/type_local.py
@@ -24,7 +24,7 @@ from neutron.plugins.ml2 import driver_api as api
 LOG = log.getLogger(__name__)
 
 
-class LocalTypeDriver(api.TypeDriver):
+class LocalTypeDriver(api.ML2TypeDriver):
     """Manage state for local networks with ML2.
 
     The LocalTypeDriver implements the 'local' network_type. Local
@@ -52,15 +52,15 @@ class LocalTypeDriver(api.TypeDriver):
                 msg = _("%s prohibited for local provider network") % key
                 raise exc.InvalidInput(error_message=msg)
 
-    def reserve_provider_segment(self, session, segment):
+    def reserve_provider_segment(self, context, segment):
         # No resources to reserve
         return segment
 
-    def allocate_tenant_segment(self, session):
+    def allocate_tenant_segment(self, context):
         # No resources to allocate
         return {api.NETWORK_TYPE: p_const.TYPE_LOCAL}
 
-    def release_segment(self, session, segment):
+    def release_segment(self, context, segment):
         # No resources to release
         pass
 
diff --git a/neutron/plugins/ml2/drivers/type_tunnel.py b/neutron/plugins/ml2/drivers/type_tunnel.py
index 31abab5f16e..5fa97751e75 100644
--- a/neutron/plugins/ml2/drivers/type_tunnel.py
+++ b/neutron/plugins/ml2/drivers/type_tunnel.py
@@ -49,16 +49,12 @@ def chunks(iterable, chunk_size):
 
 
 @six.add_metaclass(abc.ABCMeta)
-class TunnelTypeDriver(helpers.SegmentTypeDriver):
-    """Define stable abstract interface for ML2 type drivers.
+class _TunnelTypeDriverBase(helpers.SegmentTypeDriver):
 
-    tunnel type networks rely on tunnel endpoints. This class defines abstract
-    methods to manage these endpoints.
-    """
     BULK_SIZE = 100
 
     def __init__(self, model):
-        super(TunnelTypeDriver, self).__init__(model)
+        super(_TunnelTypeDriverBase, self).__init__(model)
         self.segmentation_key = next(iter(self.primary_keys))
 
     @abc.abstractmethod
@@ -193,6 +189,32 @@ class TunnelTypeDriver(helpers.SegmentTypeDriver):
                        {'key': key, 'tunnel': segment.get(api.NETWORK_TYPE)})
                 raise exc.InvalidInput(error_message=msg)
 
+    def get_mtu(self, physical_network=None):
+        seg_mtu = super(_TunnelTypeDriverBase, self).get_mtu()
+        mtu = []
+        if seg_mtu > 0:
+            mtu.append(seg_mtu)
+        if cfg.CONF.ml2.path_mtu > 0:
+            mtu.append(cfg.CONF.ml2.path_mtu)
+        version = cfg.CONF.ml2.overlay_ip_version
+        ip_header_length = p_const.IP_HEADER_LENGTH[version]
+        return min(mtu) - ip_header_length if mtu else 0
+
+
+@six.add_metaclass(abc.ABCMeta)
+class TunnelTypeDriver(_TunnelTypeDriverBase):
+    """Define stable abstract interface for ML2 type drivers.
+
+    tunnel type networks rely on tunnel endpoints. This class defines abstract
+    methods to manage these endpoints.
+
+    ML2 type driver that passes session to functions:
+    - reserve_provider_segment
+    - allocate_tenant_segment
+    - release_segment
+    - get_allocation
+    """
+
     def reserve_provider_segment(self, session, segment):
         if self.is_partial_segment(segment):
             alloc = self.allocate_partially_specified_segment(session)
@@ -246,19 +268,76 @@ class TunnelTypeDriver(helpers.SegmentTypeDriver):
                 filter_by(**{self.segmentation_key: tunnel_id}).
                 first())
 
-    def get_mtu(self, physical_network=None):
-        seg_mtu = super(TunnelTypeDriver, self).get_mtu()
-        mtu = []
-        if seg_mtu > 0:
-            mtu.append(seg_mtu)
-        if cfg.CONF.ml2.path_mtu > 0:
-            mtu.append(cfg.CONF.ml2.path_mtu)
-        version = cfg.CONF.ml2.overlay_ip_version
-        ip_header_length = p_const.IP_HEADER_LENGTH[version]
-        return min(mtu) - ip_header_length if mtu else 0
+
+@six.add_metaclass(abc.ABCMeta)
+class ML2TunnelTypeDriver(_TunnelTypeDriverBase):
+    """Define stable abstract interface for ML2 type drivers.
+
+    tunnel type networks rely on tunnel endpoints. This class defines abstract
+    methods to manage these endpoints.
+
+    ML2 type driver that passes context as argument to functions:
+    - reserve_provider_segment
+    - allocate_tenant_segment
+    - release_segment
+    - get_allocation
+    """
+
+    def reserve_provider_segment(self, context, segment):
+        if self.is_partial_segment(segment):
+            alloc = self.allocate_partially_specified_segment(context)
+            if not alloc:
+                raise exc.NoNetworkAvailable()
+        else:
+            segmentation_id = segment.get(api.SEGMENTATION_ID)
+            alloc = self.allocate_fully_specified_segment(
+                context, **{self.segmentation_key: segmentation_id})
+            if not alloc:
+                raise exc.TunnelIdInUse(tunnel_id=segmentation_id)
+        return {api.NETWORK_TYPE: self.get_type(),
+                api.PHYSICAL_NETWORK: None,
+                api.SEGMENTATION_ID: getattr(alloc, self.segmentation_key),
+                api.MTU: self.get_mtu()}
+
+    def allocate_tenant_segment(self, context):
+        alloc = self.allocate_partially_specified_segment(context)
+        if not alloc:
+            return
+        return {api.NETWORK_TYPE: self.get_type(),
+                api.PHYSICAL_NETWORK: None,
+                api.SEGMENTATION_ID: getattr(alloc, self.segmentation_key),
+                api.MTU: self.get_mtu()}
+
+    def release_segment(self, context, segment):
+        tunnel_id = segment[api.SEGMENTATION_ID]
+
+        inside = any(lo <= tunnel_id <= hi for lo, hi in self.tunnel_ranges)
+
+        info = {'type': self.get_type(), 'id': tunnel_id}
+        with context.session.begin(subtransactions=True):
+            query = (context.session.query(self.model).
+                     filter_by(**{self.segmentation_key: tunnel_id}))
+            if inside:
+                count = query.update({"allocated": False})
+                if count:
+                    LOG.debug("Releasing %(type)s tunnel %(id)s to pool",
+                              info)
+            else:
+                count = query.delete()
+                if count:
+                    LOG.debug("Releasing %(type)s tunnel %(id)s outside pool",
+                              info)
+
+        if not count:
+            LOG.warning(_LW("%(type)s tunnel %(id)s not found"), info)
+
+    def get_allocation(self, context, tunnel_id):
+        return (context.session.query(self.model).
+                filter_by(**{self.segmentation_key: tunnel_id}).
+                first())
 
 
-class EndpointTunnelTypeDriver(TunnelTypeDriver):
+class EndpointTunnelTypeDriver(ML2TunnelTypeDriver):
 
     def __init__(self, segment_model, endpoint_model):
         super(EndpointTunnelTypeDriver, self).__init__(segment_model)
diff --git a/neutron/plugins/ml2/drivers/type_vlan.py b/neutron/plugins/ml2/drivers/type_vlan.py
index 0b31e1b90a6..f52c0d1997b 100644
--- a/neutron/plugins/ml2/drivers/type_vlan.py
+++ b/neutron/plugins/ml2/drivers/type_vlan.py
@@ -161,7 +161,7 @@ class VlanTypeDriver(helpers.SegmentTypeDriver):
                 msg = _("%s prohibited for VLAN provider network") % key
                 raise exc.InvalidInput(error_message=msg)
 
-    def reserve_provider_segment(self, session, segment):
+    def reserve_provider_segment(self, context, segment):
         filters = {}
         physical_network = segment.get(api.PHYSICAL_NETWORK)
         if physical_network is not None:
@@ -172,12 +172,12 @@ class VlanTypeDriver(helpers.SegmentTypeDriver):
 
         if self.is_partial_segment(segment):
             alloc = self.allocate_partially_specified_segment(
-                session, **filters)
+                context, **filters)
             if not alloc:
                 raise exc.NoNetworkAvailable()
         else:
             alloc = self.allocate_fully_specified_segment(
-                session, **filters)
+                context, **filters)
             if not alloc:
                 raise exc.VlanIdInUse(**filters)
 
@@ -186,8 +186,8 @@ class VlanTypeDriver(helpers.SegmentTypeDriver):
                 api.SEGMENTATION_ID: alloc.vlan_id,
                 api.MTU: self.get_mtu(alloc.physical_network)}
 
-    def allocate_tenant_segment(self, session):
-        alloc = self.allocate_partially_specified_segment(session)
+    def allocate_tenant_segment(self, context):
+        alloc = self.allocate_partially_specified_segment(context)
         if not alloc:
             return
         return {api.NETWORK_TYPE: p_const.TYPE_VLAN,
@@ -195,15 +195,15 @@ class VlanTypeDriver(helpers.SegmentTypeDriver):
                 api.SEGMENTATION_ID: alloc.vlan_id,
                 api.MTU: self.get_mtu(alloc.physical_network)}
 
-    def release_segment(self, session, segment):
+    def release_segment(self, context, segment):
         physical_network = segment[api.PHYSICAL_NETWORK]
         vlan_id = segment[api.SEGMENTATION_ID]
 
         ranges = self.network_vlan_ranges.get(physical_network, [])
         inside = any(lo <= vlan_id <= hi for lo, hi in ranges)
 
-        with session.begin(subtransactions=True):
-            query = (session.query(vlan_alloc_model.VlanAllocation).
+        with context.session.begin(subtransactions=True):
+            query = (context.session.query(vlan_alloc_model.VlanAllocation).
                      filter_by(physical_network=physical_network,
                                vlan_id=vlan_id))
             if inside:
diff --git a/neutron/plugins/ml2/managers.py b/neutron/plugins/ml2/managers.py
index 60489212f3c..6ae90c57eff 100644
--- a/neutron/plugins/ml2/managers.py
+++ b/neutron/plugins/ml2/managers.py
@@ -197,18 +197,18 @@ class TypeManager(stevedore.named.NamedExtensionManager):
             if segments:
                 for segment_index, segment in enumerate(segments):
                     segment = self.reserve_provider_segment(
-                        session, segment)
+                        context, segment)
                     self._add_network_segment(context, network_id, segment,
                                               segment_index)
             elif (cfg.CONF.ml2.external_network_type and
                   self._get_attribute(network, external_net.EXTERNAL)):
-                segment = self._allocate_ext_net_segment(session)
+                segment = self._allocate_ext_net_segment(context)
                 self._add_network_segment(context, network_id, segment)
             else:
-                segment = self._allocate_tenant_net_segment(session)
+                segment = self._allocate_tenant_net_segment(context)
                 self._add_network_segment(context, network_id, segment)
 
-    def reserve_network_segment(self, session, segment_data):
+    def reserve_network_segment(self, context, segment_data):
         """Call type drivers to reserve a network segment."""
         # Validate the data of segment
         if not validators.is_attr_set(segment_data[api.NETWORK_TYPE]):
@@ -225,8 +225,8 @@ class TypeManager(stevedore.named.NamedExtensionManager):
         self.validate_provider_segment(segment)
 
         # Reserve segment in type driver
-        with session.begin(subtransactions=True):
-            return self.reserve_provider_segment(session, segment)
+        with context.session.begin(subtransactions=True):
+            return self.reserve_provider_segment(context, segment)
 
     def is_partial_segment(self, segment):
         network_type = segment[api.NETWORK_TYPE]
@@ -246,41 +246,53 @@ class TypeManager(stevedore.named.NamedExtensionManager):
             msg = _("network_type value '%s' not supported") % network_type
             raise exc.InvalidInput(error_message=msg)
 
-    def reserve_provider_segment(self, session, segment):
+    def reserve_provider_segment(self, context, segment):
         network_type = segment.get(api.NETWORK_TYPE)
         driver = self.drivers.get(network_type)
-        return driver.obj.reserve_provider_segment(session, segment)
+        if isinstance(driver.obj, api.TypeDriver):
+            return driver.obj.reserve_provider_segment(context.session,
+                                                       segment)
+        else:
+            return driver.obj.reserve_provider_segment(context,
+                                                       segment)
 
-    def _allocate_segment(self, session, network_type):
+    def _allocate_segment(self, context, network_type):
         driver = self.drivers.get(network_type)
-        return driver.obj.allocate_tenant_segment(session)
+        if isinstance(driver.obj, api.TypeDriver):
+            return driver.obj.allocate_tenant_segment(context.session)
+        else:
+            return driver.obj.allocate_tenant_segment(context)
 
-    def _allocate_tenant_net_segment(self, session):
+    def _allocate_tenant_net_segment(self, context):
         for network_type in self.tenant_network_types:
-            segment = self._allocate_segment(session, network_type)
+            segment = self._allocate_segment(context, network_type)
             if segment:
                 return segment
         raise exc.NoNetworkAvailable()
 
-    def _allocate_ext_net_segment(self, session):
+    def _allocate_ext_net_segment(self, context):
         network_type = cfg.CONF.ml2.external_network_type
-        segment = self._allocate_segment(session, network_type)
+        segment = self._allocate_segment(context, network_type)
         if segment:
             return segment
         raise exc.NoNetworkAvailable()
 
-    def release_network_segments(self, session, network_id):
-        segments = segments_db.get_network_segments(session, network_id,
+    def release_network_segments(self, context, network_id):
+        segments = segments_db.get_network_segments(context.session,
+                                                    network_id,
                                                     filter_dynamic=None)
 
         for segment in segments:
-            self.release_network_segment(session, segment)
+            self.release_network_segment(context, segment)
 
-    def release_network_segment(self, session, segment):
+    def release_network_segment(self, context, segment):
         network_type = segment.get(api.NETWORK_TYPE)
         driver = self.drivers.get(network_type)
         if driver:
-            driver.obj.release_segment(session, segment)
+            if isinstance(driver.obj, api.TypeDriver):
+                driver.obj.release_segment(context.session, segment)
+            else:
+                driver.obj.release_segment(context, segment)
         else:
             LOG.error(_LE("Failed to release segment '%s' because "
                           "network type is not supported."), segment)
@@ -295,20 +307,28 @@ class TypeManager(stevedore.named.NamedExtensionManager):
             return dynamic_segment
 
         driver = self.drivers.get(segment.get(api.NETWORK_TYPE))
-        dynamic_segment = driver.obj.reserve_provider_segment(context.session,
-                                                              segment)
-        segments_db.add_network_segment(context, network_id, dynamic_segment,
+        if isinstance(driver.obj, api.TypeDriver):
+            dynamic_segment = driver.obj.reserve_provider_segment(
+                context.session, segment)
+        else:
+            dynamic_segment = driver.obj.reserve_provider_segment(
+                context, segment)
+        segments_db.add_network_segment(context,
+                                        network_id, dynamic_segment,
                                         is_dynamic=True)
         return dynamic_segment
 
-    def release_dynamic_segment(self, session, segment_id):
+    def release_dynamic_segment(self, context, segment_id):
         """Delete a dynamic segment."""
-        segment = segments_db.get_segment_by_id(session, segment_id)
+        segment = segments_db.get_segment_by_id(context.session, segment_id)
         if segment:
             driver = self.drivers.get(segment.get(api.NETWORK_TYPE))
             if driver:
-                driver.obj.release_segment(session, segment)
-                segments_db.delete_network_segment(session, segment_id)
+                if isinstance(driver.obj, api.TypeDriver):
+                    driver.obj.release_segment(context.session, segment)
+                else:
+                    driver.obj.release_segment(context, segment)
+                segments_db.delete_network_segment(context.session, segment_id)
             else:
                 LOG.error(_LE("Failed to release segment '%s' because "
                               "network type is not supported."), segment)
diff --git a/neutron/plugins/ml2/plugin.py b/neutron/plugins/ml2/plugin.py
index 3690ca2648e..fd06bc76ea5 100644
--- a/neutron/plugins/ml2/plugin.py
+++ b/neutron/plugins/ml2/plugin.py
@@ -1868,17 +1868,16 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
             # by unifying segment creation procedure.
             return
 
-        session = context.session
         network_id = segment.get('network_id')
 
         if event == events.PRECOMMIT_CREATE:
             updated_segment = self.type_manager.reserve_network_segment(
-                session, segment)
+                context, segment)
             # The segmentation id might be from ML2 type driver, update it
             # in the original segment.
             segment[api.SEGMENTATION_ID] = updated_segment[api.SEGMENTATION_ID]
         elif event == events.PRECOMMIT_DELETE:
-            self.type_manager.release_network_segment(session, segment)
+            self.type_manager.release_network_segment(context, segment)
 
         try:
             self._notify_mechanism_driver_for_segment_change(
diff --git a/neutron/tests/unit/plugins/ml2/drivers/base_type_tunnel.py b/neutron/tests/unit/plugins/ml2/drivers/base_type_tunnel.py
index 82900258c42..80a33fb176c 100644
--- a/neutron/tests/unit/plugins/ml2/drivers/base_type_tunnel.py
+++ b/neutron/tests/unit/plugins/ml2/drivers/base_type_tunnel.py
@@ -19,7 +19,7 @@ from six import moves
 import testtools
 from testtools import matchers
 
-from neutron.db import api as db
+from neutron import context
 from neutron.plugins.common import constants as p_const
 from neutron.plugins.ml2 import config
 from neutron.plugins.ml2 import driver_api as api
@@ -45,7 +45,7 @@ class TunnelTypeTestMixin(object):
         self.driver = self.DRIVER_CLASS()
         self.driver.tunnel_ranges = TUNNEL_RANGES
         self.driver.sync_allocations()
-        self.session = db.get_session()
+        self.context = context.Context()
 
     def test_tunnel_type(self):
         self.assertEqual(self.TYPE, self.driver.get_type())
@@ -66,47 +66,47 @@ class TunnelTypeTestMixin(object):
 
     def test_sync_tunnel_allocations(self):
         self.assertIsNone(
-            self.driver.get_allocation(self.session, (TUN_MIN - 1)))
+            self.driver.get_allocation(self.context, (TUN_MIN - 1)))
         self.assertFalse(
-            self.driver.get_allocation(self.session, (TUN_MIN)).allocated)
+            self.driver.get_allocation(self.context, (TUN_MIN)).allocated)
         self.assertFalse(
-            self.driver.get_allocation(self.session, (TUN_MIN + 1)).allocated)
+            self.driver.get_allocation(self.context, (TUN_MIN + 1)).allocated)
         self.assertFalse(
-            self.driver.get_allocation(self.session, (TUN_MAX - 1)).allocated)
+            self.driver.get_allocation(self.context, (TUN_MAX - 1)).allocated)
         self.assertFalse(
-            self.driver.get_allocation(self.session, (TUN_MAX)).allocated)
+            self.driver.get_allocation(self.context, (TUN_MAX)).allocated)
         self.assertIsNone(
-            self.driver.get_allocation(self.session, (TUN_MAX + 1)))
+            self.driver.get_allocation(self.context, (TUN_MAX + 1)))
 
         self.driver.tunnel_ranges = UPDATED_TUNNEL_RANGES
         self.driver.sync_allocations()
 
         self.assertIsNone(
-            self.driver.get_allocation(self.session, (TUN_MIN + 5 - 1)))
+            self.driver.get_allocation(self.context, (TUN_MIN + 5 - 1)))
         self.assertFalse(
-            self.driver.get_allocation(self.session, (TUN_MIN + 5)).allocated)
+            self.driver.get_allocation(self.context, (TUN_MIN + 5)).allocated)
         self.assertFalse(
-            self.driver.get_allocation(self.session,
+            self.driver.get_allocation(self.context,
                                        (TUN_MIN + 5 + 1)).allocated)
         self.assertFalse(
-            self.driver.get_allocation(self.session,
+            self.driver.get_allocation(self.context,
                                        (TUN_MAX + 5 - 1)).allocated)
         self.assertFalse(
-            self.driver.get_allocation(self.session, (TUN_MAX + 5)).allocated)
+            self.driver.get_allocation(self.context, (TUN_MAX + 5)).allocated)
         self.assertIsNone(
-            self.driver.get_allocation(self.session, (TUN_MAX + 5 + 1)))
+            self.driver.get_allocation(self.context, (TUN_MAX + 5 + 1)))
 
     def _test_sync_allocations_and_allocated(self, tunnel_id):
         segment = {api.NETWORK_TYPE: self.TYPE,
                    api.PHYSICAL_NETWORK: None,
                    api.SEGMENTATION_ID: tunnel_id}
-        self.driver.reserve_provider_segment(self.session, segment)
+        self.driver.reserve_provider_segment(self.context, segment)
 
         self.driver.tunnel_ranges = UPDATED_TUNNEL_RANGES
         self.driver.sync_allocations()
 
         self.assertTrue(
-            self.driver.get_allocation(self.session, tunnel_id).allocated)
+            self.driver.get_allocation(self.context, tunnel_id).allocated)
 
     def test_sync_allocations_and_allocated_in_initial_range(self):
         self._test_sync_allocations_and_allocated(TUN_MIN + 2)
@@ -141,27 +141,27 @@ class TunnelTypeTestMixin(object):
         segment = {api.NETWORK_TYPE: self.TYPE,
                    api.PHYSICAL_NETWORK: None,
                    api.SEGMENTATION_ID: 101}
-        observed = self.driver.reserve_provider_segment(self.session, segment)
-        alloc = self.driver.get_allocation(self.session,
+        observed = self.driver.reserve_provider_segment(self.context, segment)
+        alloc = self.driver.get_allocation(self.context,
                                            observed[api.SEGMENTATION_ID])
         self.assertTrue(alloc.allocated)
 
         with testtools.ExpectedException(exc.TunnelIdInUse):
-            self.driver.reserve_provider_segment(self.session, segment)
+            self.driver.reserve_provider_segment(self.context, segment)
 
-        self.driver.release_segment(self.session, segment)
-        alloc = self.driver.get_allocation(self.session,
+        self.driver.release_segment(self.context, segment)
+        alloc = self.driver.get_allocation(self.context,
                                            observed[api.SEGMENTATION_ID])
         self.assertFalse(alloc.allocated)
 
         segment[api.SEGMENTATION_ID] = 1000
-        observed = self.driver.reserve_provider_segment(self.session, segment)
-        alloc = self.driver.get_allocation(self.session,
+        observed = self.driver.reserve_provider_segment(self.context, segment)
+        alloc = self.driver.get_allocation(self.context,
                                            observed[api.SEGMENTATION_ID])
         self.assertTrue(alloc.allocated)
 
-        self.driver.release_segment(self.session, segment)
-        alloc = self.driver.get_allocation(self.session,
+        self.driver.release_segment(self.context, segment)
+        alloc = self.driver.get_allocation(self.context,
                                            observed[api.SEGMENTATION_ID])
         self.assertIsNone(alloc)
 
@@ -172,7 +172,7 @@ class TunnelTypeTestMixin(object):
                  api.SEGMENTATION_ID: None}
 
         for x in moves.range(TUN_MIN, TUN_MAX + 1):
-            segment = self.driver.reserve_provider_segment(self.session,
+            segment = self.driver.reserve_provider_segment(self.context,
                                                            specs)
             self.assertEqual(self.TYPE, segment[api.NETWORK_TYPE])
             self.assertThat(segment[api.SEGMENTATION_ID],
@@ -182,14 +182,14 @@ class TunnelTypeTestMixin(object):
             tunnel_ids.add(segment[api.SEGMENTATION_ID])
 
         with testtools.ExpectedException(exc.NoNetworkAvailable):
-            segment = self.driver.reserve_provider_segment(self.session,
+            segment = self.driver.reserve_provider_segment(self.context,
                                                            specs)
 
         segment = {api.NETWORK_TYPE: self.TYPE,
                    api.PHYSICAL_NETWORK: 'None',
                    api.SEGMENTATION_ID: tunnel_ids.pop()}
-        self.driver.release_segment(self.session, segment)
-        segment = self.driver.reserve_provider_segment(self.session, specs)
+        self.driver.release_segment(self.context, segment)
+        segment = self.driver.reserve_provider_segment(self.context, specs)
         self.assertThat(segment[api.SEGMENTATION_ID],
                         matchers.GreaterThan(TUN_MIN - 1))
         self.assertThat(segment[api.SEGMENTATION_ID],
@@ -198,26 +198,26 @@ class TunnelTypeTestMixin(object):
 
         for tunnel_id in tunnel_ids:
             segment[api.SEGMENTATION_ID] = tunnel_id
-            self.driver.release_segment(self.session, segment)
+            self.driver.release_segment(self.context, segment)
 
     def test_allocate_tenant_segment(self):
         tunnel_ids = set()
         for x in moves.range(TUN_MIN, TUN_MAX + 1):
-            segment = self.driver.allocate_tenant_segment(self.session)
+            segment = self.driver.allocate_tenant_segment(self.context)
             self.assertThat(segment[api.SEGMENTATION_ID],
                             matchers.GreaterThan(TUN_MIN - 1))
             self.assertThat(segment[api.SEGMENTATION_ID],
                             matchers.LessThan(TUN_MAX + 1))
             tunnel_ids.add(segment[api.SEGMENTATION_ID])
 
-        segment = self.driver.allocate_tenant_segment(self.session)
+        segment = self.driver.allocate_tenant_segment(self.context)
         self.assertIsNone(segment)
 
         segment = {api.NETWORK_TYPE: self.TYPE,
                    api.PHYSICAL_NETWORK: 'None',
                    api.SEGMENTATION_ID: tunnel_ids.pop()}
-        self.driver.release_segment(self.session, segment)
-        segment = self.driver.allocate_tenant_segment(self.session)
+        self.driver.release_segment(self.context, segment)
+        segment = self.driver.allocate_tenant_segment(self.context)
         self.assertThat(segment[api.SEGMENTATION_ID],
                         matchers.GreaterThan(TUN_MIN - 1))
         self.assertThat(segment[api.SEGMENTATION_ID],
@@ -226,7 +226,7 @@ class TunnelTypeTestMixin(object):
 
         for tunnel_id in tunnel_ids:
             segment[api.SEGMENTATION_ID] = tunnel_id
-            self.driver.release_segment(self.session, segment)
+            self.driver.release_segment(self.context, segment)
 
     def add_endpoint(self, ip=TUNNEL_IP_ONE, host=HOST_ONE):
         return self.driver.add_endpoint(ip, host)
@@ -289,19 +289,19 @@ class TunnelTypeMultiRangeTestMixin(object):
         self.driver = self.DRIVER_CLASS()
         self.driver.tunnel_ranges = self.TUNNEL_MULTI_RANGES
         self.driver.sync_allocations()
-        self.session = db.get_session()
+        self.context = context.Context()
 
     def test_release_segment(self):
-        segments = [self.driver.allocate_tenant_segment(self.session)
+        segments = [self.driver.allocate_tenant_segment(self.context)
                     for i in range(4)]
 
         # Release them in random order. No special meaning.
         for i in (0, 2, 1, 3):
-            self.driver.release_segment(self.session, segments[i])
+            self.driver.release_segment(self.context, segments[i])
 
         for key in (self.TUN_MIN0, self.TUN_MAX0,
                     self.TUN_MIN1, self.TUN_MAX1):
-            alloc = self.driver.get_allocation(self.session, key)
+            alloc = self.driver.get_allocation(self.context, key)
             self.assertFalse(alloc.allocated)
 
 
diff --git a/neutron/tests/unit/plugins/ml2/drivers/test_helpers.py b/neutron/tests/unit/plugins/ml2/drivers/test_helpers.py
index 8f607344008..68485f0f558 100644
--- a/neutron/tests/unit/plugins/ml2/drivers/test_helpers.py
+++ b/neutron/tests/unit/plugins/ml2/drivers/test_helpers.py
@@ -17,7 +17,7 @@ import mock
 from oslo_db import exception as exc
 from sqlalchemy.orm import query
 
-import neutron.db.api as db
+from neutron import context
 from neutron.plugins.ml2.drivers import type_vlan
 from neutron.tests.unit import testlib_api
 
@@ -38,7 +38,7 @@ class HelpersTest(testlib_api.SqlTestCase):
         self.driver = type_vlan.VlanTypeDriver()
         self.driver.network_vlan_ranges = NETWORK_VLAN_RANGES
         self.driver._sync_vlan_allocations()
-        self.session = db.get_session()
+        self.context = context.get_admin_context()
 
     def check_raw_segment(self, expected, observed):
         for key, value in expected.items():
@@ -50,15 +50,15 @@ class HelpersTest(testlib_api.SqlTestCase):
 
     def test_allocate_specific_unallocated_segment_in_pools(self):
         expected = dict(physical_network=TENANT_NET, vlan_id=VLAN_MIN)
-        observed = self.driver.allocate_fully_specified_segment(self.session,
+        observed = self.driver.allocate_fully_specified_segment(self.context,
                                                                 **expected)
         self.check_raw_segment(expected, observed)
 
     def test_allocate_specific_allocated_segment_in_pools(self):
         raw_segment = dict(physical_network=TENANT_NET, vlan_id=VLAN_MIN)
-        self.driver.allocate_fully_specified_segment(self.session,
+        self.driver.allocate_fully_specified_segment(self.context,
                                                      **raw_segment)
-        observed = self.driver.allocate_fully_specified_segment(self.session,
+        observed = self.driver.allocate_fully_specified_segment(self.context,
                                                                 **raw_segment)
         self.assertIsNone(observed)
 
@@ -69,20 +69,20 @@ class HelpersTest(testlib_api.SqlTestCase):
         raw_segment = dict(physical_network=TENANT_NET, vlan_id=VLAN_MIN)
         with mock.patch.object(query.Query, 'update', return_value=0):
             observed = self.driver.allocate_fully_specified_segment(
-                self.session, **raw_segment)
+                self.context, **raw_segment)
             self.assertIsNone(observed)
 
     def test_allocate_specific_unallocated_segment_outside_pools(self):
         expected = dict(physical_network=TENANT_NET, vlan_id=VLAN_OUTSIDE)
-        observed = self.driver.allocate_fully_specified_segment(self.session,
+        observed = self.driver.allocate_fully_specified_segment(self.context,
                                                                 **expected)
         self.check_raw_segment(expected, observed)
 
     def test_allocate_specific_allocated_segment_outside_pools(self):
         raw_segment = dict(physical_network=TENANT_NET, vlan_id=VLAN_OUTSIDE)
-        self.driver.allocate_fully_specified_segment(self.session,
+        self.driver.allocate_fully_specified_segment(self.context,
                                                      **raw_segment)
-        observed = self.driver.allocate_fully_specified_segment(self.session,
+        observed = self.driver.allocate_fully_specified_segment(self.context,
                                                                 **raw_segment)
         self.assertIsNone(observed)
 
@@ -93,32 +93,32 @@ class HelpersTest(testlib_api.SqlTestCase):
         expected = dict(physical_network=TENANT_NET, vlan_id=VLAN_MIN)
         with mock.patch.object(self.driver.model, 'save'):
             observed = self.driver.allocate_fully_specified_segment(
-                self.session, **expected)
+                self.context, **expected)
             self.check_raw_segment(expected, observed)
 
     def test_allocate_partial_segment_without_filters(self):
         expected = dict(physical_network=TENANT_NET)
         observed = self.driver.allocate_partially_specified_segment(
-            self.session)
+            self.context)
         self.check_raw_segment(expected, observed)
 
     def test_allocate_partial_segment_with_filter(self):
         expected = dict(physical_network=TENANT_NET)
         observed = self.driver.allocate_partially_specified_segment(
-            self.session, **expected)
+            self.context, **expected)
         self.check_raw_segment(expected, observed)
 
     def test_allocate_partial_segment_no_resource_available(self):
         for i in range(VLAN_MIN, VLAN_MAX + 1):
-            self.driver.allocate_partially_specified_segment(self.session)
+            self.driver.allocate_partially_specified_segment(self.context)
         observed = self.driver.allocate_partially_specified_segment(
-            self.session)
+            self.context)
         self.assertIsNone(observed)
 
     def test_allocate_partial_segment_outside_pools(self):
         raw_segment = dict(physical_network='other_phys_net')
         observed = self.driver.allocate_partially_specified_segment(
-            self.session, **raw_segment)
+            self.context, **raw_segment)
         self.assertIsNone(observed)
 
     def test_allocate_partial_segment_first_attempt_fails(self):
@@ -127,7 +127,7 @@ class HelpersTest(testlib_api.SqlTestCase):
             self.assertRaises(
                 exc.RetryRequest,
                 self.driver.allocate_partially_specified_segment,
-                self.session, **expected)
+                self.context, **expected)
             observed = self.driver.allocate_partially_specified_segment(
-                self.session, **expected)
+                self.context, **expected)
             self.check_raw_segment(expected, observed)
diff --git a/neutron/tests/unit/plugins/ml2/drivers/test_type_flat.py b/neutron/tests/unit/plugins/ml2/drivers/test_type_flat.py
index 42e40a53398..3fa40bfff78 100644
--- a/neutron/tests/unit/plugins/ml2/drivers/test_type_flat.py
+++ b/neutron/tests/unit/plugins/ml2/drivers/test_type_flat.py
@@ -16,7 +16,7 @@
 from neutron_lib import exceptions as exc
 
 from neutron.common import exceptions as n_exc
-import neutron.db.api as db
+from neutron import context
 from neutron.db.models.plugins.ml2 import flatallocation as type_flat_model
 from neutron.plugins.common import constants as p_const
 from neutron.plugins.ml2 import config
@@ -36,11 +36,11 @@ class FlatTypeTest(testlib_api.SqlTestCase):
         config.cfg.CONF.set_override('flat_networks', FLAT_NETWORKS,
                               group='ml2_type_flat')
         self.driver = type_flat.FlatTypeDriver()
-        self.session = db.get_session()
+        self.context = context.Context()
         self.driver.physnet_mtus = []
 
-    def _get_allocation(self, session, segment):
-        return session.query(type_flat_model.FlatAllocation).filter_by(
+    def _get_allocation(self, context, segment):
+        return context.session.query(type_flat_model.FlatAllocation).filter_by(
             physical_network=segment[api.PHYSICAL_NETWORK]).first()
 
     def test_is_partial_segment(self):
@@ -100,28 +100,28 @@ class FlatTypeTest(testlib_api.SqlTestCase):
     def test_reserve_provider_segment(self):
         segment = {api.NETWORK_TYPE: p_const.TYPE_FLAT,
                    api.PHYSICAL_NETWORK: 'flat_net1'}
-        observed = self.driver.reserve_provider_segment(self.session, segment)
-        alloc = self._get_allocation(self.session, observed)
+        observed = self.driver.reserve_provider_segment(self.context, segment)
+        alloc = self._get_allocation(self.context, observed)
         self.assertEqual(segment[api.PHYSICAL_NETWORK], alloc.physical_network)
 
     def test_release_segment(self):
         segment = {api.NETWORK_TYPE: p_const.TYPE_FLAT,
                    api.PHYSICAL_NETWORK: 'flat_net1'}
-        self.driver.reserve_provider_segment(self.session, segment)
-        self.driver.release_segment(self.session, segment)
-        alloc = self._get_allocation(self.session, segment)
+        self.driver.reserve_provider_segment(self.context, segment)
+        self.driver.release_segment(self.context, segment)
+        alloc = self._get_allocation(self.context, segment)
         self.assertIsNone(alloc)
 
     def test_reserve_provider_segment_already_reserved(self):
         segment = {api.NETWORK_TYPE: p_const.TYPE_FLAT,
                    api.PHYSICAL_NETWORK: 'flat_net1'}
-        self.driver.reserve_provider_segment(self.session, segment)
+        self.driver.reserve_provider_segment(self.context, segment)
         self.assertRaises(n_exc.FlatNetworkInUse,
                           self.driver.reserve_provider_segment,
-                          self.session, segment)
+                          self.context, segment)
 
     def test_allocate_tenant_segment(self):
-        observed = self.driver.allocate_tenant_segment(self.session)
+        observed = self.driver.allocate_tenant_segment(self.context)
         self.assertIsNone(observed)
 
     def test_get_mtu(self):
diff --git a/neutron/tests/unit/plugins/ml2/drivers/test_type_gre.py b/neutron/tests/unit/plugins/ml2/drivers/test_type_gre.py
index d968b13e389..0f873feac5a 100644
--- a/neutron/tests/unit/plugins/ml2/drivers/test_type_gre.py
+++ b/neutron/tests/unit/plugins/ml2/drivers/test_type_gre.py
@@ -13,7 +13,6 @@
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
-from neutron.db.models.plugins.ml2 import gre_allocation_endpoints as gre_model
 from neutron.plugins.common import constants as p_const
 from neutron.plugins.ml2.drivers import type_gre
 from neutron.tests.unit.plugins.ml2.drivers import base_type_tunnel
@@ -27,16 +26,6 @@ HOST_ONE = 'fake_host_one'
 HOST_TWO = 'fake_host_two'
 
 
-def _add_allocation(session, gre_id, allocated=False):
-    allocation = gre_model.GreAllocation(gre_id=gre_id, allocated=allocated)
-    allocation.save(session)
-
-
-def _get_allocation(session, gre_id):
-    return session.query(gre_model.GreAllocation).filter_by(
-        gre_id=gre_id).one()
-
-
 class GreTypeTest(base_type_tunnel.TunnelTypeTestMixin,
                   testlib_api.SqlTestCase):
     DRIVER_MODULE = type_gre
diff --git a/neutron/tests/unit/plugins/ml2/drivers/test_type_local.py b/neutron/tests/unit/plugins/ml2/drivers/test_type_local.py
index 1f10187b6c3..4c7714019d0 100644
--- a/neutron/tests/unit/plugins/ml2/drivers/test_type_local.py
+++ b/neutron/tests/unit/plugins/ml2/drivers/test_type_local.py
@@ -26,7 +26,7 @@ class LocalTypeTest(base.BaseTestCase):
     def setUp(self):
         super(LocalTypeTest, self).setUp()
         self.driver = type_local.LocalTypeDriver()
-        self.session = None
+        self.context = None
 
     def test_is_partial_segment(self):
         segment = {api.NETWORK_TYPE: p_const.TYPE_LOCAL}
@@ -52,15 +52,15 @@ class LocalTypeTest(base.BaseTestCase):
 
     def test_reserve_provider_segment(self):
         segment = {api.NETWORK_TYPE: p_const.TYPE_LOCAL}
-        observed = self.driver.reserve_provider_segment(self.session, segment)
+        observed = self.driver.reserve_provider_segment(self.context, segment)
         self.assertEqual(segment, observed)
 
     def test_release_provider_segment(self):
         segment = {api.NETWORK_TYPE: p_const.TYPE_LOCAL}
-        observed = self.driver.reserve_provider_segment(self.session, segment)
-        self.driver.release_segment(self.session, observed)
+        observed = self.driver.reserve_provider_segment(self.context, segment)
+        self.driver.release_segment(self.context, observed)
 
     def test_allocate_tenant_segment(self):
         expected = {api.NETWORK_TYPE: p_const.TYPE_LOCAL}
-        observed = self.driver.allocate_tenant_segment(self.session)
+        observed = self.driver.allocate_tenant_segment(self.context)
         self.assertEqual(expected, observed)
diff --git a/neutron/tests/unit/plugins/ml2/drivers/test_type_vlan.py b/neutron/tests/unit/plugins/ml2/drivers/test_type_vlan.py
index 205b25e3b46..3ab5d3adeae 100644
--- a/neutron/tests/unit/plugins/ml2/drivers/test_type_vlan.py
+++ b/neutron/tests/unit/plugins/ml2/drivers/test_type_vlan.py
@@ -17,7 +17,7 @@ import mock
 from neutron_lib import exceptions as exc
 from testtools import matchers
 
-import neutron.db.api as db
+from neutron import context
 from neutron.db.models.plugins.ml2 import vlanallocation as vlan_alloc_model
 from neutron.plugins.common import constants as p_const
 from neutron.plugins.common import utils as plugin_utils
@@ -49,7 +49,7 @@ class VlanTypeTest(testlib_api.SqlTestCase):
             NETWORK_VLAN_RANGES)
         self.driver = type_vlan.VlanTypeDriver()
         self.driver._sync_vlan_allocations()
-        self.session = db.get_session()
+        self.context = context.Context()
         self.driver.physnet_mtus = []
 
     def test_parse_network_exception_handling(self):
@@ -59,8 +59,9 @@ class VlanTypeTest(testlib_api.SqlTestCase):
             self.assertRaises(SystemExit,
                               self.driver._parse_network_vlan_ranges)
 
-    def _get_allocation(self, session, segment):
-        return session.query(vlan_alloc_model.VlanAllocation).filter_by(
+    def _get_allocation(self, context, segment):
+        return context.session.query(
+            vlan_alloc_model.VlanAllocation).filter_by(
             physical_network=segment[api.PHYSICAL_NETWORK],
             vlan_id=segment[api.SEGMENTATION_ID]).first()
 
@@ -129,17 +130,17 @@ class VlanTypeTest(testlib_api.SqlTestCase):
 
             segment[api.SEGMENTATION_ID] = vlan_min - 1
             self.assertIsNone(
-                self._get_allocation(self.session, segment))
+                self._get_allocation(self.context, segment))
             segment[api.SEGMENTATION_ID] = vlan_max + 1
             self.assertIsNone(
-                self._get_allocation(self.session, segment))
+                self._get_allocation(self.context, segment))
 
             segment[api.SEGMENTATION_ID] = vlan_min
             self.assertFalse(
-                self._get_allocation(self.session, segment).allocated)
+                self._get_allocation(self.context, segment).allocated)
             segment[api.SEGMENTATION_ID] = vlan_max
             self.assertFalse(
-                self._get_allocation(self.session, segment).allocated)
+                self._get_allocation(self.context, segment).allocated)
 
         check_in_ranges(self.network_vlan_ranges)
         self.driver.network_vlan_ranges = UPDATED_VLAN_RANGES
@@ -150,37 +151,37 @@ class VlanTypeTest(testlib_api.SqlTestCase):
         segment = {api.NETWORK_TYPE: p_const.TYPE_VLAN,
                    api.PHYSICAL_NETWORK: PROVIDER_NET,
                    api.SEGMENTATION_ID: 101}
-        alloc = self._get_allocation(self.session, segment)
+        alloc = self._get_allocation(self.context, segment)
         self.assertIsNone(alloc)
-        observed = self.driver.reserve_provider_segment(self.session, segment)
-        alloc = self._get_allocation(self.session, observed)
+        observed = self.driver.reserve_provider_segment(self.context, segment)
+        alloc = self._get_allocation(self.context, observed)
         self.assertTrue(alloc.allocated)
 
     def test_reserve_provider_segment_already_allocated(self):
         segment = {api.NETWORK_TYPE: p_const.TYPE_VLAN,
                    api.PHYSICAL_NETWORK: PROVIDER_NET,
                    api.SEGMENTATION_ID: 101}
-        observed = self.driver.reserve_provider_segment(self.session, segment)
+        observed = self.driver.reserve_provider_segment(self.context, segment)
         self.assertRaises(exc.VlanIdInUse,
                           self.driver.reserve_provider_segment,
-                          self.session,
+                          self.context,
                           observed)
 
     def test_reserve_provider_segment_in_tenant_pools(self):
         segment = {api.NETWORK_TYPE: p_const.TYPE_VLAN,
                    api.PHYSICAL_NETWORK: TENANT_NET,
                    api.SEGMENTATION_ID: VLAN_MIN}
-        alloc = self._get_allocation(self.session, segment)
+        alloc = self._get_allocation(self.context, segment)
         self.assertFalse(alloc.allocated)
-        observed = self.driver.reserve_provider_segment(self.session, segment)
-        alloc = self._get_allocation(self.session, observed)
+        observed = self.driver.reserve_provider_segment(self.context, segment)
+        alloc = self._get_allocation(self.context, observed)
         self.assertTrue(alloc.allocated)
 
     def test_reserve_provider_segment_without_segmentation_id(self):
         segment = {api.NETWORK_TYPE: p_const.TYPE_VLAN,
                    api.PHYSICAL_NETWORK: TENANT_NET}
-        observed = self.driver.reserve_provider_segment(self.session, segment)
-        alloc = self._get_allocation(self.session, observed)
+        observed = self.driver.reserve_provider_segment(self.context, segment)
+        alloc = self._get_allocation(self.context, observed)
         self.assertTrue(alloc.allocated)
         vlan_id = observed[api.SEGMENTATION_ID]
         self.assertThat(vlan_id, matchers.GreaterThan(VLAN_MIN - 1))
@@ -188,8 +189,8 @@ class VlanTypeTest(testlib_api.SqlTestCase):
 
     def test_reserve_provider_segment_without_physical_network(self):
         segment = {api.NETWORK_TYPE: p_const.TYPE_VLAN}
-        observed = self.driver.reserve_provider_segment(self.session, segment)
-        alloc = self._get_allocation(self.session, observed)
+        observed = self.driver.reserve_provider_segment(self.context, segment)
+        alloc = self._get_allocation(self.context, observed)
         self.assertTrue(alloc.allocated)
         vlan_id = observed[api.SEGMENTATION_ID]
         self.assertThat(vlan_id, matchers.GreaterThan(VLAN_MIN - 1))
@@ -198,11 +199,11 @@ class VlanTypeTest(testlib_api.SqlTestCase):
 
     def test_reserve_provider_segment_all_allocateds(self):
         for __ in range(VLAN_MIN, VLAN_MAX + 1):
-            self.driver.allocate_tenant_segment(self.session)
+            self.driver.allocate_tenant_segment(self.context)
         segment = {api.NETWORK_TYPE: p_const.TYPE_VLAN}
         self.assertRaises(exc.NoNetworkAvailable,
                           self.driver.reserve_provider_segment,
-                          self.session,
+                          self.context,
                           segment)
 
     def test_get_mtu(self):
@@ -228,8 +229,8 @@ class VlanTypeTest(testlib_api.SqlTestCase):
 
     def test_allocate_tenant_segment(self):
         for __ in range(VLAN_MIN, VLAN_MAX + 1):
-            segment = self.driver.allocate_tenant_segment(self.session)
-            alloc = self._get_allocation(self.session, segment)
+            segment = self.driver.allocate_tenant_segment(self.context)
+            alloc = self._get_allocation(self.context, segment)
             self.assertTrue(alloc.allocated)
             vlan_id = segment[api.SEGMENTATION_ID]
             self.assertThat(vlan_id, matchers.GreaterThan(VLAN_MIN - 1))
@@ -238,14 +239,14 @@ class VlanTypeTest(testlib_api.SqlTestCase):
 
     def test_allocate_tenant_segment_no_available(self):
         for __ in range(VLAN_MIN, VLAN_MAX + 1):
-            self.driver.allocate_tenant_segment(self.session)
-        segment = self.driver.allocate_tenant_segment(self.session)
+            self.driver.allocate_tenant_segment(self.context)
+        segment = self.driver.allocate_tenant_segment(self.context)
         self.assertIsNone(segment)
 
     def test_release_segment(self):
-        segment = self.driver.allocate_tenant_segment(self.session)
-        self.driver.release_segment(self.session, segment)
-        alloc = self._get_allocation(self.session, segment)
+        segment = self.driver.allocate_tenant_segment(self.context)
+        self.driver.release_segment(self.context, segment)
+        alloc = self._get_allocation(self.context, segment)
         self.assertFalse(alloc.allocated)
 
     def test_release_segment_unallocated(self):
@@ -253,7 +254,7 @@ class VlanTypeTest(testlib_api.SqlTestCase):
                    api.PHYSICAL_NETWORK: PROVIDER_NET,
                    api.SEGMENTATION_ID: 101}
         with mock.patch.object(type_vlan.LOG, 'warning') as log_warn:
-            self.driver.release_segment(self.session, segment)
+            self.driver.release_segment(self.context, segment)
             log_warn.assert_called_once_with(
                 "No vlan_id %(vlan_id)s found on physical network "
                 "%(physical_network)s",
diff --git a/neutron/tests/unit/plugins/ml2/test_plugin.py b/neutron/tests/unit/plugins/ml2/test_plugin.py
index 2fc802f6380..3b82280567c 100644
--- a/neutron/tests/unit/plugins/ml2/test_plugin.py
+++ b/neutron/tests/unit/plugins/ml2/test_plugin.py
@@ -1842,7 +1842,7 @@ class TestMultiSegmentNetworks(Ml2PluginV2TestCase):
         dynamic_segmentation_id = dynamic_segment[driver_api.SEGMENTATION_ID]
         self.assertGreater(dynamic_segmentation_id, 0)
         self.driver.type_manager.release_dynamic_segment(
-            self.context.session, dynamic_segment[driver_api.ID])
+            self.context, dynamic_segment[driver_api.ID])
         self.assertIsNone(segments_db.get_dynamic_segment(
             self.context.session, network_id, 'physnet1'))
 
@@ -2009,7 +2009,7 @@ class TestMultiSegmentNetworks(Ml2PluginV2TestCase):
             with mock.patch('neutron.plugins.ml2.managers.segments_db') as db:
                 db.get_network_segments.return_value = (segment,)
                 self.driver.type_manager.release_network_segments(
-                    self.context.session, network_id)
+                    self.context, network_id)
 
                 log.error.assert_called_once_with(
                     "Failed to release segment '%s' because "