Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 568e398

Browse files
fix c class models on OrangePi (#2213)
1 parent 2dfb9ad commit 568e398

File tree

4 files changed

+16
-9
lines changed

4 files changed

+16
-9
lines changed

‎mindtorch/_apis/gpu.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,7 @@ def addcmul(input, tensor1, tensor2, value=1.0):
848848
return legacy.addcmul(input, tensor1, tensor2, mindspore.Tensor(value))
849849

850850
def addmm(input, mat1, mat2, alpha=1.0, beta=1.0):
851-
return add(mul(input, beta), mul(bmm(mat1, mat2), alpha))
851+
return add(mul(input, beta), mul(matmul(mat1, mat2), alpha))
852852

853853
def im2col(input, kernel_size, dilation=1, padding=0, stride=1):
854854
out = legacy.im2_col(input, kernel_size, stride, dilation, padding)

‎mindtorch/_apis/npu.py‎

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import mindspore
2+
import mindtorch
23
from mindspore._c_expression import _empty_instance
34
from ..configs import use_pyboost, ON_A1, ON_ORANGE_PI
45
from .._op_prim.ascend import legacy, pyboost
@@ -824,9 +825,11 @@ def argmax(input, axis, keepdims):
824825
return legacy.argmax(input, axis, keepdims)
825826

826827
def argmin(input, axis, keepdims):
827-
if use_pyboost():
828+
if use_pyboost()andnotON_ORANGE_PI:
828829
return pyboost.argmin_ext_op(input, axis, keepdims)
829-
return legacy.argmin(input, axis, keepdims)
830+
if axis is None:
831+
axis = -1
832+
return legacy.arg_min_with_value(input, axis, keepdims)[0]
830833

831834

832835
def bmm(input, other):
@@ -1136,7 +1139,7 @@ def masked_scatter(input, mask, value):
11361139
return legacy.masked_scatter(input, mask, value)
11371140

11381141
def neg(input):
1139-
if use_pyboost():
1142+
if use_pyboost()andnotON_ORANGE_PI:
11401143
return pyboost.neg_op(input)
11411144
return legacy.neg(input)
11421145

@@ -1557,7 +1560,7 @@ def inplace_exponential(self, lambd, generator):
15571560
return legacy.expo(self, lambd, generator)
15581561

15591562
def im2col(input, kernel_size, dilation=1, padding=0, stride=1):
1560-
if use_pyboost() and not ON_A1:
1563+
if use_pyboost() and not ON_A1andnotON_ORANGE_PI:
15611564
return pyboost.im2col_ext_op(input, kernel_size, dilation, padding, stride)
15621565
out = legacy.im2_col(input, kernel_size, stride, dilation, padding)
15631566
out_shape = out.shape[:1] + (-1,) + out.shape[-1:]
@@ -1570,9 +1573,10 @@ def upsample_nearest2d(input, output_size, scale_factors):
15701573
return legacy.upsample_nearest2d(input, scale_factor, align_corners)
15711574

15721575
def addmm(input, mat1, mat2, alpha=1.0, beta=1.0):
1573-
if use_pyboost():
1576+
if use_pyboost()andnotON_ORANGE_PI:
15741577
return pyboost.addmm_op(input, mat1, mat2, alpha, beta)
1575-
return legacy.addmm(input, mat1, mat2, alpha, beta)
1578+
return add(mul(input, beta), mul(matmul(mat1, mat2), alpha))
1579+
15761580

15771581
def meshgrid(input, lambd):
15781582
if use_pyboost():

‎mindtorch/ops/array.py‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,9 @@ def _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims):
622622
self_viewed = self
623623
self_viewed_shape = list(self.shape)
624624
dim = 0
625+
if ON_ORANGE_PI:
626+
if all([isinstance(index, slice) for index in indexes]):
627+
return getitem(self_viewed, tuple(indexes)), remain_indexes
625628
for i, index in enumerate(indexes):
626629
if isinstance(index, (list, tuple, np.ndarray)):
627630
index_np = np.array(index) if isinstance(index, (list, tuple)) else index
@@ -634,7 +637,6 @@ def _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims):
634637
raise TypeError(f"Index {index} contain unsupported elements")
635638
self_viewed, dim, remain_indexes, self_viewed_shape = _process_dim_in_multi_dim_index(
636639
self_viewed, self, index, dim, indexed_dims, i, remain_indexes, self_viewed_shape)
637-
638640
return self_viewed, remain_indexes
639641

640642

‎tests/run_test.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import mindnlp
77
from mindnlp import transformers
88

9-
mindspore.set_context(pynative_synchronize=True)
9+
# mindspore.set_context(pynative_synchronize=True)
10+
mindspore.runtime.launch_blocking()
1011

1112
def run_tests():
1213
"""

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /