
Distributed GEMM #1907

Closed · wants to merge 3 commits
861 changes: 861 additions & 0 deletions examples/64_distributed_gemm/64_distributed_gemm.cu

Large diffs are not rendered by default.

32 changes: 32 additions & 0 deletions examples/64_distributed_gemm/CMakeLists.txt
@@ -0,0 +1,32 @@
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

cutlass_example_add_executable(
64_distributed_gemm
64_distributed_gemm.cu
)
64 changes: 64 additions & 0 deletions examples/64_distributed_gemm/README.md
@@ -0,0 +1,64 @@
# Distributed GEMM

This example implements Tensor Parallel GEMMs for the Hopper architecture with the experimental
[Distributed GEMM](../../include/cutlass/experimental/distributed) API in CUTLASS.

This example requires Hopper GPUs with an any-to-any NVLink network.
Please refer to [REQUIREMENTS.md](REQUIREMENTS.md) for more information.

By default, the example assumes 8 GPUs (TP=8) and runs an All Gather + GEMM operation, which rotates
operand A. To run with a different number of GPUs or schedule, please refer to
[64_distributed_gemm.cu](64_distributed_gemm.cu).


## Getting started

Command line arguments are mostly similar to other examples:

```
--m=<int> Sets the M extent of the GEMM
--n=<int> Sets the N extent of the GEMM
--k=<int> Sets the K extent of the GEMM
--l=<int> Sets the L extent (batch) of the GEMM (default: 1)
--alpha=<f32> Epilogue scalar alpha (default: 1.0)
--beta=<f32> Epilogue scalar beta (default: 0.0)
--iterations=<int> Number of profiling iterations to perform (default: 100)
--warmup-iterations=<int> Number of warmup iterations prior to profiling (default: 10)
--eps=<f32> Threshold for error compared to reference GEMM (default: 0.0)
```

Sample run command:

```bash
./64_distributed_gemm --m=16384 --n=106496 --k=16384 --warmup-iterations=10 --iterations=100
```

This executes a GEMM with shape `<16384, 106496, 16384>`, and reports average runtime
over 100 iterations, with 10 warmup iterations.
A reference check with respect to a single-device GEMM is also performed by default.

## Trying out other schedules

Schedules that are currently supported are:

* All Gather + GEMM:
* `AllGather1D_TilingCD_RotatingA`
* `AllGather1D_TilingCD_RotatingB`

* GEMM + Reduce Scatter:
* `ReduceScatter1D_TilingA_RotatingC`
* `ReduceScatter1D_TilingB_RotatingC`

To try out a different schedule, simply change this line in the example and set your desired
schedule:

```cpp
using DistSchedule = cutlass::distributed::schedules::AllGather1D_TilingCD_RotatingA<TP>;
```

If you're interested in trying out other TP values (running on a different number of GPUs), the
procedure is the same: simply modify the following line in the example:

```cpp
using TP = _8;
```
86 changes: 86 additions & 0 deletions examples/64_distributed_gemm/REQUIREMENTS.md
@@ -0,0 +1,86 @@
# Distributed GEMM

## Requirements

### Build
Make sure to set up CUTLASS with
support for [Programmatic Dependent Launch (PDL)](../../media/docs/dependent_kernel_launch.md),
that is, with the `CUTLASS_ENABLE_GDC_FOR_SM90` flag:

```bash
cmake $PATH -DCUTLASS_NVCC_ARCHS="90a" -DCUTLASS_ENABLE_GDC_FOR_SM90=1
```
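
After configuring, the example can be built like any other CUTLASS example (assuming a
Makefile generator; the target name comes from this example's `CMakeLists.txt`):

```bash
make 64_distributed_gemm -j
```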

### Minimum software

Like all other CUTLASS examples, the NVIDIA driver, runtime, and CUDA Toolkit are required.
This example specifically requires CUDA Toolkit 12.6 or newer, since some of the CUDA graph
APIs it relies on were introduced in that release.
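
You can check your toolkit version with, for example:

```bash
nvcc --version
```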

### Hardware / driver settings

This example requires Hopper GPUs with an NVLink network.

If you're not sure, first run the following command and make sure your GPU
compute capability is 9.0:

```bash
nvidia-smi --query-gpu=name,compute_cap --format=csv
```

Sample output:

```
name, compute_cap
NVIDIA H100 80GB HBM3, 9.0
NVIDIA H100 80GB HBM3, 9.0
NVIDIA H100 80GB HBM3, 9.0
NVIDIA H100 80GB HBM3, 9.0
NVIDIA H100 80GB HBM3, 9.0
NVIDIA H100 80GB HBM3, 9.0
NVIDIA H100 80GB HBM3, 9.0
NVIDIA H100 80GB HBM3, 9.0
```


Then make sure there is an NVLink network by checking the GPU network topology,
and confirming that there are `NV*` links between every pair of GPUs:

```bash
nvidia-smi topo -m
```

Sample output:

```
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 X NV18 NV18 NV18 NV18 NV18 NV18 NV18
GPU1 NV18 X NV18 NV18 NV18 NV18 NV18 NV18
GPU2 NV18 NV18 X NV18 NV18 NV18 NV18 NV18
GPU3 NV18 NV18 NV18 X NV18 NV18 NV18 NV18
GPU4 NV18 NV18 NV18 NV18 X NV18 NV18 NV18
GPU5 NV18 NV18 NV18 NV18 NV18 X NV18 NV18
GPU6 NV18 NV18 NV18 NV18 NV18 NV18 X NV18
GPU7 NV18 NV18 NV18 NV18 NV18 NV18 NV18 X
```

Finally, check whether the driver enables peer-to-peer access, which should usually be
the case, but it's good to check anyway:

```bash
nvidia-smi topo -p2p r
```

Sample output:

```
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 X OK OK OK OK OK OK OK
GPU1 OK X OK OK OK OK OK OK
GPU2 OK OK X OK OK OK OK OK
GPU3 OK OK OK X OK OK OK OK
GPU4 OK OK OK OK X OK OK OK
GPU5 OK OK OK OK OK X OK OK
GPU6 OK OK OK OK OK OK X OK
GPU7 OK OK OK OK OK OK OK X
```
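
If you'd like to check peer access programmatically as well, below is a minimal sketch using
the CUDA runtime API (it is not part of this example):

```cpp
// peer_check.cu: report any GPU pair without peer-to-peer access.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int num_devices = 0;
  cudaGetDeviceCount(&num_devices);
  for (int i = 0; i < num_devices; ++i) {
    for (int j = 0; j < num_devices; ++j) {
      if (i == j) { continue; }
      int can_access = 0;
      cudaDeviceCanAccessPeer(&can_access, i, j);
      if (!can_access) {
        std::printf("GPU %d cannot access GPU %d\n", i, j);
      }
    }
  }
  return 0;
}
```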
118 changes: 118 additions & 0 deletions examples/64_distributed_gemm/util/benchmark.h
@@ -0,0 +1,118 @@
/***************************************************************************************************
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/

/*! \file
\brief Benchmark helpers for Distributed GEMM

A delay kernel to gate all GEMMs across devices, controlled by a flag that
the host sets once it has launched DistGEMM across all devices.

DistGpuTimer extends CUTLASS's existing cudaEvent-based timer to multiple devices.
*/

#pragma once

#include <cassert>
#include <iostream>

#include <cuda/atomic>
#include <cuda/std/atomic>


namespace cutlass {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Delay kernel
/////////////////////////////////////////////////////////////////////////////////////////////////

using AtomicBoolean = cuda::atomic<bool>;
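
// Typical usage (a sketch; `streams` and `NP` are assumptions, not part of this header):
//   AtomicBoolean* flag;
//   CUDA_CHECK(cudaMallocManaged(&flag, sizeof(AtomicBoolean)));
//   new (flag) AtomicBoolean(false);
//   for (int device = 0; device < NP; ++device) {
//     CUDA_CHECK(cudaSetDevice(device));
//     delay_kernel<<<1, 1, 0, streams[device]>>>(flag);
//   }
//   // ... launch DistGEMM on every device, then release them all at once:
//   flag->store(true);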

__global__ void delay_kernel(const AtomicBoolean* atomic_flag_ptr) {
while (not atomic_flag_ptr->load()) {
__nanosleep(40);
}
}


/////////////////////////////////////////////////////////////////////////////////////////////////
/// Distributed GPU Timer
/// Sets up cuda events for multiple processors.
/////////////////////////////////////////////////////////////////////////////////////////////////
template <int NP>
struct DistGpuTimer {
int _primary_device;
cudaEvent_t _start[NP];
cudaEvent_t _stop[NP];

/// Constructor
DistGpuTimer()
{
CUDA_CHECK(cudaGetDevice(&_primary_device));
for (int device = 0; device < NP; ++device) {
CUDA_CHECK(cudaSetDevice(device));
CUDA_CHECK(cudaEventCreate(&_start[device]));
CUDA_CHECK(cudaEventCreate(&_stop[device]));
}
CUDA_CHECK(cudaSetDevice(_primary_device));
}

/// Destructor
~DistGpuTimer()
{
for (int device = 0; device < NP; ++device) {
CUDA_CHECK(cudaSetDevice(device));
CUDA_CHECK(cudaEventDestroy(_start[device]));
CUDA_CHECK(cudaEventDestroy(_stop[device]));
}
CUDA_CHECK(cudaSetDevice(_primary_device));
}

/// Start the timer for a given device and stream
void start(int device, cudaStream_t stream) {
assert(device >= 0 && device < NP);
CUDA_CHECK(cudaEventRecord(_start[device], stream));
}

/// Stop the timer
void stop(int device, cudaStream_t stream) {
assert(device >= 0 && device < NP);
CUDA_CHECK(cudaEventRecord(_stop[device], stream));
}

/// Return the elapsed time (in milliseconds)
float elapsed_millis(int device) {
assert(device >= 0 && device < NP);
float elapsed = 0.0f;
CUDA_CHECK(cudaEventSynchronize(_stop[device]));
CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start[device], _stop[device]));
return elapsed;
}
};
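
// Example usage (a sketch; `device` and `streams` are assumptions):
//   DistGpuTimer<8> timer;
//   timer.start(device, streams[device]);
//   // ... enqueue work on streams[device] ...
//   timer.stop(device, streams[device]);
//   float ms = timer.elapsed_millis(device);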

} //namespace cutlass
82 changes: 82 additions & 0 deletions examples/64_distributed_gemm/util/device_copy.h
@@ -0,0 +1,82 @@
/******************************************************************************
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/

/*! \file
\brief A generic device-to-device data movement kernel for CuTe tensors.

NOTE: this kernel copies one element per thread, and is by no means an
efficient way of copying tensors. It should only be used for convenience in
reference checks.

*/

#pragma once

#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/cuda_host_adapter.hpp"

namespace cutlass {

template <typename TensorSource, typename TensorDestination>
void device_copy(TensorSource tensor_source,
TensorDestination tensor_destination,
cudaStream_t stream);


template <typename TensorSource, typename TensorDestination>
__global__ void device_copy_kernel(TensorSource const tensor_source,
TensorDestination tensor_destination) {
auto linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
using ElementDst = typename TensorDestination::value_type;
if (linear_idx < size(tensor_source)) {
tensor_destination(linear_idx) = static_cast<ElementDst>(tensor_source(linear_idx));
}
}

template <typename TensorSource, typename TensorDestination>
void device_copy(TensorSource tensor_source,
TensorDestination tensor_destination,
cudaStream_t stream) {

assert(tensor_source.size() == tensor_destination.size());

auto numel = tensor_source.size();
static constexpr int NumThreads = 128;
auto grid_size = cute::ceil_div(numel, NumThreads);

dim3 grid(grid_size);
dim3 block(NumThreads);
device_copy_kernel<<<grid, block, 0, stream>>>(tensor_source, tensor_destination);
}
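
// Example usage (a sketch; `d_src`, `d_dst`, `M`, `N`, and `stream` are assumptions):
//   auto src = cute::make_tensor(cute::make_gmem_ptr(d_src),
//                                cute::make_layout(cute::make_shape(M, N)));
//   auto dst = cute::make_tensor(cute::make_gmem_ptr(d_dst),
//                                cute::make_layout(cute::make_shape(M, N)));
//   cutlass::device_copy(src, dst, stream);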

} //namespace cutlass
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -143,6 +143,7 @@ foreach(EXAMPLE
61_hopper_gemm_with_topk_and_softmax
62_hopper_sparse_gemm
63_hopper_gemm_with_weight_prefetch
64_distributed_gemm
)

add_subdirectory(${EXAMPLE})