I have the following data:
trains is a dictionary with 1700 elements. The keys are the IDs of trains and the value for each train is an array with every station ID where that train stops.
departures is a dictionary with the same keys as trains, so also 1700 elements. Each value is the departure time of the train.
Now, I would like to compute intersections between trains. When train A and train B have overlapping stops, I look at the departure time of both trains. When train A departs before train B, then (A, B) is put in the resulting set, otherwise (B, A).
trains = {90: [240, 76, 18, ...], 91: [2, 17, 98, 76, ...], ...}
departures = {90: 1418732160, 91: 1418711580, ...}
intersections = []
for i in trains:
    trA = trains[i]
    for j in trains:
        if i != j:
            trB = trains[j]
            intersect = [val for val in trA if val in trB]
            if intersect:
                if departures[i] < departures[j]:
                    if (i, j) not in intersections:
                        intersections.append((i, j))
                else:
                    if (j, i) not in intersections:
                        intersections.append((j, i))
When finished, the intersections list contains 500,000 elements.
This, however, takes very long to compute! I'm guessing it is because of the (i, j) not in intersections and (j, i) not in intersections statements.
Is there any way I could alter my code to speed up this calculation?
3 Answers
Iterate over .items():
for i, trA in trains.items():
    for j, trB in trains.items():
        if i != j:
You should probably do an early continue instead.
To calculate intersect, use sets:
for i, trA in trains.items():
    trA_set = set(trA)
    for j, trB in trains.items():
        if i == j:
            continue
        intersect = trA_set.intersection(trB)
This swaps an \$\mathcal{O}(n^2)\$ operation for an \$\mathcal{O}(n)\$ one.
You only check for is-empty, so this can be swapped with:
if trA_set.isdisjoint(trB):
    continue
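To make the difference concrete, here is a minimal sketch comparing the list-based overlap test from the question with the set-based one. The stop lists and the two helper names (overlaps_list, overlaps_set) are made up for illustration and are not part of the original code.

```python
# Hypothetical stop lists; the station IDs are invented for illustration.
tr_a = [240, 76, 18]
tr_b = [2, 17, 98, 76]
tr_c = [5, 6, 7]


def overlaps_list(a, b):
    """Approach from the question: every `val in b` scans the whole list b."""
    return bool([val for val in a if val in b])


def overlaps_set(a, b):
    """Set approach: hash lookups, and isdisjoint can stop at the first hit."""
    return not set(a).isdisjoint(b)


print(overlaps_list(tr_a, tr_b), overlaps_set(tr_a, tr_b))  # True True
print(overlaps_list(tr_a, tr_c), overlaps_set(tr_a, tr_c))  # False False
```

Both helpers give the same answer; only the amount of work per pair of trains differs.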
Then you do
if departures[i] < departures[j]:
    if (i, j) not in intersections:
        intersections.append((i, j))
else:
    if (j, i) not in intersections:
        intersections.append((j, i))
This can be simplified by making intersections a set, as long as order isn't important:
intersections = set()
...
if departures[i] < departures[j]:
    intersections.add((i, j))
else:
    intersections.add((j, i))
or even
route = (i, j) if departures[i] < departures[j] else (j, i)
intersections.add(route)
Note that this still requires doing everything both ways. It would be simpler to require a given ordering at the start:
for i, trA in trains.items():
    trA_set = set(trA)
    for j, trB in trains.items():
        if i == j or departures[i] > departures[j]:
            continue
        if trA_set.isdisjoint(trB):
            continue
        intersections.add((i, j))
Note that you should probably stick to PEP 8:
for i, tr_a in trains.items():
    tr_a_set = set(tr_a)
    for j, tr_b in trains.items():
        if i == j or departures[i] > departures[j]:
            continue
        if tr_a_set.isdisjoint(tr_b):
            continue
        intersections.add((i, j))
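For reference, the final loop can be wrapped in a function and tried on a tiny data set. This is only a sketch: the function name find_intersections, the shortened stop lists, and train 92 with its departure time are assumptions added for illustration.

```python
def find_intersections(trains, departures):
    """Return (earlier, later) pairs of train IDs that share at least one stop."""
    intersections = set()
    for i, tr_a in trains.items():
        tr_a_set = set(tr_a)  # built once per outer train
        for j, tr_b in trains.items():
            # Skip the train itself and any pair where i departs after j;
            # that pair is handled when the roles of i and j are swapped.
            if i == j or departures[i] > departures[j]:
                continue
            if tr_a_set.isdisjoint(tr_b):
                continue
            intersections.add((i, j))
    return intersections


# Hypothetical sample data (stop lists shortened, train 92 invented).
trains = {90: [240, 76, 18], 91: [2, 17, 98, 76], 92: [5, 6, 7]}
departures = {90: 1418732160, 91: 1418711580, 92: 1418700000}

result = find_intersections(trains, departures)
print(result)  # {(91, 90)}
```

Trains 90 and 91 share station 76, and 91 departs first, so the only pair is (91, 90).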
You could update the loop to only go over the necessary indexes (i.e. the inner one should start at i+1), similarly to how @AJMansfield structured the pairs function. – ferada, May 10, 2015 at 15:30
@ferada Unfortunately, dictionaries aren't sliceable. I could make an intermediate list, but I doubt it would significantly speed things up, seeing as the main cost is probably tr_a_set.isdisjoint(tr_b). Using itertools.combinations also prevents the trA_set = set(trA) optimization. – Veedrac, May 10, 2015 at 15:38
But the keys return value would be sliceable; in any case I'm pretty sure that not doing \$n^2\$ for the outer two loops is a good idea. – ferada, May 10, 2015 at 15:47
@ferada On Python 2, yes, but I don't think I can assume that. Anyway, combinations(..., 2) is still \$\mathcal{O}(n^2)\$. – Veedrac, May 10, 2015 at 15:51
I would use a somewhat different algorithm to accomplish this task: build a dictionary of all trains that stop at each station, sort the trains at each station by departure time, and then take all in-order pairs of trains at each station.
The basic algorithm is like this:
from collections import defaultdict
from itertools import combinations

trains = {...}
departures = {...}
intersections = set()

# Map each station ID to the list of trains that stop there.
stations = defaultdict(list)
for t, train in trains.items():
    for s in train:
        stations[s].append(t)

# At each station, every in-order pair of trains (by departure) intersects.
for station in stations.values():
    intersections.update(combinations(sorted(station, key=lambda t: departures[t]), 2))
(This version of the code was vastly improved by @Veedrac.)
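A worked example on a small data set may help show what the two loops do. The data below is hypothetical (shortened stop lists, train 92 invented), and the sketch assumes each train lists a given station at most once; a repeated station would make combinations yield a pair of a train with itself.

```python
from collections import defaultdict
from itertools import combinations

# Hypothetical sample data for illustration only.
trains = {90: [240, 76, 18], 91: [2, 17, 98, 76], 92: [18, 5]}
departures = {90: 1418732160, 91: 1418711580, 92: 1418740000}

# Step 1: invert the mapping, station ID -> trains stopping there.
stations = defaultdict(list)
for t, stops in trains.items():
    for s in stops:
        stations[s].append(t)
# Now stations[76] == [90, 91] and stations[18] == [90, 92];
# every other station has a single train.

# Step 2: at each station, sort trains by departure and take all pairs.
# combinations() keeps the sorted order, so each pair is (earlier, later).
intersections = set()
for trains_here in stations.values():
    ordered = sorted(trains_here, key=departures.get)
    intersections.update(combinations(ordered, 2))

print(sorted(intersections))  # [(90, 92), (91, 90)]
```

Stations with only one train contribute no pairs, so the work is concentrated on stations that are actually shared.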
Could you explain your code a bit more in depth? I don't really see what you are doing. – JNevens, May 11, 2015 at 9:48
itertools.combinations(..., 2) is a handy way to consider all pairs just once. Based on @ferada's follow-up comments to @Veedrac's answer...
from itertools import combinations

intersections = set()
# Convert each stop list to a set once, up front.
train_sets = {train_id: set(stops) for train_id, stops in trains.items()}
for (i, i_stops), (j, j_stops) in combinations(train_sets.items(), 2):
    if not i_stops.isdisjoint(j_stops):
        intersections.add((i, j) if departures[i] < departures[j] else (j, i))
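A self-contained sketch of this pairwise approach, wrapped in a function so it can be run directly. The function name intersections_by_pairs and the sample data are assumptions made for illustration.

```python
from itertools import combinations


def intersections_by_pairs(trains, departures):
    """Visit each unordered pair of trains once; return (earlier, later) pairs."""
    # Convert every stop list to a set once, so each overlap test is cheap.
    train_sets = {train_id: set(stops) for train_id, stops in trains.items()}
    result = set()
    for (i, i_stops), (j, j_stops) in combinations(train_sets.items(), 2):
        if not i_stops.isdisjoint(j_stops):
            result.add((i, j) if departures[i] < departures[j] else (j, i))
    return result


# Hypothetical sample data for illustration only.
trains = {90: [240, 76, 18], 91: [2, 17, 98, 76], 92: [18, 5]}
departures = {90: 1418732160, 91: 1418711580, 92: 1418740000}

print(sorted(intersections_by_pairs(trains, departures)))
# [(90, 92), (91, 90)]
```

Because every pair is visited exactly once, the departure comparison decides the order of the tuple rather than whether the pair is skipped.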
trains and departures dictionaries? (What is the maximum key?)