
Commit eaec2d6
Revert "suppport all dim lengths for reduction (#16247)"
This reverts commit 573ed07.

This is to fix post-commit.
tt-rkim committed Jan 3, 2025
1 parent 9acc400 commit eaec2d6
Showing 2 changed files with 69 additions and 194 deletions.
144 changes: 20 additions & 124 deletions tests/ttnn/unit_tests/operations/test_reduction.py
@@ -50,38 +50,13 @@ def test_var(device, batch_size, h, w, dim):
     assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)


-@pytest.mark.parametrize("batch_size", [1])
-@pytest.mark.parametrize("c", [11])
-@pytest.mark.parametrize("h", [67])
-@pytest.mark.parametrize("w", [77])
-@pytest.mark.parametrize("dim", [0, 1, 2, 3])
-@pytest.mark.parametrize("keepdim", [True])
-def test_prod(device, batch_size, c, h, w, dim, keepdim):
-    torch.manual_seed(0)
-
-    torch_input_tensor = torch.randn((batch_size, c, h, w), dtype=torch.bfloat16)
-    torch_output_tensor = torch.prod(torch_input_tensor, dim=dim, keepdim=keepdim)
-
-    input_tensor = ttnn.from_torch(
-        torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG
-    )
-
-    output_tensor = ttnn.prod(input_tensor, dim=dim, memory_config=ttnn.L1_MEMORY_CONFIG)
-    output_tensor = ttnn.from_device(output_tensor)
-
-    output_tensor = ttnn.to_torch(output_tensor)
-    assert len(output_tensor.shape) == len(torch_output_tensor.shape)
-    assert output_tensor.shape == torch_output_tensor.shape
-    # assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)
-
-
-@pytest.mark.parametrize("batch_size", [3])
-@pytest.mark.parametrize("c", [5])
-@pytest.mark.parametrize("h", [37])
-@pytest.mark.parametrize("w", [63])
-@pytest.mark.parametrize("dim", [None, [], 0, 2, [0, 1], [1, 3], [0, 1, 2], [1, 2, 3], [0, 1, 2, 3]])
+@pytest.mark.parametrize("batch_size", [1, 16])
+@pytest.mark.parametrize("c", [1, 4, 8, 16])
+@pytest.mark.parametrize("h", [32, 64, 41, 37])
+@pytest.mark.parametrize("w", [32, 64, 31, 63])
+@pytest.mark.parametrize("dim", [None, [0, 1, 2, 3]])
 @pytest.mark.parametrize("keepdim", [True])
-def test_sum_4d_tensor_dims(device, batch_size, c, h, w, dim, keepdim):
+def test_sum_4d_tensors(device, batch_size, c, h, w, dim, keepdim):
     torch.manual_seed(0)

     torch_input_tensor = torch.randn((batch_size, c, h, w), dtype=torch.bfloat16)
@@ -97,105 +72,26 @@ def test_sum_4d_tensor_dims(device, batch_size, c, h, w, dim, keepdim):
     assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)


-@pytest.mark.parametrize("c", [3])
-@pytest.mark.parametrize("h", [31])
-@pytest.mark.parametrize("w", [32])
-@pytest.mark.parametrize("dim", [[0, 2], [0, 1, 2]])
-@pytest.mark.parametrize("keepdim", [True])
-def test_sum_3d_tensor_dims(device, c, h, w, dim, keepdim):
-    torch.manual_seed(0)
-
-    torch_input_tensor = torch.randn((c, h, w), dtype=torch.bfloat16)
-    torch_output_tensor = torch.sum(torch_input_tensor, dim=dim, keepdim=keepdim)
-
-    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
-
-    output_tensor = ttnn.sum(input_tensor, dim=dim, keepdim=keepdim)
-    output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
-    output_tensor = ttnn.from_device(output_tensor)
-
-    output_tensor = ttnn.to_torch(output_tensor)
-    assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)
-
-
-@pytest.mark.parametrize("h", [41])
-@pytest.mark.parametrize("w", [31])
-@pytest.mark.parametrize("dim", [0, 1, [0, 1]])
-@pytest.mark.parametrize("keepdim", [True])
-def test_sum_2d_tensor_dims(device, h, w, dim, keepdim):
-    torch.manual_seed(0)
-
-    torch_input_tensor = torch.randn((h, w), dtype=torch.bfloat16)
-    torch_output_tensor = torch.sum(torch_input_tensor, dim=dim, keepdim=keepdim)
-
-    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
-
-    output_tensor = ttnn.sum(input_tensor, dim=dim, keepdim=keepdim)
-    output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
-    output_tensor = ttnn.from_device(output_tensor)
-
-    output_tensor = ttnn.to_torch(output_tensor)
-    assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)
-
-
-@pytest.mark.parametrize("batch_size", [3])
-@pytest.mark.parametrize("c", [5])
-@pytest.mark.parametrize("h", [37])
-@pytest.mark.parametrize("w", [63])
-@pytest.mark.parametrize("dim", [None, [], 0, 2, [0, 1], [1, 3], [0, 1, 2], [1, 2, 3], [0, 1, 2, 3]])
+@pytest.mark.parametrize("batch_size", [1])
+@pytest.mark.parametrize("c", [11])
+@pytest.mark.parametrize("h", [67])
+@pytest.mark.parametrize("w", [77])
+@pytest.mark.parametrize("dim", [0, 1, 2, 3])
 @pytest.mark.parametrize("keepdim", [True])
-def test_mean_4d_tensor_dims(device, batch_size, c, h, w, dim, keepdim):
+def test_prod(device, batch_size, c, h, w, dim, keepdim):
     torch.manual_seed(0)

     torch_input_tensor = torch.randn((batch_size, c, h, w), dtype=torch.bfloat16)
-    torch_output_tensor = torch.mean(torch_input_tensor, dim=dim, keepdim=keepdim)
-
-    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
-
-    output_tensor = ttnn.mean(input_tensor, dim=dim, keepdim=keepdim)
-    output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
-    output_tensor = ttnn.from_device(output_tensor)
-
-    output_tensor = ttnn.to_torch(output_tensor)
-    assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)
-
-
-@pytest.mark.parametrize("c", [3])
-@pytest.mark.parametrize("h", [31])
-@pytest.mark.parametrize("w", [32])
-@pytest.mark.parametrize("dim", [[0, 2], [0, 1, 2]])
-@pytest.mark.parametrize("keepdim", [True])
-def test_mean_3d_tensor_dims(device, c, h, w, dim, keepdim):
-    torch.manual_seed(0)
-
-    torch_input_tensor = torch.randn((c, h, w), dtype=torch.bfloat16)
-    torch_output_tensor = torch.mean(torch_input_tensor, dim=dim, keepdim=keepdim)
-
-    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
-
-    output_tensor = ttnn.mean(input_tensor, dim=dim, keepdim=keepdim)
-    output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
-    output_tensor = ttnn.from_device(output_tensor)
-
-    output_tensor = ttnn.to_torch(output_tensor)
-    assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)
-
-
-@pytest.mark.parametrize("h", [41])
-@pytest.mark.parametrize("w", [31])
-@pytest.mark.parametrize("dim", [0, 1, [0, 1]])
-@pytest.mark.parametrize("keepdim", [True])
-def test_mean_2d_tensor_dims(device, h, w, dim, keepdim):
-    torch.manual_seed(0)
-
-    torch_input_tensor = torch.randn((h, w), dtype=torch.bfloat16)
-    torch_output_tensor = torch.mean(torch_input_tensor, dim=dim, keepdim=keepdim)
+    torch_output_tensor = torch.prod(torch_input_tensor, dim=dim, keepdim=keepdim)

-    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
+    input_tensor = ttnn.from_torch(
+        torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG
+    )

-    output_tensor = ttnn.mean(input_tensor, dim=dim, keepdim=keepdim)
-    output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
+    output_tensor = ttnn.prod(input_tensor, dim=dim, memory_config=ttnn.L1_MEMORY_CONFIG)
     output_tensor = ttnn.from_device(output_tensor)

     output_tensor = ttnn.to_torch(output_tensor)
-    assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)
+    assert len(output_tensor.shape) == len(torch_output_tensor.shape)
+    assert output_tensor.shape == torch_output_tensor.shape
+    # assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99)
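
For orientation (an editorial sketch, not part of the commit): stacked @pytest.mark.parametrize decorators expand into the Cartesian product of their argument lists, so the restored test_sum_4d_tensors grid above runs 2 * 4 * 4 * 4 * 2 * 1 = 256 cases. A minimal standalone illustration of that expansion:

    # Sketch only: mirrors the parametrize lists of test_sum_4d_tensors above.
    import itertools

    batch_sizes = [1, 16]
    channels = [1, 4, 8, 16]
    heights = [32, 64, 41, 37]
    widths = [32, 64, 31, 63]
    dims = [None, [0, 1, 2, 3]]
    keepdims = [True]

    # pytest generates one test case per element of this product.
    cases = list(itertools.product(batch_sizes, channels, heights, widths, dims, keepdims))
    print(len(cases))  # 256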
119 changes: 49 additions & 70 deletions ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp
@@ -12,10 +12,20 @@
 namespace ttnn {
 namespace operations::reduction {

-ttnn::SmallVector<int> generate_reduce_dim(
-    const Tensor& input_tensor_arg, const std::optional<std::variant<int, ttnn::SmallVector<int>>>& dim_arg) {
+template <ReduceType reduce_type>
+static Tensor reduce_impl(
+    const Tensor& input_tensor_arg,
+    const std::optional<std::variant<int, ttnn::SmallVector<int>>>& dim_arg,
+    const bool keepdim,
+    const std::optional<MemoryConfig>& memory_config_arg,
+    const std::optional<DeviceComputeKernelConfig>& compute_kernel_config,
+    float scalar,
+    bool reshape) {
+    using ttnn::operations::experimental::auto_format::AutoFormat;
+    auto input_shape = input_tensor_arg.get_shape();
+    auto rank = input_shape.size();
+    auto memory_config = memory_config_arg.value_or(input_tensor_arg.memory_config());
+
     ttnn::SmallVector<int> dim{};
     if (dim_arg.has_value()) {
         if (not std::holds_alternative<ttnn::SmallVector<int>>(dim_arg.value())) {
@@ -24,8 +34,7 @@ ttnn::SmallVector<int> generate_reduce_dim(
         } else {
             dim = std::get<ttnn::SmallVector<int>>(dim_arg.value());
         }
-    }
-    if (dim.empty()) {
+    } else {
         dim = ttnn::SmallVector<int>(rank);
         for (int i = 0; i < rank; i++) {
             dim[i] = i;
@@ -46,22 +55,6 @@
     }

     std::sort(dim.begin(), dim.end());
-    return dim;
-}
-
-template <ReduceType reduce_type>
-static Tensor reduce_impl(
-    const Tensor& input_tensor_arg,
-    const ttnn::SmallVector<int>& dim,
-    const bool keepdim,
-    const std::optional<MemoryConfig>& memory_config_arg,
-    const std::optional<DeviceComputeKernelConfig>& compute_kernel_config,
-    float scalar,
-    bool reshape) {
-    using ttnn::operations::experimental::auto_format::AutoFormat;
-    auto input_shape = input_tensor_arg.get_shape();
-    auto rank = input_shape.size();
-    auto memory_config = memory_config_arg.value_or(input_tensor_arg.memory_config());

     ttnn::SmallVector<uint32_t> output_shape;
     for (int axis = 0; axis < input_shape.size(); axis++) {
@@ -75,58 +68,45 @@ static Tensor reduce_impl(
         }
     }

+    if (dim.size() == 1 && (rank == 3 || rank == 4)) {
+        if (dim[0] == 1 && rank == 4) {
+            Tensor output = ttnn::transpose(input_tensor_arg, 1, -2, memory_config);
+            output = reduce_impl<reduce_type>(
+                output, 2, /*keepdim=*/true, memory_config, compute_kernel_config, scalar, /*reshape=*/true);
+            output = ttnn::transpose(output, 1, -2, memory_config);
+            if (reshape) {
+                output = ttnn::reshape(output, ttnn::Shape{output_shape});
+            }
+            return output;
+        } else if (dim[0] == 0) {
+            Tensor output = ttnn::transpose(input_tensor_arg, 0, -2, memory_config);
+            output = reduce_impl<reduce_type>(
+                output, -2, /*keepdim=*/true, memory_config, compute_kernel_config, scalar, /*reshape=*/true);
+            output = ttnn::transpose(output, 0, -2, memory_config);
+            if (reshape) {
+                output = ttnn::reshape(output, ttnn::Shape{output_shape});
+            }
+            return output;
+        }
+    }
+
     auto input_tensor = ttnn::unsqueeze_to_4D(input_tensor_arg);

     Tensor output_tensor;
-    bool single_reduce_op = (dim.size() == 1 && (dim[0] == rank - 1 || dim[0] == rank - 2)) ||
-                            (dim.size() == 2 && dim[0] == rank - 1 && dim[0] == rank - 2);
-    if (!single_reduce_op) {
-        auto reduce_4d_loop = [&](const bool use_reduce_type) -> Tensor {
-            Tensor output_tensor = input_tensor;
-            int offset = 4 - rank;
-            for (int i_dim = rank - 1; i_dim >= 0; i_dim--) {
-                bool found = std::find(dim.begin(), dim.end(), i_dim) != dim.end();
-                if (found) {
-                    bool transpose = i_dim < rank - 2;
-                    int adjusted_dim = offset + i_dim;
-                    int reduce_dim = adjusted_dim;
-                    if (transpose) {
-                        output_tensor = ttnn::transpose(output_tensor, adjusted_dim, 2, memory_config);
-                        reduce_dim = 2;
-                    }
-                    if (use_reduce_type) {
-                        output_tensor = reduce_impl<reduce_type>(
-                            output_tensor,
-                            {reduce_dim},
-                            /*keepdim=*/true,
-                            memory_config,
-                            compute_kernel_config,
-                            scalar,
-                            /*reshape=*/false);
-                    } else {
-                        output_tensor = reduce_impl<ReduceType::Sum>(
-                            output_tensor,
-                            {reduce_dim},
-                            /*keepdim=*/true,
-                            memory_config,
-                            compute_kernel_config,
-                            scalar,
-                            /*reshape=*/false);
-                    }
-                    if (transpose) {
-                        output_tensor = ttnn::transpose(output_tensor, adjusted_dim, -2, memory_config);
-                    }
-                }
+    if (!dim_arg.has_value() || dim.size() == rank) {
+        if constexpr (
+            reduce_type == ReduceType::Sum || reduce_type == ReduceType::Max || reduce_type == ReduceType::Min) {
+            output_tensor = input_tensor;
+            for (int rank = input_tensor.get_legacy_shape().rank() - 1; rank >= 0; rank--) {
+                output_tensor = reduce_impl<reduce_type>(
+                    output_tensor, rank, true, memory_config, compute_kernel_config, scalar, false);
             }
-            return output_tensor;
-        };
-        constexpr bool linear_type =
-            reduce_type == ReduceType::Sum || reduce_type == ReduceType::Max || reduce_type == ReduceType::Min;
-        if (dim.size() == 1 || linear_type) {
-            output_tensor = reduce_4d_loop(/*use_reduce_type=*/true);
         } else if constexpr (reduce_type == ReduceType::Mean) {
-            output_tensor = reduce_4d_loop(
-                /*use_reduce_type=*/false);
+            output_tensor = input_tensor;
+            for (int rank = input_tensor.get_legacy_shape().rank() - 1; rank >= 0; rank--) {
+                output_tensor = reduce_impl<ReduceType::Sum>(
+                    output_tensor, rank, true, memory_config, compute_kernel_config, scalar, false);
+            }
             float inv_volume = 1.0f / input_tensor.get_logical_volume();
             output_tensor = ttnn::mul_sfpu(inv_volume, output_tensor, memory_config);
         } else {
@@ -213,7 +193,7 @@ static Tensor reduce_impl(
     }

     if (reshape) {
-        output_tensor = ttnn::reshape(output_tensor, ttnn::SimpleShape{output_shape});
+        output_tensor = ttnn::reshape(output_tensor, ttnn::Shape{output_shape});
     }

     return output_tensor;
@@ -227,9 +207,8 @@ Tensor Reduce<reduce_type>::invoke(
     const std::optional<MemoryConfig>& memory_config_arg,
     const std::optional<DeviceComputeKernelConfig>& compute_kernel_config,
     float scalar) {
-    ttnn::SmallVector<int> dim = generate_reduce_dim(input_tensor_arg, dim_arg);
     return reduce_impl<reduce_type>(
-        input_tensor_arg, dim, keepdim, memory_config_arg, compute_kernel_config, scalar, true);
+        input_tensor_arg, dim_arg, keepdim, memory_config_arg, compute_kernel_config, scalar, true);
 }

 template class Reduce<ReduceType::Sum>;
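
For context on the restored C++ path (an illustration, not code from this commit): when the reduction targets dim 0, or dim 1 of a rank-4 tensor, reduce_impl transposes that dim into the second-to-last position, reduces there, and transposes back. The same identity expressed in PyTorch, with function and variable names of our own choosing:

    import torch

    def sum_via_transpose(x, dim):
        # Move the target dim to -2, reduce where the device path is native,
        # then move the kept size-1 dim back to its original position.
        moved = x.transpose(dim, -2)
        reduced = torch.sum(moved, dim=-2, keepdim=True)
        return reduced.transpose(dim, -2)

    x = torch.randn(3, 5, 37, 63)
    assert torch.allclose(
        sum_via_transpose(x, 0), torch.sum(x, dim=0, keepdim=True), atol=1e-5
    )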
