Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Adapt ones_like dtype for torch 2.8.0 #2598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
M-Quadra wants to merge 1 commit into apple:main from M-Quadra:pr/ones_like

Conversation

@M-Quadra
Copy link
Contributor

@M-Quadra M-Quadra commented Sep 26, 2025

This PR is compatible with executorch torch 2.7.

Unit test

pytest coremltools/converters/mil/frontend/torch/test/test_torch_ops.py::TestIndexPut::test_index_put_updates_bool

In torch 2.8.0, the error message is:

ValueError: In op, of type scatter_nd, named index_put, the named input `updates` must have the same data type as the named input `data`. However, updates has dtype fp32 whereas data has dtype int32.

Detail

import torch
import numpy as np
import coremltools as ct
class Model(torch.nn.Module):
 def forward(self, x):
 x = torch.ones(x.shape, dtype=torch.bool)
 y = torch.ones_like(x).bool()
 mask = torch.tensor([True, False, False, False, True, True]).view(3, 2)
 x[mask] = y[mask]
 return x
x = torch.randn(3, 2)
model = Model().eval()
exported_model = torch.export.export(model, (x,)).run_decompositions({})
mlmodel = ct.convert(
 exported_model,
 minimum_deployment_target=ct.target.iOS16,
)
y0 = model(x).numpy()
y1 = mlmodel.predict({"x": x.numpy()})["index_put"]
assert np.equal(y0, y1).all()
  • exported_model.graph in torch 2.7.0
%c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
%x : [num_users=0] = placeholder[target=x]
%ones : [num_users=2] = call_function[target=torch.ops.aten.ones.default](args = ([3, 2],), kwargs = {dtype: torch.bool, device: cpu, pin_memory: False})
%ones_like : [num_users=1] = call_function[target=torch.ops.aten.ones_like.default](args = (%ones,), kwargs = {pin_memory: False})
%_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%ones_like,), kwargs = {dtype: torch.bool})
%clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%c_lifted_tensor_0,), kwargs = {})
%view : [num_users=2] = call_function[target=torch.ops.aten.view.default](args = (%clone, [3, 2]), kwargs = {})
%index : [num_users=2] = call_function[target=torch.ops.aten.index.Tensor](args = (%_to_copy, [%view]), kwargs = {})
%sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%index, 0), kwargs = {})
%sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
%ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
%_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u0 >= 0 on node 'ge_2'), kwargs = {})
%le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 6), kwargs = {})
%_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 6 on node 'le_1'), kwargs = {})
%index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%ones, [%view], %index), kwargs = {})
return (index_put,)
  • exported_model.graph in torch 2.8.0
%c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
%x : [num_users=0] = placeholder[target=x]
%ones : [num_users=2] = call_function[target=torch.ops.aten.ones.default](args = ([3, 2],), kwargs = {dtype: torch.bool, device: cpu, pin_memory: False})
%ones_like : [num_users=2] = call_function[target=torch.ops.aten.ones_like.default](args = (%ones,), kwargs = {pin_memory: False})
%_assert_tensor_metadata : [num_users=0] = call_function[target=torch.ops.aten._assert_tensor_metadata.default](args = (%ones_like, None, None, torch.bool), kwargs = {device: cpu, layout: torch.strided})
%clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%c_lifted_tensor_0,), kwargs = {})
%view : [num_users=2] = call_function[target=torch.ops.aten.view.default](args = (%clone, [3, 2]), kwargs = {})
%index : [num_users=2] = call_function[target=torch.ops.aten.index.Tensor](args = (%ones_like, [%view]), kwargs = {})
%sym_size_int_1 : [num_users=3] = call_function[target=torch.ops.aten.sym_size.int](args = (%index, 0), kwargs = {})
%sym_constrain_range_for_size_default : [num_users=0] = call_function[target=torch.ops.aten.sym_constrain_range_for_size.default](args = (%sym_size_int_1,), kwargs = {})
%ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
%_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u0 >= 0 on node 'ge_2'), kwargs = {})
%le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 6), kwargs = {})
%_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u0 <= 6 on node 'le_1'), kwargs = {})
%index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%ones, [%view], %index), kwargs = {})
return (index_put,)

In torch 2.8.0, the dtype of ones_like may be moved to _assert_tensor_metadata.

Copy link
Collaborator

The max version of PyTorch that we currently support is 2.7.0 and that's what's being used in our CI system.

So there is no way to test this change other than making sure it doesn't break 2.7.0. As result, I'm reluctant to merge this change.

Would it be possible for you to look into what other changes are necessary for coremltools to support PyTorch 2.8.0?

Copy link
Contributor Author

Sure. I'll open a new issue to track PyTorch 2.8.0 compatibility problems.

@M-Quadra M-Quadra deleted the pr/ones_like branch October 19, 2025 09:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Reviewers

No reviews

Assignees

No one assigned

Labels

None yet

Projects

None yet

Milestone

No milestone

Development

Successfully merging this pull request may close these issues.

AltStyle によって変換されたページ (->オリジナル) /