I am trying to calculate a customized concordance index for survival analysis. Below is my code. It runs fine on a small input dataframe but is extremely slow on a dataframe with one million rows (more than 30 min).
```python
import pandas as pd

def c_index1(y_pred, events, times):
    df = pd.DataFrame(data={'proba': y_pred, 'event': events, 'time': times})
    n_total_correct = 0
    n_total_comparable = 0
    df = df.sort_values(by=['time'])
    for i, row in df.iterrows():
        if row['event'] == 1:
            # censored rows (event 0) observed strictly later than this event
            comparable_rows = df[(df['event'] == 0) & (df['time'] > row['time'])]
            # correct when the later censored row has a lower predicted probability
            n_correct_rows = len(comparable_rows[comparable_rows['proba'] < row['proba']])
            n_total_correct += n_correct_rows
            n_total_comparable += len(comparable_rows)
    return n_total_correct / n_total_comparable if n_total_comparable else None

c = c_index1([0.1, 0.3, 0.67, 0.45, 0.56], [1.0, 0.0, 1.0, 0.0, 1.0], [3.1, 4.5, 6.7, 5.2, 3.4])
print(c)  # prints 0.5
```
For each row (in case it matters):
If the event of the row is 1, retrieve all comparable rows whose
- index is larger (to avoid counting the same pair twice),
- event is 0, and
- time is larger than the time of the current row.
Among the comparable rows, those whose probability is lower than that of the current row count as correct predictions.
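To make the pairing rule concrete, here is a brute-force sketch (the name `c_index_pairwise` is hypothetical) that builds every event/censored pair at once with NumPy broadcasting. It assumes `events` contains only 0 and 1, and it is only a reference for small inputs: the pair matrices are n x n, so it cannot be used on one million rows.

```python
import numpy as np

def c_index_pairwise(y_pred, events, times):
    # Hypothetical reference implementation; O(n^2) memory, small inputs only.
    proba = np.asarray(y_pred, dtype=float)
    event = np.asarray(events, dtype=float)
    time = np.asarray(times, dtype=float)
    # comparable[i, j]: row i had the event, row j is censored (event 0),
    # and row j's time is strictly later than row i's time
    comparable = (
        (event[:, None] == 1)
        & (event[None, :] == 0)
        & (time[None, :] > time[:, None])
    )
    # correct[i, j]: comparable pair where the later censored row has the lower probability
    correct = comparable & (proba[None, :] < proba[:, None])
    n_comparable = int(comparable.sum())
    return int(correct.sum()) / n_comparable if n_comparable else None

print(c_index_pairwise([0.1, 0.3, 0.67, 0.45, 0.56],
                       [1.0, 0.0, 1.0, 0.0, 1.0],
                       [3.1, 4.5, 6.7, 5.2, 3.4]))  # 0.5
```

It reproduces the 0.5 from the sample above and is handy for checking faster versions against on small random inputs.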
I guess it is slow because of the for loop. How can I speed it up?
You will not get dramatic speedups until you can vectorize the operations, but here are some tips in the meantime:
indexing before iterating
Instead of

```python
for i, row in df.iterrows():
    if row['event'] == 1:
        ...
```

filter first:

```python
for i, row in df[df['event'] == 1].iterrows():
    ...
```

so that you iterate over fewer rows.
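A small illustration on the question's sample data (left unsorted here so the index labels are easy to follow): the boolean mask is evaluated once before the loop, and the loop only ever sees the event rows.

```python
import pandas as pd

df = pd.DataFrame({'proba': [0.1, 0.3, 0.67, 0.45, 0.56],
                   'event': [1.0, 0.0, 1.0, 0.0, 1.0],
                   'time': [3.1, 4.5, 6.7, 5.2, 3.4]})

# The mask is evaluated once, before the loop starts, so the loop body
# no longer needs an `if` and censored rows are never visited
event_rows = df[df['event'] == 1]
visited = [i for i, row in event_rows.iterrows()]
print(len(df), len(event_rows), visited)  # 5 3 [0, 2, 4]
```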
itertuples
Generally, `itertuples` is faster than `iterrows`.
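Part of the reason is that `iterrows` builds a new `Series` for every row, and, as the pandas documentation notes, it does not preserve dtypes across a row (mixed int/float columns are upcast). `itertuples` yields lightweight namedtuples with attribute access instead. A minimal comparison on two made-up rows:

```python
import pandas as pd

df = pd.DataFrame({'proba': [0.1, 0.3], 'event': [1, 0], 'time': [3.1, 4.5]})

# iterrows: (index, Series) pairs; the int 'event' column is upcast
# because the whole row shares a single float64 dtype
idx, row = next(df.iterrows())
print(type(row).__name__, row.dtype, row['event'])

# itertuples: namedtuples with the index as `.Index` and columns as attributes
tup = next(df.itertuples())
print(tup.Index, tup.event, tup.proba)
```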
comparable_rows
For `comparable_rows` you are only interested in the `proba` column and the length, so you might as well make it a Series or, even better, a NumPy array.
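A sketch of that selection for one event row of the sample (time 3.4). It uses `.to_numpy()`, the newer equivalent of the `.values` attribute used further down; either gives a plain NumPy array.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({'proba': [0.1, 0.3, 0.67, 0.45, 0.56],
                   'event': [1.0, 0.0, 1.0, 0.0, 1.0],
                   'time': [3.1, 4.5, 6.7, 5.2, 3.4]})
row_time = 3.4  # time of the event row being processed

# Select only the 'proba' column of the comparable rows and drop to NumPy;
# the loop needs nothing else from those rows
comparable = df.loc[(df['event'] == 0) & (df['time'] > row_time), 'proba'].to_numpy()
print(comparable, len(comparable), isinstance(comparable, np.ndarray))
```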
The test `(df['event'] == 0)` doesn't change during the iteration, so you can define `df2 = df[df['event'] == 0]` once, outside of the loop.
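Taking that one step further, a sketch where the censored times are also pulled out as a NumPy array before the loop, so that the only per-row work left is the comparison that actually depends on the current row. Here it just counts comparable rows per event row:

```python
import pandas as pd

df = pd.DataFrame({'proba': [0.1, 0.3, 0.67, 0.45, 0.56],
                   'event': [1.0, 0.0, 1.0, 0.0, 1.0],
                   'time': [3.1, 4.5, 6.7, 5.2, 3.4]}).sort_values(by=['time'])

# Loop-invariant work, done once: the censored rows and their times
df2 = df[df['event'] == 0]
censored_time = df2['time'].to_numpy()

counts = []
for row in df[df['event'] == 1].itertuples():
    # Only this comparison depends on the current row
    counts.append(int((censored_time > row.time).sum()))
print(counts)  # [2, 2, 0]
```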
n_correct_rows
Instead of `len(comparable_rows[comparable_rows['proba'] < row['proba']])`, you can use the fact that `True == 1` and sum the boolean mask directly: `(comparable_rows['proba'] < row.proba).sum()`.
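Both forms count the same rows; the difference is that indexing with the mask first materialises a filtered copy just to take its length, while `.sum()` counts the `True` values directly. A quick check on the comparable probabilities of the event row at time 3.4:

```python
import numpy as np

comparable = np.array([0.3, 0.45])  # censored probabilities later than the event
row_proba = 0.56

mask = comparable < row_proba
via_len = len(comparable[mask])  # builds a filtered copy, then measures it
via_sum = int(mask.sum())        # True counts as 1, False as 0
print(mask, via_len, via_sum)
```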
result
```python
import pandas as pd

def c_index3(y_pred, events, times):
    df = pd.DataFrame(data={'proba': y_pred, 'event': events, 'time': times})
    n_total_correct = 0
    n_total_comparable = 0
    df = df.sort_values(by=['time'])
    # loop-invariant filter, computed once
    df2 = df.loc[df['event'] == 0]
    for row in df[df['event'] == 1].itertuples():
        comparable_rows = df2.loc[(df2['time'] > row.time), 'proba'].values
        n_correct_rows = (comparable_rows < row.proba).sum()
        n_total_correct += n_correct_rows
        n_total_comparable += len(comparable_rows)
    return n_total_correct / n_total_comparable if n_total_comparable else None
```
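One further step in the same direction, as a sketch with a hypothetical name `c_index_sliced`: once the censored rows are sorted by time, the rows with a time strictly greater than `t` form a contiguous tail, so `np.searchsorted(..., side='right')` can find where that tail starts by binary search instead of building a boolean mask over every censored row on each iteration. It assumes `events` holds only 0 and 1. The per-row `<` comparison on the tail remains, so the worst case is still quadratic.

```python
import numpy as np

def c_index_sliced(y_pred, events, times):
    # Hypothetical variant of c_index3 working on plain NumPy arrays
    proba = np.asarray(y_pred, dtype=float)
    event = np.asarray(events, dtype=float)
    time = np.asarray(times, dtype=float)

    # Censored rows sorted by time, computed once
    cens = event == 0
    order = np.argsort(time[cens], kind='stable')
    cens_time = time[cens][order]
    cens_proba = proba[cens][order]

    n_total_correct = 0
    n_total_comparable = 0
    for t, p in zip(time[event == 1], proba[event == 1]):
        # First position whose time is strictly greater than t
        start = np.searchsorted(cens_time, t, side='right')
        comparable = cens_proba[start:]  # a view, no copy
        n_total_correct += int((comparable < p).sum())
        n_total_comparable += len(comparable)
    return n_total_correct / n_total_comparable if n_total_comparable else None

print(c_index_sliced([0.1, 0.3, 0.67, 0.45, 0.56],
                     [1.0, 0.0, 1.0, 0.0, 1.0],
                     [3.1, 4.5, 6.7, 5.2, 3.4]))  # 0.5
```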
timings
```
data = ([0.1, 0.3, 0.67, 0.45, 0.56], [1.0, 0.0, 1.0, 0.0, 1.0], [3.1, 4.5, 6.7, 5.2, 3.4])
%timeit c_index1(*data)
5.17 ms ± 33.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit c_index3(*data)
3.77 ms ± 160 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
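These timings use the five-row sample, so they mostly measure pandas overhead and say little about one million rows. Both versions still compare every event row with every later censored row, which is O(n²) pairs in the worst case. A different technique (not part of the tips above) reduces this to O(n log n): walk the rows from the latest to the earliest time, insert each censored probability into a Fenwick (binary indexed) tree keyed by its probability rank, and for each event row query how many already-inserted probabilities are lower. Below is a sketch with a hypothetical name `c_index_fenwick`; it assumes `events` holds only 0 and 1, and it has not been benchmarked on one million rows here, so any speed-up is something to measure.

```python
import numpy as np

def c_index_fenwick(y_pred, events, times):
    # Hypothetical O(n log n) sketch counting the same pairs as c_index1/c_index3.
    proba = np.asarray(y_pred, dtype=float)
    event = np.asarray(events, dtype=float)
    time = np.asarray(times, dtype=float)
    is_event = (event == 1).tolist()
    is_censored = event == 0

    # Rank = number of censored probabilities strictly below each row's probability
    cens_sorted = np.sort(proba[is_censored])
    rank = np.searchsorted(cens_sorted, proba, side='left').tolist()
    m = len(cens_sorted)
    tree = [0] * (m + 1)  # 1-based Fenwick tree of counts per rank

    # Latest time first; on equal times, event rows come before censored rows,
    # so only strictly later censored rows are already in the tree when queried
    order = np.lexsort((is_censored.astype(int), -time)).tolist()
    censored_list = is_censored.tolist()

    n_correct = 0
    n_comparable = 0
    n_inserted = 0
    for k in order:
        if censored_list[k]:
            pos = rank[k] + 1
            while pos <= m:  # point update: one more censored row at this rank
                tree[pos] += 1
                pos += pos & -pos
            n_inserted += 1
        elif is_event[k]:
            pos = rank[k]
            while pos > 0:  # prefix sum: inserted rows with a lower probability
                n_correct += tree[pos]
                pos -= pos & -pos
            n_comparable += n_inserted  # every inserted row is strictly later
    return n_correct / n_comparable if n_comparable else None

print(c_index_fenwick(*([0.1, 0.3, 0.67, 0.45, 0.56],
                        [1.0, 0.0, 1.0, 0.0, 1.0],
                        [3.1, 4.5, 6.7, 5.2, 3.4])))  # 0.5
```

The loop is still plain Python, but each row now costs O(log n) tree steps instead of a pass over all later censored rows. Comparing its output with `c_index1` on small random inputs (including tied times and probabilities) is a cheap way to gain confidence before running it on the full dataframe.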