Distributed GEMM #1907
Adds experimental support for running tensor-parallel GEMMs natively through CUTLASS. Distributed GEMM (DistGEMM) implements communication-fused GEMMs using point-to-point communication, which allows for better pipelining and can, in theory, hide all communication behind computation. It makes very few assumptions about the underlying kernel: it only adds a few barriers to the beginning of each GEMM kernel, and attempts to use either the epilogue source as the communication buffer or a memcopy branch in the CUDA graph. This leaves the SMs free for GEMMs and the copy engine free for communication. When benchmarked with Llama 70B and 405B training shapes, DistGEMM can reach 70-80% of peak performance. A more detailed blog post on DistGEMM will be released soon.
include/cutlass/experimental/distributed/schedules/dist_gemm_base_schedule.hpp
/* ProcessorMappingA_ = */ cute::Layout<cute::Shape<TP_, _1>, cute::Stride<_1, _0>>, // (identity) = device_idx
/* ProcessorMappingB_ = */ cute::Layout<cute::Shape<TP_, _1>, cute::Stride<_1, _0>>, // (identity) = device_idx
/* ProcessorMappingC_ = */ cute::Layout<cute::Shape<TP_, _1>, cute::Stride<_1, _0>>, // (identity) = device_idx
/* ProcessorMappingD_ = */ cute::Layout<cute::Shape<TP_, _1>, cute::Stride<_1, _0>>, // (identity) = device_idx
The "ProcessorMapping" is the same in all of these schedules and is used very trivially in the code.
(1) What do these parameterize and what other versions of this parameter do you expect to support?
(2) Why do they need to be CuTe Layouts? The sizes, shapes, ranks, strides are never used.
(3) The comments next to these parameters do not explain the domain of this function or the codomain.
(4) The size of the "bias mode" is always 1 and the stride is always 0, yet these are indexed with a coordinate 1 in that mode always. This seems like a poor parameterization and Layouts are not what you actually want.
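To make point (4) concrete, here is a minimal sketch of the arithmetic a rank-2 layout with shape (TP, 1) and stride (1, 0) performs. It is a plain C++ model, not the CuTe API: `Layout2D`, `make_processor_mapping`, and `check_identity` are hypothetical names introduced only for illustration.

```cpp
#include <cassert>

// Plain C++ model of a rank-2 layout: index = i * stride0 + j * stride1.
// This is NOT cute::Layout; the struct and its members are hypothetical.
struct Layout2D {
  int shape0, shape1;
  int stride0, stride1;
  constexpr int operator()(int i, int j) const {
    return i * stride0 + j * stride1;
  }
};

// Shape (TP, 1) with stride (1, 0), mirroring the ProcessorMapping parameters.
constexpr Layout2D make_processor_mapping(int tp) {
  return Layout2D{tp, 1, 1, 0};
}

// The second mode has stride 0, so its coordinate never affects the result:
// the mapping reduces to the identity on device_idx.
inline void check_identity(int tp) {
  const Layout2D mapping = make_processor_mapping(tp);
  for (int device_idx = 0; device_idx < tp; ++device_idx) {
    assert(mapping(device_idx, 0) == device_idx);
    // Coordinate 1 is outside a mode of size 1; it only "works" because
    // the stride is 0.
    assert(mapping(device_idx, 1) == device_idx);
  }
}
```

In this model the second mode contributes nothing, which is why a plain integer (or nothing at all) would carry the same information as the layout.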
You're right, ProcessorMapping is definitely unnecessary at this point. It was originally there in case the first iteration has a remote buffer, in which case they would map device index to the peer device's index.
I'll get rid of this, but just to clarify: this is only in reference to ProcessorMapping and not IterationMapping?
I'm still picking apart "IterationMapping"... it seems redundant with "PeerDeviceMapping", but that's still unclear to me. It also appears that the relationships between MNKL are not respected because each of these "Layout"s refers to A|B|C|D instead... I suspect there are a lot of implied invariants in this representation and those should be eliminated.
PeerDeviceMapping can technically be represented with a boolean with the schedules implemented right now, but I didn't want to assume that it would necessarily stay that way with other schedules that we may add.
And correct, it's a little difficult to solely rely on the relationships of MNKL because operands are sharded and rotated differently in different schedules. Because DistGEMM is also supposed to have a separate buffer space for remote tensors and switch between them, it wasn't possible to maintain that behavior anymore.
I'll try and think of a better way to do this, but at some point, references to ABCD will be inevitable just because schedules shard them differently.
And my point is that what is important is how the processors are moving through MNK space, which can then be translated to actions on ABC.
I see, thanks for explaining. I think we could just replace IterationMapping{A,B,C,D} with mappings to MNKL tile coordinates. It would only make sense given that the tile shape (IterationTiler) is set up in size-4 tuples corresponding to MNKL anyway. Would moving to IterationMapping{M,N,K,L} (or if I can figure out the mapping for it just one IterationMappingMNKL layout) be a step in the right direction?
I think we briefly discussed this before your internship ended -- this is just the same idea as the tile schedulers we have at the kernel layer, but at the scale of the NVLink system. I think calling this a DistributedTileScheduler makes sense, with the same abstract class hierarchy as the tile schedulers we already have. The job of this scheduler is to map a given tile coordinate (in this case, a tile of the global problem layout) to a given physical processor (in this case, the GPU ID). This will let you generalize this to 2D and 3D TP in the future. Distributed schedules (patterns, as you called them) are then just different TV layouts where the T mode can be 1D, 2D, or 3D depending on the TP strategy.
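As a rough illustration of the tile-coordinate-to-GPU mapping described above, the sketch below assigns tiles of the global problem to a grid of devices using simple block partitioning. Everything here is an assumption made for illustration: `DeviceGrid`, `TileGrid`, and `owner_of_tile` are hypothetical names, block partitioning is just one possible choice, and none of this is taken from the PR or from the existing CUTLASS tile schedulers.

```cpp
#include <cassert>

// Hypothetical device grid for tensor parallelism: tp_m x tp_n GPUs.
// 1D TP is the special case where one of the two extents is 1.
struct DeviceGrid { int tp_m; int tp_n; };

// Number of tiles along M and N in the global problem (hypothetical).
struct TileGrid { int tiles_m; int tiles_n; };

// Map a global tile coordinate (tile_m, tile_n) to a linear GPU ID.
// Assumption: tiles are block-partitioned, so each device owns a
// contiguous rectangle of tiles, and device IDs are row-major in the grid.
inline int owner_of_tile(int tile_m, int tile_n, TileGrid tiles, DeviceGrid devs) {
  assert(tiles.tiles_m % devs.tp_m == 0 && tiles.tiles_n % devs.tp_n == 0);
  assert(0 <= tile_m && tile_m < tiles.tiles_m);
  assert(0 <= tile_n && tile_n < tiles.tiles_n);
  const int per_dev_m = tiles.tiles_m / devs.tp_m;  // tiles per device along M
  const int per_dev_n = tiles.tiles_n / devs.tp_n;  // tiles per device along N
  const int dev_m = tile_m / per_dev_m;
  const int dev_n = tile_n / per_dev_n;
  return dev_m * devs.tp_n + dev_n;
}
```

With `DeviceGrid{4, 1}` this behaves as 1D TP sharded along M; changing the grid to `DeviceGrid{2, 2}` gives a 2D assignment without touching the caller, which is the kind of generalization the comment points at.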
Of course; that's the end goal, but at the same time it's kind of non-trivial to come up with an API that closely resembles kernel layer tile schedulers.
Right now it handles both cross-device and on-device tiling, and maps the device index and pipeline stage / iteration to tile coordinates. Are you saying that it should be broken up into two different components, one handling the tiling, and one handling the mapping to tile coordinates?
Partitioning is the job of the TV layout that maps the GPU rank to the values it extracts from the global coordinate. Scheduling adds the temporal component on top, which tells each GPU which coordinate to work on at which step of the computation. Separating them out into two makes the most sense. @ccecka agree?
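The split between partitioning and scheduling can be sketched in a few lines. The example assumes 1D TP with one shard per rank and a ring-style rotation as the schedule; both are illustrative assumptions, not necessarily what the schedules in this PR do, and `local_shard`, `shard_at_step`, and `schedule_is_valid` are hypothetical names.

```cpp
#include <cassert>
#include <vector>

// Partitioning: which shard of the global problem a rank holds locally.
// Assumption: 1D TP with exactly one shard per rank.
inline int local_shard(int rank) { return rank; }

// Scheduling: which shard a rank works on at a given step.
// Assumption: ring rotation, so each rank starts on its own shard and
// moves to the next one every step.
inline int shard_at_step(int rank, int step, int tp) {
  assert(tp > 0);
  return (local_shard(rank) + step) % tp;
}

// Sanity check for the schedule: every rank visits every shard exactly
// once, and no two ranks work on the same shard in the same step.
inline bool schedule_is_valid(int tp) {
  for (int rank = 0; rank < tp; ++rank) {
    std::vector<bool> seen(tp, false);
    for (int step = 0; step < tp; ++step) {
      const int s = shard_at_step(rank, step, tp);
      if (seen[s]) return false;
      seen[s] = true;
    }
  }
  for (int step = 0; step < tp; ++step) {
    std::vector<bool> taken(tp, false);
    for (int rank = 0; rank < tp; ++rank) {
      const int s = shard_at_step(rank, step, tp);
      if (taken[s]) return false;
      taken[s] = true;
    }
  }
  return true;
}
```

Keeping `local_shard` separate from `shard_at_step` means a different TP strategy would only replace the partitioning function, while a different communication pattern would only replace the schedule.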
Remove ProcessorMapping{A,B,C,D} and IterationMapping{A,B,C,D}, and use IterationMapping{M,N,K,L} instead.