\$\begingroup\$

I would like to optimize the below traceback function. This function is part of an elementary step of the program and is called a lot...

import numpy as np

def traceback(tuple_node, tuple_node_alt):
    """
    Compute which value from tuple_node_alt comes from which value from tuple_node.
    Return a dictionary where the keys are the values from tuple_node and the values are the idx at which
    the value may be located in tuple_node_alt.
    """
    # Compute the tolerances based on the node
    tolerances = [0.1 if x <= 100 else 0.2 for x in tuple_node]
    # Traceback
    distance = dict()
    alt_identification = dict()
    for k, x in enumerate(tuple_node):
        distance[k] = [abs(elt-1) for elt in [alt_x/x for alt_x in tuple_node_alt]]
        alt_identification[x] = list(np.where([elt <= tolerances[k]+0.00001 for elt in distance[k]])[0])
    # Controls the identification and corrects it
    len_values = {key: len(val) for key, val in alt_identification.items()}
    if all([x <= 1 for x in len_values.values()]):
        return alt_identification
    else:
        for key, value in alt_identification.items():
            if len(value) <= 1:
                continue
            else:
                other_values = [val for k, val in alt_identification.items() if k != key]
                if value in other_values:
                    continue
                else:
                    for val in other_values:
                        set1 = set(value)
                        intersec = set1.intersection(set(val))
                        if len(intersec) == 0:
                            continue
                        else:
                            alt_identification[key] = [v for v in value if v not in intersec]
    return alt_identification

The input is composed of 2 tuples which do not need to have the same size. e.g.

tuple_node = (40, 50, 60, 80)
tuple_node_alt = (87, 48, 59, 39)

The goal is to figure out which value from tuple_node_alt may come from which value from tuple_node. If the value from tuple_node_alt is within a 10% margin from a value from tuple_node, it is considered that it comes from this value.

e.g. 39 is within a 10% margin of 40, so it comes from 40. This part is performed in the "Traceback" section, where a distance dictionary and the matching indices are computed. With the example above, the output is:

Out[67]: {40: [3], 50: [1], 60: [2], 80: [0]}
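
The matching rule described above can be sketched in isolation, as a minimal illustration rather than the function itself. The helper name within_tolerance is hypothetical; the 0.1/0.2 tolerances and the 0.00001 slack are taken from the traceback function in the question.

```python
def within_tolerance(x, alt_x):
    """Return True if alt_x lies inside the relative tolerance band of x."""
    # Tolerance and slack copied from the traceback function in the question.
    tolerance = (0.1 if x <= 100 else 0.2) + 0.00001
    return abs(alt_x / x - 1) <= tolerance


tuple_node = (40, 50, 60, 80)
tuple_node_alt = (87, 48, 59, 39)

# For each node value, list the indices of the alternates inside its band.
matches = {
    x: [k for k, alt_x in enumerate(tuple_node_alt) if within_tolerance(x, alt_x)]
    for x in tuple_node
}
print(matches)  # {40: [3], 50: [1], 60: [2], 80: [0]}
```

Because the comparison is relative to x, the band widens as the node value grows, which is why neighbouring bands can overlap.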

However, because the tolerance bands can overlap, three scenarios exist:

Scenario 1: each value has been identified to one alternative value. That's the case above.

Scenario 2:

tuple_node = (40, 50, 60, 80)
tuple_node_alt = (42, 55, 54)

55 and 54 are both in the tolerance band of both 50 and 60. Thus, the output is:

Out[66]: {40: [0], 50: [1, 2], 60: [1, 2], 80: []}

Scenario 3:

tuple_node = (40, 50, 60, 80)
tuple_node_alt = (42, 55, 59)

This is where the control part comes into play. With this input, the traceback step first yields {40: [0], 50: [1], 60: [1, 2], 80: []}. However, 55 cannot come from 60, since it is the only candidate for 50. Because 55 is already taken, the "control & correct" section produces the correct output:

Out[66]: {40: [0], 50: [1], 60: [2], 80: []}
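
One way to picture the correction, as a sketch under assumptions rather than the code in the question: any index that is the sole candidate of one node is removed from the candidate lists of nodes that still have several options. The function name drop_taken is hypothetical, and this simplified rule deliberately leaves Scenario 2 untouched, where two nodes share the same pair of candidates.

```python
def drop_taken(candidates):
    """Remove indices already claimed as the only candidate of another node."""
    # Indices that are the single candidate of some node.
    taken = {idxs[0] for idxs in candidates.values() if len(idxs) == 1}
    return {
        # Only ambiguous entries (more than one candidate) are filtered.
        key: [i for i in idxs if i not in taken] if len(idxs) > 1 else list(idxs)
        for key, idxs in candidates.items()
    }


before = {40: [0], 50: [1], 60: [1, 2], 80: []}
after = drop_taken(before)
print(after)  # {40: [0], 50: [1], 60: [2], 80: []}
```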

I would really like to optimize this part and make it a lot quicker. At the moment, it takes:

# With an input which does not enter the control & correct part.
node = (40, 50, 60, 80)
node_alt = (39, 48, 59, 87)
%timeit traceback(node, node_alt)
22.6 μs ± 1.04 μs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
# With an input which needs correction
node = (40, 50, 60, 100)
node_alt = (42, 55, 59, 89)
%timeit traceback(node, node_alt)
28.1 μs ± 1.88 μs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
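
Outside IPython, a comparable measurement can be taken with the standard timeit module. This is only a sketch: the candidate function below is a hypothetical stand-in so that the snippet runs on its own, and it should be swapped for the real traceback when measuring.

```python
import timeit


def candidate(node, node_alt):
    # Hypothetical stand-in for the function under test; replace with traceback.
    return {
        x: [k for k, a in enumerate(node_alt) if abs(a / x - 1) <= 0.10001]
        for x in node
    }


node = (40, 50, 60, 80)
node_alt = (39, 48, 59, 87)

# repeat() returns one total time per run; report the best run per call.
number = 10_000
runs = timeit.repeat(lambda: candidate(node, node_alt), number=number, repeat=7)
best_us = min(runs) / number * 1e6
print(f"best of 7 runs: {best_us:.2f} microseconds per call")
```

Reporting the minimum rather than the mean is a common choice here, since slower runs mostly reflect interference from other processes.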
200_success
asked Apr 1, 2019 at 11:23
\$\endgroup\$

1 Answer

\$\begingroup\$

A couple of low-hanging fruit inefficiencies:

  1. distance = dict(). The distance[k] value is computed in a loop, and only ever used in the next statement of the loop. It does not need to be stored in a dictionary.
  2. all([ ...list comprehension... ]): You are using a list comprehension to build a list, which you immediately pass to all(...). There is no need to create the list. Pass a generator expression instead, all(...generator expression...), which also lets all() stop at the first failing element.
  3. set1 = set(value). This is inside a for val in other_values: loop, where value and set1 are not changed. Move the statement out of the for loop, to avoid recreating the same set on each iteration.
  4. len_values is only used in the aforementioned all(...), and only the values of the len_values dictionary are used. As such, the len_values dictionary construction is also unnecessary, and the if statement can be written:

    if all(len(val) <= 1 for val in alt_identification.values()):
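
To illustrate point 2, here is a small sketch of why the generator form can do less work: all() stops at the first failing element, whereas the list form evaluates every element before all() even starts. The at_most_one helper and the sample dictionary are made up for this demonstration.

```python
calls = []


def at_most_one(val):
    calls.append(val)  # record each evaluation
    return len(val) <= 1


sample = {40: [0], 50: [1, 2], 60: [1, 2], 80: []}

calls.clear()
all([at_most_one(v) for v in sample.values()])
list_calls = len(calls)  # the list is fully built first: 4 evaluations

calls.clear()
all(at_most_one(v) for v in sample.values())
gen_calls = len(calls)  # stops at the first False: 2 evaluations

print(list_calls, gen_calls)  # 4 2
```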
    

Since you are returning alt_identification from the if statement, and after the if...else statement, you can invert the test, and remove one return statement:

if any(len(val) > 1 for val in alt_identification.values()):
    for key, value in alt_identification.items():
        # ... omitted for brevity ...
return alt_identification

Similarly, the two if condition: continue else: constructs could be rewritten as if not condition:.


Other possible improvements:

  • tolerances[k] is only used in the following for k loop. The list can be removed and the calculation moved into the loop.
  • numpy is only used for a list(np.where([...])[0]) operation, which is fairly obfuscated. A simple list comprehension can be used instead.
  • The values of alt_identification are of type list, and converted (repeatedly) into a set() in the "control & correct" code. They could be stored as set() to avoid repeated conversions.
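
As a sketch of the second bullet, the indices that np.where extracts from a boolean list can be collected directly with enumerate. The sample distances below are made up for illustration.

```python
distances = [0.175, 0.04, 0.18, 0.22]
tolerance = 0.1 + 0.00001

# Same indices as list(np.where([d <= tolerance for d in distances])[0]),
# but as plain Python ints and without needing NumPy.
indices = [k for k, d in enumerate(distances) if d <= tolerance]
print(indices)  # [1]
```

For inputs of only a few elements, as in the question, avoiding the conversion to and from a NumPy array is likely to matter more than any vectorisation gain.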

Here is my rework of the code, with the changes based on above comments:

def traceback(tuple_node, tuple_node_alt):
    def close_alternates(x):
        tolerance = (0.1 if x <= 100 else 0.2) + 0.00001
        return set( k for k, alt_x in enumerate(tuple_node_alt)
                    if abs(alt_x/x - 1) <= tolerance )

    alt_identification = { x: close_alternates(x) for x in tuple_node }

    if any(len(val) > 1 for val in alt_identification.values()):
        for key, values in alt_identification.items():
            if len(values) > 1:
                other_values = [val for k, val in alt_identification.items() if k != key]
                if values not in other_values:
                    for other in other_values:
                        alt_identification[key] -= other
    return alt_identification

I'm getting up to a 2.8x speedup with the above code on your test data set that requires correction.

answered Apr 1, 2019 at 17:37
\$\endgroup\$
