docker.images/ansible.awx/awx-17.1.0/awx/main/scheduler/dag_simple.py

228 lines
7.1 KiB
Python

from collections import deque
class SimpleDAG(object):
    '''
    A simple implementation of a directed acyclic graph.

    Nodes are stored as dicts of the form
    ``{'node_object': obj, 'metadata': metadata}`` in ``self.nodes``; the
    position of a node in that list is its index. Edges are labelled and are
    tracked in both directions, keyed first by label and then by node index.
    Node objects must be hashable because they are used as dictionary keys.
    '''

    def __init__(self):
        '''Create an empty graph.'''
        self.nodes = []
        # Indexes of the nodes that currently have no incoming edge.
        self.root_nodes = set()

        # Track node_obj -> node index.
        # The key is a full workflow node object (or whatever is stored in
        # ['node_object']) and the value is an index into self.nodes.
        self.node_obj_to_node_index = {}

        # Track per-node from->to edges, i.e.
        # {
        #     'success': {
        #         1: [2, 3],
        #         4: [2, 3],
        #     },
        #     'failed': {
        #         1: [5],
        #     }
        # }
        self.node_from_edges_by_label = {}

        # Track per-node reverse relationship (child to parent), i.e.
        # {
        #     'success': {
        #         2: [1, 4],
        #         3: [1, 4],
        #     },
        #     'failed': {
        #         5: [1],
        #     }
        # }
        self.node_to_edges_by_label = {}

    def __contains__(self, obj):
        '''Return True when ``obj`` has been added to the graph as a node.'''
        return obj in self.node_obj_to_node_index

    def __len__(self):
        '''Return the number of nodes in the graph.'''
        return len(self.nodes)

    def __iter__(self):
        '''Iterate over the node entries (dicts) in insertion order.'''
        return iter(self.nodes)

    def generate_graphviz_plot(self, file_name="/awx_devel/graph.gv"):
        '''
        Write a Graphviz (DOT) description of the graph to ``file_name``.

        Every node object must expose ``id``. The optional attributes ``job``
        (with a ``status``) and ``do_not_run`` are used, when present, to
        build the node name and choose its color.
        '''
        def job_status(obj):
            # "NA" when the object has no job, or the job has no status.
            job = getattr(obj, 'job', None)
            if job and hasattr(job, 'status'):
                return job.status
            return "NA"

        def do_not_run(obj):
            return getattr(obj, 'do_not_run', False) is True

        def run_status(obj):
            dnr = "DNR" if do_not_run(obj) else "RUN"
            return "{}_{}_{}".format(dnr, job_status(obj), obj.id)

        parts = ["\ndigraph g {\nrankdir = LR\n"]

        for entry in self.nodes:
            obj = entry['node_object']
            status = job_status(obj)
            if status == 'successful':
                color = 'green'
            elif status == 'failed':
                color = 'red'
            elif do_not_run(obj):
                color = 'gray'
            else:
                color = 'black'
            parts.append("%s [color = %s]\n" % (run_status(obj), color))

        for label, edges in self.node_from_edges_by_label.items():
            for from_node, to_nodes in edges.items():
                for to_node in to_nodes:
                    parts.append("%s -> %s [ label=\"%s\" ];\n" % (
                        run_status(self.nodes[from_node]['node_object']),
                        run_status(self.nodes[to_node]['node_object']),
                        label,
                    ))
        parts.append("}\n")

        # The context manager closes the file even if the write fails.
        with open(file_name, 'w') as gv_file:
            gv_file.write("".join(parts))

    def add_node(self, obj, metadata=None):
        '''
        Add ``obj`` as a node with optional ``metadata``.

        Adding an object that is already in the graph is a no-op.
        '''
        if self.find_ord(obj) is not None:
            return
        # Assume node is a root node until a child edge points at it.
        node_index = len(self.nodes)
        self.root_nodes.add(node_index)
        self.node_obj_to_node_index[obj] = node_index
        self.nodes.append(dict(node_object=obj, metadata=metadata))

    def add_edge(self, from_obj, to_obj, label):
        '''
        Add a ``label`` edge going from ``from_obj`` to ``to_obj``.

        Raises LookupError when either object is not a node of the graph;
        in that case the graph is left unmodified.
        '''
        from_obj_ord = self.find_ord(from_obj)
        to_obj_ord = self.find_ord(to_obj)

        # Validate before mutating so a failed call has no side effects.
        if from_obj_ord is None and to_obj_ord is None:
            raise LookupError("From object {} and to object {} not found".format(from_obj, to_obj))
        elif from_obj_ord is None:
            raise LookupError("From object not found {}".format(from_obj))
        elif to_obj_ord is None:
            raise LookupError("To object not found {}".format(to_obj))

        # To node is no longer a root node.
        self.root_nodes.discard(to_obj_ord)

        self.node_from_edges_by_label.setdefault(label, {}) \
            .setdefault(from_obj_ord, []).append(to_obj_ord)
        self.node_to_edges_by_label.setdefault(label, {}) \
            .setdefault(to_obj_ord, []).append(from_obj_ord)

    def find_ord(self, obj):
        '''Return the index of ``obj`` in ``self.nodes``, or None if absent.'''
        return self.node_obj_to_node_index.get(obj, None)

    def _get_children_by_label(self, node_index, label):
        '''Return the node entries reached from ``node_index`` via ``label``.'''
        child_indexes = self.node_from_edges_by_label.get(label, {}).get(node_index, [])
        return [self.nodes[index] for index in child_indexes]

    def get_children(self, obj, label=None):
        '''
        Return the child node entries of ``obj``.

        With a truthy ``label`` only edges of that label are followed;
        otherwise the children of every label are concatenated (a child
        linked through several labels appears once per label).
        '''
        this_ord = self.find_ord(obj)
        if label:
            return self._get_children_by_label(this_ord, label)
        nodes = []
        for label_obj in self.node_from_edges_by_label:
            nodes.extend(self._get_children_by_label(this_ord, label_obj))
        return nodes

    def _get_parents_by_label(self, node_index, label):
        '''Return the node entries that point at ``node_index`` via ``label``.'''
        parent_indexes = self.node_to_edges_by_label.get(label, {}).get(node_index, [])
        return [self.nodes[index] for index in parent_indexes]

    def get_parents(self, obj, label=None):
        '''
        Return the parent node entries of ``obj``.

        With a truthy ``label`` only edges of that label are followed;
        otherwise the parents of every label are concatenated (a parent
        linked through several labels appears once per label).
        '''
        this_ord = self.find_ord(obj)
        if label:
            return self._get_parents_by_label(this_ord, label)
        nodes = []
        for label_obj in self.node_to_edges_by_label:
            nodes.extend(self._get_parents_by_label(this_ord, label_obj))
        return nodes

    def get_root_nodes(self):
        '''Return the node entries that have no incoming edge.'''
        return [self.nodes[index] for index in self.root_nodes]

    def _get_child_indexes(self, node_index):
        '''Return the indexes of all children of ``node_index``, any label.'''
        indexes = []
        for edges in self.node_from_edges_by_label.values():
            indexes.extend(edges.get(node_index, []))
        return indexes

    def has_cycle(self):
        '''
        Return True when the graph contains at least one cycle.

        An iterative depth-first search is started from every node that has
        not been explored yet, so cycles that cannot be reached from a root
        node are detected as well.
        '''
        unvisited, in_progress, done = 0, 1, 2
        state = [unvisited] * len(self.nodes)

        for start in range(len(self.nodes)):
            if state[start] != unvisited:
                continue
            state[start] = in_progress
            # Each frame is (node index, iterator over its remaining children).
            stack = [(start, iter(self._get_child_indexes(start)))]
            while stack:
                index, children = stack[-1]
                for child in children:
                    if state[child] == in_progress:
                        # Back edge to a node on the current DFS path.
                        return True
                    if state[child] == unvisited:
                        state[child] = in_progress
                        stack.append((child, iter(self._get_child_indexes(child))))
                        break
                else:
                    # Every child has been fully explored.
                    state[index] = done
                    stack.pop()
        return False

    def sort_nodes_topological(self):
        '''
        Return a deque of node entries in which each node comes before
        its children.

        Node objects must expose ``id``; it is used to remember which nodes
        were already emitted. The traversal is recursive and never marks a
        node as processed before its children, so it does not terminate on a
        cyclic graph -- check has_cycle() first when the input is untrusted.
        '''
        nodes_sorted = deque()
        obj_ids_processed = set()

        def visit(node):
            obj = node['node_object']
            if obj.id in obj_ids_processed:
                return
            for child in self.get_children(obj):
                visit(child)
            obj_ids_processed.add(obj.id)
            # Prepending after the children puts the parent ahead of them.
            nodes_sorted.appendleft(node)

        for node in self.nodes:
            visit(node)
        return nodes_sorted