Background
I'm implementing an algorithm that localises overlapping sound sources using the time-differences-of-arrival across 3 sensors [1]. Compatible 'triples' (e.g. those sharing 2 common sensors) need to be joined up to make larger N-sensor graphs. The compatible triples are found using a recursive routine (`combine_all`) until no new solutions are found.
The routine runs well when there is a small number of triples, but I've empirically observed that for every 10x increase in the number of triples there is roughly a 10^5.5-fold increase in run-time (estimated by linear regression) - which is problematic, meaning up to 12 hours of run-time on my actual data.
Reference
[1] Kreissig & Yang 2013, "Fast and reliable TDOA assignment in multi-source reverberant environments", ICASSP 2013.
Code and initial profiling
The `combine_all` routine accepts the compatibility-conflict matrix describing the compatibility/conflict of all triple pairs. From the currently available triples, their compatibility or incompatibility with the current solution is checked, and they are kept or eliminated. Compatible nodes are found by `get_Nvl`, and incompatible nodes by `get_NOT_Nvl`.
Initial profiling with IPython's `%lprun` told me that `get_Nvl` and `get_NOT_Nvl` are where ~80% (roughly 40% each) of the time in `combine_all` is spent - and I've tried my best to optimise the current code, with no luck.
```python
import numpy as np
from itertools import chain, product


def combine_all(Acc, V, l, X):
    '''
    Parameters
    ----------
    Acc : (N_triples, N_triples) np.array
        The compatibility-conflict graph. A value of 1 means a compatible
        node pair, -1 an incompatible pair, and 0 is undefined.
    V : set
        V_tilda. The currently considered vertices (a sub-set of all vertices).
    l : set
        The solution consisting of the currently compatible vertices.
    X : set
        The already-visited vertices.

    Returns
    -------
    solutions_l : list with sublists
        A somewhat messy output - must be run through `format_combineall`
        to get nice output.
    '''
    # determine N_v(l) and !N_v(l)
    # !N_v(l) are the vertices incompatible with the current solution
    N_vl = get_Nvl(Acc, V, l)
    N_not_vl = get_NOT_Nvl(Acc, V, l)
    solutions_l = []
    if len(N_vl) == 0:
        solutions_l.append(l)
    else:
        # remove conflicting neighbours
        V = V.difference(N_not_vl)
        # unvisited compatible neighbours
        Nvl_wo_X = N_vl.difference(X)
        for n in Nvl_wo_X:
            Vx = V.difference(set([n]))
            lx = l.union(set([n]))
            solution = combine_all(Acc, Vx, lx, X)
            if solution:
                solutions_l.append(solution)
            X = X.union(set([n]))
    return solutions_l


def get_Nvl(Acc, V_t, l):
    '''Checks for compatible vertices.

    Essentially, two for loops run across the available triples (v in V_t)
    and the current solution set (u in l) to access the `Acc[v,u]` entries.
    If all `Acc[v,u]` values are +1, then `v` is compatible with the
    current solution `l`. If any of the `Acc[v,u]` values is -1, then
    that `v` is not compatible anymore.

    Returns
    -------
    Nvl : set
        Vertices that are compatible with at least one other vertex
        and not in conflict with any of the other vertices.
    '''
    Nvl = []
    if len(l) > 0:
        for v in V_t:
            for u in l:
                if Acc[v, u] == 1:
                    Nvl.append(v)
                elif Acc[v, u] == -1:
                    if v in Nvl:
                        Nvl.pop(Nvl.index(v))
        return set(Nvl)
    else:
        return V_t


def get_Nvl_fast(Acc, V_t, l):
    '''See get_Nvl for docs.'''
    if len(l) > 0:
        all_uv = np.array(np.meshgrid(V_t, l)).T.reshape(-1, 2)

        def get_acc(X):
            return Acc[X[0], X[1]]

        Acc_values = np.apply_along_axis(get_acc, 1, all_uv)
        rows_w_min1 = np.where(Acc_values < 0)
        v_vals_w_conflicts = np.unique(all_uv[rows_w_min1, 0])
        Nvl = np.setdiff1d(V_t, v_vals_w_conflicts)
        return Nvl
    else:
        return V_t


def get_NOT_Nvl(Acc: np.array, V: set, l: set):
    N_not_vl = []
    if len(l) > 0:
        for v in V:
            for u in l:
                if Acc[v, u] == -1:
                    N_not_vl.append(v)
                elif Acc[v, u] == 1:
                    if v in N_not_vl:
                        N_not_vl.pop(N_not_vl.index(v))
    else:
        N_not_vl = []
    return set(N_not_vl)


# ---- not performance related section -- formats output into easily
# readable form
def format_combineall(output):
    semiflat = flatten_combine_all(output)
    only_sets = []
    for each in semiflat:
        if isinstance(each, list):
            for every in each:
                if isinstance(every, set):
                    only_sets.append(every)
        elif isinstance(each, set):
            only_sets.append(each)
    return only_sets


def flatten_combine_all(entry):
    if isinstance(entry, list):
        if len(entry) == 1:
            return flatten_combine_all(entry[0])
        else:
            return list(map(flatten_combine_all, entry))
    elif isinstance(entry, set):
        return entry
    else:
        raise ValueError(f'{entry} can only be set or list')


if __name__ == '__main__':
    # compatibility-conflict graph from [1]
    A = np.array([[ 0,  1,  0,  0, -1, -1],
                  [ 1,  0,  1,  1,  0,  1],
                  [ 0,  1,  0, -1,  1,  0],
                  [ 0,  1, -1,  0, -1,  0],
                  [-1,  0,  1, -1,  0,  1],
                  [-1,  1,  0,  0,  1,  0]])
    qq = combine_all(A, set(range(6)), set([]), set([]))
    neat_output = format_combineall(qq)
    # Expected answer:
    # >>> print(neat_output)
    # [{0, 1, 2}, {0, 1, 3}, {1, 2, 4, 5}, {1, 3, 5}]
```
No luck with optimisation experiments
Since `get_Nvl` + `get_NOT_Nvl` are where most of the time is spent, I've focussed my optimisation efforts there. To begin with I worked only on `get_Nvl` (both `Nvl` functions have a very similar structure).
In the current `get_Nvl` implementation there are two for loops, with `v,u` indexing values in the `Acc` compatibility-conflict matrix. I've tried multiple things that have led to no improvement, or even an increase in runtime, of which here I report those I can now recollect:

- Converting the serial loop-in-loop into a direct `(v, u)` `product` that can be used to check the values of `Acc` in a `map` call.
- Converting the loop-in-loop into a numpy iterable form with `np.apply_along_axis` (e.g. see `get_Nvl_fast`).
- Converting the `if`/`else` flow into dictionary calls (e.g. `action_to_take[Acc[v,u]]` instead of `if Acc[v,u] == -1: ...`).
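For reference, the meshgrid-plus-`apply_along_axis` detour can be collapsed into a single fancy-indexed slice of `Acc`, which avoids the per-row Python function call that makes `np.apply_along_axis` slow. A sketch (`get_Nvl_vec` is a hypothetical name; it implements the docstring semantics, i.e. at least one +1 against `l` and no -1):

```python
import numpy as np

def get_Nvl_vec(Acc, V_t, l):
    """Sketch: N_v(l) via one fancy-indexed block of Acc instead of
    meshgrid + np.apply_along_axis."""
    if not l:
        return set(V_t)
    v_arr = np.fromiter(V_t, dtype=int)
    l_arr = np.fromiter(l, dtype=int)
    sub = Acc[np.ix_(v_arr, l_arr)]          # (|V_t|, |l|) block of Acc
    # keep v with at least one +1 against l and no -1 (docstring semantics)
    ok = (sub == 1).any(axis=1) & ~(sub == -1).any(axis=1)
    return set(v_arr[ok].tolist())
```

The two `any(axis=1)` reductions run entirely in compiled numpy code, so the cost per call is one small slice plus two boolean scans rather than `|V_t| * |l|` Python-level comparisons.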
Solutions to speeding up the code?
I've run out of ideas for speeding up the code in Python, and am now considering a C++ implementation that I could call from Python (though my C++ is non-existent/rusty).
I'd be grateful for ideas on how to speed up my Python code before investing too deeply in another language!
1 Answer
Style & Maintenance
Prior to optimisation fussiness:

- Use PEP 484 type hints.
- Docstrings are typically in double quotes rather than single quotes.
- Do a PEP 8 pass with a linter or a good IDE; they will point out e.g. that you need spaces around `==`.
- You're so close to having a unit test! Write an `assert`.
Performance
Some of this will make a difference and some won't. It's good that you're profiling; keep doing that.

- You've missed some opportunities for an early `return`; pursue this, for example when you `get_NOT_Nvl(Acc, V, l)`.
- Rather than `V = V.difference(...)`, just `V -= ...`.
- Rather than `set([n])`, prefer `{n}`; but better yet, don't do a set-to-set operation at all. Clone the set and then use `.discard` and `.add` as needed, since this is a single-element operation.
- `get_Nvl` needs to operate on `Nvl` as a set instead of as a list cast to a set. I anticipate this helping, since `Nvl.pop()` and `Nvl.index()` are much less efficient than a set `discard`. The same applies to `get_NOT_Nvl`.
More broadly, Python is bad at recursion. If you absolutely must have a recursive solution then, as you fear, it is time to brush up on C++.
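That said, if the recursion itself turns out to be a bottleneck, one Python-side option before reaching for C++ is to unroll it onto an explicit work stack. A sketch under that assumption (`combine_all_iter` is a hypothetical name; the set-based helpers suggested above are repeated so the sketch is self-contained, and the flat list it returns removes the need for `format_combineall`):

```python
import numpy as np

# set-based helpers, repeated here so this sketch is self-contained
def get_Nvl(Acc, V_t, l):
    if not l:
        return set(V_t)
    Nvl = set()
    for v in V_t:
        for u in l:
            if Acc[v, u] == 1:
                Nvl.add(v)
            elif Acc[v, u] == -1:
                Nvl.discard(v)
    return Nvl

def get_NOT_Nvl(Acc, V, l):
    if not l:
        return set()
    N_not_vl = set()
    for v in V:
        for u in l:
            if Acc[v, u] == -1:
                N_not_vl.add(v)
            elif Acc[v, u] == 1:
                N_not_vl.discard(v)
    return N_not_vl

def combine_all_iter(Acc, V, l, X):
    """Iterative combine_all: an explicit LIFO stack replaces the
    recursive calls, and a flat list of solution sets is returned."""
    solutions = []
    stack = [(V, l, X)]
    while stack:
        V, l, X = stack.pop()
        N_vl = get_Nvl(Acc, V, l)
        if not N_vl:
            solutions.append(l)
            continue
        V = V - get_NOT_Nvl(Acc, V, l)
        X = set(X)
        for n in N_vl - X:
            # each pushed frame captures X as it stands right now,
            # mirroring the argument the recursive call would receive
            stack.append((V - {n}, l | {n}, set(X)))
            X.add(n)
    return solutions
```

The traversal order differs from the recursion (LIFO instead of depth-first call order), but each stack frame carries exactly the `(V, l, X)` a recursive call would see, so the same solution sets come out.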
Suggested
```python
import numpy as np


def combine_all(Acc: np.ndarray, V: set[int], l: set[int],
                X: set[int]) -> list[set[int]]:
    """
    Parameters
    ----------
    Acc : (N_triples, N_triples) np.array
        The compatibility-conflict graph. A value of 1 means a compatible
        node pair, -1 an incompatible pair, and 0 is undefined.
    V : set
        V_tilda. The currently considered vertices (a sub-set of all vertices).
    l : set
        The solution consisting of the currently compatible vertices.
    X : set
        The already-visited vertices.

    Returns
    -------
    solutions_l : list with sublists
        A somewhat messy output - must be run through `format_combineall`
        to get nice output.
    """
    # determine N_v(l) and !N_v(l)
    # !N_v(l) are the vertices incompatible with the current solution
    N_vl = get_Nvl(Acc, V, l)
    if len(N_vl) == 0:
        return [l]

    solutions_l = []
    # remove conflicting neighbours
    V -= get_NOT_Nvl(Acc, V, l)
    # unvisited compatible neighbours
    N_vl -= X
    X = set(X)
    for n in N_vl:
        Vx = set(V)
        Vx.discard(n)
        lx = set(l)
        lx.add(n)
        solution = combine_all(Acc, Vx, lx, X)
        if solution:
            solutions_l.append(solution)
        X.add(n)
    return solutions_l


def get_Nvl(Acc: np.ndarray, V_t: set[int], l: set[int]) -> set[int]:
    """Checks for compatible vertices.

    Essentially, two for loops run across the available triples (v in V_t)
    and the current solution set (u in l) to access the `Acc[v,u]` entries.
    If all `Acc[v,u]` values are +1, then `v` is compatible with the
    current solution `l`. If any of the `Acc[v,u]` values is -1, then
    that `v` is not compatible anymore.

    Returns
    -------
    Nvl : set
        Vertices that are compatible with at least one other vertex
        and not in conflict with any of the other vertices.
    """
    if len(l) < 1:
        return V_t
    Nvl = set()
    for v in V_t:
        for u in l:
            a = Acc[v, u]
            if a == 1:
                Nvl.add(v)
            elif a == -1:
                Nvl.discard(v)
    return Nvl


def get_NOT_Nvl(Acc: np.ndarray, V: set[int], l: set[int]) -> set[int]:
    if len(l) < 1:
        return set()
    N_not_vl = set()
    for v in V:
        for u in l:
            if Acc[v, u] == -1:
                N_not_vl.add(v)
            elif Acc[v, u] == 1:
                N_not_vl.discard(v)
    return N_not_vl


def format_combineall(output: list) -> list[set[int]]:
    """Formats output into easily readable form."""
    semiflat = flatten_combine_all(output)
    only_sets = []
    for each in semiflat:
        if isinstance(each, list):
            for every in each:
                if isinstance(every, set):
                    only_sets.append(every)
        elif isinstance(each, set):
            only_sets.append(each)
    return only_sets


def flatten_combine_all(entry: list) -> list:
    if isinstance(entry, list):
        if len(entry) == 1:
            return flatten_combine_all(entry[0])
        return list(map(flatten_combine_all, entry))
    if isinstance(entry, set):
        return entry
    raise TypeError(f'{entry} can only be set or list')


def test() -> None:
    # compatibility-conflict graph from [1]
    A = np.array([[ 0,  1,  0,  0, -1, -1],
                  [ 1,  0,  1,  1,  0,  1],
                  [ 0,  1,  0, -1,  1,  0],
                  [ 0,  1, -1,  0, -1,  0],
                  [-1,  0,  1, -1,  0,  1],
                  [-1,  1,  0,  0,  1,  0]])
    qq = combine_all(A, set(range(6)), set(), set())
    neat_output = format_combineall(qq)
    assert neat_output == [{0, 1, 2}, {0, 1, 3}, {1, 2, 4, 5}, {1, 3, 5}]


if __name__ == '__main__':
    test()
```
Comment (Thejasvi, Aug 9, 2022): Thanks a lot for the suggestions! The set-based formulation is much more elegant - however, as you predicted, sadly not much faster :|. Need to go down the Cpp path I guess..