Commit ffc8c0c

[tests] feat: add AoT compilation tests (#12203)
* feat: add a test for aot.
* up
1 parent 4acbfbf commit ffc8c0c

File tree

1 file changed: +24 -2 lines changed

tests/models/test_modeling_common.py

Lines changed: 24 additions & 2 deletions
@@ -2059,6 +2059,7 @@ def test_torch_compile_recompilation_and_graph_break(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

         model = self.model_class(**init_dict).to(torch_device)
+        model.eval()
         model = torch.compile(model, fullgraph=True)

         with (
@@ -2076,6 +2077,7 @@ def test_torch_compile_repeated_blocks(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

         model = self.model_class(**init_dict).to(torch_device)
+        model.eval()
         model.compile_repeated_blocks(fullgraph=True)

         recompile_limit = 1
@@ -2098,7 +2100,6 @@ def test_compile_with_group_offloading(self):

         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
-
         model.eval()
         # TODO: Can test for other group offloading kwargs later if needed.
         group_offload_kwargs = {
@@ -2111,25 +2112,46 @@ def test_compile_with_group_offloading(self):
         }
         model.enable_group_offload(**group_offload_kwargs)
         model.compile()
+
         with torch.no_grad():
             _ = model(**inputs_dict)
             _ = model(**inputs_dict)

-    @require_torch_version_greater("2.7.1")
     def test_compile_on_different_shapes(self):
         if self.different_shapes_for_compilation is None:
             pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
         torch.fx.experimental._config.use_duck_shape = False

         init_dict, _ = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict).to(torch_device)
+        model.eval()
         model = torch.compile(model, fullgraph=True, dynamic=True)

         for height, width in self.different_shapes_for_compilation:
             with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
                 inputs_dict = self.prepare_dummy_input(height=height, width=width)
                 _ = model(**inputs_dict)

+    def test_compile_works_with_aot(self):
+        from torch._inductor.package import load_package
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict).to(torch_device)
+        exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2")
+            _ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
+            assert os.path.exists(package_path)
+            loaded_binary = load_package(package_path, run_single_threaded=True)
+
+        model.forward = loaded_binary
+
+        with torch.no_grad():
+            _ = model(**inputs_dict)
+            _ = model(**inputs_dict)

     @slow
     @require_torch_2
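For reference, the new test_compile_works_with_aot test exercises PyTorch's AOTInductor flow end to end: torch.export.export captures the model, torch._inductor.aoti_compile_and_package compiles it ahead of time into a .pt2 package, and torch._inductor.package.load_package reloads the compiled binary, which the test then calls in place of the eager forward. Below is a minimal standalone sketch of that same flow, assuming a toy TinyModel module and dummy inputs rather than a diffusers model class; the names TinyModel and example_kwargs are illustrative and not part of this commit.

# Minimal sketch of the AoT Inductor flow exercised by test_compile_works_with_aot.
# TinyModel and example_kwargs are assumptions for illustration; they are not part of the commit.
import os
import tempfile

import torch
from torch._inductor.package import load_package


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, hidden_states):
        return self.proj(hidden_states)


model = TinyModel().eval()
example_kwargs = {"hidden_states": torch.randn(1, 16)}

# 1) Capture the model as an ExportedProgram.
exported = torch.export.export(model, args=(), kwargs=example_kwargs)

with tempfile.TemporaryDirectory() as tmpdir:
    # 2) AoT-compile the exported program and package it as a .pt2 archive.
    package_path = os.path.join(tmpdir, "TinyModel.pt2")
    torch._inductor.aoti_compile_and_package(exported, package_path=package_path)
    assert os.path.exists(package_path)

    # 3) Reload the compiled binary and call it in place of the eager forward,
    #    mirroring what the test does via `model.forward = loaded_binary`.
    compiled = load_package(package_path, run_single_threaded=True)
    with torch.no_grad():
        _ = compiled(**example_kwargs)
        _ = compiled(**example_kwargs)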
