diff --git a/diskimage_builder/graph/digraph.py b/diskimage_builder/graph/digraph.py index 6e8b7cb2a..5b7758d9c 100644 --- a/diskimage_builder/graph/digraph.py +++ b/diskimage_builder/graph/digraph.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. # +import bisect class Digraph(object): @@ -20,6 +21,32 @@ class Digraph(object): Each node of the digraph must have a unique name. """ + class Edge(object): + """Directed graph edge. + + The digraph has weighted edges. This class holds the weight and + a reference to the node. + """ + + def __init__(self, node, weight): + self.__node = node + self.__weight = weight + + def __eq__(self, other): + return self.__weight == other.get_weight() \ + and self.__node == other.get_node() + + def __lt__(self, other): + return self.__weight < other.get_weight() + + def get_node(self): + """Return the (pointed to) node""" + return self.__node + + def get_weight(self): + """Return the edge's weight""" + return self.__weight + class Node(object): """Directed graph node. @@ -35,33 +62,38 @@ class Digraph(object): computed. """ self.__name = name - self.__incoming = set() - self.__outgoing = set() + self.__incoming = [] + self.__outgoing = [] + + def __repr__(self): + return "<Node [%s]>" % self.__name def get_name(self): """Returns the name of the node.""" return self.__name - def add_incoming(self, node): + def add_incoming(self, node, weight): """Add node to the incoming list.""" + bisect.insort(self.__incoming, Digraph.Edge(node, weight)) - self.__incoming.add(node) - - def add_outgoing(self, node): - """Add node to the incoming list.""" - - self.__outgoing.add(node) + def add_outgoing(self, node, weight): + """Add node to the outgoing list.""" + bisect.insort(self.__outgoing, Digraph.Edge(node, weight)) def get_iter_outgoing(self): """Return an iterator over the outgoing nodes.""" - return iter(self.__outgoing) + return iter([x.get_node() for x in self.__outgoing]) + + def has_incoming(self): + """Returns True if the node has incoming edges""" + return self.__incoming @staticmethod def __as_named_list(inlist): """Return given list as list of names.""" - return map(lambda x: x.get_name(), inlist) + return [x.get_node().get_name() for x in inlist] def get_outgoing_as_named_list(self): """Return the names of all outgoing nodes as a list.""" @@ -105,7 +137,7 @@ class Digraph(object): "exists" % node.get_name()) self._named_nodes[anode.get_name()] = anode - def create_edge(self, anode, bnode): + def create_edge(self, anode, bnode, weight=0): """Creates an edge from a to b - both must be nodes.""" assert issubclass(anode.__class__, Digraph.Node) @@ -114,8 +146,8 @@ class Digraph(object): assert anode == self._named_nodes[anode.get_name()] assert bnode.get_name() in self._named_nodes.keys() assert bnode == self._named_nodes[bnode.get_name()] - anode.add_outgoing(bnode) - bnode.add_incoming(anode) + anode.add_outgoing(bnode, weight) + bnode.add_incoming(anode, weight) def get_iter_nodes_values(self): """Returns the nodes dict to the values. @@ -144,7 +176,7 @@ class Digraph(object): rval[node.get_name()] = node.get_outgoing_as_named_list() return rval - def topological_sort(dg): + def topological_sort(self): """Digraph topological search. This algorithm is based upon a depth first search with @@ -169,7 +201,9 @@ class Digraph(object): tsort.insert(0, node) # The 'main' function of the topological sort - for node in dg.get_iter_nodes_values(): + for node in self.get_iter_nodes_values(): + if node.has_incoming(): + continue visit(node) return tsort @@ -189,6 +223,6 @@ def node_list_to_node_name_list(node_list): """Converts a node list into a list of the corresponding node names.""" node_name_list = [] - for n in node_list: - node_name_list.append(n.get_name()) + for node in node_list: + node_name_list.append(node.get_name()) return node_name_list diff --git a/diskimage_builder/tests/functional/test_graph.py b/diskimage_builder/tests/functional/test_graph.py index 44eb148e2..8066cdbef 100644 --- a/diskimage_builder/tests/functional/test_graph.py +++ b/diskimage_builder/tests/functional/test_graph.py @@ -121,3 +121,23 @@ class TestDigraph(testtools.TestCase): self.assertTrue(False) except RuntimeError: pass + + def test_iter_outgoing_weight_01(self): + """Tests iter_outgoing in a graph with weights""" + + digraph = Digraph() + node0 = Digraph.Node("R") + digraph.add_node(node0) + node1 = Digraph.Node("A") + digraph.add_node(node1) + node2 = Digraph.Node("B") + digraph.add_node(node2) + node3 = Digraph.Node("C") + digraph.add_node(node3) + + digraph.create_edge(node0, node1, 1) + digraph.create_edge(node0, node2, 2) + digraph.create_edge(node0, node3, 3) + + self.assertEqual([node1, node2, node3], + list(node0.get_iter_outgoing())) diff --git a/diskimage_builder/tests/functional/test_graph_toposort.py b/diskimage_builder/tests/functional/test_graph_toposort.py index 25a14475f..4d30cdb93 100644 --- a/diskimage_builder/tests/functional/test_graph_toposort.py +++ b/diskimage_builder/tests/functional/test_graph_toposort.py @@ -12,9 +12,11 @@ # License for the specific language governing permissions and limitations # under the License. +import testtools + +from diskimage_builder.graph.digraph import Digraph from diskimage_builder.graph.digraph import digraph_create_from_dict from diskimage_builder.graph.digraph import node_list_to_node_name_list -import testtools class TestTopologicalSearch(testtools.TestCase): @@ -67,3 +69,48 @@ class TestTopologicalSearch(testtools.TestCase): self.assertTrue(tnames.index('A') < tnames.index('B')) self.assertTrue(tnames.index('B') < tnames.index('C')) self.assertTrue(tnames.index('D') < tnames.index('E')) + + def test_tsort_006(self): + """Complex digraph with weights""" + + digraph = Digraph() + node0 = Digraph.Node("R") + digraph.add_node(node0) + node1 = Digraph.Node("A") + digraph.add_node(node1) + node2 = Digraph.Node("B") + digraph.add_node(node2) + node3 = Digraph.Node("C") + digraph.add_node(node3) + node4 = Digraph.Node("B1") + digraph.add_node(node4) + node5 = Digraph.Node("B2") + digraph.add_node(node5) + node6 = Digraph.Node("B3") + digraph.add_node(node6) + + digraph.create_edge(node0, node1, 1) + digraph.create_edge(node0, node2, 2) + digraph.create_edge(node0, node3, 3) + + digraph.create_edge(node2, node4, 7) + digraph.create_edge(node2, node5, 14) + digraph.create_edge(node2, node6, 21) + + tsort = digraph.topological_sort() + tnames = node_list_to_node_name_list(tsort) + + # Also here: many possible solutions + self.assertTrue(tnames.index('R') < tnames.index('A')) + self.assertTrue(tnames.index('R') < tnames.index('B')) + self.assertTrue(tnames.index('R') < tnames.index('C')) + self.assertTrue(tnames.index('B') < tnames.index('B1')) + self.assertTrue(tnames.index('B') < tnames.index('B2')) + self.assertTrue(tnames.index('B') < tnames.index('B3')) + + # In addition in the weighted graph the following + # must also hold: + self.assertTrue(tnames.index('B') < tnames.index('A')) + self.assertTrue(tnames.index('C') < tnames.index('B')) + self.assertTrue(tnames.index('B2') < tnames.index('B1')) + self.assertTrue(tnames.index('B3') < tnames.index('B2'))