
I need to generate a column a_b based on columns a and b of df: if both a and b are greater than 0, a_b is assigned 1; if both a and b are less than 0, a_b is assigned -1; otherwise it is 0. I am currently using a nested np.where.

My code is as follows: generate_data produces demo data, and get_result is the production function, which needs to be run 4 million times:

import numpy as np
import pandas as pd

rand = np.random.default_rng(seed=0)
pd.set_option('display.max_columns', None)


def generate_data() -> pd.DataFrame:
    _df = pd.DataFrame(rand.uniform(-1, 1, 70).reshape(10, 7),
                       columns=['a', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6'])
    return _df


def get_result(_df: pd.DataFrame) -> pd.DataFrame:
    a = _df.a.to_numpy()
    for col in ['b1', 'b2', 'b3', 'b4', 'b5', 'b6']:
        b = _df[col].to_numpy()
        _df[f'a_{col}'] = np.where(
            (a > 0) & (b > 0), 1., np.where(
                (a < 0) & (b < 0), -1., 0.)
        )
    return _df


def main():
    df = generate_data()
    print(df)
    df = get_result(df)
    print(df)


if __name__ == '__main__':
    main()

Data generated by generate_data:

 a b1 b2 b3 b4 b5 b6
0 0.273923 -0.460427 -0.918053 -0.966945 0.626540 0.825511 0.213272
1 0.458993 0.087250 0.870145 0.631707 -0.994523 0.714809 -0.932829
2 0.459311 -0.648689 0.726358 0.082922 -0.400576 -0.154626 -0.943361
3 -0.751433 0.341249 0.294379 0.230770 -0.232645 0.994420 0.961671
4 0.371084 0.300919 0.376893 -0.222157 -0.729807 0.442977 0.050709
5 -0.379516 -0.028329 0.778976 0.868087 -0.284410 0.143060 -0.356261
6 0.188600 -0.324178 -0.216762 0.780549 -0.545685 0.246374 -0.831969
7 0.665288 0.574197 -0.521261 0.752968 -0.882864 -0.327766 -0.699441
8 -0.099321 0.592649 -0.538716 -0.895957 -0.190896 -0.602974 -0.818494
9 0.160665 -0.402608 0.343990 -0.600969 0.884226 -0.269780 -0.789009

My desired result:


 a b1 b2 b3 b4 b5 b6 a_b1 \
0 0.273923 -0.460427 -0.918053 -0.966945 0.626540 0.825511 0.213272 0.0 
1 0.458993 0.087250 0.870145 0.631707 -0.994523 0.714809 -0.932829 1.0 
2 0.459311 -0.648689 0.726358 0.082922 -0.400576 -0.154626 -0.943361 0.0 
3 -0.751433 0.341249 0.294379 0.230770 -0.232645 0.994420 0.961671 0.0 
4 0.371084 0.300919 0.376893 -0.222157 -0.729807 0.442977 0.050709 1.0 
5 -0.379516 -0.028329 0.778976 0.868087 -0.284410 0.143060 -0.356261 -1.0 
6 0.188600 -0.324178 -0.216762 0.780549 -0.545685 0.246374 -0.831969 0.0 
7 0.665288 0.574197 -0.521261 0.752968 -0.882864 -0.327766 -0.699441 1.0 
8 -0.099321 0.592649 -0.538716 -0.895957 -0.190896 -0.602974 -0.818494 0.0 
9 0.160665 -0.402608 0.343990 -0.600969 0.884226 -0.269780 -0.789009 0.0 
 a_b2 a_b3 a_b4 a_b5 a_b6 
0 0.0 0.0 1.0 1.0 1.0 
1 1.0 1.0 0.0 1.0 0.0 
2 1.0 1.0 0.0 0.0 0.0 
3 0.0 0.0 -1.0 0.0 0.0 
4 1.0 0.0 0.0 1.0 1.0 
5 0.0 0.0 -1.0 0.0 -1.0 
6 0.0 1.0 0.0 1.0 0.0 
7 0.0 1.0 0.0 0.0 0.0 
8 -1.0 -1.0 -1.0 -1.0 -1.0 
9 1.0 0.0 1.0 0.0 0.0 

Performance evaluation:

%timeit get_result(df)
1.56 ms ± 54.7 μs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

How can it be faster?

asked Jun 10, 2022 at 10:18
  • Based on what you've shown, pandas is not needed here. Either you haven't shown code that's realistic enough, or my suggested changes are going to use pure Numpy. (Commented Jun 10, 2022 at 11:19)
  • @Reinderien My real code is exactly the same as the data generated by generate_data, just the column names are different. (Commented Jun 10, 2022 at 11:31)
  • I'm more talking about the way that the data are used before and after this segment of code in the program. I'll assume that pandas is not necessary. (Commented Jun 10, 2022 at 11:48)
  • @Reinderien You can see my other question; I need to process the self.status generated there, and the final result is observation, whose type is list. codereview.stackexchange.com/questions/277147 (Commented Jun 10, 2022 at 12:34)
  • If you don't show the code you care about here, in this question, then for the purposes of review it doesn't exist. (Commented Jun 10, 2022 at 12:35)

1 Answer


I don't see any value in Pandas here. Use Numpy broadcasting directly between an a array of shape (10, 1) and a b array of shape (10, 6), producing a new (10, 6) array. Your inner where() does work with these arrays unmodified, but there are faster methods that avoid where() entirely and instead use np.sign or np.ceil.

from collections import defaultdict
from statistics import mean
from timeit import timeit

import numpy as np

rand = np.random.default_rng(seed=0)
a = rand.uniform(low=-1, high=1, size=(10, 1))
b = rand.uniform(low=-1, high=1, size=(10, 6))


# if both a and b are greater than 0, a_b is assigned a value of 1,
# if both a and b are less than 0, a_b is assigned a value of -1
def op():
    return np.where(
        (a > 0) & (b > 0),
        1,
        np.where(
            (a < 0) & (b < 0), -1, 0,
        )
    )


def ceils():
    return np.ceil(a)*np.ceil(b) - np.floor(a)*np.floor(b)


def signs():
    sa = np.sign(a)
    return sa * (sa == np.sign(b))


METHODS = (op, ceils, signs)
results = [method() for method in METHODS]
for result in results[1:]:
    assert np.allclose(results[0], result)

N = 10_000
times = defaultdict(list)
for _ in range(10):
    for method in METHODS:
        t = timeit(method, number=N)
        times[method.__name__].append(t)

for name, method_time in times.items():
    print(f'{name:>6}: {mean(method_time)/N*1e6:6.1f} us')

Output

 op: 24.3 us
 ceils: 10.0 us
 signs: 8.0 us
answered Jun 10, 2022 at 12:38
  • Pure numpy signs() is 100 times faster than my code. (Commented Jun 10, 2022 at 14:49)
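
For reference, here is a minimal sketch (not part of the original answer) of how the signs() idea could be folded back into the question's get_result, assuming the same column names as in the question; get_result_signs is a hypothetical name:

import numpy as np
import pandas as pd


def get_result_signs(_df: pd.DataFrame) -> pd.DataFrame:
    b_cols = ['b1', 'b2', 'b3', 'b4', 'b5', 'b6']
    # a as a (10, 1) column so it broadcasts against the (10, 6) block of b columns
    a = _df['a'].to_numpy()[:, None]
    b = _df[b_cols].to_numpy()
    sa = np.sign(a)
    # sa * (sa == sign(b)) is 1 where both are positive, -1 where both are negative, else 0
    _df[[f'a_{col}' for col in b_cols]] = sa * (sa == np.sign(b))
    return _df

On the data from the question this should produce the same a_b1..a_b6 columns as the nested np.where version, while replacing the per-column Python loop with a single broadcast operation.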
