adding rotary embedding example, with graph rewrite for complex subgraph #3570
base: main
Changes from all commits: 9d5b3c0, 5a2ad50, f5cc275, a90f651, 109e5c2

rotary_embedding.py (new file)
@@ -0,0 +1,117 @@
""" | ||
.. _rotary_embedding: | ||
|
||
Rotary Embedding Implementation for Tensor Parallel Attention | ||
============================================================ | ||
|
||
This module provides an implementation of rotary positional embeddings (RoPE) for transformer models | ||
with support for tensor parallel distributed inference. Rotary embeddings are used to encode positional | ||
information in transformer attention mechanisms. | ||
""" | ||

import time

import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_initialize_dist import initialize_distributed_env
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

"""
This example covers the rotary embedding and rotary attention case for tensor parallelism.
"""


def precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, n_parallel=1
) -> torch.Tensor:
    """Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        n_parallel (int, optional): Number of GPUs for parallel computation. Defaults to 1.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim // n_parallel, 2).float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)


def rotary_embedding(xq, xk, dim, freqs_cis=None):
    """This calculates the rotary embedding for the query and key tensors.

    Args:
        xq (torch.Tensor): Query tensor.
        xk (torch.Tensor): Key tensor.
        dim (int): Dimension of the query and key tensors.
        freqs_cis (torch.Tensor, optional): Precomputed frequency tensor. Defaults to None.

    Returns:
        tuple: Tuple containing the rotated query and key tensors.
    """
    freqs_cis = freqs_cis[None, :, None, :]
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return (xq_out.type_as(xq), xk_out.type_as(xk))


######## Tensor Parallel ########
def parallel_rotary_block(rotary_block, tp_mesh):
    """Parallelize a rotary attention block for tensor parallel execution.

    Args:
        rotary_block: Rotary block to parallelize
        tp_mesh: Tensor parallel device mesh
    """
    if tp_mesh.size() <= 1:
        return

    plan = {
        "wq": ColwiseParallel(),
        "wk": ColwiseParallel(),
        "wo": RowwiseParallel(output_layouts=Shard(0)),
    }
    # TODO: remove this hardcoded value; it should reflect the tensor parallel size
    rotary_block.n_parallel = 1

    parallelize_module(rotary_block, tp_mesh, plan)


class RotaryAttention(nn.Module):
    def __init__(self, dim: int, seq_len: int):
        super().__init__()
        self.dim = dim
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, dim)
        self.wo = nn.Linear(dim, dim)
        self.seq_len = seq_len
        self.n_parallel = 1
        self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
        self.init_weights()

    def _precompute_freqs_cis(self) -> torch.Tensor:
        theta = 10000.0
        return precompute_freqs_cis(self.dim, self.seq_len, theta, self.n_parallel)

    def init_weights(self):
        # re-materialize the freqs_cis buffer on its registered device
        with torch.device(self.freqs_cis.device):
            self.freqs_cis = self._precompute_freqs_cis()

    def forward(self, x):
        q = self.wq(x)
        k = self.wk(x)
        freqs_cis = self._precompute_freqs_cis().to(q.device)
        q, k = rotary_embedding(q, k, self.dim, freqs_cis=freqs_cis)
        return self.wo(q)
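
As a quick orientation for reviewers, here is a minimal single-GPU sketch (editorial, not part of the PR) of how the helpers above fit together. It assumes the file is importable as the `rotary_embedding` module and uses small, arbitrary shapes; the runner script below uses the same BATCH, SEQ_LEN, HEADS, DIM layout.

import torch
from rotary_embedding import RotaryAttention, precompute_freqs_cis, rotary_embedding

B, S, H, D = 2, 16, 4, 64  # batch, sequence length, heads, embedding dim

# precompute_freqs_cis returns a complex64 tensor of shape (S, D // 2)
freqs_cis = precompute_freqs_cis(D, S)

xq = torch.randn(B, S, H, D)
xk = torch.randn(B, S, H, D)
xq_rot, xk_rot = rotary_embedding(xq, xk, D, freqs_cis=freqs_cis)
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape  # rotation preserves shapes

# The attention wrapper applies wq/wk, rotates, and projects with wo
model = RotaryAttention(dim=D, seq_len=S)
out = model(torch.randn(B, S, H, D))
print(out.shape)  # torch.Size([2, 16, 4, 64])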

tensor_parallel_rotary_embedding.py (new file)
@@ -0,0 +1,70 @@
""" | ||
.. _tensor_parallel_rotary_embedding: | ||
Tensor Parallel Rotary Embedding Example | ||
======================================= | ||
|
||
This example demonstrates how to use Torch-TensorRT with tensor parallel distributed inference | ||
for models that use rotary positional embeddings (RoPE). It lowers the complex | ||
operations in attention models with rotary embeddings across multiple GPUs. | ||
|
||
""" | ||

import logging
import os
import time

import torch
import torch_tensorrt
from rotary_embedding import RotaryAttention, parallel_rotary_block
from tensor_parallel_initialize_dist import (
    cleanup_distributed_env,
    initialize_distributed_env,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
    "./tensor_parallel_rotary_embedding"
)


"""
This example covers the rotary embedding in the Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning

Command to run with a single GPU: mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py
"""

BATCH = 2
SEQ_LEN = 128
HEADS = 4
DIM = 128

with torch.no_grad():
    model = RotaryAttention(DIM, SEQ_LEN)
    parallel_rotary_block(model, device_mesh)
    device = torch.device("cuda", device_mesh.get_rank())
    model.to(device)
    x = torch.randn(BATCH, SEQ_LEN, HEADS, DIM).to(device)

    python_result = model(x)

    logger.info("Torch-TensorRT compilation for rotary embedding")

    model = torch.compile(model, backend="torch_tensorrt")

    try:
        for i in range(15):
            # seed identically on every rank so TP groups see the same random state
            torch.manual_seed(i)
            start = time.time()
            output = model(x)
            end = time.time()
            if i == 0:
                logger.info(f"Compilation time is {end-start}")
                assert (
                    python_result - output
                ).std() < 0.01, "Compilation result is not correct."
            elif _rank == 0:
                logger.info(f"Inference time is {end-start}")
    except Exception as e:
        logger.error(f"Error: {e}")
        raise e
    finally:
        cleanup_distributed_env()

Review comment on lines +52 to +70:

[Reviewer] Do we need a try/except and finally here for the example? And also, why are we looping? If you want to display any results of this block, please use the right formatting: https://github.com/pytorch/TensorRT/blob/main/examples/dynamo/torch_export_sam2.py#L282-L297

[Author] The try/except block is there so that we can check the inference-time improvement over a number of iterations after graph compilation, to see the performance gain without graph breaks. I am not clear about the formatting part pointed out in the link. Did you point it out because the try block won't be rendered? I see that the rendering is correct. I could also just remove the loop.
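
For illustration, a trimmed variant of that block (an editorial sketch using the names already defined in the script above, not the reviewer's exact suggestion and not code from this PR) could time the compiling call and a warm call separately, without the loop or the try/except:

# First call triggers torch.compile / Torch-TensorRT engine compilation
compile_start = time.time()
output = model(x)
logger.info(f"Compilation time is {time.time() - compile_start}")
assert (python_result - output).std() < 0.01, "Compilation result is not correct."

# Subsequent calls reuse the compiled engine
infer_start = time.time()
output = model(x)
if _rank == 0:
    logger.info(f"Inference time is {time.time() - infer_start}")

cleanup_distributed_env()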

@@ -23,7 +23,10 @@ def getitem_validator(getitem_node: Node, settings: CompilationSettings = None)
     from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS

     # Getitem nodes can only be converted if their parent node also can
-    return getitem_node.args[0] in DYNAMO_CONVERTERS
+    return (
+        getitem_node.args[0] in DYNAMO_CONVERTERS
+        or getitem_node.args[0].op == "get_attr"
+    )


 # TODO: Subsequent evaluators should be registered here with their own validators

Review comment on the added `get_attr` check:

[Reviewer] Why is this needed?

[Author] This is needed because the complex tensor is at present a buffer, from which we extract the real and imaginary parts through input[..., 0] and input[..., 1]. This leads to a getitem node whose arg is a get_attr node.
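
To make the node pattern described in that reply concrete, here is a small standalone FX sketch (an editorial illustration, not the PR's lowering pass): it builds a graph where operator.getitem nodes take a get_attr node (a registered buffer) as their first argument, which is the case the added `op == "get_attr"` branch admits.

import operator

import torch
import torch.fx as fx

graph = fx.Graph()
x = graph.placeholder("x")
freqs = graph.get_attr("freqs_cis")  # buffer holding complex values as (..., 2) reals
real = graph.call_function(operator.getitem, (freqs, (Ellipsis, 0)))  # input[..., 0]
imag = graph.call_function(operator.getitem, (freqs, (Ellipsis, 1)))  # input[..., 1]
out = graph.call_function(torch.ops.aten.mul.Tensor, (x, real))
graph.output(out)

for node in graph.nodes:
    if node.target is operator.getitem:
        # args[0] is the get_attr node, which has no converter of its own,
        # so a membership check against DYNAMO_CONVERTERS alone would reject it
        print(node.name, "-> parent op:", node.args[0].op)  # prints "get_attr"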

@@ -43,7 +46,10 @@ def generic_evaluator(
     _LOGGER.debug(
         f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}"
     )
-    return target(*args)
+    from torch._subclasses.fake_tensor import unset_fake_temporarily
+
+    with unset_fake_temporarily():
+        return target(*args)


 def rand_validator(rand_node: Node, settings: CompilationSettings = None) -> bool:
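
For context on why generic_evaluator now wraps the call, here is a standalone sketch (not Torch-TensorRT code) of what unset_fake_temporarily changes: under FakeTensorMode, tensor factory calls produce fake tensors that carry only metadata, so an evaluator that needs concrete values has to step out of fake mode temporarily.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

with FakeTensorMode():
    fake = torch.ones(2, 2)              # FakeTensor: shape/dtype only, no real storage
    with unset_fake_temporarily():
        real = torch.ones(2, 2) * 3.0    # real tensor with concrete values
        print(real.sum().item())         # 12.0
    print(type(fake).__name__)           # FakeTensor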

@@ -1,4 +1,3 @@
 from ._aten_lowering_pass import *
-from ._modify_reshape_complex_nodes import modify_reshape_complex_nodes
 from .remove_sym_nodes import remove_sym_nodes
 from .repair_input_aliasing import repair_input_aliasing