
MeshWorkload: Initial Implementation #16405

Open

tt-asaigal wants to merge 3 commits into main from asaigal/mesh_workload

Conversation

tt-asaigal
Contributor

@tt-asaigal commented Jan 2, 2025

Ticket

#16409

Problem description

MeshWorkload APIs need to be implemented as per the spec presented here. Please see the issue for more details and the scope of this work.

What's changed

TT-Metal Dispatch Changes:

  • Expose finalize as a generic function templated on Program and MeshWorkload to support computing L1 offsets for both data structures through a shared path (sketched after this list)
  • Move write_program_command_sequence out of the EnqueueProgramCommand and expose it as a utility function, since it is used by both MeshWorkload and Program
  • Add an API to query the program dispatch core per CQ per device, since this is needed by MeshCommandQueue
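As a rough illustration of the shared path (not verbatim from this PR): both structures go through the same templated entry point. The header path below is an assumption; the function name and namespace are taken from the explicit instantiations quoted further down in this review.

// Hedged sketch: both Program and MeshWorkload flow through one templated finalize path
// that computes L1 offsets. Header path and namespace qualification are assumptions.
#include "tt_metal/impl/program/program_dispatch_utils.hpp"

using namespace tt::tt_metal;

void finalize_offsets_example(Program& program, distributed::MeshWorkload& workload, Device* device) {
    // Single-device program: compute kernel text / semaphore / CB / RTA offsets in L1.
    program_dispatch::finalize_program_offsets(program, device);
    // MeshWorkload: the same shared path computes offsets for the workload's programs.
    program_dispatch::finalize_program_offsets(workload, device);
}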

TT-Mesh Changes:

  • Add the MeshWorkload class. Currently supports Single-Program-Multi-Device and Multi-Program-Multi-Device use cases. Heterogeneous Runtime Args will be brought up in a separate commit.
  • Add the MeshCommandQueue class. Currently piggybacks off the single-device Command Queues for performing IO. All functionality will eventually be moved into the MeshCommandQueue once we support MeshBuffer reads and writes.
  • The MeshCommandQueue maintains independent accelerator state for dispatching MeshWorkloads. Since buffer reads and writes are still done through the single-device CQs, this state must be kept in sync across all CQ objects. This is done through the experimental::write_program_commands function in mesh_workload_utils.hpp.
  • Expose top-level APIs to create, populate and enqueue a MeshWorkload through a MeshCommandQueue when using Fast Dispatch (see the sketch after this list)
  • Add several sanity, randomized and end-to-end tests for MeshWorkload creation and dispatch
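For orientation, a minimal sketch of how these pieces could compose under Fast Dispatch. CreateMeshWorkload and InsertProgramInMeshWorkload appear in the diff below; the EnqueueMeshWorkload name, the mesh_command_queue() accessor, the argument order, and the header path are assumptions, not confirmed API.

// Hedged sketch only: several names here (EnqueueMeshWorkload, mesh_command_queue(),
// argument order of InsertProgramInMeshWorkload, header path) are assumptions.
#include "tt_metal/distributed/distributed.hpp"

using namespace tt::tt_metal;
using namespace tt::tt_metal::distributed;

void dispatch_workload(MeshDevice* mesh_device, Program&& program) {
    // Create an empty workload and bind the program to a 2x2 logical device range.
    MeshWorkload workload = CreateMeshWorkload();
    LogicalDeviceRange device_range(CoreCoord{0, 0}, CoreCoord{1, 1});
    InsertProgramInMeshWorkload(workload, device_range, std::move(program));

    // Enqueue through the MeshCommandQueue (Fast Dispatch only at this stage).
    MeshCommandQueue& mesh_cq = mesh_device->mesh_command_queue();  // accessor name assumed
    EnqueueMeshWorkload(mesh_cq, workload, /*blocking=*/false);     // API name assumed
}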

Checklist

  • Post commit CI passes
  • Blackhole Post commit (if applicable)
  • Model regression CI testing passes (if applicable)
  • Device performance regression CI testing passes (if applicable)
  • (For models and ops writers) Full new models tests passes
  • New/Existing tests provide coverage for changes

@tt-asaigal force-pushed the asaigal/mesh_workload branch from 0e77e6a to 6a95e83 on January 2, 2025 21:38
@tt-asaigal marked this pull request as ready for review on January 2, 2025 22:00
@tt-asaigal force-pushed the asaigal/mesh_workload branch 3 times, most recently from 05ef361 to c1610b6 on January 3, 2025 17:15
@tt-asaigal
Contributor Author

tests/tt_metal/distributed/distributed_fixture.hpp (outdated, resolved)
}

void TearDown() override {
    mesh_device_->close_devices();
Contributor

In this PR, @TT-BrianLiu found that mesh_device_ isn't even created if the test is skipped, so TearDown needs a conditional block.

This is duplicated code btw. Can we use the fixture you defined here in place of the one in tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp? https://github.com/tenstorrent/tt-metal/blob/aaf2d7304c138c046efd5c0a94c14da3e6f95ce4/tests/tt_metal/tt_metal/common/multi_device_fixture.hpp would be even better!

Ideally the tests would #include this header directly, but I think for the time being you can add a forwarding include in ttnn_test_fixtures and just delete T3kMultiDeviceFixture.

Collaborator

+1

Contributor Author

I've updated the teardown step for the distributed fixture. Changing the other multi-device fixtures is outside the scope of this PR; I'll clean this up in a separate change.

tt_metal/distributed/distributed.hpp (resolved)
tt_metal/distributed/mesh_command_queue.hpp (outdated, resolved)
tt_metal/impl/program/program_dispatch_utils.cpp (outdated, resolved)
Comment on lines 40 to 41
template <typename T>
void finalize(T& workload_type, Device* device);
Contributor

The friend relation here and in program_base_addr_on_core is non-ideal, especially when the state is mutated externally... Is it possible to keep the member function "finalize" on both mesh workload / program, but then move parts of the implementation here? For example, a shared utility may compute and return all of the necessary offsets / sizes based on the provided core type and some other parameters (e.g. the ones in workload.get_kernels(index), workload.get_kernel_groups(index), etc), then MeshWorkload / Program may use the result to set everything necessary internally?
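To make the suggestion concrete, a hypothetical shape for such a shared utility. All names and fields below are illustrative only, not from the PR; the tt-metal types (Kernel, KernelGroup, HalProgrammableCoreType) are assumed to be available.

#include <cstdint>
#include <memory>
#include <vector>

// Illustrative only: a free helper that computes offsets/sizes and hands them back,
// so Program / MeshWorkload keep a member finalize() that applies the result internally,
// avoiding the friend relation and external state mutation.
struct ProgramConfigOffsets {
    uint32_t rta_offset = 0;
    std::vector<uint32_t> crta_offsets;
    uint32_t sem_offset = 0;
    uint32_t cb_offset = 0;
    uint32_t kernel_text_offset = 0;
    uint32_t total_size = 0;
};

// Inputs would come from workload.get_kernels(index), workload.get_kernel_groups(index), etc.
ProgramConfigOffsets compute_program_config_offsets(
    HalProgrammableCoreType core_type,
    const std::vector<std::shared_ptr<Kernel>>& kernels,
    const std::vector<std::shared_ptr<KernelGroup>>& kernel_groups);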

Contributor

For program_base_addr_on_core, you can similarly have a shared implementation that works on a vector of sub-devices and the last used CQ; no need to templatize and no friend relation (unless I missed something?)

tests/tt_metal/distributed/test_mesh_workload.cpp (outdated, resolved)
tests/tt_metal/distributed/test_mesh_workload.cpp (outdated, resolved)
// The LogicalDeviceRange concept is fundamentally identical to the CoreRange concept
// Use this definition for now, since CoreRange contains several utility functions required
// in the MeshWorkload context. CoreRange can eventually be renamed to Range2D.
using LogicalDeviceRange = CoreRange;
Contributor

+1 here

@tt-asaigal force-pushed the asaigal/mesh_workload branch 4 times, most recently from f35feb2 to 118672f on January 6, 2025 21:27
tt_metal/impl/dispatch/command_queue.hpp (outdated, resolved)
tt_metal/distributed/mesh_device.cpp (outdated, resolved)
std::unordered_map<LogicalDeviceRange, Program>& get_programs() { return this->programs_; }
// For testing purposes only
void set_last_used_command_queue_for_testing(MeshCommandQueue* mesh_cq);
MeshCommandQueue* get_last_used_command_queue() const;
Collaborator

Is this API just for testing? What is it used for?

Contributor Author

Yes, this is an API for testing purposes only (for now at least). Each MeshDevice will eventually have multiple command queues. Each CQ has a separate WorkerConfigBufferMgr object that tracks the state of the program config ring buffer in L1/SRAM.
The last used command queue is referenced to query this object and derive global addresses in functions like get_sem_base_addr and get_cb_base_addr.

We have essentially the same function in our tt_metal CQs for the same purpose.
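For context, a hedged sketch of how a test might use this hook. set_last_used_command_queue_for_testing and get_last_used_command_queue are from the diff above; the mesh_command_queue() accessor and the namespaces are assumptions.

// Hedged, test-only sketch; mesh_command_queue() is an assumed accessor name.
using namespace tt::tt_metal;

void record_last_used_cq_for_address_queries(distributed::MeshDevice* mesh_device,
                                              distributed::MeshWorkload& workload) {
    distributed::MeshCommandQueue& mesh_cq = mesh_device->mesh_command_queue();
    // Remember which CQ dispatched the workload, so its WorkerConfigBufferMgr state
    // can be consulted when deriving addresses such as semaphore or CB base addresses.
    workload.set_last_used_command_queue_for_testing(&mesh_cq);

    // Later, a test queries the recorded CQ:
    distributed::MeshCommandQueue* last_cq = workload.get_last_used_command_queue();
    (void)last_cq;
}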

// Main User-Facing API building blocks
MeshWorkload();
void add_program(const LogicalDeviceRange& device_range, Program& program);
std::unordered_map<LogicalDeviceRange, Program>& get_programs() { return this->programs_; }
Collaborator

I don't think we want users to update this map externally. Should this just be a const ref?

Contributor Author

Good catch here. We want users to be able to mutate certain program state (e.g. runtime args) but not the map itself. I made the function return a const ref for the map and added a separate getter for a program.
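A minimal sketch of what the reworked accessors might look like; the per-program getter's name is an assumption, and the surrounding class members are trimmed for brevity.

#include <unordered_map>

// Sketch only: exact getter names may differ in the PR.
class MeshWorkload {
public:
    // Read-only view of the map: callers cannot add or remove programs through it.
    const std::unordered_map<LogicalDeviceRange, Program>& get_programs() const { return programs_; }

    // Mutable access to a single program, e.g. to update runtime args. Name is illustrative.
    Program& get_program_on_device_range(const LogicalDeviceRange& device_range) {
        return programs_.at(device_range);
    }

private:
    std::unordered_map<LogicalDeviceRange, Program> programs_;
};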

Collaborator

great!


void MeshWorkload::add_program(const LogicalDeviceRange& device_range, Program& program) {
    // Add a program to a MeshWorkload and tie it to a specific logical device range
    this->programs_[device_range] = std::move(program);
Collaborator

This is a bug. After the move, the original program parameter will be left in a moved-from state. Should the signature be Program&& program?

Contributor Author

@tt-asaigal Jan 7, 2025

I updated this; not sure why it doesn't show up in the diff here. It's there in the latest commit.
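For reference, the fixed signature presumably looks roughly like this sketch, based on the suggestion above:

// Taking the program by rvalue reference makes the ownership transfer explicit at call sites.
void MeshWorkload::add_program(const LogicalDeviceRange& device_range, Program&& program) {
    // Add a program to a MeshWorkload and tie it to a specific logical device range.
    this->programs_[device_range] = std::move(program);
}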

Collaborator

great!

@@ -278,6 +278,7 @@ class Device {
void load_sub_device_manager(SubDeviceManagerId sub_device_manager_id);
void clear_loaded_sub_device_manager();
LaunchMessageRingBufferState& get_worker_launch_message_buffer_state(SubDeviceId sub_device_id);
CoreCoord virtual_program_dispatch_core(uint8_t cq_id) const;
Collaborator

can you clarify what this API is?

Contributor Author

This API returns the virtual coordinates of the dispatch core responsible for program management on device. TT-Mesh infra needs access to this when generating the FD commands for a MeshWorkload.
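A small hedged example of the intended use when assembling Fast Dispatch commands for a MeshWorkload. get_devices() and virtual_program_dispatch_core() appear in this diff; the wrapper function and everything else is illustrative.

#include <vector>

// Hedged sketch: iterate the physical devices backing the mesh and look up the dispatch
// core that manages program state for the given command queue on each of them.
void collect_dispatch_cores(tt::tt_metal::distributed::MeshDevice* mesh_device, uint8_t cq_id,
                            std::vector<CoreCoord>& dispatch_cores_out) {
    for (Device* device : mesh_device->get_devices()) {
        dispatch_cores_out.push_back(device->virtual_program_dispatch_core(cq_id));
    }
}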

tt_metal/distributed/distributed.hpp (outdated, resolved)
tt_metal/distributed/distributed.cpp (outdated, resolved)
tt_metal/distributed/mesh_command_queue.hpp (outdated, resolved)
tt_metal/distributed/mesh_command_queue.hpp (outdated, resolved)
tt_metal/distributed/mesh_command_queue.hpp (outdated, resolved)
Comment on lines 1762 to 1768
template void finalize_program_offsets<Program>(Program&, Device*);
template void finalize_program_offsets<distributed::MeshWorkload>(distributed::MeshWorkload&, Device*);
template uint32_t program_base_addr_on_core<Program, Device*>(Program&, Device*, HalProgrammableCoreType);
template uint32_t program_base_addr_on_core<distributed::MeshWorkload, std::shared_ptr<distributed::MeshDevice>>(
distributed::MeshWorkload&, std::shared_ptr<distributed::MeshDevice>, HalProgrammableCoreType);
} // namespace program_dispatch
Collaborator

As discussed, let's just get rid of all this templating and just have interface classes.

Contributor Author

I went through a potential change with a Workload interface class here.

We can't remove the template parameters here until MeshDevice and Device inherit from an IDevice object.

This is because MeshWorkload needs to be templated on MeshDevice and Program needs to be templated on Device, due to a disconnect between the two classes and their implementations.

I think it makes sense to revisit this once we have Artem's Device cleanup on main.

Collaborator

Yes, as agreed we'll rework this once Artem merges in the changes on IDevice 👍

@tt-asaigal force-pushed the asaigal/mesh_workload branch 3 times, most recently from 7b0a60c to a40cc50 on January 7, 2025 02:40
@tt-asaigal force-pushed the asaigal/mesh_workload branch 3 times, most recently from ed5d8ec to d4be00f on January 7, 2025 20:18
@tt-asaigal force-pushed the asaigal/mesh_workload branch from d4be00f to f651fa0 on January 7, 2025 20:26

MeshWorkload CreateMeshWorkload() { return MeshWorkload(); }

void InsertProgramInMeshWorkload(
Contributor

Nit and minor (so feel free to ignore): consider renaming to AddProgramToMeshWorkload.

CoreType dispatch_core_type_;

public:
MeshCommandQueue(MeshDevice* mesh_device, uint32_t id);
Contributor

What does a MeshCommandQueue id correspond to? Are there only 2 MeshCommandQueues per MeshDevice?

uint32_t num_workers = 0;
for (auto& device : this->mesh_device_->get_devices()) {
    if (num_workers) {
        TT_FATAL(
Contributor

Will this always be true? If two devices have a different number of harvested rows/cols, then they can't be programmed with the same MeshWorkload, even if the number of workers used by the workload is the same.

}

void TearDown() override {
    if (!::testing::Test::IsSkipped()) {
Collaborator

It's probably safest to just check that mesh_device_ is not null.
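i.e., roughly this sketch:

void TearDown() override {
    // Guard against the skipped-test case where SetUp never created the mesh device.
    if (mesh_device_ != nullptr) {
        mesh_device_->close_devices();
    }
}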

Comment on lines +4 to +5
// dispatch.hpp
#include "dispatch.hpp"
Contributor

Remove comment?

And normally the style in the codebase is to include the full path to the header, right?

Comment on lines +299 to +308
program.get_program_config(index).rta_offset = rta_offset;
program.get_program_config(index).crta_offsets = crta_offsets;
program.get_program_config(index).crta_sizes = crta_sizes;
program.get_program_config(index).sem_offset = sem_offset;
program.get_program_config(index).sem_size = sem_size;
program.get_program_config(index).cb_offset = cb_offset;
program.get_program_config(index).cb_size = cb_size;
program.get_program_config(index).kernel_text_offset = kernel_text_offset;
program.get_program_config(index).kernel_text_size = kernel_text_size;
program.get_program_config_sizes()[index] = offset;
Contributor

Nit: would it be better to just do auto& program_config = program.get_program_config(index); and then just use program_config?
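i.e., roughly:

// Fetch the config once instead of calling get_program_config(index) for every field.
auto& program_config = program.get_program_config(index);
program_config.rta_offset = rta_offset;
program_config.crta_offsets = crta_offsets;
program_config.crta_sizes = crta_sizes;
program_config.sem_offset = sem_offset;
program_config.sem_size = sem_size;
program_config.cb_offset = cb_offset;
program_config.cb_size = cb_size;
program_config.kernel_text_offset = kernel_text_offset;
program_config.kernel_text_size = kernel_text_size;
program.get_program_config_sizes()[index] = offset;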

Comment on lines +308 to +318
// void MeshWorkload::set_runtime_args(const LogicalDeviceRange& device_range, const CoreRangeSet& core_range_set,
// KernelHandle kernel_id, const std::vector<uint32_t> runtime_args) {
// std::size_t intersection_count = 0;

// for (auto& program_on_grid : this->programs_) {
// auto& program_device_range = program_on_grid.first;
// if (device_range.intersects(program_device_range)) {
// program_to_set_rt
// }
// }
// }
Contributor

Remove?
