Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 0932cc9

Browse files
authored
Merge pull request #179 from bcaller/ifexp
Better handling of IfExp (ternary)
2 parents 5d7a94b + 2e4f8c9 commit 0932cc9

File tree

7 files changed

+152
-1
lines changed

7 files changed

+152
-1
lines changed

‎examples/example_inputs/ternary.py‎

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
result = (
2+
"abc"
3+
if t.u == v.w else
4+
"def"
5+
if x else
6+
y # This is the only RHS variable which taints result
7+
if func(z if 1 + 1 == 2 else z) else
8+
"ghi"
9+
)

‎pyt/core/transformer.py‎

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,80 @@ def visit_Return(self, node):
6464
return self.visit_chain(node)
6565

6666

67-
class PytTransformer(AsyncTransformer, ChainedFunctionTransformer, ast.NodeTransformer):
67+
class IfExpRewriter(ast.NodeTransformer):
68+
"""Splits IfExp ternary expressions containing complex tests into multiple statements
69+
70+
Will change
71+
72+
a if b(c) else d
73+
74+
into
75+
76+
a if __if_exp_0 else d
77+
78+
with Assign nodes in assignments [__if_exp_0 = b(c)]
79+
"""
80+
81+
def __init__(self, starting_index=0):
82+
self._temporary_variable_index = starting_index
83+
self.assignments = []
84+
super().__init__()
85+
86+
def visit_IfExp(self, node):
87+
if isinstance(node.test, (ast.Name, ast.Attribute)):
88+
return self.generic_visit(node)
89+
else:
90+
temp_var_id = '__if_exp_{}'.format(self._temporary_variable_index)
91+
self._temporary_variable_index += 1
92+
assignment_of_test = ast.Assign(
93+
targets=[ast.Name(id=temp_var_id, ctx=ast.Store())],
94+
value=self.visit(node.test),
95+
)
96+
ast.copy_location(assignment_of_test, node)
97+
self.assignments.append(assignment_of_test)
98+
transformed_if_exp = ast.IfExp(
99+
test=ast.Name(id=temp_var_id, ctx=ast.Load()),
100+
body=self.visit(node.body),
101+
orelse=self.visit(node.orelse),
102+
)
103+
ast.copy_location(transformed_if_exp, node)
104+
return transformed_if_exp
105+
106+
def visit_FunctionDef(self, node):
107+
return node
108+
109+
110+
class IfExpTransformer:
111+
"""Goes through module and function bodies, adding extra Assign nodes due to IfExp expressions."""
112+
113+
def visit_body(self, nodes):
114+
new_nodes = []
115+
count = 0
116+
for node in nodes:
117+
rewriter = IfExpRewriter(count)
118+
possibly_transformed_node = rewriter.visit(node)
119+
if rewriter.assignments:
120+
new_nodes.extend(rewriter.assignments)
121+
count += len(rewriter.assignments)
122+
new_nodes.append(possibly_transformed_node)
123+
return new_nodes
124+
125+
def visit_FunctionDef(self, node):
126+
transformed = ast.FunctionDef(
127+
name=node.name,
128+
args=node.args,
129+
body=self.visit_body(node.body),
130+
decorator_list=node.decorator_list,
131+
returns=node.returns
132+
)
133+
ast.copy_location(transformed, node)
134+
return self.generic_visit(transformed)
135+
136+
def visit_Module(self, node):
137+
transformed = ast.Module(self.visit_body(node.body))
138+
ast.copy_location(transformed, node)
139+
return self.generic_visit(transformed)
140+
141+
142+
class PytTransformer(AsyncTransformer, IfExpTransformer, ChainedFunctionTransformer, ast.NodeTransformer):
68143
pass

‎pyt/helper_visitors/label_visitor.py‎

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,3 +324,12 @@ def visit_FormattedValue(self, node):
324324
def visit_Starred(self, node):
325325
self.result += '*'
326326
self.visit(node.value)
327+
328+
def visit_IfExp(self, node):
329+
self.result += '('
330+
self.visit(node.test)
331+
self.result += ') ? ('
332+
self.visit(node.body)
333+
self.result += ') : ('
334+
self.visit(node.orelse)
335+
self.result += ')'

‎pyt/helper_visitors/right_hand_side_visitor.py‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ def visit_Call(self, node):
2222
for keyword in node.keywords:
2323
self.visit(keyword)
2424

25+
def visit_IfExp(self, node):
26+
# The test doesn't taint the assignment
27+
self.visit(node.body)
28+
self.visit(node.orelse)
29+
2530
@classmethod
2631
def result_for_node(cls, node):
2732
visitor = cls()

‎tests/cfg/cfg_test.py‎

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,39 @@ def test_if_not(self):
580580
(_exit, _if)
581581
])
582582

583+
def test_ternary_ifexp(self):
584+
self.cfg_create_from_file('examples/example_inputs/ternary.py')
585+
586+
# entry = 0
587+
tmp_if_1 = 1
588+
# tmp_if_inner = 2
589+
call = 3
590+
# tmp_if_call = 4
591+
actual_if_exp = 5
592+
exit = 6
593+
594+
self.assert_length(self.cfg.nodes, expected_length=exit + 1)
595+
self.assertInCfg([
596+
(i + 1, i) for i in range(exit)
597+
])
598+
599+
self.assertCountEqual(
600+
self.cfg.nodes[actual_if_exp].right_hand_side_variables,
601+
['y'],
602+
"The variables in the test expressions shouldn't appear as RHS variables"
603+
)
604+
605+
self.assertCountEqual(
606+
self.cfg.nodes[tmp_if_1].right_hand_side_variables,
607+
['t', 'v'],
608+
)
609+
610+
self.assertIn(
611+
'ret_func(',
612+
self.cfg.nodes[call].label,
613+
"Function calls inside the test expressions should still appear in the CFG",
614+
)
615+
583616

584617
class CFGWhileTest(CFGBaseTestCase):
585618

‎tests/core/transformer_test.py‎

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,19 @@ def test_chained_function(self):
5252

5353
transformed = PytTransformer().visit(chained_tree)
5454
self.assertEqual(ast.dump(transformed), ast.dump(separated_tree))
55+
56+
def test_if_exp(self):
57+
complex_if_exp_tree = ast.parse("\n".join([
58+
"def a():",
59+
" b = c if d.e(f) else g if h else i if j.k(l) else m",
60+
]))
61+
62+
separated_tree = ast.parse("\n".join([
63+
"def a():",
64+
" __if_exp_0 = d.e(f)",
65+
" __if_exp_1 = j.k(l)",
66+
" b = c if __if_exp_0 else g if h else i if __if_exp_1 else m",
67+
]))
68+
69+
transformed = PytTransformer().visit(complex_if_exp_tree)
70+
self.assertEqual(ast.dump(transformed), ast.dump(separated_tree))

‎tests/helper_visitors/label_visitor_test.py‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,7 @@ def test_joined_str_with_format_spec(self):
8383
def test_starred(self):
8484
label = self.perform_labeling_on_expression('[a, *b] = *c, d')
8585
self.assertEqual(label.result, '[a, *b] = (*c, d)')
86+
87+
def test_if_exp(self):
88+
label = self.perform_labeling_on_expression('a = b if c else d')
89+
self.assertEqual(label.result, 'a = (c) ? (b) : (d)')

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /