Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 8c95773

Browse files
bcallerBen Caller
authored and
Ben Caller
committed
Chained function calls separated into multiple assignments
Take the example from examples/vulnerable_code/sql/sqli.py: `result = session.query(User).filter("username={}".format(TAINT))` The `filter` function is marked as a sink. However, previously this did not get marked as a vulnerability. The call label used to be `session.query`, ignoring the filter function. Now, when the file is read, it is transformed into 2 lines: ``` __chain_tmp_1 = session.query(User) result = __chain_tmp_1.filter("username={}".format(TAINT)) ``` And the vulnerability is found. We don't find everything here: just ordinary assignments and return statements. We can't just transform all Call nodes here since Call nodes can appear in many different scenarios e.g. comprehensions, bare function calls.
1 parent 11567c4 commit 8c95773

File tree

6 files changed

+110
-7
lines changed

6 files changed

+110
-7
lines changed

‎pyt/core/ast_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import subprocess
77
from functools import lru_cache
88

9-
from .transformer import AsyncTransformer
9+
from .transformer import PytTransformer
1010

1111

1212
BLACK_LISTED_CALL_NAMES = ['self']
@@ -35,7 +35,7 @@ def generate_ast(path):
3535
with open(path, 'r') as f:
3636
try:
3737
tree = ast.parse(f.read())
38-
return AsyncTransformer().visit(tree)
38+
return PytTransformer().visit(tree)
3939
except SyntaxError: # pragma: no cover
4040
global recursive
4141
if not recursive:

‎pyt/core/transformer.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import ast
22

33

4-
class AsyncTransformer(ast.NodeTransformer):
4+
class AsyncTransformer():
55
"""Converts all async nodes into their synchronous counterparts."""
66

77
def visit_Await(self, node):
@@ -16,3 +16,55 @@ def visit_AsyncFor(self, node):
1616

1717
def visit_AsyncWith(self, node):
1818
return self.visit(ast.With(**node.__dict__))
19+
20+
21+
class ChainedFunctionTransformer():
22+
def visit_chain(self, node, depth=1):
23+
if (
24+
isinstance(node.value, ast.Call) and
25+
isinstance(node.value.func, ast.Attribute) and
26+
isinstance(node.value.func.value, ast.Call)
27+
):
28+
call_node = node.value
29+
# If we want to handle nested functions in future, depth needs fixing
30+
temp_var_id = '__chain_tmp_{}'.format(depth)
31+
unvisited_inner_call = ast.Assign(
32+
targets=[ast.Name(id=temp_var_id, ctx=ast.Store())],
33+
value=call_node.func.value,
34+
)
35+
ast.copy_location(unvisited_inner_call, node)
36+
inner_calls = self.visit_chain(unvisited_inner_call, depth + 1)
37+
inner_calls = inner_calls if isinstance(inner_calls, list) else [inner_calls]
38+
for inner_call_node in inner_calls:
39+
ast.copy_location(inner_call_node, node)
40+
outer_call = self.generic_visit(type(node)(
41+
value=ast.Call(
42+
func=ast.Attribute(
43+
value=ast.Name(id=temp_var_id, ctx=ast.Load()),
44+
attr=call_node.func.attr,
45+
ctx=ast.Load(),
46+
),
47+
args=call_node.args,
48+
keywords=call_node.keywords,
49+
),
50+
**{field: value for field, value in ast.iter_fields(node) if field != 'value'}
51+
))
52+
ast.copy_location(outer_call, node)
53+
ast.copy_location(outer_call.value, node)
54+
ast.copy_location(outer_call.value.func, node)
55+
return [
56+
*inner_calls,
57+
outer_call,
58+
]
59+
else:
60+
return self.generic_visit(node)
61+
62+
def visit_Assign(self, node):
63+
return self.visit_chain(node)
64+
65+
def visit_Return(self, node):
66+
return self.visit_chain(node)
67+
68+
69+
class PytTransformer(AsyncTransformer, ChainedFunctionTransformer, ast.NodeTransformer):
70+
pass

‎tests/base_test_case.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pyt.cfg import make_cfg
55
from pyt.core.ast_helper import generate_ast
66
from pyt.core.module_definitions import project_definitions
7+
from pyt.core.transformer import PytTransformer
78

89

910
class BaseTestCase(unittest.TestCase):
@@ -36,7 +37,7 @@ def cfg_create_from_ast(
3637
):
3738
project_definitions.clear()
3839
self.cfg = make_cfg(
39-
ast_tree,
40+
PytTransformer().visit(ast_tree),
4041
project_modules,
4142
local_modules,
4243
filename='?'

‎tests/cfg/cfg_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,3 +1497,35 @@ def test_name_for(self):
14971497

14981498
self.assert_length(self.cfg.nodes, expected_length=4)
14991499
self.assertEqual(self.cfg.nodes[1].label, 'for x in l:')
1500+
1501+
1502+
class CFGFunctionChain(CFGBaseTestCase):
1503+
def test_simple(self):
1504+
self.cfg_create_from_ast(
1505+
ast.parse('a = b.c(z)')
1506+
)
1507+
middle_nodes = self.cfg.nodes[1:-1]
1508+
self.assert_length(middle_nodes, expected_length=2)
1509+
self.assertEqual(middle_nodes[0].label, '~call_1 = ret_b.c(z)')
1510+
self.assertEqual(middle_nodes[0].func_name, 'b.c')
1511+
self.assertCountEqual(middle_nodes[0].right_hand_side_variables, ['z', 'b'])
1512+
1513+
def test_chain(self):
1514+
self.cfg_create_from_ast(
1515+
ast.parse('a = b.xxx.c(z).d(y)')
1516+
)
1517+
middle_nodes = self.cfg.nodes[1:-1]
1518+
self.assert_length(middle_nodes, expected_length=4)
1519+
1520+
self.assertEqual(middle_nodes[0].left_hand_side, '~call_1')
1521+
self.assertCountEqual(middle_nodes[0].right_hand_side_variables, ['b', 'z'])
1522+
self.assertEqual(middle_nodes[0].label, '~call_1 = ret_b.xxx.c(z)')
1523+
1524+
self.assertEqual(middle_nodes[1].left_hand_side, '__chain_tmp_1')
1525+
self.assertCountEqual(middle_nodes[1].right_hand_side_variables, ['~call_1'])
1526+
1527+
self.assertEqual(middle_nodes[2].left_hand_side, '~call_2')
1528+
self.assertCountEqual(middle_nodes[2].right_hand_side_variables, ['__chain_tmp_1', 'y'])
1529+
1530+
self.assertEqual(middle_nodes[3].left_hand_side, 'a')
1531+
self.assertCountEqual(middle_nodes[3].right_hand_side_variables, ['~call_2'])

‎tests/core/transformer_test.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import ast
22
import unittest
33

4-
from pyt.core.transformer import AsyncTransformer
4+
from pyt.core.transformer import PytTransformer
55

66

77
class TransformerTest(unittest.TestCase):
88
"""Tests for the AsyncTransformer."""
99

1010
def test_async_removed_by_transformer(self):
11+
self.maxDiff = 99999
1112
async_tree = ast.parse("\n".join([
1213
"async def a():",
1314
" async for b in c():",
@@ -30,7 +31,24 @@ def test_async_removed_by_transformer(self):
3031
]))
3132
self.assertIsInstance(sync_tree.body[0], ast.FunctionDef)
3233

33-
transformed = AsyncTransformer().visit(async_tree)
34+
transformed = PytTransformer().visit(async_tree)
3435
self.assertIsInstance(transformed.body[0], ast.FunctionDef)
3536

3637
self.assertEqual(ast.dump(transformed), ast.dump(sync_tree))
38+
39+
def test_chained_function(self):
40+
chained_tree = ast.parse("\n".join([
41+
"def a():",
42+
" b = c.d(e).f(g).h(i).j(k)",
43+
]))
44+
45+
separated_tree = ast.parse("\n".join([
46+
"def a():",
47+
" __chain_tmp_3 = c.d(e)",
48+
" __chain_tmp_2 = __chain_tmp_3.f(g)",
49+
" __chain_tmp_1 = __chain_tmp_2.h(i)",
50+
" b = __chain_tmp_1.j(k)",
51+
]))
52+
53+
transformed = PytTransformer().visit(chained_tree)
54+
self.assertEqual(ast.dump(transformed), ast.dump(separated_tree))

‎tests/vulnerabilities/vulnerabilities_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def test_path_traversal_sanitised_2_result(self):
282282

283283
def test_sql_result(self):
284284
vulnerabilities = self.run_analysis('examples/vulnerable_code/sql/sqli.py')
285-
self.assert_length(vulnerabilities, expected_length=2)
285+
self.assert_length(vulnerabilities, expected_length=3)
286286
vulnerability_description = str(vulnerabilities[0])
287287
EXPECTED_VULNERABILITY_DESCRIPTION = """
288288
File: examples/vulnerable_code/sql/sqli.py

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /