11import mindspore
2+ import mindtorch
23from mindspore ._c_expression import _empty_instance
34from ..configs import use_pyboost , ON_A1 , ON_ORANGE_PI
45from .._op_prim .ascend import legacy , pyboost
@@ -824,9 +825,11 @@ def argmax(input, axis, keepdims):
824825 return legacy .argmax (input , axis , keepdims )
825826
826827def argmin (input , axis , keepdims ):
827- if use_pyboost ():
828+ if use_pyboost ()and not ON_ORANGE_PI :
828829 return pyboost .argmin_ext_op (input , axis , keepdims )
829- return legacy .argmin (input , axis , keepdims )
830+ if axis is None :
831+ axis = - 1
832+ return legacy .arg_min_with_value (input , axis , keepdims )[0 ]
830833
831834
832835def bmm (input , other ):
@@ -1136,7 +1139,7 @@ def masked_scatter(input, mask, value):
11361139 return legacy .masked_scatter (input , mask , value )
11371140
11381141def neg (input ):
1139- if use_pyboost ():
1142+ if use_pyboost ()and not ON_ORANGE_PI :
11401143 return pyboost .neg_op (input )
11411144 return legacy .neg (input )
11421145
@@ -1557,7 +1560,7 @@ def inplace_exponential(self, lambd, generator):
15571560 return legacy .expo (self , lambd , generator )
15581561
15591562def im2col (input , kernel_size , dilation = 1 , padding = 0 , stride = 1 ):
1560- if use_pyboost () and not ON_A1 :
1563+ if use_pyboost () and not ON_A1 and not ON_ORANGE_PI :
15611564 return pyboost .im2col_ext_op (input , kernel_size , dilation , padding , stride )
15621565 out = legacy .im2_col (input , kernel_size , stride , dilation , padding )
15631566 out_shape = out .shape [:1 ] + (- 1 ,) + out .shape [- 1 :]
@@ -1570,9 +1573,10 @@ def upsample_nearest2d(input, output_size, scale_factors):
15701573 return legacy .upsample_nearest2d (input , scale_factor , align_corners )
15711574
15721575def addmm (input , mat1 , mat2 , alpha = 1.0 , beta = 1.0 ):
1573- if use_pyboost ():
1576+ if use_pyboost ()and not ON_ORANGE_PI :
15741577 return pyboost .addmm_op (input , mat1 , mat2 , alpha , beta )
1575- return legacy .addmm (input , mat1 , mat2 , alpha , beta )
1578+ return add (mul (input , beta ), mul (matmul (mat1 , mat2 ), alpha ))
1579+ 15761580
15771581def meshgrid (input , lambd ):
15781582 if use_pyboost ():
0 commit comments