I'm trying to get an efficient algorithm to calculate the height of a tree in Python for large datasets. The code I have works for small datasets, but takes a long time for really large ones (100,000 items) so I'm trying to figure out ways to optimize it but am getting stuck. Sorry if it seems like a really newbie question, I'm pretty new to Python.
The input is a list length and a list of values, with each list item pointing to its parent, with list item -1 indicating the root of the tree. So with an input of:
5
4 -1 4 1 1
The answer would be 3 — the tree is: ({key:1, children: [{key: 3}, {key:4, children:[{key:0, {key:2}]}] }
Here is the code that I have so far:
import sys, threading
sys.setrecursionlimit(10**7) # max depth of recursion
threading.stack_size(2**25) # new thread will get stack of such size
class TreeHeight:
def read(self):
self.n = int(sys.stdin.readline())
self.parent = list(map(int, sys.stdin.readline().split()))
def getChildren(self, node, nodes):
parent = {'key': node, 'children': []}
children = [i for i, x in enumerate(nodes) if x == parent['key']]
for child in children:
parent['children'].append(self.getChildren(child, nodes))
return parent
def compute_height(self, tree):
if len(tree['children']) == 0:
return 0
else:
max_values = []
for child in tree['children']:
max_values.append(self.compute_height(child))
return 1 + max(max_values)
def main():
tree = TreeHeight()
tree.read()
treeChild = tree.getChildren(-1, tree.parent)
print(tree.compute_height(treeChild))
threading.Thread(target=main).start()
4 Answers 4
You don't need to explicitly build the tree to compute its depth.
class TreeDepth(object):
def __init__(self, parents):
self._parents = parents
self._n = len(parents)
self._max_depth = None
self._depths = [None] * self._n
def max_depth(self):
if self._max_depth is None:
for idx, parent in enumerate(self._parents):
depth = self.get_depth(idx)
if self._max_depth < depth:
self._max_depth = depth
return self._max_depth
def get_depth(self, idx):
depth = self._depths[idx]
if depth is not None:
return depth
parent = self._parents[idx]
if parent == -1:
depth = 1
else:
depth = self.get_depth(parent) + 1
self._depths[idx] = depth
return depth
>>> TreeDepth([4, -1, 4, 1, 1]).max_depth()
3
If my math is good, this goes over each item in self._parents
at most twice, so it will have \$O(n)\$ performance. Also, I have used a recursive approach to compute the depth, but if speed is your main goal, you will probably want to turn that into an iterative approach:
def max_depth2(self):
if self._max_depth is not None:
return self._max_depth
for idx, parent in enumerate(self._parents):
parent_stack = []
while parent != -1 and self._depths[idx] is None:
parent_stack.append(idx)
idx, parent = parent, self._parents[parent]
if parent == -1:
depth = 1
else:
depth = self._depths[idx]
while parent_stack:
self._depths[parent_stack.pop()] = depth
depth += 1
if self._max_depth < depth:
self._max_depth = depth
return self._max_depth
-
\$\begingroup\$ @DanielChepenko
None < 1
returnsTrue
(and does so for basically every object, including zero and minus infinity). So this shouldn't crash in the first iteration. \$\endgroup\$Graipher– Graipher2018年09月07日 12:56:53 +00:00Commented Sep 7, 2018 at 12:56 -
\$\begingroup\$ @DanielChepenko ok, scratch that, this was true in Python 2, but is no longer true in Python 3. \$\endgroup\$Graipher– Graipher2018年09月07日 13:00:01 +00:00Commented Sep 7, 2018 at 13:00
The most efficient way of computing the height of a tree runs in linear time, and it looks like this:
class TreeNode:
def __init__(self):
self.left = None
self.right = None
def get_tree_height(root):
if root is None:
return -1
return max(get_tree_height(root.left), get_tree_height(root.right)) + 1
def main():
a = TreeNode()
b = TreeNode()
c = TreeNode()
d = TreeNode()
e = TreeNode()
a.left = b
b.left = c
c.right = d
b.right = e
print("Tree height:", get_tree_height(a))
if __name__ == "__main__":
main()
If you need to compute the height of a tree which allows more than 2 children, just extend the algorithm to call itself at all the child nodes, and return the maximum of them.
Hope that helps.
threading
So... you use multithreading to spawn 1 thread that execute a sequential function (meaning the function won't spawn any more threads) and do nothing in the meantime. That is, your main thread is just waiting for the thread executing the main()
function to complete. I’m unsure about why you would do something like that, but it's a waste of resources.
compute_height
Checking if a sequence is empty is usualy done by testing the sequence directly since an implicit call to bool()
is performed as needed. bool()
returns False
on an empty sequence and True
otherwise. You can thus write:
if tree['childreen']:
return 0
else:
...
You can also make use of generator expressions to avoir building a temporary list in memory:
max_value = max(self.compute_height(child) for child in tree['children'])
return 1 + max_value
And, last but not least, you can make use of the default
parameter of max()
to avoid checking for empty sequence:
def compute_height(self, tree):
children = tree['children']
return 1 + max((self.compute_height(c) for c in children), default=-1)
getChildren
First of, you should be consistent in your naming. Either use camelCase or snake_case but not both. PEP8 recommends snake_case.
Second, you could make use of list-comprehensions to simplify the writting and avoid calling append
:
parent['children'] = [
self.getChildren(child, nodes)
for child, node in enumerate(nodes)
if node == parent['key']
]
Since it is the only operation you do besides building the parent
dictionary, you could build this list and then build the dictionary:
def get_children(self, root_node, nodes):
children = [
self.get_children(child, nodes)
for child, node in enumerate(nodes)
if node == root_node
]
return {'key': root_node, 'children': children}
read
You can read from stdin
using the builtin input()
function. It will return a line of input at each call:
def read(self):
self.n = int(input())
self.parent = list(map(int, input().split()))
However, you’re never making use of self.n
in your program, so you might as well avoid storing it:
def read(self):
input() # throw away number of nodes
self.parent = list(map(int, input().split()))
TreeHeight
Building a class is generaly done to save a state and have utility functions that operates on said state. As regard to how you use it, I'd say you don't need a class. Just having a function building a tree and a function computing its height is enough. Your functions already have the right parameters to operate on without using the state associated to your objects.
You could thus simplify your code to:
def build_tree(root_node, nodes):
children = [
build_tree(child, nodes)
for child, node in enumerate(nodes)
if node == root_node
]
return {'key': root_node, 'children': children}
def compute_height(tree):
return 1 + max((compute_height(c) for c in tree['children']), default=-1)
def main():
input() # Throw away the number of nodes
tree = build_tree(-1, list(map(int, input().split())))
print(compute_height(tree))
if __name__ == '__main__':
main()
In some point my solution is similar to both @mathias-ettinger solution and @coderodde suggestions.
class Node:
def __init__(self, value, range):
self.value = value
self.range = range
self.childs = []
def addChilds(self, nodes):
self.childs = nodes
class TreeHeight:
def create_tree(self, node):
if node:
child_idx = [i for i,x in enumerate(self.parent) if x == node.range]
child_nodes = [self.nodes[i] for i in child_idx]
node.addChilds(child_nodes)
for child in node.childs:
self.create_tree(child)
def read(self):
self.n = int(sys.stdin.readline())
self.parent = list(map(int, sys.stdin.readline().split()))
self.nodes = [Node(vertex,i) for i, vertex in enumerate(self.parent)]
def get_height(self,node):
return 1 + max((self.get_height(c) for c in node.childs), default=0)
def compute_height(self):
# Replace this code with a faster implementation
root_idx = self.parent.index(-1)
if root_idx == -1:
return 0
root = self.nodes[root_idx]
child_idx = [i for i,x in enumerate(self.parent) if x == root_idx]
child_nodes = [self.nodes[i] for i in child_idx]
root.addChilds(child_nodes)
for child in root.childs:
self.create_tree(child)
return self.get_height(root)
def main():
tree = TreeHeight()
tree.read()
print(tree.compute_height())
threading.Thread(target=main).start()
But I agree with @jaime that you don't need to create a tree to compute it's height.
Explore related questions
See similar questions with these tags.