Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#15642: Replace shapes in eltwise #15646

Merged
merged 5 commits into from
Dec 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,8 @@ Tensor ExecutePrelu::invoke(

Tensor ExecutePrelu::invoke(
const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
const auto s_a = input_a.get_shape();
const auto s_a = input_a.get_logical_shape();
const auto volume = input_b.get_logical_volume();

TT_FATAL(
s_a[1] == volume,
"Mismatch of parameter numbers and input channel size. Found parameter numbers = {} and channel size = {}.",
Expand All @@ -360,7 +359,7 @@ Tensor ExecutePrelu::invoke(
if (s_a.rank() > 2) {
SmallVector<uint32_t> reshape(s_a.rank(), 1);
reshape[1] = s_a[1];
b = ttnn::reshape(input_b, ttnn::Shape(reshape));
b = ttnn::reshape(input_b, ttnn::SimpleShape(reshape));
}

Tensor result = ttnn::where(ttnn::ltz(input_a, output_mem_config), ttnn::multiply(input_a, b), input_a);
Expand Down Expand Up @@ -491,9 +490,9 @@ Tensor _scatter(const Tensor& input_a, const Tensor& input_b, const std::optiona
* by running reshape.
*/
Tensor _outer(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
const tt::tt_metal::LegacyShape s_a = input_a.get_legacy_shape();
const tt::tt_metal::LegacyShape s_b = input_b.get_legacy_shape();
auto num_ones = [](const tt::tt_metal::LegacyShape& s) -> uint32_t {
const ttnn::SimpleShape s_a = input_a.padded_shape();
const ttnn::SimpleShape s_b = input_b.padded_shape();
auto num_ones = [](const ttnn::SimpleShape& s) -> uint32_t {
uint32_t num1s = 0;
for (uint32_t idx = 0; idx < 4; idx++) {
num1s += (uint32_t)(s[idx] == 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ BinaryDeviceOperation::program_factory_t BinaryDeviceOperation::select_program_f
}
if (height_b == 1) {
if (tensor_args.input_tensor_a.is_sharded()) {
if (tensor_args.input_tensor_a.get_padded_shape()[0] ==
tensor_args.input_tensor_b->get_padded_shape()[0] ||
tensor_args.input_tensor_a.get_padded_shape()[0] > 1 and
tensor_args.input_tensor_b->get_padded_shape()[0] == 1) {
if (tensor_args.input_tensor_a.padded_shape()[0] ==
tensor_args.input_tensor_b->padded_shape()[0] ||
tensor_args.input_tensor_a.padded_shape()[0] > 1 and
tensor_args.input_tensor_b->padded_shape()[0] == 1) {
return BroadcastHeightMultiCoreShardedOptimized{};
}
return BroadcastHeightMultiCoreSharded{};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::create(
const auto& b = tensor_args.input_tensor_b;
auto& output = tensor_return_value;
auto bcast_math = binary_op_type_to_bcast_op_math(operation_attributes.binary_op_type);
const auto ashape = a.get_padded_shape();
const auto bshape = b.has_value() ? b->get_padded_shape() : Shape{1, 1};
const auto ashape = a.padded_shape();
const auto bshape = b.has_value() ? b->padded_shape() : Shape{1, 1};
uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1;
uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1;
uint32_t H = ashape[-2];
Expand Down Expand Up @@ -298,8 +298,8 @@ void BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::override_runtime_a

auto dst_buffer = output_tensor.buffer();

const auto ashape = input_tensor_a.get_padded_shape();
const auto bshape = input_tensor_b.has_value() ? input_tensor_b->get_padded_shape() : Shape{1, 1};
const auto ashape = input_tensor_a.padded_shape();
const auto bshape = input_tensor_b.has_value() ? input_tensor_b->padded_shape() : Shape{1, 1};
uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1;
uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1;
uint32_t H = ashape[-2];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ BinaryDeviceOperation ::BroadcastHeightMultiCore::create(
auto& output = tensor_return_value;
auto bcast_math = binary_op_type_to_bcast_op_math(operation_attributes.binary_op_type);

const auto ashape = a.get_legacy_shape();
const auto bshape = b->get_legacy_shape();
const auto ashape = a.padded_shape();
const auto bshape = b->padded_shape();
uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1;
uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1;
uint32_t H = ashape[-2];
Expand Down Expand Up @@ -238,8 +238,8 @@ void BinaryDeviceOperation ::BroadcastHeightMultiCore::override_runtime_argument

auto dst_dram_buffer = output_tensor.buffer();

const auto ashape = input_tensor_a.get_legacy_shape();
const auto bshape = input_tensor_b->get_legacy_shape();
const auto ashape = input_tensor_a.padded_shape();
const auto bshape = input_tensor_b->padded_shape();
uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1;
uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1;
uint32_t H = ashape[-2];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ BinaryDeviceOperation::BroadcastHeightMultiCoreShardedOptimized::create(
auto& output = tensor_return_value;
auto bcast_math = binary_op_type_to_bcast_op_math(operation_attributes.binary_op_type);

const auto ashape = a.get_legacy_shape();
const auto bshape = b->get_legacy_shape();
const auto ashape = a.padded_shape();
const auto bshape = b->padded_shape();
uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1;
uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1;
uint32_t H = ashape[-2];
Expand Down Expand Up @@ -267,9 +267,9 @@ void BinaryDeviceOperation ::BroadcastHeightMultiCoreShardedOptimized::override_
auto all_cores = shard_spec.grid;
uint32_t ncores = shard_spec.num_cores();
uint32_t Wt = 0, Ht = 0;
const auto ashape = input_tensor_a.get_legacy_shape();
const auto ashape = input_tensor_a.padded_shape();
uint32_t N = ashape[0], C = ashape[1], H = ashape[2], W = ashape[3];
uint32_t bN = input_tensor_b->get_legacy_shape()[0];
uint32_t bN = input_tensor_b->padded_shape()[0];
uint32_t NC = N * C;
if (a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {
Wt = shard_spec.shape[1] / TILE_WIDTH;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ BinaryDeviceOperation::BroadcastHeightMultiCoreSharded::create(
auto& output = tensor_return_value;
auto bcast_math = binary_op_type_to_bcast_op_math(operation_attributes.binary_op_type);

const auto ashape = a.get_legacy_shape();
const auto bshape = b->get_legacy_shape();
const auto ashape = a.padded_shape();
const auto bshape = b->padded_shape();
uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1;
uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1;
uint32_t H = ashape[-2];
Expand Down Expand Up @@ -127,7 +127,7 @@ BinaryDeviceOperation::BroadcastHeightMultiCoreSharded::create(
.set_globally_allocated_address(*output.buffer());
auto out_cb = tt_metal::CreateCircularBuffer(program, all_cores, output_cb_config);

uint32_t num_input_tiles = (b->get_legacy_shape()[-1] * output.element_size() + TILE_HW - 1) / TILE_HW;
uint32_t num_input_tiles = (b->padded_shape()[-1] * output.element_size() + TILE_HW - 1) / TILE_HW;
uint32_t src1_cb_index = tt::CBIndex::c_1;
tt_metal::CircularBufferConfig src1_cb_config =
tt_metal::CircularBufferConfig(num_input_tiles * input1_tile_size, {{src1_cb_index, b_df}})
Expand Down Expand Up @@ -249,9 +249,9 @@ void BinaryDeviceOperation ::BroadcastHeightMultiCoreSharded::override_runtime_a
auto all_cores = shard_spec.grid;
uint32_t ncores = shard_spec.num_cores();
uint32_t Wt = 0, Ht = 0;
const auto ashape = input_tensor_a.get_legacy_shape();
const auto ashape = input_tensor_a.padded_shape();
uint32_t N = ashape[0], C = ashape[1], H = ashape[2], W = ashape[3];
uint32_t bN = input_tensor_b->get_legacy_shape()[0];
uint32_t bN = input_tensor_b->padded_shape()[0];
uint32_t NC = N * C;
if (a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {
Wt = shard_spec.shape[1] / TILE_WIDTH;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ BinaryDeviceOperation::BroadcastWidthMultiCore::cached_program_t BinaryDeviceOpe
auto& output = tensor_return_value;
auto bcast_math = binary_op_type_to_bcast_op_math(operation_attributes.binary_op_type);

const auto ashape = a.get_legacy_shape();
const auto bshape = b->get_legacy_shape();
const auto ashape = a.padded_shape();
const auto bshape = b->padded_shape();
uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1;
uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1;
uint32_t H = ashape[-2];
Expand Down Expand Up @@ -240,8 +240,8 @@ void BinaryDeviceOperation::BroadcastWidthMultiCore::override_runtime_arguments(

auto dst_dram_buffer = output_tensor.buffer();

const auto ashape = input_tensor_a.get_legacy_shape();
const auto bshape = input_tensor_b->get_legacy_shape();
const auto ashape = input_tensor_a.padded_shape();
const auto bshape = input_tensor_b->padded_shape();
uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1;
uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1;
uint32_t H = ashape[-2];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,5 +238,4 @@ void BinaryDeviceOperation::ElementWiseMultiCore::override_runtime_arguments(
shared_variables.src1_single_tile_size,
shared_variables.dst_single_tile_size);
}

} // namespace ttnn::operations::binary
Original file line number Diff line number Diff line change
Expand Up @@ -632,10 +632,7 @@ std::vector<std::optional<Tensor>> ExecuteBackwardConcat::invoke(
if (are_required_outputs[0]) {
ttnn::SmallVector<uint32_t> start_index = {0, 0, 0, 0};
ttnn::SmallVector<uint32_t> end_index = {
input.get_legacy_shape()[0],
input.get_legacy_shape()[1],
input.get_legacy_shape()[2],
input.get_legacy_shape()[3]};
input.padded_shape()[0], input.padded_shape()[1], input.padded_shape()[2], input.padded_shape()[3]};
ttnn::SmallVector<uint32_t> step = {1, 1, 1, 1};
ttnn::slice(queue_id, grad, start_index, end_index, step, std::nullopt, input_grad);
grad_tensor[0] = input_grad;
Expand All @@ -644,19 +641,16 @@ std::vector<std::optional<Tensor>> ExecuteBackwardConcat::invoke(
if (are_required_outputs[1]) {
ttnn::SmallVector<uint32_t> start_index_2 = {0, 0, 0, 0};
if (dim == 0) {
start_index_2 = {input.get_legacy_shape()[0], 0, 0, 0};
start_index_2 = {input.padded_shape()[0], 0, 0, 0};
} else if (dim == 1) {
start_index_2 = {0, input.get_legacy_shape()[1], 0, 0};
start_index_2 = {0, input.padded_shape()[1], 0, 0};
} else if (dim == 2) {
start_index_2 = {0, 0, input.get_legacy_shape()[2], 0};
start_index_2 = {0, 0, input.padded_shape()[2], 0};
} else if (dim == 3) {
start_index_2 = {0, 0, 0, input.get_legacy_shape()[3]};
start_index_2 = {0, 0, 0, input.padded_shape()[3]};
}
ttnn::SmallVector<uint32_t> end_index_2 = {
grad.get_legacy_shape()[0],
grad.get_legacy_shape()[1],
grad.get_legacy_shape()[2],
grad.get_legacy_shape()[3]};
grad.padded_shape()[0], grad.padded_shape()[1], grad.padded_shape()[2], grad.padded_shape()[3]};
ttnn::SmallVector<uint32_t> step_2 = {1, 1, 1, 1};
ttnn::slice(queue_id, grad, start_index_2, end_index_2, step_2, std::nullopt, other_grad);
grad_tensor[1] = other_grad;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,9 @@ void set_or_update_runtime_arguments(
const auto& a = tensor_args.input_tensor_a;
const auto& b = tensor_args.input_tensor_b;

const auto ashape = a.get_padded_shape();
const auto bshape = b.has_value() ? b->get_padded_shape() : SimpleShape{1, 1};
const auto cshape = c.get_padded_shape();
const auto ashape = a.padded_shape();
const auto bshape = b.has_value() ? b->padded_shape() : SimpleShape{1, 1};
const auto cshape = c.padded_shape();

const auto [aN, aC, aHt, aWt] = extract_shape_dims(a);
const auto [bN, bC, bHt, bWt] = b.has_value() ? extract_shape_dims(*b) : std::tuple{1u, 1u, 1u, 1u};
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/eltwise/complex/complex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ void ComplexTensor::deallocate() {
}

ComplexTensor CreateComplexTensor::invoke(const Tensor& real, const Tensor& imag) {
TT_ASSERT(real.get_legacy_shape() == imag.get_legacy_shape(), "Tensor shapes of real and imag should be identical");
TT_ASSERT(real.padded_shape() == imag.padded_shape(), "Tensor shapes of real and imag should be identical");
return ComplexTensor({real, imag});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ Tensor _variance_impl(
const std::optional<MemoryConfig>& output_mem_config) {
ttnn::SmallVector<int> dims = {2, 3};
constexpr float correction = 0.0f;
auto shape_wh = y.get_legacy_shape();
auto shape_wh = y.padded_shape();
float scale = 1.0f / ((float)(shape_wh[3] * shape_wh[2]) - correction);
Tensor sqr_y_minus_mean_y = ttnn::square(y_minus_mean_y, output_mem_config);
return ttnn::sum(sqr_y_minus_mean_y, dims, true, std::nullopt, std::nullopt, scale);
Expand Down Expand Up @@ -599,10 +599,10 @@ Tensor ExecuteUnaryCompositeThreshold::invoke(
std::vector<Tensor> split_tensor_for_glu(
const Tensor& input_a, int32_t dim, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> t_split;
tt::tt_metal::LegacyShape inshape(input_a.get_legacy_shape());
ttnn::SimpleShape inshape(input_a.padded_shape());
TT_FATAL(((inshape[dim] / 2) % tt::constants::TILE_WIDTH == 0), "Split tensor dimension should be in full tile");
ttnn::SmallVector<uint32_t> s_a = {0, 0, 0, 0};
ttnn::SmallVector<uint32_t> e_a = {input_a.get_legacy_shape()[0], inshape[1], inshape[2], inshape[3] / 2};
ttnn::SmallVector<uint32_t> e_a = {input_a.padded_shape()[0], inshape[1], inshape[2], inshape[3] / 2};

ttnn::SmallVector<uint32_t> s_b = {0, 0, 0, inshape[3] / 2};
ttnn::SmallVector<uint32_t> e_b = {inshape[0], inshape[1], inshape[2], inshape[3]};
Expand Down
Loading
Loading