I would like to optimize the traceback function below. It is part of an elementary step of the program and is called a lot.
import numpy as np

def traceback(tuple_node, tuple_node_alt):
    """
    Compute which value from tuple_node_alt comes from which value from tuple_node.
    Return a dictionnary where the key are the values from tuple_node and the values are the idx at which
    the value may be located in tuple_node_alt.
    """
    # Compute the tolerances based on the node
    tolerances = [0.1 if x <= 100 else 0.2 for x in tuple_node]

    # Traceback
    distance = dict()
    alt_identification = dict()
    for k, x in enumerate(tuple_node):
        distance[k] = [abs(elt-1) for elt in [alt_x/x for alt_x in tuple_node_alt]]
        alt_identification[x] = list(np.where([elt <= tolerances[k]+0.00001 for elt in distance[k]])[0])

    # Controls the identification and corrects it
    len_values = {key: len(val) for key, val in alt_identification.items()}

    if all([x <= 1 for x in len_values.values()]):
        return alt_identification
    else:
        for key, value in alt_identification.items():
            if len(value) <= 1:
                continue
            else:
                other_values = [val for k, val in alt_identification.items() if k != key]
                if value in other_values:
                    continue
                else:
                    for val in other_values:
                        set1 = set(value)
                        intersec = set1.intersection(set(val))
                        if len(intersec) == 0:
                            continue
                        else:
                            alt_identification[key] = [v for v in value if v not in intersec]

    return alt_identification

The input is composed of two tuples, which do not need to have the same size, e.g.
tuple_node = (40, 50, 60, 80)
tuple_node_alt = (87, 48, 59, 39)
The goal is to figure out which value from tuple_node_alt may come from which value from tuple_node. If a value from tuple_node_alt is within a 10% margin of a value from tuple_node (20% for values above 100, as in the code), it is considered to come from that value.
E.g. 39 is within a 10% margin of 40, so it comes from 40. This part is performed in the "Traceback" section, where a distance dictionary and the matching indices are computed. With the example above, the output is:
Out[67]: {40: [3], 50: [1], 60: [2], 80: [0]}
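The margin check described above can be sketched in isolation. This is a minimal illustration, assuming the same rule as the code in the question (10% up to 100, 20% above, plus a small epsilon); within_tolerance is a hypothetical helper name, not part of the original function.

```python
def within_tolerance(x, alt_x, eps=0.00001):
    """Return True if alt_x lies inside the tolerance band around x."""
    # Tolerance rule taken from the question's code: 10% up to 100, 20% above.
    tolerance = 0.1 if x <= 100 else 0.2
    # Relative deviation of alt_x from x, compared with the widened band.
    return abs(alt_x / x - 1) <= tolerance + eps

print(within_tolerance(40, 39))  # True: 39/40 = 0.975, deviation 0.025
print(within_tolerance(40, 48))  # False: 48/40 = 1.2, deviation 0.2
```

The epsilon matters for boundary cases such as 55 against 50, where floating-point rounding puts the deviation slightly above 0.1.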
However, because the tolerance bands may overlap, three scenarios exist:
Scenario 1: each value has been identified to one alternative value. That's the case above.
Scenario 2:
tuple_node = (40, 50, 60, 80)
tuple_node_alt = (42, 55, 54)
55 and 54 are both in the tolerance band of both 50 and 60. Thus, the output is:
Out[66]: {40: [0], 50: [1, 2], 60: [1, 2], 80: []}
Scenario 3:
tuple_node = (40, 50, 60, 80)
tuple_node_alt = (42, 55, 59)
This is where the control part comes into play. With this input, alt_identification becomes:
Out[66]: {40: [0], 50: [1], 60: [1, 2], 80: []}
However, 55 cannot come from 60, since 50 has only one possibility: 55. Since that number is already taken, the correct output, which is provided by the "control & correct" section, is:
Out[66]: {40: [0], 50: [1], 60: [2], 80: []}
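The correction in Scenario 3 can be illustrated with a simplified sketch. It is not the code from the question: resolve_unique_claims is a hypothetical helper that makes a single pass and only removes indices already claimed by a node with exactly one candidate, so it may not cover every edge case the original handles.

```python
def resolve_unique_claims(candidates):
    """Remove indices claimed by a single-candidate node from the other nodes."""
    # Indices that some node holds as its only possibility.
    taken = {idx[0] for idx in candidates.values() if len(idx) == 1}
    resolved = {}
    for node, idx in candidates.items():
        if len(idx) > 1:
            # Keep only the indices that no single-candidate node has taken.
            idx = [i for i in idx if i not in taken]
        resolved[node] = idx
    return resolved

before = {40: [0], 50: [1], 60: [1, 2], 80: []}
print(resolve_unique_claims(before))
# {40: [0], 50: [1], 60: [2], 80: []}
```

With the Scenario 2 candidates, no index of 50 or 60 is uniquely claimed, so the sketch leaves them unchanged.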
I would really like to optimize this part and make it much faster. At the moment, it takes:
# With an input which does not enter the control & correct part.
node = (40, 50, 60, 80)
node_alt = (39, 48, 59, 87)
%timeit traceback(node, node_alt)
22.6 μs ± 1.04 μs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# With an input which needs correction
node = (40, 50, 60, 100)
node_alt = (42, 55, 59, 89)
%timeit traceback(node, node_alt)
28.1 μs ± 1.88 μs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
1 Answer
A couple of low-hanging fruit inefficiencies:
- distance = dict(). The distance[k] value is computed in a loop, and only ever used in the next statement of the loop. It does not need to be stored in a dictionary.
- all([ ...list comprehension... ]): You are using a list comprehension to build up a list, which you immediately pass to all(...). There is no need to actually create the list. Just drop the square brackets and use all(...), which passes a generator expression instead.
- set1 = set(value). This is inside a for val in other_values: loop, where value and set1 are not changed. Move the statement out of the for loop, to avoid recreating the same set each iteration.
- len_values is only used in the aforementioned all(...), and only the values of the len_values dictionary are used. As such, the len_values dictionary construction is also unnecessary, and the if statement can be written: if all(len(val) <= 1 for val in alt_identification.values()):
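The difference between the two all(...) forms can be made visible with a small counting experiment. This is an illustrative sketch with made-up data; is_short and the calls counter are hypothetical names used only to count how many values get inspected.

```python
calls = 0

def is_short(val):
    """Check the length and count how many values were inspected."""
    global calls
    calls += 1
    return len(val) <= 1

values = [[1, 2], [0], [3], []]

calls = 0
all([is_short(v) for v in values])  # the whole list is built before all() runs
list_calls = calls

calls = 0
all(is_short(v) for v in values)    # all() stops at the first False
gen_calls = calls

print(list_calls, gen_calls)  # 4 1
```

For very small inputs the overhead of creating a generator can offset this saving, so the effect on the real function is worth measuring rather than assuming.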
Since you are returning alt_identification from the if statement, and after the if...else statement, you can invert the test, and remove one return statement:
if any(len(val) > 1 for val in alt_identification.values()):
    for key, value in alt_identification.items():
        # ... omitted for brevity ...

return alt_identification
Similarly, the two if condition: continue else: constructs could be rewritten as if not condition:.
Other possible improvements:
- tolerances[k] is only used in the next for k loop. The list can be removed and the calculations moved into the loop.
- numpy is only used for a list(np.where([...])[0]) operation, which is fairly obfuscated. A simple list comprehension can be used instead.
- The values of alt_identification are of type list, and converted (repeatedly) into a set() in the "control & correct" code. They could be stored as set() to avoid repeated conversions.
Here is my rework of the code, with the changes based on above comments:
def traceback(tuple_node, tuple_node_alt):

    def close_alternates(x):
        tolerance = (0.1 if x <= 100 else 0.2) + 0.00001
        return set( k for k, alt_x in enumerate(tuple_node_alt)
                      if abs(alt_x/x - 1) <= tolerance )

    alt_identification = { x: close_alternates(x) for x in tuple_node }

    if any(len(val) > 1 for val in alt_identification.values()):
        for key, values in alt_identification.items():
            if len(values) > 1:
                other_values = [val for k, val in alt_identification.items() if k != key]
                if values not in other_values:
                    for other in other_values:
                        alt_identification[key] -= other

    return alt_identification

I'm getting up to a 2.8x speedup with the above code on your test data set that requires correction.
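Outside IPython, where %timeit is not available, a comparison like this can be reproduced with the standard timeit module. The helper below is a hypothetical sketch (microseconds_per_call is not from the original post), and sorted is only a stand-in workload; the actual numbers depend on the machine.

```python
import timeit

def microseconds_per_call(func, *args, number=10_000, repeat=7):
    """Best-of-`repeat` average time per call, in microseconds."""
    timer = timeit.Timer(lambda: func(*args))
    # Each repeat runs `number` calls; take the fastest run to reduce noise.
    best = min(timer.repeat(repeat=repeat, number=number))
    return best / number * 1e6

# Stand-in workload; pass each traceback version and its inputs to compare them.
print(microseconds_per_call(sorted, (42, 55, 59, 89)))
```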