diff --git a/diskimage_builder/block_device/level1/lvm.py b/diskimage_builder/block_device/level1/lvm.py
index 2097f5e56..c1b6172c3 100644
--- a/diskimage_builder/block_device/level1/lvm.py
+++ b/diskimage_builder/block_device/level1/lvm.py
@@ -288,41 +288,21 @@ class LVMNode(NodeBase):
 
         exec_sudo(['udevadm', 'settle'])
 
-
-class LVMUmountNode(NodeBase):
-    def __init__(self, name, state, pvs):
-        """Umount Node for LVM
-
-        Information about the PV, VG and LV is typically
-        cached in lvmetad. Even after removing PV device and
-        partitions this data is not automatically updated
-        which leads to a couple of problems.
-        the 'pvscan --cache' scans the available disks
-        and updates the cache.
-        This must be called after the umount of the
-        containing block device is done.
-        """
-        super(LVMUmountNode, self).__init__(name, state)
-        self.pvs = pvs
-
-    def create(self):
-        # This class is used for cleanup only
-        pass
-
-    def umount(self):
+    def cleanup(self):
+        # Information about the PV, VG and LV is typically
+        # cached in lvmetad. Even after removing PV device and
+        # partitions this data is not automatically updated
+        # which leads to a couple of problems.
+        # the 'pvscan --cache' scans the available disks
+        # and updates the cache.
+        # This is in cleanup because it must be called after the
+        # umount of the containing block device is done, (which should
+        # all be done in umount phase).
         try:
             exec_sudo(['pvscan', '--cache'])
         except BlockDeviceSetupException as e:
             logger.info("pvscan call failed [%s]", e.returncode)
 
-    def get_edges(self):
-        # This node depends on all physical device(s), which is
-        # recorded in the "base" argument of the PV nodes.
-        edge_to = set()
-        for pv in self.pvs:
-            edge_to.add(pv.base)
-        return ([], edge_to)
-
 
 class LVMPlugin(PluginBase):
 
@@ -411,12 +391,10 @@ class LVMPlugin(PluginBase):
         # create the "driver" node
         self.lvm_node = LVMNode(config['name'], state,
                                 self.pvs, self.lvs, self.vgs)
-        self.lvm_umount_node = LVMUmountNode(
-            config['name'] + "-UMOUNT", state, self.pvs)
 
     def get_nodes(self):
         # the nodes for insertion into the graph are all of the pvs,
-        # vgs and lvs nodes we have created above, the root node and
+        # vgs and lvs nodes we have created above and the root node and
         # the cleanup node.
         return self.pvs + self.vgs + self.lvs \
-            + [self.lvm_node, self.lvm_umount_node]
+            + [self.lvm_node]
diff --git a/diskimage_builder/block_device/tests/test_lvm.py b/diskimage_builder/block_device/tests/test_lvm.py
index 3ca37287c..a16b0d04f 100644
--- a/diskimage_builder/block_device/tests/test_lvm.py
+++ b/diskimage_builder/block_device/tests/test_lvm.py
@@ -23,7 +23,6 @@ from diskimage_builder.block_device.exception import \
     BlockDeviceSetupException
 from diskimage_builder.block_device.level1.lvm import LVMNode
 from diskimage_builder.block_device.level1.lvm import LVMPlugin
-from diskimage_builder.block_device.level1.lvm import LVMUmountNode
 from diskimage_builder.block_device.level1.lvm import LvsNode
 from diskimage_builder.block_device.level1.lvm import PvsNode
 from diskimage_builder.block_device.level1.lvm import VgsNode
@@ -89,7 +88,7 @@ class TestLVM(tc.TestGraphGeneration):
             # XXX: This has not mocked out the "lower" layers of
             # creating the devices, which we're assuming works OK, nor
             # the upper layers.
-            if isinstance(node, (LVMNode, LVMUmountNode, PvsNode,
+            if isinstance(node, (LVMNode, PvsNode,
                                  VgsNode, LvsNode)):
                 # only the LVMNode actually does anything here...
                 node.create()
@@ -198,7 +197,7 @@ class TestLVM(tc.TestGraphGeneration):
                 # XXX: This has not mocked out the "lower" layers of
                 # creating the devices, which we're assuming works OK, nor
                 # the upper layers.
-                if isinstance(node, (LVMNode, LVMUmountNode, PvsNode,
+                if isinstance(node, (LVMNode, PvsNode,
                                      VgsNode, LvsNode)):
                     # only the PvsNode actually does anything here...
                     node.create()
@@ -306,11 +305,16 @@ class TestLVM(tc.TestGraphGeneration):
                 return r
             mock_temp.side_effect = new_tempfile
 
-            reverse_order = reversed(call_order)
-            for node in reverse_order:
-                if isinstance(node, (LVMNode, LVMUmountNode, PvsNode,
-                                     VgsNode, LvsNode)):
-                    node.umount()
+            def run_it(phase):
+                reverse_order = reversed(call_order)
+                for node in reverse_order:
+                    if isinstance(node, (LVMNode, PvsNode, VgsNode, LvsNode)):
+                        getattr(node, phase)()
+                    else:
+                        logger.debug("Skipping node for test: %s", node)
+
+            run_it('umount')
+            run_it('cleanup')
 
             cmd_sequence = [
                 # delete the lv's
@@ -365,8 +369,7 @@ class TestLVM(tc.TestGraphGeneration):
                 # XXX: This has not mocked out the "lower" layers of
                 # creating the devices, which we're assuming works OK, nor
                 # the upper layers.
-                if isinstance(node, (LVMNode, LVMUmountNode,
-                                     PvsNode, VgsNode, LvsNode)):
+                if isinstance(node, (LVMNode, PvsNode, VgsNode, LvsNode)):
                     # only the LVMNode actually does anything here...
                     node.create()
 
@@ -407,11 +410,16 @@ class TestLVM(tc.TestGraphGeneration):
                 return r
             mock_temp.side_effect = new_tempfile
 
-            reverse_order = reversed(call_order)
-            for node in reverse_order:
-                if isinstance(node, (LVMNode, LVMUmountNode,
-                                     PvsNode, VgsNode, LvsNode)):
-                    node.umount()
+            def run_it(phase):
+                reverse_order = reversed(call_order)
+                for node in reverse_order:
+                    if isinstance(node, (LVMNode, PvsNode, VgsNode, LvsNode)):
+                        getattr(node, phase)()
+                    else:
+                        logger.debug("Skipping node for test: %s", node)
+
+            run_it('umount')
+            run_it('cleanup')
 
             cmd_sequence = [
                 # deactivate lv's