From 2d5687c9e0a1bc343168b9f62319a0ea310abd59 Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 19 Aug 2025 00:25:24 -0700 Subject: [PATCH] Distributed llama3 example --- .../distributed_inference/llama3_model.py | 496 ++++++++++++++++++ .../distributed_inference/rotary_embedding.py | 6 +- .../tensor_parallel_llama3.py | 68 +++ .../tensor_parallel_rotary_embedding.py | 2 +- py/torch_tensorrt/dynamo/_compiler.py | 9 +- py/torch_tensorrt/dynamo/backend/backends.py | 5 - .../passes/_modify_reshape_complex_nodes.py | 105 ---- .../_replace_complex_placeholder_to_tuple.py | 112 ---- .../lowering/passes/complex_graph_rewrite.py | 16 + .../runtime/_CudaGraphsTorchTensorRTModule.py | 4 + .../runtime/_PythonTorchTensorRTModule.py | 4 + .../dynamo/runtime/_TorchTensorRTModule.py | 2 + py/torch_tensorrt/dynamo/runtime/utils.py | 8 + 13 files changed, 603 insertions(+), 234 deletions(-) create mode 100644 examples/distributed_inference/llama3_model.py create mode 100644 examples/distributed_inference/tensor_parallel_llama3.py delete mode 100644 py/torch_tensorrt/dynamo/lowering/passes/_modify_reshape_complex_nodes.py delete mode 100644 py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py create mode 100644 py/torch_tensorrt/dynamo/runtime/utils.py diff --git a/examples/distributed_inference/llama3_model.py b/examples/distributed_inference/llama3_model.py new file mode 100644 index 0000000000..dddee63871 --- /dev/null +++ b/examples/distributed_inference/llama3_model.py @@ -0,0 +1,496 @@ +# Taken and modified pytorch lightening +# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning + + +from dataclasses import dataclass +from typing import Any, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, + parallelize_module, +) + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + rope_theta: float = 10000 + + max_batch_size: int = 32 + max_seq_len: int = 2048 + # If `True`, then each transformer block init uses its layer ID, and if + # `False`, each uses the total number of transformer blocks + depth_init: bool = True + device: str = "cuda" + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + """Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. 
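+    Example (illustrative only; small values chosen so the shapes are easy to check):
+        >>> freqs_cis = precompute_freqs_cis(dim=8, end=16)
+        >>> freqs_cis.shape, freqs_cis.dtype
+        (torch.Size([16, 4]), torch.complex64)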
+ """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + return torch.polar(torch.ones_like(freqs), freqs) # complex64 + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """Reshape frequency tensor for broadcasting it with another tensor. + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + and the first seqlen elements will be sliced, but dim must match x. + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert 0 <= 1 < ndim + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply rotary embeddings to input tensors using the given frequency tensor. + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class RMSNorm(nn.Module): + """Initialize the RMSNorm normalization layer. + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. 
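+    Example (illustrative; the input shape is arbitrary and the layer is shape-preserving):
+        >>> norm = RMSNorm(dim=4096)
+        >>> x = torch.randn(2, 16, 4096)
+        >>> norm(x).shape
+        torch.Size([2, 16, 4096])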
+ """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) # type: ignore + + +class Attention(nn.Module): + """Multi-head attention module. + Args: + model_args (ModelArgs): Model configuration arguments. + Attributes: + n_kv_heads (int): Number of key and value heads. + n_heads (int): Number of query heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. + """ + + def __init__(self, model_args: ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = ( + model_args.n_heads + if model_args.n_kv_heads is None + else model_args.n_kv_heads + ) + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.dim // model_args.n_heads + + self.wq = nn.Linear( + model_args.dim, model_args.n_heads * self.head_dim, bias=False + ) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear( + model_args.n_heads * self.head_dim, model_args.dim, bias=False + ) + + def init_weights(self, init_std: float) -> None: + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ) -> Any: + """Forward pass of the attention module. + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + Returns: + torch.Tensor: Output tensor after attention. + """ + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + xq = xq.view(bs, seqlen, self.n_heads, self.head_dim) + xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim) + xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + # we use casual mask for training + output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + """FeedForward module. + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None. 
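+        Note (worked example using the defaults elsewhere in this file): with dim=4096,
+            hidden_dim=4*dim=16384 and multiple_of=256, the constructor ends up with
+            hidden_dim = 256 * ((int(2 * 16384 / 3) + 255) // 256) = 11008.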
+ Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x) -> Any: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float) -> None: + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class TransformerBlock(nn.Module): + """TransformerBlock Module. + Args: + layer_id (int): Identifier for the layer. + model_args (ModelArgs): Model configuration arguments. + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (RMSNorm): Layer normalization for attention output. + ffn_norm (RMSNorm): Layer normalization for feedforward output. + """ + + def __init__(self, layer_id: int, model_args: ModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.dim = model_args.dim + self.attention = Attention(model_args) + self.feed_forward = FeedForward( + dim=model_args.dim, + hidden_dim=4 * model_args.dim, + multiple_of=model_args.multiple_of, + ffn_dim_multiplier=model_args.ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.num_layers = model_args.n_layers + + self.attention_norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps) + + if model_args.depth_init: + self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """Perform a forward pass through the TransformerBlock. + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. + """ + h = x + self.attention(self.attention_norm(x), freqs_cis) + return h + self.feed_forward(self.ffn_norm(h)) + + def init_weights(self): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + +class ParallelTransformer(nn.Module): + """Transformer Module. + Args: + model_args (ModelArgs): Model configuration arguments. + Attributes: + model_args (ModelArgs): Model configuration arguments. + vocab_size (int): Vocabulary size. + n_layers (int): Number of layers in the model. + tok_embeddings (ParallelEmbedding): Token embeddings. + layers (torch.nn.ModuleList): List of Transformer blocks. + norm (RMSNorm): Layer normalization for the model output. 
+ output (ColumnParallelLinear): Linear layer for final output. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + """ + + def __init__(self, model_args: ModelArgs, tp_mesh: DeviceMesh = None): + # Here we use distributed model initialization to avoid memory overflow + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + self.tok_embeddings.to(model_args.device) + self.tok_embeddings = self.parallel_embeddings(self.tok_embeddings, tp_mesh) + + # TODO persistent should be set to false, since this buffer can be recomputed. + # however, we set it to true for 2 reasons. (1) due to pytorch/pytorch#123411, + # compile or pipeline-tracer will not correctly handle non-persistent buffers, + # so we need to fix that. (2) if we initialize pipeline-parallel models from + # a seed checkpoint rather than calling init_weights, we need freqs_cis to be + # initialized by the checkpoint, or we need to add a separate initializer for + # just the non-persistent buffers that is called after loading checkpoints. + self.register_buffer( + "freqs_cis", + self._precompute_freqs_cis().to(model_args.device), + persistent=True, + ) + + self.layers = torch.nn.ModuleDict().to(model_args.device) + for layer_id in range(model_args.n_layers): + block = TransformerBlock(layer_id, model_args).to(model_args.device) + self.layers[str(layer_id)] = block + self.parallel_transformer_block(self.layers[str(layer_id)], tp_mesh) + + self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps).to( + model_args.device + ) + self.norm = self.parallel_norm(self.norm, tp_mesh) + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False).to( + model_args.device + ) + self.output = self.parallel_output(self.output, tp_mesh) + self.init_weights() + + def parallel_transformer_block(self, transformer_block, tp_mesh): + if tp_mesh.size() <= 1: + return + plan = { + "attention": PrepareModuleInput( + input_layouts=(Shard(1), None), + desired_input_layouts=(Replicate(), None), + ), + "attention.wq": ColwiseParallel(), + "attention.wk": ColwiseParallel(), + "attention.wv": ColwiseParallel(), + "attention.wo": RowwiseParallel(output_layouts=Shard(1)), + "attention_norm": SequenceParallel(), + "feed_forward": PrepareModuleInput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "feed_forward.w1": ColwiseParallel(), + "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)), + "feed_forward.w3": ColwiseParallel(), + "ffn_norm": SequenceParallel(), + } + + # Adjust attention module to use the local number of heads + attn_layer = transformer_block.attention + attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size() + attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size() + + # Apply the plan for the current transformer block + parallelize_module(transformer_block, tp_mesh, plan) + + def parallel_embeddings(self, embedding, tp_mesh): + plan = { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ) + } + return parallelize_module(embedding, tp_mesh, plan) + + def parallel_output(self, output, tp_mesh): + plan = { + "output": ColwiseParallel( + input_layouts=Shard(1), + ), + } + return parallelize_module(output, tp_mesh, plan) + + def parallel_norm(self, norm, tp_mesh): + plan = { + "norm": SequenceParallel(), + } + return parallelize_module(norm, tp_mesh, plan) + + def 
reset_parameters(self): + with torch.device(self.freqs_cis.device): + self.freqs_cis = self._precompute_freqs_cis() + + def init_weights(self): + """[Note: On ``init_weights`` vs. + ``reset_parameters``] + Modules may define ``reset_parameters`` to initialize parameter values. + ``reset_parameters`` is meant to only initialize directly owned + parameters/buffers, not those of their child modules, and it can be + used to give the initial values for these tensors. + Separately, users may want custom initialization for their modules, + different from that in ``reset_parameters``. For this, we define + ``init_weights``. We only call it in the constructor of this + ``Transformer`` root module to avoid reinitializing tensors. + """ + with torch.device(self.freqs_cis.device): + self.freqs_cis = self._precompute_freqs_cis() + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + layer.init_weights() + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def _precompute_freqs_cis(self) -> torch.Tensor: + return precompute_freqs_cis( + self.model_args.dim // self.model_args.n_heads, + # Need to compute until at least the max token limit for generation + # (use 2x max sequence length to be safe) + self.model_args.max_seq_len * 2, + self.model_args.rope_theta, + ) + + def forward(self, tokens: torch.Tensor): + """Perform a forward pass through the Transformer model. + Args: + tokens (torch.Tensor): Input token indices. + Returns: + torch.Tensor: Output logits after applying the Transformer model. + """ + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + + h = self.norm(h) if self.norm else h + return self.output(h).float() if self.output else h + + @classmethod + def from_model_args(cls, model_args: ModelArgs) -> "Transformer": + """Initialize a Transformer model from a ModelArgs object. + Args: + model_args (ModelArgs): Model configuration arguments. + Returns: + Transformer: Transformer model. 
+ """ + return cls(model_args) diff --git a/examples/distributed_inference/rotary_embedding.py b/examples/distributed_inference/rotary_embedding.py index 1153ea2180..6c18f9eb8f 100644 --- a/examples/distributed_inference/rotary_embedding.py +++ b/examples/distributed_inference/rotary_embedding.py @@ -84,20 +84,20 @@ def parallel_rotary_block(rotary_block, tp_mesh): "wk": ColwiseParallel(), "wo": RowwiseParallel(output_layouts=Shard(0)), } - rotary_block.n_parallel = 1 # this is for single GPU, to do remove this hardcode + rotary_block.n_parallel = tp_mesh.size() parallelize_module(rotary_block, tp_mesh, plan) class RotaryAttention(nn.Module): - def __init__(self, dim: int, seq_len: int): + def __init__(self, dim: int, seq_len: int, n_parallel: int = 1): super().__init__() self.dim = dim self.wq = nn.Linear(dim, dim) self.wk = nn.Linear(dim, dim) self.wo = nn.Linear(dim, dim) self.seq_len = seq_len - self.n_parallel = 1 + self.n_parallel = n_parallel self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) self.init_weights() diff --git a/examples/distributed_inference/tensor_parallel_llama3.py b/examples/distributed_inference/tensor_parallel_llama3.py new file mode 100644 index 0000000000..842ebc530f --- /dev/null +++ b/examples/distributed_inference/tensor_parallel_llama3.py @@ -0,0 +1,68 @@ +# Taken and modified pytorch lightening +# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning +import logging +import os +import time + +import torch +import torch_tensorrt +from llama3_model import ModelArgs, ParallelTransformer +from torch.distributed._composable.fsdp import MixedPrecisionPolicy +from torch.distributed._composable.fsdp.fully_shard import fully_shard +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, +) +from torch_tensorrt.dynamo.distributed.utils import ( + cleanup_distributed_env, + get_tensor_parallel_device_mesh, + initialize_distributed_env, + initialize_logger, +) + +if not dist.is_initialized(): + initialize_distributed_env() + +device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh() +logger = initialize_logger(_rank, "tensor_parallel_simple_example") + +logger.info(f"Starting PyTorch TP example on rank {_rank}.") +assert ( + _world_size % 2 == 0 +), f"TP examples require even number of GPUs, but got {_world_size} gpus" + +model_args = ModelArgs( + vocab_size=32000, + dim=1024, + n_layers=4, + n_heads=8, + rope_theta=500000.0, + n_kv_heads=8, + device="cuda", +) + +with torch.no_grad(): + model = ParallelTransformer(model_args, device_mesh) + torch.manual_seed(0) + inp = torch.randint(32000, (8, 256), device="cuda") + python_result = model(inp) + torch_tensorrt.runtime.set_multi_device_safe_mode(True) + model = torch.compile( + model, + fullgraph=True, + backend="torch_tensorrt", + options={ + "use_python_runtime": True, + "use_distributed_mode_trace": True, + "debug": True, + }, + dynamic=False, + ) + + start = time.time() + output = model(inp) + end = time.time() + logger.info(f"Compilation time is {end-start}") + assert (python_result - output).std() < 0.01, "Compilation result is not correct." 
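+
+    # Launch note (illustrative; the exact command depends on your environment): start
+    # one process per GPU, for example
+    #   torchrun --nproc_per_node=2 tensor_parallel_llama3.py
+    # so that initialize_distributed_env() can set up the process group before the
+    # model is built and parallelized.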
+ + cleanup_distributed_env() diff --git a/examples/distributed_inference/tensor_parallel_rotary_embedding.py b/examples/distributed_inference/tensor_parallel_rotary_embedding.py index d51f9a5787..950f31c941 100644 --- a/examples/distributed_inference/tensor_parallel_rotary_embedding.py +++ b/examples/distributed_inference/tensor_parallel_rotary_embedding.py @@ -41,7 +41,7 @@ DIM = 128 with torch.no_grad(): - model = RotaryAttention(DIM, SEQ_LEN) + model = RotaryAttention(DIM, SEQ_LEN, device_mesh.size()) parallel_rotary_block(model, device_mesh) device = torch.device("cuda", device_mesh.get_rank()) model.to(device) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 5f62506a02..3e685e8046 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -13,7 +13,6 @@ from torch_tensorrt._Device import Device from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt._features import needs_cross_compile -from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import _defaults, partitioning from torch_tensorrt.dynamo._DryRunTracker import ( DryRunTracker, @@ -296,7 +295,6 @@ def cross_compile_for_windows( arg_inputs = [arg_inputs] # type: ignore # Prepare torch_trt inputs - trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs) trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs) device = to_torch_tensorrt_device(device) enabled_precisions = {dtype._from(p) for p in enabled_precisions} @@ -386,7 +384,6 @@ def cross_compile_for_windows( ) trt_gm = compile_module( gm, - trt_arg_inputs, trt_kwarg_inputs, settings, ) @@ -632,7 +629,6 @@ def compile( arg_inputs = [arg_inputs] # type: ignore # Prepare torch_trt inputs - trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs) trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs) device = to_torch_tensorrt_device(device) enabled_precisions = {dtype._from(p) for p in enabled_precisions} @@ -723,16 +719,13 @@ def compile( logger.warning( "Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True" ) - trt_gm = compile_module( - gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache - ) + trt_gm = compile_module(gm, trt_kwarg_inputs, settings, engine_cache) return trt_gm @fn_supports_debugger def compile_module( gm: torch.fx.GraphModule, - sample_arg_inputs: Sequence[Input], sample_kwarg_inputs: Optional[dict[Any, Any]] = None, settings: CompilationSettings = CompilationSettings(), engine_cache: Optional[BaseEngineCache] = None, diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index c39fe57197..cb25a105f2 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -22,7 +22,6 @@ from torch_tensorrt.dynamo.utils import ( is_tegra_platform, parse_dynamo_kwargs, - prepare_inputs, set_log_level, ) @@ -150,9 +149,6 @@ def _pretraced_backend( logger.debug("Lowered Input graph:\n " + str(gm.graph)) - torchtrt_inputs = prepare_inputs( - torch_inputs, disable_memory_format_check=True - ) if settings.require_full_compilation: logger.warning( "require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt" @@ -163,7 +159,6 @@ def _pretraced_backend( ) trt_compiled = compile_module( gm, - torchtrt_inputs, settings=settings, engine_cache=engine_cache, ) diff --git 
a/py/torch_tensorrt/dynamo/lowering/passes/_modify_reshape_complex_nodes.py b/py/torch_tensorrt/dynamo/lowering/passes/_modify_reshape_complex_nodes.py deleted file mode 100644 index f8ca1f71b9..0000000000 --- a/py/torch_tensorrt/dynamo/lowering/passes/_modify_reshape_complex_nodes.py +++ /dev/null @@ -1,105 +0,0 @@ -import logging - -import torch - -logger = logging.getLogger(__name__) - -from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( - clean_up_graph_after_modifications, - find_complex_nodes, -) - -from ._replace_complex_placeholder_to_tuple import replace_complex_placeholder_to_tuple - - -def tensorrt_complex_mul(args0, args1): - args0_real, args0_imag = torch.ops.aten.split.Tensor(args0, 1, -1) - args1_real, args1_imag = torch.ops.aten.split.Tensor(args1, 1, -1) - - args0_real = torch.ops.aten.squeeze.dim(args0_real, -1) - args0_imag = torch.ops.aten.squeeze.dim(args0_imag, -1) - args1_real = torch.ops.aten.squeeze.dim(args1_real, -1) - args1_imag = torch.ops.aten.squeeze.dim(args1_imag, -1) - - complex_mul_real = torch.ops.aten.sub( - torch.ops.aten.mul(args0_real, args1_real), - torch.ops.aten.mul(args0_imag, args1_imag), - ) - complex_mul_imag = torch.ops.aten.add( - torch.ops.aten.mul(args0_real, args1_imag), - torch.ops.aten.mul(args0_imag, args1_real), - ) - - return torch.ops.aten.stack((complex_mul_real, complex_mul_imag), -1) - - -def remove_complex_real_view_nodes(gm: torch.fx.GraphModule): - modified_graph = False - nodes_to_remove = [] - for node in gm.graph.nodes: - if "view_as_complex" in node.name or "view_as_real" in node.name: - nodes_to_remove.append(node) - - for node in nodes_to_remove: - input_node = node.args[0] if node.args else None - - for other_node in gm.graph.nodes: - new_args = tuple( - input_node if arg is node else arg for arg in other_node.args - ) - other_node.args = new_args - - gm.graph.erase_node(node) - modified_graph = True - - if modified_graph: - gm = clean_up_graph_after_modifications(gm) - logger.debug( - f"Graph after removing view_as_complex nodes and view_as_real nodes:\n{gm.graph}" - ) - - -def modify_reshape_nodes(gm: torch.fx.GraphModule, complex_nodes): - for node in gm.graph.nodes: - if node in complex_nodes: - # slice and transpose will remain same - if "reshape" in node.name: - new_shape = list(node.args[1]) + [2] - node.args = (node.args[0], tuple(new_shape)) - - -def modify_mul_nodes(gm: torch.fx.GraphModule, complex_nodes): - modified_graph = False - for node in gm.graph.nodes: - if node in complex_nodes: - if "mul" in node.name: - complex_mul_args = (node.args[0], node.args[1]) - with gm.graph.inserting_after(node): - replacement_node = gm.graph.create_node( - op="call_function", - target=tensorrt_complex_mul, - args=complex_mul_args, - ) - node.replace_all_uses_with(replacement_node) - replacement_node.meta.update(node.meta) - modified_graph = True - gm.graph.erase_node(node) - - if modified_graph: - gm = clean_up_graph_after_modifications(gm) - logger.debug( - f"Graph after custom complex mul nodes is applied to the graph:\n{gm.graph}" - ) - - -def modify_complex_nodes(gm: torch.fx.GraphModule, complex_nodes): - modify_reshape_nodes(gm, complex_nodes) - remove_complex_real_view_nodes(gm) - modify_mul_nodes(gm, complex_nodes) - - -def modify_reshape_complex_nodes(gm: torch.fx.GraphModule, complexInputIndices): - complex_nodes = find_complex_nodes(gm) - if complex_nodes: - replace_complex_placeholder_to_tuple(gm, complexInputIndices) - modify_complex_nodes(gm, complex_nodes) diff --git 
a/py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py b/py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py deleted file mode 100644 index e2edec3d28..0000000000 --- a/py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py +++ /dev/null @@ -1,112 +0,0 @@ -import logging -from typing import List, Tuple - -import torch -from torch._subclasses.fake_tensor import FakeTensorMode -from torch.fx.node import _get_qualified_name -from torch_tensorrt.dynamo.conversion.converter_utils import args_bounds_check - -# dead-code elimination, linting, and recompilation for graph, in-place -from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( - clean_up_graph_after_modifications, -) - -logger = logging.getLogger(__name__) - - -def replace_complex_placeholder_to_tuple( - gm: torch.fx.GraphModule, - inputListindices: List[int], -) -> torch.fx.GraphModule: - modified_graph = False - input_arg_list = [f"arg{inputListIndex}_1" for inputListIndex in inputListindices] - for node in gm.graph.nodes: - if node.op == "placeholder" and node.target in input_arg_list: - from torch._subclasses.fake_tensor import FakeTensorMode - - node_shape = node.meta["val"].size() - new_node_shape = node_shape + (2,) - new_node_dtype = None - if node.meta["val"].dtype == torch.complex64: - new_node_dtype = torch.float32 - else: - new_node_dtype = torch.float64 - fake_mode = FakeTensorMode() - - real_tensor = torch.empty(new_node_shape, dtype=new_node_dtype) - with FakeTensorMode() as fake_mode: - new_placeholder_tuple = fake_mode.from_tensor(real_tensor) - node.meta["val"] = new_placeholder_tuple - modified_graph = True - # propagate the meta data change for the downstream ops - # TODO:to check if this is required in all cases - propogate_complex_num_shape_change_till_complex_mul(gm, node, fake_mode) - - # If graph was modified, clean it up - if modified_graph: - gm = clean_up_graph_after_modifications(gm) - logger.debug( - f"Graph after fusing wait_tensor and distributed op tensor:\n{gm.graph}" - ) - - return gm - - -def infer_slice_shape(node: torch.fx.Node) -> Tuple[int, ...]: - input_shape = node.args[0].meta["val"].shape - slice_args = node.args - dim = slice_args[1] - start = slice_args[2] - end = slice_args[3] - step = args_bounds_check(slice_args, 4, replacement=1) - new_shape = list(input_shape) - new_shape[dim] = (end - start + step - 1) // step - return tuple(new_shape) - - -def infer_reshape_shape(node: torch.fx.Node) -> torch.fx.node.Argument: - return node.args[1] - - -shape_inference_funcs = { - "torch.ops.aten.slice.Tensor": infer_slice_shape, - "torch.ops.aten.reshape.default": infer_reshape_shape, -} - - -# Please note this function is for the use case of Llama model -# with complex placeholder->reshape->slice->complex mul -# Hence mul is the terminating op -def propogate_complex_num_shape_change_till_complex_mul( - node: torch.fx.Node, start_node: torch.fx.Node, fake_mode: FakeTensorMode -) -> None: - visited_nodes = set() - stack = [start_node] - while stack: - node = stack.pop() - if node in visited_nodes: - continue - visited_nodes.add(node) - update_node_meta(node, fake_mode) - for user in node.users: - if ( - user.op == "call_function" - and _get_qualified_name(user.target) == "torch.ops.aten.mul.Tensor" - ): - continue - stack.append(user) - - -def update_node_meta(node: torch.fx.Node, fake_mode: FakeTensorMode) -> None: - op_name = node.name - op_target = node.target - - if node.op == "call_function": - op_target = 
_get_qualified_name(node.target) - - if op_target in shape_inference_funcs: - new_shape = shape_inference_funcs[op_target](node) - real_tensor = torch.empty(new_shape, dtype=node.meta["val"].dtype) - node.meta["val"] = fake_mode.from_tensor(real_tensor) - else: - print("No shape for the inference function", {op_name}) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py b/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py index c3ead218aa..23cedc2211 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py @@ -108,6 +108,7 @@ class ComplexGraphRewriter: def __init__(self, gm: GraphModule, truncate_double: bool = False) -> None: self.gm = gm self.truncate_double = truncate_double + self.processed_input_nodes = set() def extract_shape_dtype_device( self, input_node: Node @@ -185,8 +186,12 @@ def rewrite_subgraph_nodes(self, subgraphs: List[ComplexSubGraphInfo]) -> None: for subgraph in subgraphs: for input_node in subgraph.input_nodes: logger.debug(f"Input node rewrite: {input_node.name}") + if input_node in self.processed_input_nodes: + logger.debug(f"Skipping {input_node.name}, already processed.") + continue if input_node.op not in ("call_function"): self.replace_input_node(input_node) + self.processed_input_nodes.add(input_node) for node in subgraph.subgraph_nodes: logger.debug(f"Subgraph Node rewrite: {node.name}") if node.target == torch.ops.aten.view_as_complex.default: @@ -230,6 +235,17 @@ def match_complex_mul( # type: ignore[no-untyped-def] elif node.target == torch.ops.aten.view_as_real.default: node.replace_all_uses_with(node.args[0]) self.gm.graph.erase_node(node) + elif node.target == torch.ops.aten._reshape_copy.default: + old_shape = node.args[1] + if isinstance(old_shape, (list, tuple)) and all( + isinstance(x, int) for x in old_shape + ): + new_shape = list(old_shape) + [2] + node.args = (node.args[0], new_shape) + logger.debug( + f"Updated reshape {node.name} from {old_shape} to {new_shape}" + ) + modified = True else: logger.debug(f"Unsupported node target: {node.target}") logger.debug( diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index 9e54fbac3d..9290bf7909 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -8,6 +8,7 @@ from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten from torch_tensorrt.dynamo import partitioning +from torch_tensorrt.dynamo.runtime.utils import complex_to_ri_stacked_tensor logger = logging.getLogger(__name__) @@ -142,6 +143,9 @@ def forward( for i, _ in enumerate(inputs): if not contiguous_inputs[i].is_cuda: + contiguous_inputs[i] = complex_to_ri_stacked_tensor( + contiguous_inputs[i] + ) logger.warning( f"Detected input[{i}] is not on a cuda device. 
" "This tensor is being moved by the runtime but for performance considerations, " diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 8e18a3ae32..17377365c7 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -14,6 +14,7 @@ from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger +from torch_tensorrt.dynamo.runtime.utils import complex_to_ri_stacked_tensor from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.logging import TRT_LOGGER from torch_tensorrt.runtime._utils import ( @@ -358,6 +359,9 @@ def setup_input_tensors( ) -> None: for i, input_name in enumerate(self.input_names): if not contiguous_inputs[i].is_cuda: + contiguous_inputs[i] = complex_to_ri_stacked_tensor( + contiguous_inputs[i] + ) logger.warning( f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. " "This tensor is being moved by the runtime but for performance considerations, " diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 95f1581881..f238532a51 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -15,6 +15,7 @@ needs_torch_tensorrt_runtime, ) from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.runtime.utils import complex_to_ri_stacked_tensor logger = logging.getLogger(__name__) @@ -320,6 +321,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]: # directly cast the input to a Torch Tensor. # # This also avoids the need for type-checking inputs, since they are now explicitly casted to Torch tensors + inputs = tuple(complex_to_ri_stacked_tensor(i) for i in inputs) input_tensors: List[torch.Tensor] = [ (i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()) for i in inputs diff --git a/py/torch_tensorrt/dynamo/runtime/utils.py b/py/torch_tensorrt/dynamo/runtime/utils.py new file mode 100644 index 0000000000..ad391b66b1 --- /dev/null +++ b/py/torch_tensorrt/dynamo/runtime/utils.py @@ -0,0 +1,8 @@ +import torch + + +def complex_to_ri_stacked_tensor(t: torch.Tensor) -> torch.Tensor: + # Converts complex tensor to real/imag stack + if torch.is_complex(t): + return torch.stack([t.real, t.imag], dim=-1) + return t