adding rotary embedding example, with graph rewrite for complex subgraph #3570

Open · wants to merge 5 commits into main
117 changes: 117 additions & 0 deletions examples/distributed_inference/rotary_embedding.py
@@ -0,0 +1,117 @@
"""
.. _rotary_embedding:

Rotary Embedding Implementation for Tensor Parallel Attention
==============================================================

This module provides an implementation of rotary positional embeddings (RoPE) for transformer models
with support for tensor parallel distributed inference. Rotary embeddings are used to encode positional
information in transformer attention mechanisms.
"""

import time

import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_initialize_dist import initialize_distributed_env
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)

"""
This example covers the rotary embedding and rotary attention case for tensor parallelism.
"""


def precompute_freqs_cis(
dim: int, end: int, theta: float = 10000.0, n_parallel=1
) -> torch.Tensor:
"""Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
end (int): End index for precomputing frequencies.
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
n_parallel (int, optional): Number of GPUs for parallel computation. Defaults to 1.
Returns:
torch.Tensor: Precomputed frequency tensor with complex exponentials.
"""
freqs = 1.0 / (theta ** (torch.arange(0, dim // n_parallel, 2).float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
return torch.polar(torch.ones_like(freqs), freqs)


def rotary_embedding(xq, xk, dim, freqs_cis=None):
"""This calculates the rotary embedding for the query and key tensors.
Args:
xq (torch.Tensor): Query tensor.
xk (torch.Tensor): Key tensor.
dim (int): Dimension of the query and key tensors.
freqs_cis (torch.Tensor, optional): Precomputed frequency tensor. Defaults to None.
Returns:
tuple: Tuple containing the rotated query and key tensors.
"""
freqs_cis = freqs_cis[None, :, None, :]
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return (xq_out.type_as(xq), xk_out.type_as(xk))


########Tensor Parallel########
def parallel_rotary_block(rotary_block, tp_mesh):
"""Parallel rotary block for tensor parallel
Args:
rotary_block: Rotary block to parallelize
tp_mesh: Tensor parallel mesh
"""
if tp_mesh.size() <= 1:
return

plan = {
"wq": ColwiseParallel(),
"wk": ColwiseParallel(),
"wo": RowwiseParallel(output_layouts=Shard(0)),
}
rotary_block.n_parallel = 1  # this is for a single GPU; TODO: remove this hardcoded value

parallelize_module(rotary_block, tp_mesh, plan)


class RotaryAttention(nn.Module):
def __init__(self, dim: int, seq_len: int):
super().__init__()
self.dim = dim
self.wq = nn.Linear(dim, dim)
self.wk = nn.Linear(dim, dim)
self.wo = nn.Linear(dim, dim)
self.seq_len = seq_len
self.n_parallel = 1
self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
self.init_weights()

def _precompute_freqs_cis(self) -> torch.Tensor:
theta = 10000.0
return precompute_freqs_cis(self.dim, self.seq_len, theta, self.n_parallel)

def init_weights(self):
with torch.device(self.freqs_cis.device):
self.freqs_cis = self.freqs_cis

def forward(self, x):
q = self.wq(x)
k = self.wk(x)
freqs_cis = self._precompute_freqs_cis().to(q.device)
q, k = rotary_embedding(q, k, self.dim, freqs_cis=freqs_cis)
return self.wo(q)
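
For context, a minimal single-device usage sketch of the RotaryAttention module added above (not part of the PR; it assumes rotary_embedding.py is importable from the working directory and mirrors the shapes used in tensor_parallel_rotary_embedding.py):

import torch

from rotary_embedding import RotaryAttention

# Shapes mirror tensor_parallel_rotary_embedding.py: (batch, seq_len, heads, dim)
model = RotaryAttention(dim=128, seq_len=128)
x = torch.randn(2, 128, 4, 128)

with torch.no_grad():
    out = model(x)

print(out.shape)  # torch.Size([2, 128, 4, 128])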
14 changes: 14 additions & 0 deletions examples/distributed_inference/tensor_parallel_initialize_dist.py
@@ -1,3 +1,11 @@
"""
.. _tensor_parallel_initialize_dist:

Tensor Parallel Initialize Distributed Environment
==================================================

This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference.
"""

import logging
import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
@@ -65,3 +73,9 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
torch.cuda.set_device(device_id)

return device_mesh, world_size, rank, logger


def cleanup_distributed_env():
"""Clean up distributed process group to prevent resource leaks."""
if dist.is_initialized():
dist.destroy_process_group()
70 changes: 70 additions & 0 deletions examples/distributed_inference/tensor_parallel_rotary_embedding.py
@@ -0,0 +1,70 @@
"""
.. _tensor_parallel_rotary_embedding:

Tensor Parallel Rotary Embedding Example
========================================

This example demonstrates how to use Torch-TensorRT with tensor parallel distributed inference
for models that use rotary positional embeddings (RoPE). It lowers the complex-valued
operations used by rotary embeddings in attention models and runs them across multiple GPUs.

"""

import logging
import os
import time

import torch
import torch_tensorrt
from rotary_embedding import RotaryAttention, parallel_rotary_block
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
initialize_distributed_env,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_rotary_embedding"
)


"""
This example covers the rotary embedding in the Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
Command to run with a single GPU: mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py
"""

BATCH = 2
SEQ_LEN = 128
HEADS = 4
DIM = 128

with torch.no_grad():
model = RotaryAttention(DIM, SEQ_LEN)
parallel_rotary_block(model, device_mesh)
device = torch.device("cuda", device_mesh.get_rank())
model.to(device)
x = torch.randn(BATCH, SEQ_LEN, HEADS, DIM).to(device)

python_result = model(x)

logger.info("Torch-tensorrt compilation for rotary embedding")

model = torch.compile(model, backend="torch_tensorrt")

try:
for i in range(15):
# seed each iteration so that all TP ranks see identical inputs
torch.manual_seed(i)
start = time.time()
output = model(x)
end = time.time()
if i == 0:
logger.info(f"Compilation time is {end-start}")
assert (
python_result - output
).std() < 0.01, "Compilation result is not correct."
elif _rank == 0:
logger.info(f"Inference time is {end-start}")
except Exception as e:
logger.error(f"Error: {e}")
raise e
finally:
cleanup_distributed_env()
Comment on lines +52 to +70

Collaborator: Do we need a try/except and finally here for the example? And also, why are we looping? If you want to display any results of this block, please use the right formatting: https://github.com/pytorch/TensorRT/blob/main/examples/dynamo/torch_export_sam2.py#L282-L297

Collaborator (Author): The try/except block is there because we check the inference time over a number of iterations after graph compilation, to show the performance improvement once there are no graph breaks. I am not clear about the formatting part pointed out in the link above. Did you point it out because the try block won't be rendered? I see that the rendering is correct.

I could also remove the loop.
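
For reference, one possible shape of that block without the try/except, keeping the timing loop (a sketch only, not part of this PR; it reuses model, x, _rank, logger, time, and cleanup_distributed_env from the script above):

times = []
for i in range(15):
    torch.manual_seed(i)
    start = time.time()
    output = model(x)
    torch.cuda.synchronize()
    times.append(time.time() - start)

if _rank == 0:
    # the first iteration includes torch.compile / Torch-TensorRT compilation
    logger.info(f"Compilation (first iteration) time: {times[0]:.4f} s")
    logger.info(f"Average inference time after warm-up: {sum(times[1:]) / len(times[1:]):.4f} s")

cleanup_distributed_env()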

31 changes: 27 additions & 4 deletions examples/distributed_inference/tensor_parallel_simple_example.py
@@ -1,11 +1,35 @@
"""
.. _tensor_parallel_simple_example:

Torch-TensorRT Tensor Parallel Distributed Inference for a Simple Model
========================================================================

The example below shows how to use the Torch-TensorRT backend for distributed inference with tensor parallelism.

This example demonstrates:
- Setting up distributed environment for tensor parallelism
- Model sharding across multiple GPUs
- Compilation with Torch-TensorRT
- Distributed inference execution

Usage
-----
.. code-block:: bash

mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
"""

import time

import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_initialize_dist import initialize_distributed_env
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
initialize_distributed_env,
)
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
@@ -18,7 +42,7 @@
)

"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
This example takes some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""


@@ -97,5 +121,4 @@ def forward(self, x):
logger.info(f"Inference time is {end-start}")
finally:
# This cleans up the distributed process group
if dist.is_initialized():
dist.destroy_process_group()
cleanup_distributed_env()
10 changes: 8 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
@@ -23,7 +23,10 @@ def getitem_validator(getitem_node: Node, settings: CompilationSettings = None)
from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS

# Getitem nodes can only be converted if their parent node also can
return getitem_node.args[0] in DYNAMO_CONVERTERS
return (
getitem_node.args[0] in DYNAMO_CONVERTERS
or getitem_node.args[0].op == "get_attr"
Collaborator: Why is this needed?

Collaborator (Author): This is needed because the complex tensor is currently a buffer, and we extract its real and imaginary parts through input[..., 0] and input[..., 1]. This produces a getitem node whose argument is a get_attr node. If I don't include the change above, this part of the graph breaks on the GPU, which is unnecessary if we support get_attr.

graph():
   %_frozen_param3_reshaped : [num_users=2] = get_attr[target=_frozen_param3_reshaped]
   %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_frozen_param3_reshaped, (Ellipsis, 0)), kwargs = {})
   %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%_frozen_param3_reshaped, (Ellipsis, 1)), kwargs = {})
   return (getitem, getitem_1)
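
For illustration, a small standalone sketch (not part of the PR) of how this getitem-over-get_attr pattern can arise when a complex tensor is stored as an interleaved real (..., 2) buffer and split via input[..., 0] / input[..., 1]. It uses torch.fx symbolic tracing with proxy_buffer_attributes enabled so that buffer reads are recorded as get_attr nodes:

import operator

import torch
import torch.fx as fx


class SplitComplexBuffer(torch.nn.Module):
    """Stores a complex tensor as an interleaved real (..., 2) buffer and splits it."""

    def __init__(self):
        super().__init__()
        freqs_cis = torch.polar(torch.ones(8, 4), torch.rand(8, 4))   # complex64
        self.register_buffer("freqs", torch.view_as_real(freqs_cis))  # real view, shape (8, 4, 2)

    def forward(self, x):
        real, imag = self.freqs[..., 0], self.freqs[..., 1]
        return x * real, x * imag


tracer = fx.Tracer()
tracer.proxy_buffer_attributes = True  # record buffer reads as get_attr nodes
graph = tracer.trace(SplitComplexBuffer())
for node in graph.nodes:
    if node.op == "call_function" and node.target is operator.getitem:
        # each getitem's first argument is the get_attr node for the buffer
        print(node.format_node())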

)


# TODO: Subsequent evaluators should be registered here with their own validators
@@ -43,7 +46,10 @@ def generic_evaluator(
_LOGGER.debug(
f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}"
)
return target(*args)
from torch._subclasses.fake_tensor import unset_fake_temporarily

with unset_fake_temporarily():
return target(*args)


def rand_validator(rand_node: Node, settings: CompilationSettings = None) -> bool:
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/passes/__init__.py
@@ -1,4 +1,3 @@
from ._aten_lowering_pass import *
from ._modify_reshape_complex_nodes import modify_reshape_complex_nodes
from .remove_sym_nodes import remove_sym_nodes
from .repair_input_aliasing import repair_input_aliasing
@@ -6,6 +6,7 @@
from torch_tensorrt.dynamo.utils import is_tegra_platform

from .accumulate_fp32_matmul import accumulate_fp32_matmul
from .complex_graph_rewrite import complex_graph_detection
from .constant_folding import constant_fold
from .fuse_distributed_ops import fuse_distributed_ops
from .fuse_prims_broadcast import fuse_prims_broadcast
@@ -26,6 +27,7 @@
remove_assert_nodes,
accumulate_fp32_matmul,
remove_num_users_is_0_nodes,
complex_graph_detection,
]

pre_lowering_pass_list = [