
#16391: propagate sub_device_ids to mesh #16410

Open
wants to merge 7 commits into main from snijjar/issue-16391

Conversation

SeanNijjar
Contributor

@SeanNijjar SeanNijjar commented Jan 2, 2025

  • Further updates the all-gather-async tests to pass sub-device ID information.
    • Also modifies the tests to tear down the fabric on exception, to avoid hangs. (Longer term this can hopefully be replaced with something cleaner, like a teardown-callback registration exposed by metal, so we don't need to wrap in try-catch.)
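The teardown-on-exception change above boils down to a try/finally guard around the test body. A minimal sketch, assuming placeholder helpers (`setup_fabric`, `run_body`, `teardown_fabric` are illustrative names, not real tt-metal APIs):

```python
# Hedged sketch of the teardown-on-exception pattern described above. The
# helper names are placeholders, not real tt-metal APIs.

def run_with_fabric_teardown(setup_fabric, run_body, teardown_fabric):
    """Run a test body, guaranteeing fabric teardown even if the body raises."""
    setup_fabric()
    try:
        run_body()
    finally:
        # Tear down the persistent fabric unconditionally so a failing test
        # does not leave the device hung.
        teardown_fabric()

# Usage: even though the body raises, teardown still runs.
events = []

def failing_body():
    events.append("body")
    raise RuntimeError("simulated test failure")

try:
    run_with_fabric_teardown(
        lambda: events.append("setup"),
        failing_body,
        lambda: events.append("teardown"),
    )
except RuntimeError:
    pass

print(events)  # ['setup', 'body', 'teardown']
```

A registered teardown callback, as hoped for above, would replace the explicit try/finally with the same guarantee.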

Ticket

#16391

Problem description

All-gather v2 hangs when running with cluster axis API on persistent fabric

What's changed

In tests:

  • Updated tensor to/from-device calls to take sub-device IDs

Infra:

  • Updated the mesh tensor / mesh composer APIs to accept and properly handle sub-device IDs when copying tensors
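To illustrate what "accept and handle sub-device IDs" means for a copy call, here is a self-contained mock (not the real ttnn API; `copy_to_device` and the `device` dict are stand-ins): the copy stalls only on the listed sub-devices, defaulting to all of them when none are given.

```python
# Illustrative mock of a copy call that threads sub_device_ids through.
# copy_to_device and the device dict are stand-ins, not the real ttnn API.

def copy_to_device(tensor, device, sub_device_ids=None):
    # None mirrors the documented default of waiting on all sub-devices.
    stall_ids = sub_device_ids if sub_device_ids is not None else device["sub_devices"]
    device["last_stalled_on"] = list(stall_ids)  # record which sub-devices we waited on
    device["buffer"] = list(tensor)
    return device["buffer"]

device = {"sub_devices": ["sd0", "sd1"], "last_stalled_on": None, "buffer": None}

copy_to_device([1, 2, 3], device)                          # stalls on sd0 and sd1
copy_to_device([4, 5, 6], device, sub_device_ids=["sd0"])  # stalls only on sd0
```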

Checklist

Closes #16391

@SeanNijjar
Copy link
Contributor Author

FYI @xuncaiTT

@SeanNijjar force-pushed the snijjar/issue-16391 branch from 6254b84 to 3ff6e59 (January 3, 2025 15:56)
Contributor

Do we now also need to update the C++ APIs? See distributed_tensor.hpp

@tt-aho / @SeanNijjar - what is the high-level plan for the APIs involving sub-devices? Passing sub-device IDs to mesh composers is odd, as it is completely unrelated to the mesh distribution functionality. Do we plan to plumb sub-device IDs into all of the APIs that copy tensors under the hood? From the documentation: "The sub-device IDs to wait on. Defaults to all sub-devices." What does this mean exactly: do we wait before copying a tensor, or after? If this is a synchronization primitive, can we make it an explicit API instead, like ttnn.wait_for_subdevices(...)?

@ayerofieiev-tt

Contributor

This is for stalling before reading/writing buffers.
I am currently working on a new API similar to what you proposed, but instead of an explicit stall API, it adjusts the default set of sub-devices to stall on. Adjusting a stored/cached stall list minimizes the burden on the user: they don't have to inject synchronization calls themselves or track their own stall list everywhere. This should also let us remove the need to propagate sub_device_ids to all of these APIs.

Ex below:

What would be coded now:

sub_device_0 = ...
sub_device_1 = ...
manager = create_manager([sub_device_0, sub_device_1])
load_manager(manager)
run_long_running_op_on_sub_device_1()
write_buffer(sub_device_ids=[sub_device_0])
run_op_on_sub_device_0()
read_buffer(sub_device_ids=[sub_device_0])

With the new API (adjust_default_stalls is the new API; not its final name):

sub_device_0 = ...
sub_device_1 = ...
manager = create_manager([sub_device_0, sub_device_1])
load_manager(manager)
run_long_running_op_on_sub_device_1()
adjust_default_stalls([sub_device_0])
write_buffer()
run_op_on_sub_device_0()
read_buffer()
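The cached-default-stalls idea above can be sketched as a small runnable mock. All names here (`StallContext`, `set_default_stalls`, `write_buffer`, `read_buffer`) are illustrative stand-ins, not the real tt-metal API:

```python
# Runnable sketch of the cached-default-stalls pattern. All names are
# illustrative stand-ins, not the real tt-metal API.

class StallContext:
    """Caches the sub-device IDs that buffer reads/writes stall on by default."""

    def __init__(self, sub_device_ids):
        self.default_stalls = list(sub_device_ids)  # stall on all by default
        self.stall_log = []

    def set_default_stalls(self, sub_device_ids):
        # A one-time adjustment replaces per-call sub_device_ids arguments.
        self.default_stalls = list(sub_device_ids)

    def write_buffer(self, data, sub_device_ids=None):
        ids = sub_device_ids if sub_device_ids is not None else self.default_stalls
        self.stall_log.append(tuple(ids))  # record what this write stalled on

    def read_buffer(self, sub_device_ids=None):
        ids = sub_device_ids if sub_device_ids is not None else self.default_stalls
        self.stall_log.append(tuple(ids))  # record what this read stalled on

ctx = StallContext(["sub_device_0", "sub_device_1"])
ctx.set_default_stalls(["sub_device_0"])
ctx.write_buffer("tensor")  # no per-call sub_device_ids needed
ctx.read_buffer()

print(ctx.stall_log)  # [('sub_device_0',), ('sub_device_0',)]
```

The design win is that callers only touch the stall list at the point where it changes, rather than threading it through every read/write call site.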

@SeanNijjar SeanNijjar (Contributor Author) Jan 6, 2025

@tt-aho - based on the discussion above, I take it the recommendation is to abandon the part of this PR that updates the mesh composer, and once your changes are available, rebase and merge (after review, of course). Correct?

Contributor

I think so. This is my current PR, for reference: #16473. I'm planning to add the new API first, then remove the sub_device_ids propagation from the read/write APIs in a subsequent PR.

Contributor

Thanks, this is great! Makes sense; also +1 to using the term "set" instead of "adjust", as per #16473.
