@@ -2059,6 +2059,7 @@ def test_torch_compile_recompilation_and_graph_break(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

         model = self.model_class(**init_dict).to(torch_device)
+        model.eval()
         model = torch.compile(model, fullgraph=True)

         with (
@@ -2076,6 +2077,7 @@ def test_torch_compile_repeated_blocks(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

         model = self.model_class(**init_dict).to(torch_device)
+        model.eval()
         model.compile_repeated_blocks(fullgraph=True)

         recompile_limit = 1
@@ -2098,7 +2100,6 @@ def test_compile_with_group_offloading(self):

         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
-
         model.eval()
         # TODO: Can test for other group offloading kwargs later if needed.
         group_offload_kwargs = {
@@ -2111,25 +2112,46 @@ def test_compile_with_group_offloading(self):
         }
         model.enable_group_offload(**group_offload_kwargs)
         model.compile()
+
         with torch.no_grad():
             _ = model(**inputs_dict)
             _ = model(**inputs_dict)

-    @require_torch_version_greater("2.7.1")
     def test_compile_on_different_shapes(self):
         if self.different_shapes_for_compilation is None:
             pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
         torch.fx.experimental._config.use_duck_shape = False

         init_dict, _ = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict).to(torch_device)
+        model.eval()
         model = torch.compile(model, fullgraph=True, dynamic=True)

         for height, width in self.different_shapes_for_compilation:
             with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
                 inputs_dict = self.prepare_dummy_input(height=height, width=width)
                 _ = model(**inputs_dict)

+    def test_compile_works_with_aot(self):
+        from torch._inductor.package import load_package
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict).to(torch_device)
+        exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2")
+            _ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
+            assert os.path.exists(package_path)
+            loaded_binary = load_package(package_path, run_single_threaded=True)
+
+        model.forward = loaded_binary
+
+        with torch.no_grad():
+            _ = model(**inputs_dict)
+            _ = model(**inputs_dict)
+

 @slow
 @require_torch_2
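
For reference, a minimal standalone sketch of the export → AOT-compile → load round trip that the new `test_compile_works_with_aot` exercises. `TinyModel` and its example input are hypothetical stand-ins for `self.model_class` and `inputs_dict`; the `torch.export.export`, `torch._inductor.aoti_compile_and_package`, and `load_package` calls mirror those used in the patch:

```python
# Sketch only: TinyModel is a hypothetical stand-in, not part of this patch.
import os
import tempfile

import torch
from torch._inductor.package import load_package


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.silu(x)


model = TinyModel().eval()
example_kwargs = {"x": torch.randn(2, 8)}

# torch.export traces the module into an ExportedProgram.
exported = torch.export.export(model, args=(), kwargs=example_kwargs)

with tempfile.TemporaryDirectory() as tmpdir:
    package_path = os.path.join(tmpdir, "TinyModel.pt2")
    # Ahead-of-time Inductor compilation, serialized as a .pt2 package.
    torch._inductor.aoti_compile_and_package(exported, package_path=package_path)
    # The loaded package is callable with the same kwargs as forward(),
    # which is why the test can assign it directly to model.forward.
    compiled = load_package(package_path, run_single_threaded=True)
    with torch.no_grad():
        _ = compiled(**example_kwargs)
```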