-
Notifications
You must be signed in to change notification settings - Fork 6.3k
[tests] feat: add AoT compilation tests #12203
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2059,6 +2059,7 @@ def test_torch_compile_recompilation_and_graph_break(self): | |
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() | ||
|
||
model = self.model_class(**init_dict).to(torch_device) | ||
model.eval() | ||
model = torch.compile(model, fullgraph=True) | ||
|
||
with ( | ||
|
@@ -2076,6 +2077,7 @@ def test_torch_compile_repeated_blocks(self): | |
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() | ||
|
||
model = self.model_class(**init_dict).to(torch_device) | ||
model.eval() | ||
model.compile_repeated_blocks(fullgraph=True) | ||
|
||
recompile_limit = 1 | ||
|
@@ -2098,7 +2100,6 @@ def test_compile_with_group_offloading(self): | |
|
||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() | ||
model = self.model_class(**init_dict) | ||
|
||
model.eval() | ||
# TODO: Can test for other group offloading kwargs later if needed. | ||
group_offload_kwargs = { | ||
|
@@ -2111,25 +2112,46 @@ def test_compile_with_group_offloading(self): | |
} | ||
model.enable_group_offload(**group_offload_kwargs) | ||
model.compile() | ||
|
||
with torch.no_grad(): | ||
_ = model(**inputs_dict) | ||
_ = model(**inputs_dict) | ||
|
||
@require_torch_version_greater("2.7.1") | ||
def test_compile_on_different_shapes(self): | ||
if self.different_shapes_for_compilation is None: | ||
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.") | ||
torch.fx.experimental._config.use_duck_shape = False | ||
|
||
init_dict, _ = self.prepare_init_args_and_inputs_for_common() | ||
model = self.model_class(**init_dict).to(torch_device) | ||
model.eval() | ||
model = torch.compile(model, fullgraph=True, dynamic=True) | ||
|
||
for height, width in self.different_shapes_for_compilation: | ||
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad(): | ||
inputs_dict = self.prepare_dummy_input(height=height, width=width) | ||
_ = model(**inputs_dict) | ||
|
||
def test_compile_works_with_aot(self): | ||
from torch._inductor.package import load_package | ||
|
||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() | ||
|
||
model = self.model_class(**init_dict).to(torch_device) | ||
exported_model = torch.export.export(model, args=(), kwargs=inputs_dict) | ||
|
||
with tempfile.TemporaryDirectory() as tmpdir: | ||
package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2") | ||
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path) | ||
assert os.path.exists(package_path) | ||
loaded_binary = load_package(package_path, run_single_threaded=True) | ||
|
||
model.forward = loaded_binary | ||
|
||
with torch.no_grad(): | ||
_ = model(**inputs_dict) | ||
_ = model(**inputs_dict) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a particular reason why you're running it twice? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To emulate the real scenario as the model is typically invoked more than once during the actual generation process. |
||
|
||
|
||
@slow | ||
@require_torch_2 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not passing in a path should also automatically give you a path in the tmp dir!