Skip to content

Commit

Permalink
#16318: Add support for act_block_h_override to Width Sharded Conv2d (#16374)
Browse files Browse the repository at this point in the history

### Ticket
#16318 

### Problem description
Conv2d with large input height/width leads to out-of-memory (OOM) errors due to large circular buffers (CBs).

### What's changed
Enabling act_block_h_override reduces the size of the circular buffers, preventing
OOM errors.

### Checklist
- [x] Post commit CI
[passes](https://github.com/tenstorrent/tt-metal/actions/runs/12592969729)
- [x] New/Existing tests provide coverage for changes
  • Loading branch information
sankarmanoj-tt authored Jan 3, 2025
1 parent a9c2b0e commit c675228
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 204 deletions.
24 changes: 10 additions & 14 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,20 +354,16 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config(
div_up(conv_op_parallel_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT);

if (act_block_h_override > 0) {
if (parallel_config.shard_scheme == TensorMemoryLayout::WIDTH_SHARDED) {
log_info(LogOp, "act_block_h_override is set, but ignored when Width Sharding is used");
uint32_t act_block_h_override_ntiles = act_block_h_override / constants::TILE_HEIGHT;
if (padded_output_height_ntiles % act_block_h_override_ntiles == 0) {
act_block_h_ntiles = act_block_h_override_ntiles;
} else {
uint32_t act_block_h_override_ntiles = act_block_h_override / constants::TILE_HEIGHT;
if (padded_output_height_ntiles % act_block_h_override_ntiles == 0) {
act_block_h_ntiles = act_block_h_override_ntiles;
} else {
log_info(
LogOp,
"act_block_h_override {} is not a valid override for padded_output_height_ntiles {}, override will "
"be ignored",
act_block_h_override_ntiles,
padded_output_height_ntiles);
}
log_info(
LogOp,
"act_block_h_override {} is not a valid override for padded_output_height_ntiles {}, override will "
"be ignored",
act_block_h_override_ntiles,
padded_output_height_ntiles);
}
}

Expand Down Expand Up @@ -836,7 +832,7 @@ void adjust_conv_op_config_for_auto_shard_if_necessary(
shard_orientation,
compute_grid_size);

if (conv_config.act_block_h_override == 0 && conv_config.shard_layout != TensorMemoryLayout::WIDTH_SHARDED) {
if (conv_config.act_block_h_override == 0) {
if (in_channels <= constants::TILE_WIDTH / 2 && conv_config.input_channels_alignment == constants::TILE_WIDTH &&
!is_mm_conv && conv_config.shard_layout == TensorMemoryLayout::HEIGHT_SHARDED &&
input_tensor_layout == Layout::ROW_MAJOR) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -430,11 +430,16 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_width_sh
if (has_bias) {
bias_buffer = bias.value().buffer();
bias_dram_addr = bias_buffer->address();
bias_ntiles =
bias.value().get_legacy_shape()[3] / constants::TILE_WIDTH; // TODO: support non tile multiple sizes
bias_ntiles = weight_block_w_ntiles;
bias_in_dram = bias_buffer->buffer_type() == BufferType::DRAM;
}

uint32_t num_weight_slices_width = weight_matrix_width_ntiles / per_core_out_matrix_width_ntiles;
uint32_t num_blocks_act_h_per_core =
(per_core_out_matrix_height_ntiles + act_block_h_ntiles - 1) / act_block_h_ntiles;
uint32_t num_blocks_weight_w_per_core = per_core_out_matrix_width_ntiles / weight_block_w_ntiles;
uint32_t bias_ntiles_per_core = bias_ntiles / num_weight_slices_width;

auto output_shape = sliding_window_config.get_output_shape();
uint32_t conv_output_size_h = output_shape[1];
uint32_t conv_output_size_w = output_shape[2];
Expand Down Expand Up @@ -463,7 +468,6 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_width_sh
out_block_h_ntiles * act_block_w_ntiles * tt::tt_metal::detail::TileSize(act_df);
uint32_t dst_l1_weight_buffer_size_bytes =
weight_block_h_ntiles * weight_block_w_ntiles * tt::tt_metal::detail::TileSize(weight_df);

// Number of bytes to be read from the channel dimension in one block.
uint32_t conv_act_c_read_bytes = conv_act_size_c * a.element_size() / (input_num_cores * per_core_num_blocks_act_w);

Expand All @@ -482,6 +486,7 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_width_sh
log_debug(LogOp, "weight_block_w_ntiles: {}", weight_block_w_ntiles);
log_debug(LogOp, "out_subblock_h_ntiles_padded: {}", out_subblock_h_ntiles_padded);
log_debug(LogOp, "out_subblock_w_ntiles: {}", out_subblock_w_ntiles);
log_debug(LogOp, "num_blocks_weight_w_per_core: {}", num_blocks_weight_w_per_core);
}

// For debug
Expand Down Expand Up @@ -576,6 +581,7 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_width_sh
(uint32_t)act_block_h_datums,
(uint32_t)act_block_num_tiles,
(uint32_t)input_num_cores,
(uint32_t)num_blocks_act_h_per_core,
(uint32_t)per_core_num_blocks_act_w,
(uint32_t)act_mcast_sender_semaphore,
(uint32_t)act_mcast_receiver_semaphore,
Expand All @@ -598,15 +604,10 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_width_sh
per_core_num_blocks_act_w,
input_num_cores, // other_core_weight_height_blocks
per_core_num_blocks_act_w, // this_core_weight_height_blocks
num_blocks_act_h_per_core,
bias_cb,
bias_in_dram};

uint32_t num_weight_slices_width = weight_matrix_width_ntiles / per_core_out_matrix_width_ntiles;
uint32_t num_blocks_act_h_per_core =
(per_core_out_matrix_height_ntiles + act_block_h_ntiles - 1) / act_block_h_ntiles;
uint32_t num_blocks_weight_w_per_core = per_core_out_matrix_width_ntiles / weight_block_w_ntiles;
uint32_t bias_ntiles_per_core = bias_ntiles / num_weight_slices_width;

std::map<string, string> writer_defines;
std::map<string, string> writer_mcast_sender_defines;
std::map<string, string> compute_defines;
Expand Down Expand Up @@ -662,7 +663,7 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_width_sh
tilize_in0, // tilize_in0
untilize_out, // untilize_out

bias_ntiles_per_core,
bias_ntiles,

out0_cb,
num_output_tiles,
Expand All @@ -678,10 +679,12 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_width_sh
uint32_t weight_tile_size = tt_metal::detail::TileSize(weight_df);

// Local L1 to store temp vars
// Used for act_mcast_sender_semaphore_valid_addr_ptr
CircularBufferConfig cb_for_l1_array_config =
CircularBufferConfig(32 * 2, {{cb_for_l1_array, tt::DataFormat::Float16_b}})
.set_page_size(cb_for_l1_array, 32 * 2);
tt_metal::CreateCircularBuffer(program, all_cores, cb_for_l1_array_config);
log_debug(LogOp, "CB for L1 Array CB: {}, npages: {}, pagesize: {}", cb_for_l1_array, 1, 32 * 2);

CircularBufferConfig cb_sharded_act_config =
CircularBufferConfig(shard_shape[0] * shard_shape[1] * datum_size(act_df), {{sharded_act_cb, act_df}})
Expand Down Expand Up @@ -762,13 +765,14 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_width_sh
.set_page_size(matmul_partials_cb, out_tile_size);
if (output.is_sharded()) {
cb_matmul_partials_config = cb_matmul_partials_config.set_globally_allocated_address(*output.buffer());
} else {
log_debug(
LogOp,
"Matmul Partials CB: {}, npages: {}, pagesize: {}",
matmul_partials_cb,
num_output_tiles,
out_tile_size);
}
log_debug(
LogOp,
"Matmul Partials CB: {}, npages: {}, pagesize: {}",
matmul_partials_cb,
num_output_tiles,
out_tile_size);
cb_output = tt_metal::CreateCircularBuffer(program, all_cores, cb_matmul_partials_config);
} else {
// Separate buffer if not same data format
Expand Down
Loading

0 comments on commit c675228

Please sign in to comment.