From 692da2bc7d16f0898d23b3928dd221e5592ca27c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: 2025年8月21日 09:39:01 +0530 Subject: [PATCH 1/2] feat: add a test for aot. --- tests/models/test_modeling_common.py | 34 ++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 1e08191f56aa..d539b027dded 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -2059,6 +2059,7 @@ def test_torch_compile_recompilation_and_graph_break(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict).to(torch_device) + model.eval() model = torch.compile(model, fullgraph=True) with ( @@ -2076,6 +2077,7 @@ def test_torch_compile_repeated_blocks(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict).to(torch_device) + model.eval() model.compile_repeated_blocks(fullgraph=True) recompile_limit = 1 @@ -2098,7 +2100,6 @@ def test_compile_with_group_offloading(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) - model.eval() # TODO: Can test for other group offloading kwargs later if needed. group_offload_kwargs = { @@ -2111,11 +2112,11 @@ def test_compile_with_group_offloading(self): } model.enable_group_offload(**group_offload_kwargs) model.compile() + with torch.no_grad(): _ = model(**inputs_dict) _ = model(**inputs_dict) - @require_torch_version_greater("2.7.1") def test_compile_on_different_shapes(self): if self.different_shapes_for_compilation is None: pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.") @@ -2123,6 +2124,7 @@ def test_compile_on_different_shapes(self): init_dict, _ = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict).to(torch_device) + model.eval() model = torch.compile(model, fullgraph=True, dynamic=True) for height, width in self.different_shapes_for_compilation: @@ -2130,6 +2132,34 @@ def test_compile_on_different_shapes(self): inputs_dict = self.prepare_dummy_input(height=height, width=width) _ = model(**inputs_dict) + def test_compile_works_with_aot(self): + from torch._inductor.package import load_package + + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict).to(torch_device) + exported_model = torch.export.export(model, args=(), kwargs=inputs_dict) + + with tempfile.TemporaryDirectory() as tmpdir: + package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2") + _ = torch._inductor.aoti_compile_and_package( + exported_model, + package_path=package_path, + inductor_configs={ + "aot_inductor.package_constants_in_so": False, + "aot_inductor.package_constants_on_disk": True, + "aot_inductor.package": True, + }, + ) + assert os.path.exists(package_path) + loaded_binary = load_package(package_path) + + model.forward = loaded_binary + + with torch.no_grad(): + _ = model(**inputs_dict) + _ = model(**inputs_dict) + @slow @require_torch_2 From 0fef57a9baf62bfbd277b72dfcd38938cab37eae Mon Sep 17 00:00:00 2001 From: sayakpaul Date: 2025年8月21日 10:00:11 +0530 Subject: [PATCH 2/2] up --- tests/models/test_modeling_common.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index d539b027dded..444a0a050645 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -2142,17 +2142,9 @@ def test_compile_works_with_aot(self): with tempfile.TemporaryDirectory() as tmpdir: package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2") - _ = torch._inductor.aoti_compile_and_package( - exported_model, - package_path=package_path, - inductor_configs={ - "aot_inductor.package_constants_in_so": False, - "aot_inductor.package_constants_on_disk": True, - "aot_inductor.package": True, - }, - ) + _ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path) assert os.path.exists(package_path) - loaded_binary = load_package(package_path) + loaded_binary = load_package(package_path, run_single_threaded=True) model.forward = loaded_binary

AltStyle によって変換されたページ (->オリジナル) /