Depending on whether the watcher is enabled and whether a specific `program_config` is passed to `ttnn.linear`, the following test behaves differently:
```python
from loguru import logger
import torch
import pytest

from models.utility_functions import (
    is_wormhole_b0,
    skip_for_grayskull,
    is_grayskull,
    is_x2_harvested,
    is_blackhole,
    skip_for_blackhole,
)
from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc, check_with_pcc_without_tensor_printout
import ttnn


def _nearest_32(x):
    import math

    return math.ceil(x / 32) * 32


def test_resnet_l4m1c1_7x8(device, use_program_cache):
    # torch.manual_seed(0)
    batch_size = 16
    input_channels = 1024
    output_channels = 512
    input_h, input_w = 14, 14
    input_shape = [1, 1, batch_size * input_h * input_w, input_channels]
    torch_input_tensor = torch.randn(input_shape, dtype=torch.bfloat16)
    torch_weight_tensor = torch.randn([1, 1, output_channels, input_channels], dtype=torch.bfloat16)
    torch_bias_tensor = torch.randn([1, 1, 1, output_channels], dtype=torch.bfloat16)
    torch_out_golden_tensor = torch.nn.functional.linear(
        torch_input_tensor[0, 0, :, :], torch_weight_tensor[0, 0, :, :], bias=torch_bias_tensor[0, 0, :, :]
    )

    # tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT)
    # torch_input_tensor = torch.nn.functional.pad(torch_input_tensor, (0, 0, 0, 3328 - 3136), value=0)
    tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat8_b, device=device, layout=ttnn.TILE_LAYOUT)
    tt_weight_tensor = ttnn.from_torch(
        torch.permute(torch_weight_tensor, (0, 1, 3, 2)), ttnn.bfloat8_b, device=device, layout=ttnn.TILE_LAYOUT
    )
    tt_bias_tensor = ttnn.from_torch(torch_bias_tensor, ttnn.bfloat8_b, device=device, layout=ttnn.TILE_LAYOUT)

    compute_config = ttnn.init_device_compute_kernel_config(
        device.arch(),
        math_fidelity=ttnn.MathFidelity.LoFi,
        math_approx_mode=False,
        fp32_dest_acc_en=False,
        packer_l1_acc=True,
    )
    grid_size = (7, 8)
    matmul_config = ttnn.MatmulMultiCoreReuseMultiCastProgramConfig(
        compute_with_storage_grid_size=grid_size,
        in0_block_w=4,
        out_subblock_h=1,
        out_subblock_w=2,
        per_core_M=14,
        per_core_N=2,
        transpose_mcast=True,
        fuse_batch=True,
        fused_activation=None,  # ttnn.UnaryOpType.RELU
    )
    shard_grid = ttnn.CoreRangeSet(
        {
            ttnn.CoreRange(
                ttnn.CoreCoord(0, 0),
                ttnn.CoreCoord(grid_size[0] - 1, grid_size[1] - 1),
            )
        }
    )
    x = tt_input_tensor
    # Per-core shard shape; for this test: [_nearest_32(3136 // 7), 1024 // 8] = [448, 128].
    shard_shape = [
        _nearest_32(x.volume() // x.shape.with_tile_padding()[-1] // grid_size[0]),
        # x.volume() // x.shape.with_tile_padding()[-1] // grid_size[0],
        x.shape.with_tile_padding()[-1] // grid_size[1],
    ]
    shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.COL_MAJOR)
    sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.BufferType.L1, shard_spec)
    x = ttnn.to_memory_config(x, sharded_mem_config)

    tt_output_tensor_on_device = ttnn.linear(
        x,
        tt_weight_tensor,
        bias=tt_bias_tensor,
        memory_config=sharded_mem_config,
        dtype=ttnn.bfloat8_b,
        compute_kernel_config=compute_config,
        # With watcher on: the watcher fails regardless of whether program_config
        # is passed explicitly or left as None.
        # With watcher off: passing matmul_config lets the test pass, while
        # passing None causes a hang, even though the auto-generated config is
        # the same as the passed-in one.
        program_config=matmul_config,  # None
    )
    tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device)
    torch_output_tensor = ttnn.to_torch(tt_output_tensor)[:, :, :3136, :]
    assert_with_pcc(torch_out_golden_tensor, torch_output_tensor[0, 0, :, :], pcc=0.999)
```
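(To reproduce both modes: in a standard tt-metal environment the watcher is typically toggled via the `TT_METAL_WATCHER` environment variable, e.g. `TT_METAL_WATCHER=10 pytest <test file>` enables it with a 10-second polling interval, and running without the variable gives the watcher-off case. This assumes the standard tt-metal watcher setup.)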
The program configs seen by the matmul op are identical. Auto-generated program config:

```
MatmulMultiCoreReuseMultiCastProgramConfig(compute_with_storage_grid_size=(x=7,y=8),in0_block_w=4,out_subblock_h=1,out_subblock_w=2,out_block_h=14,out_block_w=2,per_core_M=14,per_core_N=2,transpose_mcast=1,fused_activation=std::nullopt,fuse_batch=1)
```

vs. the explicitly specified program config (no fused activation):

```
MatmulMultiCoreReuseMultiCastProgramConfig(compute_with_storage_grid_size=(x=7,y=8),in0_block_w=4,out_subblock_h=1,out_subblock_w=2,out_block_h=14,out_block_w=2,per_core_M=14,per_core_N=2,transpose_mcast=1,fused_activation=std::nullopt,fuse_batch=1)
```
This issue is to investigate:

- the watcher failure
- why passing in a program config vs. automatically selecting the same program config causes different behaviour with the watcher off
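For reference, a minimal sketch of the two code paths being compared (same tensors and configs as in the test above; only the `program_config` argument differs, and the variable names here are illustrative):

```python
# Explicit config: passes with watcher off, triggers the watcher failure with watcher on.
out_explicit = ttnn.linear(
    x,
    tt_weight_tensor,
    bias=tt_bias_tensor,
    memory_config=sharded_mem_config,
    dtype=ttnn.bfloat8_b,
    compute_kernel_config=compute_config,
    program_config=matmul_config,
)

# Auto-selected config: hangs with watcher off, even though the op reports an
# identical MatmulMultiCoreReuseMultiCastProgramConfig.
out_auto = ttnn.linear(
    x,
    tt_weight_tensor,
    bias=tt_bias_tensor,
    memory_config=sharded_mem_config,
    dtype=ttnn.bfloat8_b,
    compute_kernel_config=compute_config,
    program_config=None,
)
```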
The behaviour exists on both
Found while investigating #16869