-
Notifications
You must be signed in to change notification settings - Fork 293
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add unit tests on CPU for TritonBench features (#2323)
Summary: Pull Request resolved: #2323 Add unit tests that run on the CPU to verify the behavior of the following: - `x_only = True` for metric registration in [`register_metric()`](https://www.internalfb.com/code/fbsource/[731f07681fbbb38750aee3b165137e39fa6cee50]/fbcode/pytorch/benchmark/torchbenchmark/util/triton_op.py?lines=337) - custom `label` argument for benchmark registration in [`register_benchmark()`](https://www.internalfb.com/code/fbsource/[731f07681fbbb38750aee3b165137e39fa6cee50]/fbcode/pytorch/benchmark/torchbenchmark/util/triton_op.py?lines=316) Reviewed By: xuzhao9 Differential Revision: D58558868 fbshipit-source-id: 50e35ec3359db02ce86f45b1d44c34dac9c1a03b
- Loading branch information
1 parent
1f0fecd
commit a529b5a
Showing
3 changed files
with
46 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .operator import Operator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from typing import Generator, List, Optional | ||
|
||
import torch | ||
|
||
from torchbenchmark.util.triton_op import ( | ||
BenchmarkOperator, | ||
BenchmarkOperatorMetrics, | ||
register_benchmark, | ||
register_metric, | ||
) | ||
|
||
|
||
class Operator(BenchmarkOperator): | ||
|
||
DEFAULT_METRICS = ["test_metric"] | ||
|
||
def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None): | ||
super().__init__(mode=mode, device=device, extra_args=extra_args) | ||
|
||
@register_benchmark(label="new_op_label") | ||
def test_op(self, x: torch.Tensor): | ||
return lambda: x | ||
|
||
def get_x_val(self, example_inputs): | ||
return example_inputs[0].shape | ||
|
||
def get_x_vals(self) -> List[int]: | ||
return [2**n for n in [1, 2, 3]] | ||
|
||
def get_input_iter(self) -> Generator: | ||
for x in self.get_x_vals(): | ||
yield (torch.Tensor(torch.randn(x, device=self.device, dtype=self.dtype)),) | ||
|
||
@register_metric(x_only=True) | ||
def test_metric( | ||
self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics | ||
): | ||
return [ex.shape[0] + 2 for ex in example_inputs] | ||
|
||
@register_metric() | ||
def test_metric_per_benchmark( | ||
self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics | ||
): | ||
return [ex.shape[0] + 3 for ex in example_inputs] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters