
Below is my code for building a control-flow graph (CFG) from a Python abstract syntax tree. I'm not satisfied with it because the algorithm is very complicated, perhaps much more complicated than it needs to be. I'd like suggestions on how to simplify the algorithm and how to make the code more readable.

from ast import *
from collections import defaultdict

class BasicBlock:
    def __init__(self, insns):
        self.insns = insns

    def type(self):
        return type(self.insns[-1]) if self.insns else None

class CFGBuilder:
    def __init__(self):
        self.succ = defaultdict(list)
        self.bbs = []

    def build_tree(self, nodes):
        buf = []
        for node in nodes:
            tp = type(node)
            if tp in (For, While):
                if buf:
                    yield BasicBlock(buf), []
                yield BasicBlock([node]), [self.build_tree(node.body)]
                buf = []
            elif tp == If:
                buf.append(node)
                branches = [self.build_tree(node.body),
                            self.build_tree(node.orelse)]
                yield BasicBlock(buf), branches
                buf = []
            elif tp in (Break, Continue, Pass, Return):
                buf.append(node)
                yield BasicBlock(buf), []
                return
            elif tp in (Assign, Expr):
                buf.append(node)
            else:
                assert False
        if buf:
            yield BasicBlock(buf), []

    def connect(self, bb_tree, parent_bb, loop_bb):
        tails = []
        breaks = []
        if parent_bb and bb_tree:
            tails = [parent_bb]
        for bb, branches in bb_tree:
            self.bbs.append(bb)
            for tail in tails:
                self.succ[tail].append(bb)
            tp = bb.type()
            if tp == If:
                true_tails, true_breaks = \
                    self.connect(branches[0], bb, loop_bb)
                false_tails, false_breaks = \
                    self.connect(branches[1], bb, loop_bb)
                breaks.extend(true_breaks + false_breaks)
                if not true_tails and not false_tails:
                    return [], breaks
                tails = true_tails + false_tails
                if not branches[1]:
                    tails.append(bb)
            elif tp in (For, While):
                tails, loop_breaks = \
                    self.connect(branches[0], bb, bb)
                for tail in tails:
                    self.succ[tail].append(bb)
                tails = [bb] + loop_breaks
            elif tp == Break:
                if loop_bb:
                    return [], breaks + [bb]
                return [bb], []
            elif tp == Continue:
                if loop_bb:
                    self.succ[bb].append(loop_bb)
                    return [], breaks
                return [bb], []
            elif tp == Return:
                return [], []
            elif tp in (Assign, Expr, Pass, None):
                tails = [bb]
            else:
                assert False
        return tails, breaks

    def build(self, nodes):
        # SSA construction requires an entry block.
        bb_tree = [(BasicBlock([]), [])]
        bb_tree.extend(self.build_tree(nodes))
        # If the last block is a block statement, close the cfg with
        # an empty block.
        if bb_tree[-1][1]:
            dummy = BasicBlock([])
            bb_tree.append((dummy, []))
        self.connect(bb_tree, None, None)
        return self.bbs, self.succ

You can run it like this:

PROG = '''
for ba in range(10):
 if -(a + 10):
 break
 for y in range(10):
 if not x:
 print('cont from x_99')
 continue
 x = x + 1
 if x:
 if y:
 return x
 pass
 break
 print(x)
print('blah')
while 1:
 if x == 0:
 while x < 10:
 x = x + 1
 y = y + 1
 if x == 5:
 break
 if y == 5:
 continue
 else:
 if y == 2:
 b = b + 1
 continue
 a = a + 1
 continue
 print('prutt')
 break
 if b:
 return
 x = 99
 if m:
 while 2:
 if 1:
 break
 else:
 continue
 x = 9
 print('dead code')
 print('this reaches')
print('but ok')
if 2:
 break
else:
 pass
break
print(10)
'''
root = parse(PROG)
builder = CFGBuilder()
bbs, succ = builder.build(root.body)

It produces this CFG:

[Image: the rendered control-flow graph]

Here is the gist for plotting.

asked May 7, 2022 at 16:46
  • This immediately chokes on ast.Import. So like... I'll only verify it with your example source and not realistic source, I guess, because realistic source won't work. Commented May 7, 2022 at 18:34
  • Correct. It is only supposed to support the basic control flow in the example. Commented May 7, 2022 at 18:51
  • Where's your dot graphing code? It would be useful to include that. Commented May 7, 2022 at 19:02
  • @Reinderien I added a gist for plotting. Though I don't need help with the plotting code. Commented May 12, 2022 at 8:53
  • Please include it verbatim in the question body. Commented May 12, 2022 at 18:12

1 Answer


Avoid from ast import *; ast is already conveniently short as a prefix to a fully-qualified name, and importing splat pollutes your namespace.

Add type hints. This is especially important to understand what on earth your algorithm is doing. In fact, one of your types is recursive! This will likely confuse static analysers like mypy, but better to have it than nothing at all.

It's somewhat unidiomatic to pull the type of a variable and then compare that reference to a sequence of other types. The traditional idiomatic way is isinstance; the modern way is a match specifying your types. In the sample code I demonstrate both.

Don't assert in production code, and especially don't assert False: assertions are skipped entirely when Python runs with -O. Raise a meaningful exception instead - for now I have filled this in as NotImplementedError.
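A minimal sketch of the difference, using a hypothetical `block_kind` helper rather than your actual builder. The message format is only a suggestion; the point is that the error names the unsupported node.

```python
import ast


def block_kind(node: ast.AST) -> str:
    if isinstance(node, (ast.Assign, ast.Expr)):
        return 'simple'
    # Unlike `assert False`, this still fires under `python -O`
    # and tells the caller which node type was not handled.
    raise NotImplementedError(f'unsupported node: {type(node).__name__}')


try:
    block_kind(ast.parse('import os').body[0])
except NotImplementedError as e:
    message = str(e)
```

With this, feeding the builder an import statement would report the offending node type instead of a bare AssertionError.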

Avoid line continuation \ - if multiple lines are necessary (which I don't think they are in your case), prefer parens.

Instead of

breaks.extend(true_breaks + false_breaks)

which creates an intermediate list only to throw it away, prefer two extend calls.
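In isolation, with placeholder strings standing in for basic blocks (the values are made up), that looks roughly like this:

```python
breaks = ['outer_break']
true_breaks = ['break_in_body']
false_breaks = ['break_in_orelse']

# Each call appends in place; no temporary concatenated list is built.
breaks.extend(true_breaks)
breaks.extend(false_breaks)
```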

Rather than branches[0] or [1], unpack this to a true and false branch. It's more self-documenting, with a bonus that it will catch unexpected sequence lengths.
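A quick sketch of the length check you get for free, again with placeholder values instead of real branch generators:

```python
branches = ['body', 'orelse']

# Unpacking names both branches and requires exactly two elements.
true_branch, false_branch = branches

try:
    # A sequence of the wrong length fails loudly instead of silently
    # indexing whatever happens to be there.
    only_true, missing_false = ['body']
except ValueError as e:
    error = str(e)
```

Indexing with `[0]` and `[1]` would also fail on a too-short list, but it would silently accept a list with three or more elements.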

Don't compare a type to None - compare it to NoneType (importable from types since Python 3.10). The type of None is NoneType, not None itself.
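Your original membership test only works because your `type()` method returns the value None (not a type) for an empty block. Once you switch to isinstance, a real type is needed. A minimal sketch, defining `NoneType` via `type(None)` so it also runs on versions before 3.10:

```python
NoneType = type(None)  # same object as types.NoneType on Python 3.10+

last = None  # stands in for the last statement of an empty block

matches_none_type = isinstance(last, NoneType)

try:
    isinstance(last, None)  # None is a value, not a type
    rejected = False
except TypeError:
    rejected = True
```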

Suggested

This produces the same output as your version and needs Python 3.10 or later (for match and types.NoneType). The plotting code is wrapped in a containing function:

import ast
from collections import defaultdict
from types import NoneType
from typing import Optional, Type, Iterable, Iterator


class BasicBlock:
    def __init__(self, insns: list[ast.stmt]) -> None:
        self.insns = insns

    @property
    def last(self) -> Optional[ast.stmt]:
        return self.insns[-1] if self.insns else None

    def matches(self, *types: Type[ast.stmt]) -> bool:
        return isinstance(self.last, types)


TreeNode = tuple[
    BasicBlock,
    list[Iterable['TreeNode']]  # this is recursive - yikes
]


class CFGBuilder:
    def __init__(self) -> None:
        # Basic blocks in program order and a mapping of blocks to
        # their successors.
        self.succ: defaultdict[BasicBlock, list[BasicBlock]] = defaultdict(list)
        self.bbs: list[BasicBlock] = []

    def build_tree(self, nodes: list[ast.stmt]) -> Iterator[TreeNode]:
        buf: list[ast.stmt] = []

        for node in nodes:
            match node:
                case ast.For() | ast.While():
                    if buf:
                        yield BasicBlock(buf), []
                    yield BasicBlock([node]), [self.build_tree(node.body)]
                    buf = []

                case ast.If():
                    buf.append(node)
                    branches = [self.build_tree(node.body),
                                self.build_tree(node.orelse)]
                    yield BasicBlock(buf), branches
                    buf = []

                case ast.Break() | ast.Continue() | ast.Pass() | ast.Return():
                    buf.append(node)
                    yield BasicBlock(buf), []
                    return

                case ast.Assign() | ast.Expr():
                    buf.append(node)

                case other:
                    raise NotImplementedError()

        if buf:
            yield BasicBlock(buf), []

    def connect(
        self,
        bb_tree: Iterable[TreeNode],
        parent_bb: Optional[BasicBlock],
        loop_bb: Optional[BasicBlock],
    ) -> tuple[
        list[BasicBlock],  # tails
        list[BasicBlock],  # breaks
    ]:
        breaks = []
        if parent_bb and bb_tree:
            tails = [parent_bb]
        else:
            tails = []

        for bb, branches in bb_tree:
            self.bbs.append(bb)
            for tail in tails:
                self.succ[tail].append(bb)

            if bb.matches(ast.If):
                true_branch, false_branch = branches
                true_tails, true_breaks = self.connect(true_branch, bb, loop_bb)
                false_tails, false_breaks = self.connect(false_branch, bb, loop_bb)
                breaks.extend(true_breaks)
                breaks.extend(false_breaks)
                if not (true_tails or false_tails):
                    return [], breaks
                tails = true_tails + false_tails
                if not false_branch:
                    tails.append(bb)
            elif bb.matches(ast.For, ast.While):
                branch, = branches
                tails, loop_breaks = self.connect(branch, bb, bb)
                for tail in tails:
                    self.succ[tail].append(bb)
                tails = [bb] + loop_breaks
            elif bb.matches(ast.Break):
                if loop_bb:
                    return [], breaks + [bb]
                return [bb], []
            elif bb.matches(ast.Continue):
                if loop_bb:
                    self.succ[bb].append(loop_bb)
                    return [], breaks
                return [bb], []
            elif bb.matches(ast.Return):
                return [], []
            elif bb.matches(ast.Assign, ast.Expr, ast.Pass, NoneType):
                tails = [bb]
            else:
                raise NotImplementedError()

        return tails, breaks

    def build(self, nodes: list[ast.stmt]) -> tuple[
        list[BasicBlock],
        defaultdict[BasicBlock, list[BasicBlock]],
    ]:
        # SSA construction requires an entry block.
        bb_tree = [(BasicBlock([]), [])]
        bb_tree.extend(self.build_tree(nodes))
        # If the last block is a block statement, close the cfg with
        # an empty block.
        if bb_tree[-1][1]:
            dummy = BasicBlock([])
            bb_tree.append((dummy, []))
        self.connect(bb_tree, None, None)
        return self.bbs, self.succ


def plot():
    from pygraphviz import AGraph
    from re import sub

    COLOR_KWD = '#a020f0'
    COLOR_VAR = '#6c71c4'
    COLOR_STR = '#8b2252'

    def colorize(s, col):
        return '<font color="%s">%s</font>' % (col, s)

    def kwd(s):
        return colorize(s, COLOR_KWD)

    def string(s):
        return colorize(f'&#x27;{s}&#x27;', COLOR_STR)

    OPS_HTML = {
        ast.Not: kwd('not'),
        ast.Add: '+',
        ast.Mult: '*',
        ast.Sub: '-',
        ast.USub: '-'
    }
    OPS_PRECEDENCES = {
        ast.Add: 0,
        ast.Sub: 0,
        ast.Mult: 1
    }

    def htmlify(node):
        tp = type(node)
        if tp == ast.If:
            return kwd('if') + ' ' + htmlify(node.test)
        elif tp == ast.Assign:
            targets = ', '.join([htmlify(t) for t in node.targets])
            value = htmlify(node.value)
            return f'{targets} &larr; {value}'
        if tp == ast.BinOp:
            left = node.left
            right = node.right
            left_html = htmlify(left)
            right_html = htmlify(right)
            if type(left) == ast.BinOp and OPS_PRECEDENCES[type(left.op)] == 0:
                left_html = f'({left_html})'
            if type(right) == ast.BinOp:
                right_html = f'({right_html})'
            return ' '.join([
                left_html, htmlify(node.op), right_html
            ])
        elif tp == ast.UnaryOp:
            operand = node.operand
            operand_html = htmlify(operand)
            if type(operand) == ast.BinOp:
                if OPS_PRECEDENCES[type(operand.op)] == 0:
                    operand_html = f'({operand_html})'
            op_html = htmlify(node.op)
            if type(node.op) == ast.Not:
                op_html += ' '
            return op_html + operand_html
        elif tp in OPS_HTML:
            return OPS_HTML[tp]
        elif tp == ast.For:
            return ' '.join([
                kwd('for'), htmlify(node.target),
                kwd('in'), htmlify(node.iter)
            ])
        elif tp == ast.While:
            return kwd('while') + ' ' + htmlify(node.test)
        elif tp == ast.Compare:
            assert len(node.ops) == 1
            return ' '.join([
                htmlify(node.left), htmlify(node.ops[0]),
                htmlify(node.comparators[0])
            ])
        elif tp == ast.Expr:
            return htmlify(node.value)
        elif tp == ast.Eq:
            return '='
        elif tp == ast.Lt:
            return '&lt;'
        elif tp == ast.Name:
            id = node.id
            s = f'<i>{id}</i>'
            if '_' in id:
                s = sub(r'(\w+)_(\d+|\?)', r'<i>\1</i><sub>\2</sub>', id)
            return colorize(s, COLOR_VAR)
        elif tp == ast.Constant:
            value = node.value
            if type(value) == str:
                return string(value)
            return str(value)
        elif tp == ast.Num:
            return str(node.n)
        elif tp == ast.Str:
            return string(node.s)
        elif tp == ast.Subscript:
            value_html = htmlify(node.value)
            slice_html = htmlify(node.slice)
            return f'{value_html}&#91;{slice_html}&#93;'
        elif tp == ast.Index:
            return htmlify(node.value)
        elif tp == ast.Call:
            id = node.func.id
            args = ', '.join(htmlify(a) for a in node.args)
            # Phis are so special we pretend that they are keywords
            if id == 'phi':
                return ' '.join([kwd('phi'), args])
            return f'{id}({args})'
        elif tp in (ast.Break, ast.Continue, ast.Pass):
            return kwd(str(tp.__name__).lower())
        elif tp == ast.Return:
            s = kwd('return')
            if node.value:
                return s + ' ' + htmlify(node.value)
            return s
        else:
            assert False

    def plot_bbs(bbs, succ):
        G = AGraph(strict=False, directed=True)
        graph_attrs = {
            'dpi': 300,
            'ranksep': 0.3,
            'fontname': 'Inconsolata',
            'bgcolor': 'transparent'
        }
        G.graph_attr.update(graph_attrs)
        node_attrs = {
            'shape': 'box',
            'width': 0.55,
            'style': 'filled',
            'fillcolor': 'white'
        }
        G.node_attr.update(node_attrs)
        edge_attrs = {
            'fontsize': '10pt'
        }
        G.edge_attr.update(edge_attrs)
        names = {bb: i for i, bb in enumerate(bbs)}
        # Add nodes and edges.
        for bb, name in names.items():
            edges = succ[bb]
            peri = 2 if bb == bbs[0] or not edges else 1
            lines = [htmlify(node) for node in bb.insns]
            label = ''.join(l + '<br align="left"/>' for l in lines)
            label = f'<{label}>'
            G.add_node(name, label=label, peripheries=peri)
            colors = ['black'] * len(edges)
            if len(edges) == 2:
                colors = ['#00aa00', '#aa0000']
            for bb2, color in zip(edges, colors):
                G.add_edge(names[bb], names[bb2], color=color)
        G.draw('test.png', prog='dot')

    return plot_bbs


def main() -> None:
 root = ast.parse('''
for ba in range(10):
 if -(a + 10):
 break
 for y in range(10):
 if not x:
 print('cont from x_99')
 continue
 x = x + 1
 if x:
 if y:
 return x
 pass
 break
 print(x)
print('blah')
while 1:
 if x == 0:
 while x < 10:
 x = x + 1
 y = y + 1
 if x == 5:
 break
 if y == 5:
 continue
 else:
 if y == 2:
 b = b + 1
 continue
 a = a + 1
 continue
 print('prutt')
 break
 if b:
 return
 x = 99
 if m:
 while 2:
 if 1:
 break
 else:
 continue
 x = 9
 print('dead code')
 print('this reaches')
print('but ok')
if 2:
 break
else:
 pass
break
print(10)
''')
 builder = CFGBuilder()
 bbs, succ = builder.build(root.body)
 plot()(bbs, succ)
if __name__ == '__main__':
 main()
answered May 12, 2022 at 19:44
  • Thanks. Though plotting is not part of my problem, so you don't have to include that. Commented May 12, 2022 at 20:28
