#15894: Add conv1d tests in convolution short sweep.
Signed-off-by: Nilaykumar Patel <[email protected]>
nkpatel-tt authored Dec 12, 2024
1 parent dd874dd commit 74311e9
Showing 5 changed files with 108 additions and 17 deletions.
78 changes: 76 additions & 2 deletions tests/sweep_framework/sweep_utils/conv2d_common.py
@@ -48,7 +48,7 @@ def mesh_device_fixture():
    ttnn.close_device(device)


-def run_full(
+def run_conv2d_full_sweep(
    input_specs,
    input_channels,
    output_channels,
@@ -174,7 +174,7 @@ def run_full(
    return [check_with_pcc(torch_output_tensor, torch_out_golden_tensor, pcc=0.998), e2e_perf]


-def run_short(
+def run_conv2d_short_sweep(
    input_specs,
    device,
) -> list:
@@ -256,3 +256,77 @@ def run_short(
    torch_output_tensor = torch.permute(torch_output_tensor, (0, 3, 1, 2))

    return [check_with_pcc(torch_output_tensor, torch_out_golden_tensor, pcc=0.998), e2e_perf]
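Both short-sweep helpers gate pass/fail on a Pearson-correlation check against the torch golden output. As a rough sketch of what a comparison like check_with_pcc(expected, actual, pcc=0.998) is assumed to compute (the real helper lives in tests/ttnn/utils_for_testing.py; this is not its actual code), flatten both tensors and threshold their correlation:

```python
import torch

# Sketch only: assumed shape of the check_with_pcc comparison, not the real helper.
def pcc_check_sketch(expected: torch.Tensor, actual: torch.Tensor, pcc: float = 0.998):
    # Pearson correlation coefficient of the two flattened tensors
    stacked = torch.stack([expected.flatten().float(), actual.flatten().float()])
    corr = torch.corrcoef(stacked)[0, 1].item()
    return corr >= pcc, f"PCC: {corr}"
```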


def run_conv1d_short_sweep(
    input_specs,
    device,
) -> list:
    [
        batch_size,
        output_channels,
        input_channels,
        input_length,
        kernel_size,
        stride,
        padding,
        groups,
        has_bias,
        dilation,
    ] = input_specs
    print(input_specs)

    torch.manual_seed(0)
    conv_input_shape = [batch_size, input_channels, input_length]
    conv_weight_shape = [output_channels, input_channels // groups, kernel_size]
    conv_bias_shape = [1, 1, 1, output_channels]
    torch_input_tensor_ncl = torch.randn(conv_input_shape, dtype=torch.bfloat16).float()
    torch_input_tensor = torch.permute(torch_input_tensor_ncl, (0, 2, 1))
    torch_weight_tensor = torch.randn(conv_weight_shape, dtype=torch.bfloat16).float()
    torch_bias_tensor = torch.randn(conv_bias_shape, dtype=torch.bfloat16).float() if has_bias else None
    torch_out_golden_tensor = torch.nn.functional.conv1d(
        torch_input_tensor_ncl,
        torch_weight_tensor,
        bias=torch_bias_tensor.reshape(-1) if has_bias else None,
        stride=stride,
        padding=padding,
        groups=groups,
    )

    tt_weight_tensor = ttnn.from_torch(torch_weight_tensor, ttnn.bfloat16)
    tt_bias_tensor = None
    if has_bias:
        tt_bias_tensor = ttnn.from_torch(torch_bias_tensor, ttnn.bfloat16)

    tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16)

    start_time = start_measuring_time()
    [tt_output_tensor_on_device, out_length, [weights_device, bias_device]] = ttnn.Conv1d(
        input_tensor=tt_input_tensor,
        weight_tensor=tt_weight_tensor,
        in_channels=input_channels,
        out_channels=output_channels,
        device=device,
        bias_tensor=tt_bias_tensor,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        batch_size=batch_size,
        input_length=input_length,
        groups=groups,
        return_output_dim=True,
        return_weights_and_bias=True,
    )

    tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
    torch_output_tensor = ttnn.to_torch(tt_output_tensor)
    e2e_perf = stop_measuring_time(start_time)

    # torch_output_tensor is in row-major layout and NLC shape
    # NLC to NCL
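    # e.g. a [batch_size, out_length, output_channels] tensor becomes [batch_size, output_channels, out_length]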
    torch_output_tensor = torch_output_tensor.reshape(batch_size, out_length, output_channels)

    torch_output_tensor = torch.permute(torch_output_tensor, (0, 2, 1))

    return [check_with_pcc(torch_output_tensor, torch_out_golden_tensor, pcc=0.998), e2e_perf]
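For local experimentation, here is a minimal sketch of driving the new helper directly. It is hypothetical and not part of this commit: it assumes ttnn.open_device/ttnn.close_device behave as in mesh_device_fixture above, and that check_with_pcc returns a (passed, message) pair.

```python
import ttnn

from tests.sweep_framework.sweep_utils.conv2d_common import run_conv1d_short_sweep

# First row of the short_sweep_suite_conv1d suite added in this commit:
# [batch_size, output_channels, input_channels, input_length, kernel_size, stride, pad, groups, bias, dilation]
spec = [1, 256, 1024, 512, 1, 1, 0, 1, True, 1]

# l1_small_size mirrors the device_params used by the pytest entry points below
device = ttnn.open_device(device_id=0, l1_small_size=16384)
try:
    (passed, pcc_message), e2e_perf = run_conv1d_short_sweep(spec, device)
    print(passed, pcc_message, e2e_perf)
finally:
    ttnn.close_device(device)
```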
4 changes: 2 additions & 2 deletions tests/sweep_framework/sweeps/conv2d/full/conv2d_misc.py
@@ -12,7 +12,7 @@

from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from models.utility_functions import torch_random
-from tests.sweep_framework.sweep_utils.conv2d_common import run_full, get_input_specs, mesh_device_fixture
+from tests.sweep_framework.sweep_utils.conv2d_common import run_conv2d_full_sweep, get_input_specs, mesh_device_fixture

# Override the default timeout in seconds for hang detection.
TIMEOUT = 30
@@ -242,7 +242,7 @@ def run(
    *,
    device,
) -> list:
-    return run_full(
+    return run_conv2d_full_sweep(
        input_specs,
        input_channels,
        output_channels,
4 changes: 2 additions & 2 deletions tests/sweep_framework/sweeps/conv2d/full/conv2d_sharding.py
@@ -12,7 +12,7 @@

from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from models.utility_functions import torch_random
-from tests.sweep_framework.sweep_utils.conv2d_common import run_full, get_input_specs, mesh_device_fixture
+from tests.sweep_framework.sweep_utils.conv2d_common import run_conv2d_full_sweep, get_input_specs, mesh_device_fixture

# Override the default timeout in seconds for hang detection.
TIMEOUT = 30
@@ -111,7 +111,7 @@ def run(
    *,
    device,
) -> list:
-    return run_full(
+    return run_conv2d_full_sweep(
        input_specs,
        input_channels,
        output_channels,
@@ -11,7 +11,7 @@

from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from models.utility_functions import torch_random
-from tests.sweep_framework.sweep_utils.conv2d_common import run_full, get_input_specs, mesh_device_fixture
+from tests.sweep_framework.sweep_utils.conv2d_common import run_conv2d_full_sweep, get_input_specs, mesh_device_fixture

# Override the default timeout in seconds for hang detection.
TIMEOUT = 30
@@ -109,7 +109,7 @@ def run(
    *,
    device,
) -> list:
-    return run_full(
+    return run_conv2d_full_sweep(
        input_specs,
        input_channels,
        output_channels,
35 changes: 26 additions & 9 deletions tests/sweep_framework/sweeps/conv2d/short/conv2d_short_sweep.py
@@ -12,10 +12,14 @@

from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from models.utility_functions import torch_random
-from tests.sweep_framework.sweep_utils.conv2d_common import run_short, mesh_device_fixture
+from tests.sweep_framework.sweep_utils.conv2d_common import (
+    run_conv2d_short_sweep,
+    run_conv1d_short_sweep,
+    mesh_device_fixture,
+)

parameters = {
-    "short_sweep_suite": {
+    "short_sweep_suite_conv2d": {
        "input_specs": [
            # Contains following params
            # [batch_size, output_channels, input_channels, input_height, input_width, kernel_height, kernel_width, stride_x, stride_y, pad_x, pad_y, groups, bias, dilation]
@@ -1566,6 +1570,18 @@
            [1, 320, 960, 64, 64, 1, 1, 1, 1, 0, 0, 1, True, 1],
            [1, 320, 960, 64, 64, 3, 3, 1, 1, 1, 1, 1, True, 1],
        ],
        "is_conv1d": [False],
    },
    "short_sweep_suite_conv1d": {
        "input_specs": [
            # Contains following params
            # [batch_size, output_channels, input_channels, input_length, kernel_size, stride, pad, groups, bias, dilation]
            [1, 256, 1024, 512, 1, 1, 0, 1, True, 1],
            [1, 1024, 256, 512, 1, 1, 0, 1, True, 1],
            [1, 768, 768, 3000, 3, 2, 1, 1, True, 1],
            [1, 768, 80, 3000, 3, 1, 1, 1, True, 1],
        ],
        "is_conv1d": [True],
    },
}
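To make the conv1d spec layout concrete, here is a small illustrative sketch (not part of the commit) that unpacks one row from short_sweep_suite_conv1d and checks the expected output length, mirroring the golden-model setup in run_conv1d_short_sweep:

```python
# Illustrative only: decode a conv1d spec row the same way run_conv1d_short_sweep does.
import torch

spec = [1, 768, 80, 3000, 3, 1, 1, 1, True, 1]  # last row of short_sweep_suite_conv1d
(batch_size, output_channels, input_channels, input_length,
 kernel_size, stride, padding, groups, has_bias, dilation) = spec

x = torch.randn(batch_size, input_channels, input_length)  # NCL, as in the golden model
w = torch.randn(output_channels, input_channels // groups, kernel_size)
b = torch.randn(output_channels) if has_bias else None
y = torch.nn.functional.conv1d(x, w, bias=b, stride=stride, padding=padding, groups=groups)

# Standard conv output-length formula; for this row: (3000 + 2*1 - 1*(3-1) - 1)//1 + 1 = 3000
out_length = (input_length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
assert y.shape == (batch_size, output_channels, out_length)
```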

@@ -1576,22 +1592,23 @@ def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:

def run(
    input_specs,
    is_conv1d=False,
    *,
    device,
) -> list:
-    return run_short(
-        input_specs,
-        device,
-    )
+    if is_conv1d:
+        return run_conv1d_short_sweep(input_specs, device)
+    else:
+        return run_conv2d_short_sweep(input_specs, device)


import pytest


@pytest.mark.parametrize("input_spec", parameters["short_sweep_suite"]["input_specs"])
@pytest.mark.parametrize("input_spec", parameters["short_sweep_suite_conv2d"]["input_specs"])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_conv2d_localrun(device, input_spec):
-    run_short(
+    run_conv2d_short_sweep(
        input_spec,
        device,
    )
@@ -1658,7 +1675,7 @@ def test_conv2d_localrun(device, input_spec):
@pytest.mark.parametrize("input_spec", failing_parameters)
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_conv2d_localrun_fail_only(device, input_spec):
-    run_short(
+    run_conv2d_short_sweep(
        input_spec,
        device,
    )
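Since the new conv1d suite gets no local pytest entry point in this commit, a natural follow-up (hypothetical, not part of this diff) would mirror test_conv2d_localrun for the conv1d specs:

```python
# Hypothetical local test for the new conv1d suite, modeled on test_conv2d_localrun above.
@pytest.mark.parametrize("input_spec", parameters["short_sweep_suite_conv1d"]["input_specs"])
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_conv1d_localrun(device, input_spec):
    run_conv1d_short_sweep(
        input_spec,
        device,
    )
```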
