@@ -118,6 +118,7 @@ def __init__(self, *args, **kwargs):
118118Tensor .__init__ = __init__
119119origin_setitem = Tensor .__setitem__
120120origin_is_contiguous = Tensor .is_contiguous
121+ origin_to = Tensor .to
121122Tensor ._requires_grad = False
122123
123124def tensor (data , * , dtype = None , device = None , requires_grad = False ):
@@ -248,6 +249,8 @@ def __getitem__(self, slices):
248249
249250 if self .device .type == 'meta' :
250251 out = ops .getitem_np (self , slices )
252+ elif self .device .type == 'cpu' :
253+ return ops .getitem (self , slices )
251254 else :
252255 out = ops .tensor_getitem (self , slices )
253256
@@ -2251,7 +2254,7 @@ def _move_to(self, device, non_blocking=False):
22512254 # self.data_sync(True)
22522255 if self .device .type == 'cpu' :
22532256 self .data_ptr ()
2254- data = Tensor . move_to (self , device_str , blocking = not non_blocking )
2257+ data = origin_to (self , device_str , non_blocking = non_blocking )
22552258
22562259 out = Tensor (data )
22572260 out ._device = device
0 commit comments