@@ -64,5 +64,80 @@ def visit_Return(self, node):
6464 return self .visit_chain (node )
6565
6666
67- class PytTransformer (AsyncTransformer , ChainedFunctionTransformer , ast .NodeTransformer ):
67+ class IfExpRewriter (ast .NodeTransformer ):
68+ """Splits IfExp ternary expressions containing complex tests into multiple statements
69+
70+ Will change
71+
72+ a if b(c) else d
73+
74+ into
75+
76+ a if __if_exp_0 else d
77+
78+ with Assign nodes in assignments [__if_exp_0 = b(c)]
79+ """
80+ 81+ def __init__ (self , starting_index = 0 ):
82+ self ._temporary_variable_index = starting_index
83+ self .assignments = []
84+ super ().__init__ ()
85+ 86+ def visit_IfExp (self , node ):
87+ if isinstance (node .test , (ast .Name , ast .Attribute )):
88+ return self .generic_visit (node )
89+ else :
90+ temp_var_id = '__if_exp_{}' .format (self ._temporary_variable_index )
91+ self ._temporary_variable_index += 1
92+ assignment_of_test = ast .Assign (
93+ targets = [ast .Name (id = temp_var_id , ctx = ast .Store ())],
94+ value = self .visit (node .test ),
95+ )
96+ ast .copy_location (assignment_of_test , node )
97+ self .assignments .append (assignment_of_test )
98+ transformed_if_exp = ast .IfExp (
99+ test = ast .Name (id = temp_var_id , ctx = ast .Load ()),
100+ body = self .visit (node .body ),
101+ orelse = self .visit (node .orelse ),
102+ )
103+ ast .copy_location (transformed_if_exp , node )
104+ return transformed_if_exp
105+ 106+ def visit_FunctionDef (self , node ):
107+ return node
108+ 109+ 110+ class IfExpTransformer :
111+ """Goes through module and function bodies, adding extra Assign nodes due to IfExp expressions."""
112+ 113+ def visit_body (self , nodes ):
114+ new_nodes = []
115+ count = 0
116+ for node in nodes :
117+ rewriter = IfExpRewriter (count )
118+ possibly_transformed_node = rewriter .visit (node )
119+ if rewriter .assignments :
120+ new_nodes .extend (rewriter .assignments )
121+ count += len (rewriter .assignments )
122+ new_nodes .append (possibly_transformed_node )
123+ return new_nodes
124+ 125+ def visit_FunctionDef (self , node ):
126+ transformed = ast .FunctionDef (
127+ name = node .name ,
128+ args = node .args ,
129+ body = self .visit_body (node .body ),
130+ decorator_list = node .decorator_list ,
131+ returns = node .returns
132+ )
133+ ast .copy_location (transformed , node )
134+ return self .generic_visit (transformed )
135+ 136+ def visit_Module (self , node ):
137+ transformed = ast .Module (self .visit_body (node .body ))
138+ ast .copy_location (transformed , node )
139+ return self .generic_visit (transformed )
140+ 141+ 142+ class PytTransformer (AsyncTransformer , IfExpTransformer , ChainedFunctionTransformer , ast .NodeTransformer ):
68143 pass
0 commit comments