Code Review

Speed up this python function on concordance index calculation

I am trying to calculate a customized concordance index for survival analysis. My code is below. It runs fine on a small input dataframe, but is extremely slow on a dataframe with one million rows (over 30 minutes).

import pandas as pd

def c_index(y_pred, events, times):
    df = pd.DataFrame(data={'proba': y_pred, 'event': events, 'time': times})
    n_total_correct = 0
    n_total_comparable = 0
    df = df.sort_values(by=['time'])
    for i, row in df.iterrows():
        if row['event'] == 1:
            comparable_rows = df[(df['event'] == 0) & (df['time'] > row['time'])]
            n_correct_rows = len(comparable_rows[comparable_rows['proba'] < row['proba']])
            n_total_correct += n_correct_rows
            n_total_comparable += len(comparable_rows)
    return n_total_correct / n_total_comparable if n_total_comparable else None

c = c_index([0.1, 0.3, 0.67, 0.45, 0.56], [1.0, 0.0, 1.0, 0.0, 1.0], [3.1, 4.5, 6.7, 5.2, 3.4])
print(c)  # prints 0.5

For each row (in case it matters...):

  • If the event of the row is 1, retrieve all comparable rows whose
  1. index is larger (to avoid duplicate calculations),
  2. event is 0, and
  3. time is larger than the time of the current row.

  Out of the comparable rows, those whose probability is less than the current row's are correct predictions.

I guess it is slow because of the for loop. How can I speed it up?
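For reference, here is one way the quadratic per-row scan above could be avoided: sweep the rows once in descending time order, maintaining a sorted list of the probabilities of censored rows already seen (which by construction have strictly larger times, once ties are handled group-wise). This is a sketch, not the original author's code; the function name `c_index_fast` and the use of `bisect` are my own choices.

```python
import bisect
import numpy as np

def c_index_fast(y_pred, events, times):
    proba = np.asarray(y_pred, dtype=float)
    event = np.asarray(events, dtype=float)
    time = np.asarray(times, dtype=float)

    order = np.argsort(-time)  # indices in descending time order
    seen = []                  # sorted probas of censored rows with strictly larger time
    n_correct = 0
    n_comparable = 0

    i, n = 0, len(order)
    while i < n:
        # Collect the group of rows sharing this timestamp, so that
        # rows tied on time are never compared with each other.
        j = i
        while j < n and time[order[j]] == time[order[i]]:
            j += 1
        # Count each event row against censored rows with larger time:
        # bisect_left gives how many seen probabilities are strictly smaller.
        for k in order[i:j]:
            if event[k] == 1:
                n_comparable += len(seen)
                n_correct += bisect.bisect_left(seen, proba[k])
        # Only now add this group's censored rows to `seen`.
        for k in order[i:j]:
            if event[k] == 0:
                bisect.insort(seen, proba[k])
        i = j
    return n_correct / n_comparable if n_comparable else None
```

On the five-row example this returns the same 0.5 as the pandas version. Note that `bisect.insort` is O(n) per insertion, so the sweep is O(n²) in the worst case but with a tiny constant; for a million rows, replacing `seen` with a Fenwick tree over rank-transformed probabilities would bring it down to O(n log n).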

