Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit f023eaa

Browse files
Merge pull request #173 from bcaller/recursion
Recursive function calls shouldn't raise RecursionError
2 parents c7b244d + 093f506 commit f023eaa

File tree

7 files changed

+62
-2
lines changed

7 files changed

+62
-2
lines changed

‎examples/vulnerable_code/recursive.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from flask import Flask, request
2+
3+
app = Flask(__name__)
4+
5+
6+
def recur_without_any_propagation(x):
7+
if len(x) < 20:
8+
return recur_without_any_propagation("a" * 24)
9+
return "Done"
10+
11+
12+
def recur_no_propagation_false_positive(x):
13+
if len(x) < 20:
14+
return recur_no_propagation_false_positive(x + "!")
15+
return "Done"
16+
17+
18+
def recur_with_propagation(x):
19+
if len(x) < 20:
20+
return recur_with_propagation(x + "!")
21+
return x
22+
23+
24+
@app.route('/recursive')
25+
def route():
26+
param = request.args.get('param', 'not set')
27+
repeated_completely_untainted = recur_without_any_propagation(param)
28+
app.db.execute(repeated_completely_untainted)
29+
repeated_untainted = recur_no_propagation_false_positive(param)
30+
app.db.execute(repeated_untainted)
31+
repeated_tainted = recur_with_propagation(param)
32+
app.db.execute(repeated_tainted)

‎pyt/__main__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,9 @@ def main(command_line_args=sys.argv[1:]): # noqa: C901
125125
)
126126

127127
initialize_constraint_table(cfg_list)
128+
log.info("Analysing")
128129
analyse(cfg_list)
130+
log.info("Finding vulnerabilities")
129131
vulnerabilities = find_vulnerabilities(
130132
cfg_list,
131133
args.blackbox_mapping_file,

‎pyt/cfg/expr_visitor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ast
2+
import logging
23

34
from .alias_helper import handle_aliases_in_calls
45
from ..core.ast_helper import (
@@ -30,6 +31,8 @@
3031
from .stmt_visitor import StmtVisitor
3132
from .stmt_visitor_helper import CALL_IDENTIFIER
3233

34+
log = logging.getLogger(__name__)
35+
3336

3437
class ExprVisitor(StmtVisitor):
3538
def __init__(
@@ -52,6 +55,7 @@ def __init__(
5255
self.undecided = False
5356
self.function_names = list()
5457
self.function_return_stack = list()
58+
self.function_definition_stack = list() # used to avoid recursion
5559
self.module_definitions_stack = list()
5660
self.prev_nodes_to_avoid = list()
5761
self.last_control_flow_nodes = list()
@@ -543,6 +547,7 @@ def process_function(self, call_node, definition):
543547
first_node
544548
)
545549
self.function_return_stack.pop()
550+
self.function_definition_stack.pop()
546551

547552
return self.nodes[-1]
548553

@@ -560,11 +565,15 @@ def visit_Call(self, node):
560565
last_attribute = _id.rpartition('.')[-1]
561566

562567
if definition:
568+
if definition in self.function_definition_stack:
569+
log.debug("Recursion encountered in function %s", _id)
570+
return self.add_blackbox_or_builtin_call(node, blackbox=True)
563571
if isinstance(definition.node, ast.ClassDef):
564572
self.add_blackbox_or_builtin_call(node, blackbox=False)
565573
elif isinstance(definition.node, ast.FunctionDef):
566574
self.undecided = False
567575
self.function_return_stack.append(_id)
576+
self.function_definition_stack.append(definition)
568577
return self.process_function(node, definition)
569578
else:
570579
raise Exception('Definition was neither FunctionDef or ' +

‎pyt/web_frameworks/framework_adaptor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""A generic framework adaptor that leaves route criteria to the caller."""
22

33
import ast
4+
import logging
45

56
from ..cfg import make_cfg
67
from ..core.ast_helper import Arguments
@@ -10,6 +11,8 @@
1011
TaintedNode
1112
)
1213

14+
log = logging.getLogger(__name__)
15+
1316

1417
class FrameworkAdaptor():
1518
"""An engine that uses the template pattern to find all
@@ -31,6 +34,7 @@ def __init__(
3134

3235
def get_func_cfg_with_tainted_args(self, definition):
3336
"""Build a function cfg and return it, with all arguments tainted."""
37+
log.debug("Getting CFG for %s", definition.name)
3438
func_cfg = make_cfg(
3539
definition.node,
3640
self.project_modules,

‎tests/cfg/cfg_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .cfg_base_test_case import CFGBaseTestCase
44

55
from pyt.core.node_types import (
6+
BBorBInode,
67
EntryOrExitNode,
78
Node
89
)
@@ -1389,6 +1390,13 @@ def test_call_on_call(self):
13891390
path = 'examples/example_inputs/call_on_call.py'
13901391
self.cfg_create_from_file(path)
13911392

1393+
def test_recursive_function(self):
1394+
path = 'examples/example_inputs/recursive.py'
1395+
self.cfg_create_from_file(path)
1396+
recursive_call = self.cfg.nodes[7]
1397+
assert recursive_call.label == '~call_3 = ret_rec(wat)'
1398+
assert isinstance(recursive_call, BBorBInode) # Not RestoreNode
1399+
13921400

13931401
class CFGCallWithAttributeTest(CFGBaseTestCase):
13941402
def setUp(self):

‎tests/main_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,11 @@ def test_targets_with_recursive(self):
108108
excluded_files = ""
109109

110110
included_files = discover_files(targets, excluded_files, True)
111-
self.assertEqual(len(included_files), 31)
111+
self.assertEqual(len(included_files), 32)
112112

113113
def test_targets_with_recursive_and_excluded(self):
114114
targets = ["examples/vulnerable_code/"]
115115
excluded_files = "inter_command_injection.py"
116116

117117
included_files = discover_files(targets, excluded_files, True)
118-
self.assertEqual(len(included_files), 30)
118+
self.assertEqual(len(included_files), 31)

‎tests/vulnerabilities/vulnerabilities_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,11 @@ def assert_vulnerable(fixture):
465465
assert_vulnerable('result = repr(str("%s" % TAINT.lower().upper()))')
466466
assert_vulnerable('result = repr(str("{}".format(TAINT.lower())))')
467467

468+
def test_recursion(self):
469+
# Really this file only has one vulnerability, but for now it's safer to keep the false positive.
470+
vulnerabilities = self.run_analysis('examples/vulnerable_code/recursive.py')
471+
self.assert_length(vulnerabilities, expected_length=2)
472+
468473

469474
class EngineDjangoTest(VulnerabilitiesBaseTestCase):
470475
def run_analysis(self, path):

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /