66#
77
88
9+ from os import fdopen
910from antlr4 import *
1011from solidity_parser .solidity_antlr4 .SolidityLexer import SolidityLexer
1112from solidity_parser .solidity_antlr4 .SolidityParser import SolidityParser
@@ -124,6 +125,27 @@ def visitEnumValue(self, ctx):
124125 type = "EnumValue" ,
125126 name = ctx .identifier ().getText ())
126127
128+ def visitTypeDefinition (self , ctx ):
129+ return Node (ctx = ctx ,
130+ type = "TypeDefinition" ,
131+ typeKeyword = ctx .TypeKeyword ().getText (),
132+ elementaryTypeName = self .visit (ctx .elementaryTypeName ()))
133+ 134+ 135+ def visitCustomErrorDefinition (self , ctx ):
136+ return Node (ctx = ctx ,
137+ type = "CustomErrorDefinition" ,
138+ name = self .visit (ctx .identifier ()),
139+ parameterList = self .visit (ctx .parameterList ()))
140+ 141+ def visitFileLevelConstant (self , ctx ):
142+ return Node (ctx = ctx ,
143+ type = "FileLevelConstant" ,
144+ name = self .visit (ctx .identifier ()),
145+ typeName = self .visit (ctx .typeName ()),
146+ ConstantKeyword = self .visit (ctx .ConstantKeyword ()))
147+ 148+ 127149 def visitUsingForDeclaration (self , ctx : SolidityParser .UsingForDeclarationContext ):
128150 typename = None
129151 if ctx .getChild (3 ) != '*' :
@@ -138,45 +160,29 @@ def visitInheritanceSpecifier(self, ctx: SolidityParser.InheritanceSpecifierCont
138160 return Node (ctx = ctx ,
139161 type = "InheritanceSpecifier" ,
140162 baseName = self .visit (ctx .userDefinedTypeName ()),
141- arguments = self .visit (ctx .expression ()))
163+ arguments = self .visit (ctx .expressionList ()))
142164
143165 def visitContractPart (self , ctx : SolidityParser .ContractPartContext ):
144166 return self .visit (ctx .children [0 ])
145167
146- def visitConstructorDefinition (self , ctx : SolidityParser .ConstructorDefinitionContext ):
147- parameters = self .visit (ctx .parameterList ())
148- block = self .visit (ctx .block ()) if ctx .block () else []
149- modifiers = [self .visit (i ) for i in ctx .modifierList ().modifierInvocation ()]
150168
151- if ctx .modifierList ().ExternalKeyword (0 ):
152- visibility = "external"
153- elif ctx .modifierList ().InternalKeyword (0 ):
154- visibility = "internal"
155- elif ctx .modifierList ().PublicKeyword (0 ):
156- visibility = "public"
157- elif ctx .modifierList ().PrivateKeyword (0 ):
158- visibility = "private"
159- else :
160- visibility = 'default'
161- 162- if ctx .modifierList ().stateMutability (0 ):
163- stateMutability = ctx .modifierList ().stateMutability (0 ).getText ()
169+ def visitFunctionDefinition (self , ctx : SolidityParser .FunctionDefinitionContext ):
170+ isConstructor = isFallback = isReceive = False
171+ 172+ fd = ctx .functionDescriptor ()
173+ if fd .ConstructorKeyword ():
174+ name = fd .ConstructorKeyword ().getText ()
175+ isConstructor = True
176+ elif fd .FallbackKeyword ():
177+ name = fd .FallbackKeyword ().getText ()
178+ isFallback = True
179+ elif fd .ReceiveKeyword ():
180+ name = fd .ReceiveKeyword ().getText ()
181+ isReceive = True
182+ elif fd .identifier ():
183+ name = fd .identifier ().getText ()
164184 else :
165- stateMutability = None
166- 167- return Node (ctx = ctx ,
168- type = "FunctionDefinition" ,
169- name = None ,
170- parameters = parameters ,
171- returnParameters = None ,
172- body = block ,
173- visibility = visibility ,
174- modifiers = modifiers ,
175- isConstructor = True ,
176- stateMutability = stateMutability )
177- 178- def visitFunctionDefinition (self , ctx : SolidityParser .ConstructorDefinitionContext ):
179- name = ctx .identifier ().getText () if ctx .identifier () else ""
185+ raise Exception ("unexpected function descriptor" )
180186
181187 parameters = self .visit (ctx .parameterList ())
182188 returnParameters = self .visit (ctx .returnParameters ()) if ctx .returnParameters () else []
@@ -207,7 +213,9 @@ def visitFunctionDefinition(self, ctx: SolidityParser.ConstructorDefinitionConte
207213 body = block ,
208214 visibility = visibility ,
209215 modifiers = modifiers ,
210- isConstructor = name == self ._currentContract ,
216+ isConstructor = isConstructor ,
217+ isFallback = isFallback ,
218+ isReceive = isReceive ,
211219 stateMutability = stateMutability )
212220
213221 def visitReturnParameters (self , ctx : SolidityParser .ReturnParametersContext ):
@@ -393,6 +401,21 @@ def visitIfStatement(self, ctx):
393401 TrueBody = TrueBody ,
394402 FalseBody = FalseBody )
395403
404+ def visitTryStatement (self , ctx ):
405+ return Node (ctx = ctx ,
406+ type = 'TryStatement' ,
407+ expression = self .visit (ctx .expression ()),
408+ block = self .visit (ctx .block ()),
409+ returnParameters = self .visit (ctx .returnParameters ()),
410+ catchClause = self .visit (ctx .catchClause ()))
411+ 412+ def visitCatchClause (self , ctx ):
413+ return Node (ctx = ctx ,
414+ type = 'CatchClause' ,
415+ identifier = self .visit (ctx .identifier ()),
416+ parameterList = self .visit (ctx .parameterList ()),
417+ block = self .visit (ctx .block ()))
418+ 396419 def visitUserDefinedTypeName (self , ctx ):
397420 return Node (ctx = ctx ,
398421 type = 'UserDefinedTypeName' ,
@@ -428,7 +451,7 @@ def visitNumberLiteral(self, ctx):
428451 def visitMapping (self , ctx ):
429452 return Node (ctx = ctx ,
430453 type = 'Mapping' ,
431- keyType = self .visit (ctx .elementaryTypeName ()),
454+ keyType = self .visit (ctx .mappingKey ()),
432455 valueType = self .visit (ctx .typeName ()))
433456
434457 def visitModifierDefinition (self , ctx ):
@@ -449,6 +472,16 @@ def visitStatement(self, ctx):
449472 def visitSimpleStatement (self , ctx ):
450473 return self .visit (ctx .getChild (0 ))
451474
475+ def visitUncheckedStatement (self , ctx ):
476+ return Node (ctx = ctx ,
477+ type = 'UncheckedStatement' ,
478+ body = self .visit (ctx .block ()))
479+ 480+ def visitRevertStatement (self , ctx ):
481+ return Node (ctx = ctx ,
482+ type = 'RevertStatement' ,
483+ functionCall = self .visit (ctx .functionCall ()))
484+ 452485 def visitExpression (self , ctx ):
453486
454487 children_length = len (ctx .children )
@@ -641,16 +674,15 @@ def visitPrimaryExpression(self, ctx):
641674 type = 'BooleanLiteral' ,
642675 value = ctx .BooleanLiteral ().getText () == 'true' )
643676
644- if ctx .HexLiteral ():
677+ if ctx .hexLiteral ():
645678 return Node (ctx = ctx ,
646- type = 'HexLiteral ' ,
647- value = ctx .HexLiteral ().getText ())
679+ type = 'hexLiteral ' ,
680+ value = ctx .hexLiteral ().getText ())
648681
649- if ctx .StringLiteral ():
682+ if ctx .stringLiteral ():
650683 text = ctx .getText ()
651- 652684 return Node (ctx = ctx ,
653- type = 'StringLiteral ' ,
685+ type = 'stringLiteral ' ,
654686 value = text [1 : len (text ) - 1 ])
655687
656688 if len (ctx .children ) == 3 and ctx .getChild (1 ).getText () == '[' and ctx .getChild (2 ).getText () == ']' :
@@ -737,32 +769,6 @@ def visitVariableDeclarationStatement(self, ctx):
737769 variables = variables ,
738770 initialValue = initialValue )
739771
740- def visitImportDirective (self , ctx ):
741- pathString = ctx .StringLiteral ().getText ()
742- unitAlias = None
743- symbolAliases = None
744- 745- impDecLen = len (ctx .importDeclaration ())
746- if impDecLen > 0 :
747- symbolAliases = []
748- for decl in ctx .importDeclaration ():
749- symbol = decl .identifier (0 ).getText ()
750- alias = None
751- if decl .identifier (1 ):
752- alias = decl .identifier (1 ).getText ()
753- 754- symbolAliases .append ([symbol , alias ])
755- elif impDecLen == 7 :
756- unitAlias = ctx .getChild (3 ).getText ()
757- elif impDecLen == 5 :
758- unitAlias = ctx .getChild (3 ).getText ()
759- 760- return Node (ctx = ctx ,
761- type = 'ImportDirective' ,
762- path = pathString [1 : len (pathString ) - 1 ],
763- unitAlias = unitAlias ,
764- symbolAliases = symbolAliases )
765- 766772 def visitEventDefinition (self , ctx ):
767773 return Node (ctx = ctx ,
768774 type = 'EventDefinition' ,
@@ -792,8 +798,8 @@ def visitEventParameterList(self, ctx):
792798 def visitInlineAssemblyStatement (self , ctx ):
793799 language = None
794800
795- if ctx .StringLiteral ():
796- language = ctx .StringLiteral ().getText ()
801+ if ctx .StringLiteralFragment ():
802+ language = ctx .StringLiteralFragment ().getText ()
797803 language = language [1 : len (language ) - 1 ]
798804
799805 return Node (ctx = ctx ,
@@ -810,13 +816,13 @@ def visitAssemblyBlock(self, ctx):
810816
811817 def visitAssemblyItem (self , ctx ):
812818
813- if ctx .HexLiteral ():
819+ if ctx .hexLiteral ():
814820 return Node (ctx = ctx ,
815821 type = 'HexLiteral' ,
816- value = ctx .HexLiteral ().getText ())
822+ value = ctx .hexLiteral ().getText ())
817823
818- if ctx .StringLiteral ():
819- text = ctx .StringLiteral ().getText ()
824+ if ctx .stringLiteral ():
825+ text = ctx .stringLiteral ().getText ()
820826 return Node (ctx = ctx ,
821827 type = 'StringLiteral' ,
822828 value = text [1 : len (text ) - 1 ])
@@ -834,6 +840,11 @@ def visitAssemblyItem(self, ctx):
834840 def visitAssemblyExpression (self , ctx ):
835841 return self .visit (ctx .getChild (0 ))
836842
843+ def visitAssemblyMember (self , ctx ):
844+ return Node (ctx = ctx ,
845+ type = 'AssemblyMember' ,
846+ name = ctx .identifier ().getText ())
847+ 837848 def visitAssemblyCall (self , ctx ):
838849 functionName = ctx .getChild (0 ).getText ()
839850 args = [self .visit (arg ) for arg in ctx .assemblyExpression ()]
@@ -845,7 +856,7 @@ def visitAssemblyCall(self, ctx):
845856
846857 def visitAssemblyLiteral (self , ctx ):
847858
848- if ctx .StringLiteral ():
859+ if ctx .stringLiteral ():
849860 text = ctx .getText ()
850861 return Node (ctx = ctx ,
851862 type = 'StringLiteral' ,
@@ -861,7 +872,7 @@ def visitAssemblyLiteral(self, ctx):
861872 type = 'HexNumber' ,
862873 value = ctx .getText ())
863874
864- if ctx .HexLiteral ():
875+ if ctx .hexLiteral ():
865876 return Node (ctx = ctx ,
866877 type = 'HexLiteral' ,
867878 value = ctx .getText ())
@@ -981,7 +992,7 @@ def visitImportDirective(self, ctx):
981992
982993 return Node (ctx = ctx ,
983994 type = "ImportDirective" ,
984- path = ctx .StringLiteral ().getText ().strip ('"' ),
995+ path = ctx .importPath ().getText ().strip ('"' ),
985996 symbolAliases = symbol_aliases ,
986997 unitAlias = unit_alias
987998 )
@@ -1106,10 +1117,6 @@ def visitStructDefinition(self, _node):
11061117 self .structs [_node .name ]= _node
11071118 self .names [_node .name ]= _node
11081119
1109- def visitConstructorDefinition (self , _node ):
1110- self .constructor = _node
1111- 1112- 11131120 def visitStateVariableDeclaration (self , _node ):
11141121
11151122 class VarDecVisitor (object ):
@@ -1150,10 +1157,15 @@ def __init__(self, node):
11501157 if (node .type == "FunctionDefinition" ):
11511158 self .visibility = node .visibility
11521159 self .stateMutability = node .stateMutability
1160+ self .isConstructor = node .isConstructor
1161+ self .isFallback = node .isFallback
1162+ self .isReceive = node .isReceive
11531163 self .arguments = {}
11541164 self .returns = {}
11551165 self .declarations = {}
11561166 self .identifiers = []
1167+ 1168+ 11571169
11581170 class FunctionArgumentVisitor (object ):
11591171
@@ -1182,13 +1194,14 @@ def visitIdentifier(self, __node):
11821194 def visitAssemblyCall (self , __node ):
11831195 self .idents .append (__node )
11841196
1185- 11861197 current_function = FunctionObject (_node )
11871198 self .names [_node .name ] = current_function
11881199 if _definition_type == "ModifierDefinition" :
11891200 self .modifiers [_node .name ] = current_function
11901201 else :
11911202 self .functions [_node .name ] = current_function
1203+ if current_function .isConstructor :
1204+ self .constructor = current_function
11921205
11931206 ## get parameters
11941207 funcargvisitor = FunctionArgumentVisitor ()
0 commit comments