I have a CSV holding a "flat" table of a tree's edges (NOT binary, but a node cannot have two parents), ~1M edges:
node_id parent_id
1 0
2 1
3 1
4 2
...
The nodes are sorted such that a parent_id always comes before any of its children, so a parent_id will always be lower than its node_id.
I wish, for each node_id, to get the set of all its ancestor nodes (including itself, propagated up to the root, which is node 0 here), and the set of all its descendant nodes (including itself, propagated down to the leaves), and speed is crucial.
Currently, what I do at a high level:
- Read the CSV in pandas, call it nodes_df
- Iterate once through nodes_df to get node_ancestors, a {node_id: set(ancestors)} dict, adding for each node itself and its parent's ancestors (all of which I know I have seen by then)
- Iterate through nodes_df again in reverse order to get node_descendants, a {node_id: set(descendants)} dict, adding for each node itself and its children's descendants (all of which I know I have seen by then)
import pandas as pd
from collections import defaultdict

# phase 1: read the edge list
nodes_df = pd.read_csv('input.csv')

# phase 2: forward pass -- each node inherits its parent's ancestors
node_ancestors = defaultdict(set)
node_ancestors[0] = set([0])
for id, ndata in nodes_df.iterrows():
    node_ancestors[ndata['node_id']].add(ndata['node_id'])
    node_ancestors[ndata['node_id']].update(node_ancestors[ndata['parent_id']])

# phase 3: reverse pass -- each parent inherits its children's descendants
node_descendants = defaultdict(set)
node_descendants[0] = set([0])
for id, ndata in nodes_df[::-1].iterrows():
    node_descendants[ndata['node_id']].add(ndata['node_id'])
    node_descendants[ndata['parent_id']].update(node_descendants[ndata['node_id']])
So, this takes dozens of seconds on my laptop, which is ages for my application. How do I improve?
Plausible directions:
- Can I use pandas better? Can I get node_ancestors and/or node_descendants by some clever join which is out of my league?
- Can I use a python graph library like NetworkX or igraph (which in my experience is faster on large graphs)? E.g. both libraries have a get_all_shortest_paths method, which returns something like a {node_id: dist} dictionary, from which I could select the keys, but... I need this for every node, so again a long long loop
- Parallelizing - no idea how to do this
1 Answer
id

You shadow the builtin id by using it as a variable name.
itertuples

One way to improve performance is to use itertuples to iterate over the DataFrame:

for _, node, parent in df.itertuples():
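A minimal sketch of that unpacking, assuming pandas is installed and the frame has exactly the node_id and parent_id columns: itertuples yields plain namedtuples (index first), skipping the per-row Series construction that makes iterrows slow on large frames.

```python
import pandas as pd

# Hypothetical miniature version of the question's edge table.
df = pd.DataFrame({'node_id': [1, 2, 3, 4], 'parent_id': [0, 1, 1, 2]})

pairs = []
for _, node, parent in df.itertuples():  # each row is (Index, node_id, parent_id)
    pairs.append((node, parent))

print(pairs)  # [(1, 0), (2, 1), (3, 1), (4, 2)]
```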
iterations
You can do this in one iteration over the input with a nested loop over the ancestors:
node_ancestors = defaultdict(set)
node_ancestors[0] = set([0])
node_descendants = defaultdict(set)
node_descendants[0] = set([0])
for _, node, parent in df.itertuples():
    node_ancestors[node].add(node)
    node_ancestors[node].update(node_ancestors[parent])
    for ancestor in node_ancestors[node]:
        node_descendants[ancestor].add(node)
Depending on how nested the tree is, this will be faster or slower than iterating over the whole input twice. You'll need to test it on your dataset.
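A quick sanity check of the single-pass idea on a tiny made-up edge list (plain tuples stand in for the DataFrame rows, so this runs without pandas):

```python
from collections import defaultdict

# (node_id, parent_id) rows, parents always appearing before children.
edges = [(1, 0), (2, 1), (3, 1), (4, 2)]

node_ancestors = defaultdict(set)
node_ancestors[0] = {0}
node_descendants = defaultdict(set)
node_descendants[0] = {0}
for node, parent in edges:
    # Ancestors: the node itself plus everything already known for its parent.
    node_ancestors[node].add(node)
    node_ancestors[node].update(node_ancestors[parent])
    # Descendants: register this node with every one of its ancestors.
    for ancestor in node_ancestors[node]:
        node_descendants[ancestor].add(node)

print(node_ancestors[4])    # {0, 1, 2, 4}
print(node_descendants[1])  # {1, 2, 3, 4}
```

The inner loop runs once per (node, ancestor) pair, which is the same total work as building the descendant sets in the reverse pass; the difference is cache behaviour and doing only one sweep over the input.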
global vs local
Another speedup might be achieved by doing this in a function instead of the global namespace, since local variable lookups are faster than global ones in CPython:
def parse_tree(df):
    node_ancestors = defaultdict(set)
    node_ancestors[0] = set([0])
    node_descendants = defaultdict(set)
    node_descendants[0] = set([0])
    for _, node, parent in df.itertuples():
        node_ancestors[node].add(node)
        node_ancestors[node].update(node_ancestors[parent])
        for ancestor in node_ancestors[node]:
            node_descendants[ancestor].add(node)
    return node_ancestors, node_descendants
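To see the local-vs-global effect in isolation, here is a hedged micro-benchmark (pure stdlib; the loop body is arbitrary): the same summation runs once at module scope via exec, where its names are dictionary-backed globals, and once inside a function, where they are fast array-indexed locals.

```python
import time

N = 2_000_000

# Module-scope version: executed via exec with a plain dict as its globals,
# so every read/write of `total` and `i` is a dict lookup.
src_global = (
    "total = 0\n"
    "for i in range(N):\n"
    "    total += i\n"
)
g = {'N': N}
t0 = time.perf_counter()
exec(src_global, g)
t_global = time.perf_counter() - t0

def summed(n):
    # Same loop, but `total` and `i` are function locals here,
    # accessed by index into the frame's fast-locals array.
    total = 0
    for i in range(n):
        total += i
    return total

t0 = time.perf_counter()
result = summed(N)
t_local = time.perf_counter() - t0

print(f"global: {t_global:.3f}s  local: {t_local:.3f}s  same result: {result == g['total']}")
```

On CPython the in-function version is usually noticeably faster; exact numbers vary by machine, so measure on your own data.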
- Note: parents[node] = parent isn't necessary. – Giora Simchoni, Jan 23, 2019 at 6:47