I implemented a binary expression tree; it uses my previous shunting-yard parser to provide the postfix expression.
I have a few specific questions, in addition to general comments on making it more Pythonic:
1. Should I check the type when assigning the `left` and `right` nodes? I read that type checking is usually discouraged (https://stackoverflow.com/a/6725913/1693004), but the `Node` class is internal to the `binary_expression_tree` module.
2. Should the `BinaryExpressionTree` constructor take the postfix token expression produced by the shunting-yard parser, or should it take an infix expression and convert it internally?
Code:
import operator

import shunting_yard_parser
class Node:
    """A single node of a binary expression tree.

    Wraps one token. The ``left`` and ``right`` children start as ``None``
    and may only be assigned other ``Node`` instances.
    """

    def __init__(self, token):
        # Per-instance state is initialised here rather than declared as
        # class attributes, so every node owns its own fields.
        self._token = token
        self._left = None
        self._right = None

    def __repr__(self):
        return '{0}({1!r})'.format(type(self).__name__, self._token)

    @property
    def left(self):
        """The left child node, or ``None`` if not yet assigned."""
        return self._left

    @left.setter
    def left(self, value):
        # Reject non-Node children early so a malformed tree fails at build
        # time rather than during evaluation.
        if not isinstance(value, Node):
            raise TypeError("Left node must be of type Node")
        self._left = value

    @property
    def right(self):
        """The right child node, or ``None`` if not yet assigned."""
        return self._right

    @right.setter
    def right(self, value):
        if not isinstance(value, Node):
            raise TypeError("Right node must be of type Node")
        self._right = value

    @property
    def token(self):
        """The (read-only) token held by this node."""
        return self._token
class BinaryExpressionTree:
    """Binary expression tree built from a postfix (RPN) token sequence.

    Operand tokens become leaves; each operator token becomes an internal
    node whose children are the two sub-expressions that precede it in
    postfix order.
    """

    # Maps each supported operator symbol to the function that applies it.
    _OPERATIONS = {
        '+': operator.add,
        '-': operator.sub,
        '*': operator.mul,
        '/': operator.truediv,
    }

    def __init__(self, postfix_tokens):
        """Build the tree from *postfix_tokens*.

        :param postfix_tokens: iterable of tokens with ``kind`` and ``value``
            attributes, in postfix order (e.g. the output of
            ``shunting_yard_parser.parse``). Tokens that are neither operands
            nor operators are ignored.
        :raises TypeError: if the tokens are empty or do not form exactly one
            well-formed expression. (TypeError is kept for compatibility with
            existing callers/tests; ValueError would be the more natural type.)
        """
        stack = []
        for token in postfix_tokens:
            if token.kind == shunting_yard_parser.TokenSpecification.OPERAND:
                stack.append(Node(token))
            elif token.kind == shunting_yard_parser.TokenSpecification.OPERATOR:
                if len(stack) < 2:
                    raise TypeError('Incorrectly formatted expression - check operators.')
                operator_node = Node(token)
                # In postfix order the most recent sub-expression is the right
                # operand, the one before it the left operand.
                operator_node.right = stack.pop()
                operator_node.left = stack.pop()
                stack.append(operator_node)
        if not stack:
            # Without this check an empty input would surface as IndexError.
            raise TypeError('Empty expression.')
        if len(stack) > 1:
            raise TypeError('Incorrectly formatted expression - check operands.')
        self._root = stack.pop()

    def solve(self):
        """Evaluate the expression and return its value.

        Operands are converted with ``int``; ``/`` is true division, so the
        result is a float whenever a division occurs and an int otherwise.

        :raises ZeroDivisionError: on division by zero.
        :raises ValueError: if the tree contains an unsupported operator.
        """
        return self._solve(self._root)

    @staticmethod
    def _solve(node):
        """Evaluate the subtree rooted at *node* without recursion.

        An explicit stack replaces recursion so that deep trees (e.g. long
        chains such as ``1+1+...+1``) do not hit Python's recursion limit.
        """
        values = []
        # Each entry is (node, children_already_evaluated).
        pending = [(node, False)]
        while pending:
            current, children_evaluated = pending.pop()
            token = current.token
            if token.kind == shunting_yard_parser.TokenSpecification.OPERAND:
                values.append(int(token.value))
            elif children_evaluated:
                right_value = values.pop()
                left_value = values.pop()
                values.append(BinaryExpressionTree._apply(token.value, left_value, right_value))
            else:
                # Revisit this node once both children are done. Push the right
                # child first so the left child is evaluated first and its
                # value sits below the right child's value on `values`.
                pending.append((current, True))
                pending.append((current.right, False))
                pending.append((current.left, False))
        return values.pop()

    @staticmethod
    def _apply(symbol, left_value, right_value):
        """Apply the binary operator *symbol* to the two operand values."""
        try:
            operation = BinaryExpressionTree._OPERATIONS[symbol]
        except KeyError:
            # Previously an unknown operator silently evaluated to None.
            raise ValueError('Unsupported operator: {0}'.format(symbol)) from None
        return operation(left_value, right_value)
if __name__ == "__main__":
    # Simple REPL: read an infix expression, build the tree, print its value.
    while True:
        try:
            infix_expression = input('Enter infix expression:')
        except (EOFError, KeyboardInterrupt):
            # End of input (Ctrl-D / Ctrl-C) ends the session cleanly.
            print()
            break
        try:
            # parse() is a generator: its ValueErrors (unmatched parentheses,
            # unrecognised tokens) are raised only while it is consumed, so it
            # is materialised inside this try block.
            postfix_expression = list(shunting_yard_parser.parse(infix_expression))
            if not postfix_expression:
                continue  # nothing to evaluate (blank input, "()", ...)
            binary_expression_tree = BinaryExpressionTree(postfix_expression)
            solution = binary_expression_tree.solve()
        except (ValueError, TypeError, ZeroDivisionError) as error:
            # Malformed input or arithmetic error: report it and keep going.
            print(error)
        else:
            print(solution)
Unit tests:
from unittest import TestCase
import binary_expression_tree
import shunting_yard_parser
class TestBinaryExpressionTree(TestCase):
    """Unit tests for binary_expression_tree.BinaryExpressionTree."""

    @staticmethod
    def _operand(value):
        """Build an OPERAND token with the given text value."""
        return shunting_yard_parser.Token(shunting_yard_parser.TokenSpecification.OPERAND, value)

    @staticmethod
    def _operator(value):
        """Build an OPERATOR token with the given symbol."""
        return shunting_yard_parser.Token(shunting_yard_parser.TokenSpecification.OPERATOR, value)

    def test_only_operands(self):
        # arrange
        postfix_expression = [self._operand('2'), self._operand('5')]
        # act & assert
        with self.assertRaises(TypeError):
            binary_expression_tree.BinaryExpressionTree(postfix_expression)

    def test_only_operators(self):
        # arrange
        postfix_expression = [self._operator('+'), self._operator('*')]
        # act & assert
        with self.assertRaises(TypeError):
            binary_expression_tree.BinaryExpressionTree(postfix_expression)

    def test_unbalanced_expression(self):
        # arrange
        postfix_expression = [self._operand('2'), self._operator('+'), self._operator('*')]
        # act & assert
        with self.assertRaises(TypeError):
            binary_expression_tree.BinaryExpressionTree(postfix_expression)

    def test_balanced_expression(self):
        # arrange: (2 + 5) * 10
        postfix_expression = [self._operand('2'), self._operand('5'), self._operator('+'),
                              self._operand('10'), self._operator('*')]
        # act
        tree = binary_expression_tree.BinaryExpressionTree(postfix_expression)
        # assert
        self.assertEqual(70, tree.solve())

    def test_single_operand(self):
        # arrange
        postfix_expression = [self._operand('42')]
        # act
        tree = binary_expression_tree.BinaryExpressionTree(postfix_expression)
        # assert
        self.assertEqual(42, tree.solve())

    def test_subtraction_operand_order(self):
        # arrange: "10 - 4" is "10 4 -" in postfix; 10 must end up on the left.
        postfix_expression = [self._operand('10'), self._operand('4'), self._operator('-')]
        # act
        tree = binary_expression_tree.BinaryExpressionTree(postfix_expression)
        # assert
        self.assertEqual(6, tree.solve())

    def test_division_operand_order(self):
        # arrange: "10 / 4" is "10 4 /" in postfix.
        postfix_expression = [self._operand('10'), self._operand('4'), self._operator('/')]
        # act
        tree = binary_expression_tree.BinaryExpressionTree(postfix_expression)
        # assert
        self.assertEqual(2.5, tree.solve())

    def test_parsed_infix_expression(self):
        # arrange: end-to-end through the shunting-yard parser.
        postfix_expression = list(shunting_yard_parser.parse('(2 + 5) * 10'))
        # act
        tree = binary_expression_tree.BinaryExpressionTree(postfix_expression)
        # assert
        self.assertEqual(70, tree.solve())
Shunting-yard algorithm (only for reference, already reviewed):
from collections import namedtuple, deque
from enum import Enum
import re
class TokenSpecification(Enum):
    """Kinds of token produced by the tokenizer.

    Each member's name doubles as the regex group name used in ``_tokenize``.
    """

    PARENTHESES = 1  # "(" or ")"
    OPERAND = 2      # run of digits
    OPERATOR = 3     # one of + - * /
    IGNORE = 4       # whitespace
    JUNK = 5         # unrecognised input
class Token(namedtuple("Token", "kind value")):
    """Immutable token: a ``kind`` (TokenSpecification member) and its text ``value``."""

    # A namedtuple subclass needs an empty __slots__, otherwise every
    # instance silently gains a per-instance __dict__.
    __slots__ = ()

    # Binding strength of each operator symbol; higher binds tighter.
    _operatorPrecedence = {"*": 3,
                           "/": 3,
                           "+": 2,
                           "-": 2}

    @property
    def precedence(self):
        """Return the precedence of this operator token.

        :raises TypeError: if the token is not an operator.
        :raises KeyError: if the operator symbol has no precedence entry.
        """
        if self.kind != TokenSpecification.OPERATOR:
            raise TypeError('Only operator tokens have a precedence, got {0}'.format(self.kind))
        return self._operatorPrecedence[self.value]
def _tokenize(expression):
    """Split *expression* into tokens, skipping whitespace.

    This is a generator, so errors are raised while it is being iterated,
    not when it is called.

    :param expression: the infix expression string to tokenize.
    :raises ValueError: on the first input that is not a parenthesis, a run
        of digits, one of ``+ - * /`` or whitespace.
    """
    token_specification = [
        (TokenSpecification.PARENTHESES.name, r'[()]'),
        (TokenSpecification.OPERAND.name, r'\d+'),
        (TokenSpecification.OPERATOR.name, r'[+\-*/]'),
        (TokenSpecification.IGNORE.name, r'\s+'),
        # Catch-all: a run of word characters or any other single non-space
        # character. The old pattern (\S+?\b) could not match e.g. "$", and
        # finditer silently skipped such characters.
        (TokenSpecification.JUNK.name, r'\w+|\S'),
    ]
    tokenizer_regex = '|'.join('(?P<{kind}>{pattern})'.format(kind=kind, pattern=pattern)
                               for kind, pattern in token_specification)
    tokenizer = re.compile(tokenizer_regex)
    for match in tokenizer.finditer(expression):
        # lastgroup is the group *name* (a str); convert it to the enum member
        # before comparing. Comparing the raw str to enum members was always
        # unequal, so junk was never rejected and whitespace was yielded.
        kind = TokenSpecification[match.lastgroup]
        value = match.group()
        if kind == TokenSpecification.JUNK:
            raise ValueError('Unrecognized token: {0}'.format(value))
        if kind != TokenSpecification.IGNORE:
            yield Token(kind, value)
def parse(infix_expression):
    """Convert an infix expression to postfix order (shunting-yard algorithm).

    This is a generator: tokens are yielded lazily in postfix order, so any
    ValueError is raised while the result is being iterated.

    All operators are treated as left-associative: an incoming operator first
    flushes stacked operators of greater or equal precedence.

    :param infix_expression: the expression string to convert.
    :raises ValueError: on unmatched parentheses (and whatever the tokenizer
        raises).
    """
    pending = []  # operator stack; its top is pending[-1]
    for tok in _tokenize(infix_expression):
        if tok.kind == TokenSpecification.OPERAND:
            yield tok
            continue
        if tok.kind == TokenSpecification.OPERATOR:
            while pending:
                top = pending[-1]
                if top.value == '(' or top.precedence < tok.precedence:
                    break
                yield pending.pop()
            pending.append(tok)
        elif tok.value == '(':
            pending.append(tok)
        elif tok.value == ')':
            # Flush everything stacked since the matching "(".
            while pending and pending[-1].value != '(':
                yield pending.pop()
            # Stack exhausted without finding "(": there was no opener.
            if not pending:
                raise ValueError('Unmatched )')
            pending.pop()  # discard the matching "("
    # Drain the remaining operators, top of stack first; any "(" still
    # present never had a closing partner.
    while pending:
        top = pending.pop()
        if top.value == '(':
            raise ValueError('Unmatched (')
        yield top
if __name__ == "__main__":
    # Interactive demo: print the postfix form of each entered expression.
    while True:
        try:
            infix_expression = input('Enter infix expression:')
        except (EOFError, KeyboardInterrupt):
            # End of input (Ctrl-D / Ctrl-C) ends the session cleanly.
            print()
            break
        try:
            # parse() is lazy, so its errors surface while this list is built.
            postfix_expression = [op.value for op in parse(infix_expression)]
        except ValueError as e:
            # Print the message itself rather than the e.args tuple.
            print(e)
        else:
            print(postfix_expression)