diff --git a/examples/distributed_inference/rotary_embedding.py b/examples/distributed_inference/rotary_embedding.py new file mode 100644 index 0000000000..1153ea2180 --- /dev/null +++ b/examples/distributed_inference/rotary_embedding.py @@ -0,0 +1,117 @@ +""" +.. _rotary_embedding: + +Rotary Embedding Implementation for Tensor Parallel Attention +============================================================ + +This module provides an implementation of rotary positional embeddings (RoPE) for transformer models +with support for tensor parallel distributed inference. Rotary embeddings are used to encode positional +information in transformer attention mechanisms. +""" + +import time + +import tensorrt as trt +import torch +import torch.distributed as dist +import torch.nn as nn +import torch_tensorrt +from tensor_parallel_initialize_dist import initialize_distributed_env +from torch.distributed._tensor import Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + RowwiseParallel, + parallelize_module, +) + +""" +This example covers the rotary embedding and rotary attention case for tensor parallel +""" + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, n_parallel=1 +) -> torch.Tensor: + """Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + n_parallel (int, optional): Number of GPUs for parallel computation. Defaults to 1. + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim // n_parallel, 2).float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + return torch.polar(torch.ones_like(freqs), freqs) + + +def rotary_embedding(xq, xk, dim, freqs_cis=None): + """This calculates the rotary embedding for the query and key tensors. + Args: + xq (torch.Tensor): Query tensor. + xk (torch.Tensor): Key tensor. + dim (int): Dimension of the query and key tensors. + freqs_cis (torch.Tensor, optional): Precomputed frequency tensor. Defaults to None. + Returns: + tuple: Tuple containing the rotated query and key tensors. 
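+    Example (illustrative only; shapes follow the BATCH/SEQ_LEN/HEADS/DIM values used in tensor_parallel_rotary_embedding.py):
+        >>> freqs_cis = precompute_freqs_cis(128, 128)  # shape [end, dim // 2], complex64
+        >>> xq = torch.randn(2, 128, 4, 128)  # [batch, seq_len, n_heads, dim]
+        >>> xk = torch.randn(2, 128, 4, 128)
+        >>> xq_out, xk_out = rotary_embedding(xq, xk, 128, freqs_cis=freqs_cis)
+        >>> xq_out.shape
+        torch.Size([2, 128, 4, 128])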
+ """ + freqs_cis = freqs_cis[None, :, None, :] + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return (xq_out.type_as(xq), xk_out.type_as(xk)) + + +########Tensor Parallel######## +def parallel_rotary_block(rotary_block, tp_mesh): + """Parallel rotary block for tensor parallel + Args: + rotary_block: Rotary block to parallelize + tp_mesh: Tensor parallel mesh + """ + if tp_mesh.size() <= 1: + return + + plan = { + "wq": ColwiseParallel(), + "wk": ColwiseParallel(), + "wo": RowwiseParallel(output_layouts=Shard(0)), + } + rotary_block.n_parallel = 1 # this is for single GPU, to do remove this hardcode + + parallelize_module(rotary_block, tp_mesh, plan) + + +class RotaryAttention(nn.Module): + def __init__(self, dim: int, seq_len: int): + super().__init__() + self.dim = dim + self.wq = nn.Linear(dim, dim) + self.wk = nn.Linear(dim, dim) + self.wo = nn.Linear(dim, dim) + self.seq_len = seq_len + self.n_parallel = 1 + self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True) + self.init_weights() + + def _precompute_freqs_cis(self) -> torch.Tensor: + theta = 10000.0 + return precompute_freqs_cis(self.dim, self.seq_len, theta, self.n_parallel) + + def init_weights(self): + with torch.device(self.freqs_cis.device): + self.freqs_cis = self.freqs_cis + + def forward(self, x): + q = self.wq(x) + k = self.wk(x) + freqs_cis = self._precompute_freqs_cis().to(q.device) + q, k = rotary_embedding(q, k, self.dim, freqs_cis=freqs_cis) + return self.wo(q) diff --git a/examples/distributed_inference/tensor_parallel_initialize_dist.py b/examples/distributed_inference/tensor_parallel_initialize_dist.py index 21e4cbc282..98d3ca18e9 100644 --- a/examples/distributed_inference/tensor_parallel_initialize_dist.py +++ b/examples/distributed_inference/tensor_parallel_initialize_dist.py @@ -1,3 +1,11 @@ +""" +.. _tensor_parallel_initialize_dist: +Tensor Parallel Initialize Distributed Environment +================================================== + +This module provides functions to initialize and clean up the distributed environment for tensor parallel distributed inference. +""" + import logging import os from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union @@ -65,3 +73,9 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950 torch.cuda.set_device(device_id) return device_mesh, world_size, rank, logger + + +def cleanup_distributed_env(): + """Clean up distributed process group to prevent resource leaks.""" + if dist.is_initialized(): + dist.destroy_process_group() diff --git a/examples/distributed_inference/tensor_parallel_rotary_embedding.py b/examples/distributed_inference/tensor_parallel_rotary_embedding.py new file mode 100644 index 0000000000..da3f3fd8fd --- /dev/null +++ b/examples/distributed_inference/tensor_parallel_rotary_embedding.py @@ -0,0 +1,59 @@ +""" +.. _tensor_parallel_rotary_embedding: +Tensor Parallel Rotary Embedding Example +======================================= + +This example demonstrates how to use Torch-TensorRT with tensor parallel distributed inference +for models that use rotary positional embeddings (RoPE). It lowers the complex +operations in attention models with rotary embeddings across multiple GPUs. 
+ +""" + +import logging +import os +import time + +import torch +import torch_tensorrt +from rotary_embedding import RotaryAttention, parallel_rotary_block +from tensor_parallel_initialize_dist import ( + cleanup_distributed_env, + initialize_distributed_env, +) + +device_mesh, _world_size, _rank, logger = initialize_distributed_env( + "./tensor_parallel_rotary_embedding" +) + + +""" +This example covers the rotary embedding in Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning +Command to run with single GPU: mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py +""" + +BATCH = 2 +SEQ_LEN = 128 +HEADS = 4 +DIM = 128 + +with torch.no_grad(): + model = RotaryAttention(DIM, SEQ_LEN) + parallel_rotary_block(model, device_mesh) + device = torch.device("cuda", device_mesh.get_rank()) + model.to(device) + x = torch.randn(BATCH, SEQ_LEN, HEADS, DIM).to(device) + + python_result = model(x) + + logger.info("Torch-tensorrt compilation for rotary embedding") + + model = torch.compile(model, backend="torch_tensorrt") + + torch.manual_seed(0) + start = time.time() + output = model(x) + end = time.time() + logger.info(f"Compilation time is {end-start}") + assert (python_result - output).std() < 0.01, "Compilation result is not correct." + + cleanup_distributed_env() diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py index d2e3c590c6..c5688c6e5b 100755 --- a/examples/distributed_inference/tensor_parallel_simple_example.py +++ b/examples/distributed_inference/tensor_parallel_simple_example.py @@ -1,3 +1,24 @@ +""" +.. _tensor_parallel_simple_example: + +Torch Parallel Distributed example for simple model +========================================= + +Below example shows how to use Torch-TensorRT backend for distributed inference with tensor parallelism. + +This example demonstrates: + - Setting up distributed environment for tensor parallelism + - Model sharding across multiple GPUs + - Compilation with Torch-TensorRT + - Distributed inference execution + +Usage +----- +.. code-block:: bash + + mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py +""" + import time import tensorrt as trt @@ -5,7 +26,10 @@ import torch.distributed as dist import torch.nn as nn import torch_tensorrt -from tensor_parallel_initialize_dist import initialize_distributed_env +from tensor_parallel_initialize_dist import ( + cleanup_distributed_env, + initialize_distributed_env, +) from torch.distributed._tensor import Shard from torch.distributed.tensor.parallel import ( ColwiseParallel, @@ -18,7 +42,7 @@ ) """ -This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py +This example takes some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py """ @@ -79,23 +103,15 @@ def forward(self, x): dynamic=None, ) -try: - for i in range(10): - # For TP, input needs to be same across all TP ranks. - # Setting the random seed is to mimic the behavior of dataloader. - torch.manual_seed(i) - inp = torch.rand(20, 10, device="cuda") - start = time.time() - output = tp_model(inp) - end = time.time() - if i == 0: - logger.info(f"Compilation time is {end-start}") - assert ( - python_result - output - ).std() < 0.01, "Compilation result is not correct." 
- elif _rank == 0: - logger.info(f"Inference time is {end-start}") -finally: - # This cleans up the distributed process group - if dist.is_initialized(): - dist.destroy_process_group() +# For TP, input needs to be same across all TP ranks. +# Setting the random seed is to mimic the behavior of dataloader. +torch.manual_seed(0) +inp = torch.rand(20, 10, device="cuda") +start = time.time() +output = tp_model(inp) +end = time.time() +logger.info(f"Compilation time is {end - start}") +assert (python_result - output).std() < 0.01, "Result is not correct." + +# This cleans up the distributed process group +cleanup_distributed_env() diff --git a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py index f320505c94..a2feb99d56 100644 --- a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py +++ b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py @@ -23,7 +23,10 @@ def getitem_validator(getitem_node: Node, settings: CompilationSettings = None) from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS # Getitem nodes can only be converted if their parent node also can - return getitem_node.args[0] in DYNAMO_CONVERTERS + return ( + getitem_node.args[0] in DYNAMO_CONVERTERS + or getitem_node.args[0].op == "get_attr" + ) # TODO: Subsequent evaluators should be registered here with their own validators @@ -43,7 +46,10 @@ def generic_evaluator( _LOGGER.debug( f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}" ) - return target(*args) + from torch._subclasses.fake_tensor import unset_fake_temporarily + + with unset_fake_temporarily(): + return target(*args) def rand_validator(rand_node: Node, settings: CompilationSettings = None) -> bool: diff --git a/py/torch_tensorrt/dynamo/lowering/passes/__init__.py b/py/torch_tensorrt/dynamo/lowering/passes/__init__.py index 716c6505fe..c0e2803e60 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/__init__.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/__init__.py @@ -1,4 +1,3 @@ from ._aten_lowering_pass import * -from ._modify_reshape_complex_nodes import modify_reshape_complex_nodes from .remove_sym_nodes import remove_sym_nodes from .repair_input_aliasing import repair_input_aliasing diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index c7fe264c5a..fff4473b47 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -6,6 +6,7 @@ from torch_tensorrt.dynamo.utils import is_tegra_platform from .accumulate_fp32_matmul import accumulate_fp32_matmul +from .complex_graph_rewrite import complex_graph_detection from .constant_folding import constant_fold from .fuse_distributed_ops import fuse_distributed_ops from .fuse_prims_broadcast import fuse_prims_broadcast @@ -26,6 +27,7 @@ remove_assert_nodes, accumulate_fp32_matmul, remove_num_users_is_0_nodes, + complex_graph_detection, ] pre_lowering_pass_list = [ diff --git a/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py b/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py new file mode 100644 index 0000000000..c3ead218aa --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py @@ -0,0 +1,361 @@ +import logging +from typing import Callable, List, Set, Tuple + +import torch +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx 
import GraphModule, Node +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) + + +class ComplexSubGraphInfo: + def __init__( + self, + anchor_nodes: List[Node], + subgraph_nodes: List[Node], + input_nodes: List[Node], + ): + self.anchor_nodes = anchor_nodes + self.subgraph_nodes = subgraph_nodes + self.input_nodes = input_nodes + + def __repr__(self) -> str: + return ( + f"ComplexOpSubGraphInfo(anchor_nodes={[n.name for n in self.anchor_nodes]}, " + f"subgraph={[n.name for n in self.subgraph_nodes]}, " + f"inputs={[n.name for n in self.input_nodes]})" + ) + + +class ComplexOpDetector: + def __init__(self) -> None: + pass + + def is_complex_dtype(self, node: Node) -> bool: + # Check if node's metadata or dtype is complex + dtype = None + if "val" in node.meta: + val = node.meta["val"] + if hasattr(val, "dtype"): + dtype = val.dtype + + logger.debug(f"dtype of node: {dtype}") + return dtype in {torch.complex64, torch.complex128} + + def node_include_in_subgraph(self, node: Node) -> bool: + # Include only call_function ops on complex tensors + if node.op == "call_function" and self.is_complex_dtype(node): + logger.debug( + f"node.op is added to subgraph: {node.op}, node name: {node.name} is complex" + ) + return node.op == "call_function" and self.is_complex_dtype(node) + + def subgraph_from_anchor(self, anchor_node: Node) -> ComplexSubGraphInfo: + subgraph_nodes: Set[Node] = set() + input_nodes: Set[Node] = set() + stack = [anchor_node] + while stack: + n = stack.pop() + if n in subgraph_nodes: + continue + subgraph_nodes.add(n) + logger.debug(f"node {n.name} is added to subgraph") + for inp in n.all_input_nodes: + if self.node_include_in_subgraph(inp): + stack.append(inp) + else: + input_nodes.add(inp) + return ComplexSubGraphInfo( + [anchor_node], list(subgraph_nodes), list(input_nodes) + ) + + def find_complex_op_subgraphs( + self, gm: GraphModule, anchor_target: str + ) -> List[ComplexSubGraphInfo]: + complex_op_subgraphs: List[ComplexSubGraphInfo] = [] + for node in gm.graph.nodes: + if node.target == anchor_target: + new_sub = self.subgraph_from_anchor(node) + # if any intersecting nodes between seen and sub.subgraph_nodes they should be merged + merged = False + for existing_sub in complex_op_subgraphs: + if set(existing_sub.subgraph_nodes) & set(new_sub.subgraph_nodes): + logger.debug(f"merging subgraphs {existing_sub} {new_sub}") + # merge the two subgraphs + existing_sub.subgraph_nodes = list( + set(existing_sub.subgraph_nodes) + | set(new_sub.subgraph_nodes) + ) + existing_sub.input_nodes = list( + set(existing_sub.input_nodes) | set(new_sub.input_nodes) + ) + existing_sub.anchor_nodes = list( + set(existing_sub.anchor_nodes) | set(new_sub.anchor_nodes) + ) + merged = True + break + if not merged: + complex_op_subgraphs.append(new_sub) + return complex_op_subgraphs + + +class ComplexGraphRewriter: + def __init__(self, gm: GraphModule, truncate_double: bool = False) -> None: + self.gm = gm + self.truncate_double = truncate_double + + def extract_shape_dtype_device( + self, input_node: Node + ) -> Tuple[Tuple[int, ...], torch.dtype, torch.device]: + if input_node.op == "placeholder": + tensor_val = input_node.meta["val"] + + elif input_node.op == "get_attr": + tensor_val = self.get_attr_tensor(input_node.target) # type: ignore + + else: + raise 
ValueError(f"Unsupported node type: {input_node.op}") + + node_shape = tensor_val.size() + dtype = tensor_val.dtype + new_node_shape = node_shape + (2,) + device = tensor_val.device + + if dtype == torch.complex64: + new_node_dtype = torch.float32 + elif dtype == torch.complex128 and self.truncate_double: + new_node_dtype = torch.float32 + else: + new_node_dtype = torch.float64 + + return new_node_shape, new_node_dtype, device + + def get_attr_tensor(self, target): # type: ignore + # Check if target is param or buffer + if target in dict(self.gm.named_parameters()): + return self.gm.get_parameter(target) + elif target in dict(self.gm.named_buffers()): + return self.gm.get_buffer(target) + else: + raise ValueError( + f"Attribute {target} not found in gm parameters or buffers." + ) + + def replace_input_node(self, input_node: Node) -> None: + modified = False + logger.debug(f"Replacing input node: {input_node.name}") + new_shape, new_dtype, device = self.extract_shape_dtype_device(input_node) + real_tensor = torch.empty(new_shape, dtype=new_dtype, device=device) + + if input_node.op == "placeholder": + with FakeTensorMode() as fake_mode: + fake_tensor = fake_mode.from_tensor(real_tensor) + with self.gm.graph.inserting_before(input_node): + new_node = self.gm.graph.placeholder(input_node.target + "_reshaped") + new_node.meta["val"] = fake_tensor + + elif input_node.op == "get_attr": + new_attr_name = input_node.target + "_reshaped" + with unset_fake_temporarily(): + original_tensor = self.get_attr_tensor(input_node.target) # type: ignore + stacked_tensor = torch.stack( + [original_tensor.real, original_tensor.imag], dim=-1 + ) + self.gm.register_buffer(new_attr_name, stacked_tensor) + with self.gm.graph.inserting_after(input_node): + new_node = self.gm.graph.get_attr(new_attr_name) + else: + logger.debug( + f"Unsupported node type in replacement of input node: {input_node.op}" + ) + logger.debug( + "This complex subgraph inputnode type does not need to replaced" + ) + input_node.replace_all_uses_with(new_node) + self.gm.graph.erase_node(input_node) + clean_up_graph_after_modifications(self.gm) + + def rewrite_subgraph_nodes(self, subgraphs: List[ComplexSubGraphInfo]) -> None: + modified = False + for subgraph in subgraphs: + for input_node in subgraph.input_nodes: + logger.debug(f"Input node rewrite: {input_node.name}") + if input_node.op not in ("call_function"): + self.replace_input_node(input_node) + for node in subgraph.subgraph_nodes: + logger.debug(f"Subgraph Node rewrite: {node.name}") + if node.target == torch.ops.aten.view_as_complex.default: + node.replace_all_uses_with(node.args[0]) + self.gm.graph.erase_node(node) + elif node.target == torch.ops.aten.mul.Tensor: + # this is complex mul where inputs = a+ib and output = c+id. 
+                    # complex mul returns (ac - bd) + (ad + bc)i
+                    # which view_as_real represents as (ac - bd) and (ad + bc) stacked along a trailing dimension of size 2
+                    x_placeholder_or_func = node.args[0].op != "get_attr"
+                    y_placeholder_or_func = node.args[1].op != "get_attr"
+
+                    replaced_nodes = []
+                    original_mul, replacement = complex_mul_replacement(
+                        x_placeholder_or_func, y_placeholder_or_func
+                    )
+
+                    def match_complex_mul(  # type: ignore[no-untyped-def]
+                        match: torch.fx.subgraph_rewriter.Match,
+                        original_graph,
+                        pattern_graph,
+                    ) -> bool:
+                        for original_node in match.nodes_map.values():
+                            if original_node.name == node.name:
+                                return True
+                        return False
+
+                    nodes = torch.fx.subgraph_rewriter.replace_pattern_with_filters(
+                        self.gm,
+                        original_mul,
+                        replacement,
+                        match_filters=[match_complex_mul],
+                        ignore_literals=True,
+                    )
+                    replaced_nodes += nodes
+                    modified = True
+                elif node.target == torch.ops.aten.view_as_real.default:
+                    node.replace_all_uses_with(node.args[0])
+                    self.gm.graph.erase_node(node)
+                else:
+                    logger.debug(f"Unsupported node target: {node.target}")
+                    logger.debug(
+                        "This complex subgraph node type does not need to be replaced"
+                    )
+
+        if modified:
+            self.propagate_metadata()
+            self.gm.graph.lint()
+            self.gm.recompile()
+
+    def propagate_metadata(self) -> None:
+        fake_inputs = []
+        from torch._subclasses.fake_tensor import FakeTensorMode
+        from torch.fx.passes.fake_tensor_prop import FakeTensorProp
+
+        for node in self.gm.graph.nodes:
+            if node.op == "placeholder":
+                if "val" in node.meta:
+                    with FakeTensorMode(allow_non_fake_inputs=True):
+                        fake_val = node.meta["val"]
+                        fake_inputs.append(
+                            fake_val.to("cuda")
+                            if fake_val.device.type == "cuda"
+                            else fake_val
+                        )
+                else:
+                    fake_tensor = torch.empty(
+                        [s if s != 0 else 1 for s in node.meta["tensor_meta"].shape],
+                        dtype=node.meta["tensor_meta"].dtype,
+                        device=node.meta["tensor_meta"].device,
+                    )
+                    fake_inputs.append(fake_tensor)
+        FakeTensorProp(
+            self.gm, mode=FakeTensorMode(allow_non_fake_inputs=True)
+        ).propagate(*fake_inputs)
+
+
+def extract_real_imag(input, placeholder_or_func: bool = True):  # type: ignore
+    """Extract real and imaginary parts from a tensor.
+    This function handles different tensor types based on whether they are placeholder/function
+    tensors or get_attr tensors. For placeholder/function tensors, it uses select operations,
+    while for get_attr tensors, it uses indexing.
+    Args:
+        input: Input tensor to extract real and imaginary parts from
+        placeholder_or_func: Boolean flag indicating if the input is a placeholder/function tensor (True)
+            or a get_attr tensor (False). Defaults to True.
+    Returns:
+        Tuple of (real_part, imaginary_part) where both parts have the same type as the input
+    Note:
+        - When placeholder_or_func=True: Uses torch.ops.aten.select.int operations
+        - When placeholder_or_func=False: Uses tensor indexing [..., 0] and [..., 1]
+    """
+    if placeholder_or_func:
+        # For ITensor, use select operations
+        real_part = torch.ops.aten.select.int(input, -1, 0)
+        imag_part = torch.ops.aten.select.int(input, -1, 1)
+        return real_part, imag_part
+    else:
+        # For get_attr, use indexing
+        return input[..., 0], input[..., 1]
+
+
+def complex_mul_replacement(
+    x_placeholder_or_func: bool = True, y_placeholder_or_func: bool = True
+) -> Tuple[
+    Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+    Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+]:
+    """Constructs the original and replacement functions for complex multiplication.
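+    As a quick illustration of the rewrite: with the stacked real/imag layout described
+    below, x = [1., 2.] and y = [3., 4.] stand for 1+2j and 3+4j, and the replacement
+    returns [-5., 10.], matching (1+2j) * (3+4j) = -5 + 10j.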
+ + The original functions correspond to native complex multiplication + via torch.mul or operator.mul on complex tensors. + + The replacement function assumes x and y are real tensors with the last + dimension size 2 representing real and imaginary parts, and performs + complex multiplication manually returning the same shaped tensor. + """ + + # Original pattern: torch.mul for complex tensors + def original_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.ops.aten.mul.Tensor(x, y) + + # Replacement function: manual complex multiplication on real/imag stacked tensors + def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + x_real, x_imag = extract_real_imag(x, x_placeholder_or_func) + y_real, y_imag = extract_real_imag(y, y_placeholder_or_func) + + real_part1 = torch.ops.aten.mul.Tensor(x_real, y_real) + real_part2 = torch.ops.aten.mul.Tensor(x_imag, y_imag) + real = torch.ops.aten.sub.Tensor(real_part1, real_part2) + + imag_part1 = torch.ops.aten.mul.Tensor(x_real, y_imag) + imag_part2 = torch.ops.aten.mul.Tensor(x_imag, y_real) + imag = torch.ops.aten.add.Tensor(imag_part1, imag_part2) + + return torch.ops.aten.cat.default( + [ + torch.ops.aten.unsqueeze.default(real, -1), + torch.ops.aten.unsqueeze.default(imag, -1), + ], + -1, + ) + + return (original_mul, replacement) + + +# This lowering pass is used to detect and rewrite complex subgraphs in the graph +def complex_graph_detection( + gm: GraphModule, settings: CompilationSettings +) -> GraphModule: + """Detect and rewrite complex subgraphs in the graph. + This lowering pass is used to detect and rewrite complex subgraphs in the graph. + This lowering pass works for complex tensor in mul which are parameter or buffers in the graph. + Args: + gm: The GraphModule to process + settings: Compilation settings + Returns: + The modified GraphModule with complex subgraphs rewritten + """ + complex_op_detector = ComplexOpDetector() + complex_subgraphs = complex_op_detector.find_complex_op_subgraphs( + gm, anchor_target=torch.ops.aten.view_as_real.default + ) + for subgraph in complex_subgraphs: + logger.debug(f"Complex subgraph info: {subgraph}") + complex_graph_rewriter = ComplexGraphRewriter(gm, settings.truncate_double) + complex_graph_rewriter.rewrite_subgraph_nodes(complex_subgraphs) + return gm diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py index 69c91db475..7eaccf9348 100644 --- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py +++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py @@ -237,5 +237,97 @@ def forward(self, input, mat1, mat2): torch._dynamo.reset() +class TestComplexSubgraph(TestCase): + def test_complex_subgraph(self): + BATCH = 1 + SEQ_LEN = 2 + HEADS = 1 + DIM = 2 + + class RotaryAttention(torch.nn.Module): + def __init__(self): + super().__init__() + self.dim = DIM + self.wq = torch.nn.Linear(self.dim, self.dim) + self.seq_len = SEQ_LEN + + self.register_buffer( + "freqs_ex_tensor", + self._freqs_ex_tensor(), + persistent=True, + ) + + def rotary_embedding(self, x, dim, freqs_cis=None): + x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_out_flatten = torch.view_as_real(x_ * freqs_cis) + return x_out_flatten.type_as(x) + + def _freqs_ex_tensor(self): + real = torch.tensor([[[[1.0000]], [[2.0000]]]], device="cuda") + imag = torch.tensor([[[[0.0000]], [[3.0000]]]], device="cuda") + + z = torch.complex(real, imag) + return z + + def forward(self, x): + q = self.wq(x) + freqs_cis = 
self._freqs_ex_tensor().to(q.device) + q_out = self.rotary_embedding(q, self.dim, freqs_cis=freqs_cis) + return q_out + + inputs = [torch.randn(BATCH, SEQ_LEN, HEADS, DIM).cuda()] + model = RotaryAttention() + model = model.cuda() + + expected_ops = {torch.ops.aten.mul.Tensor} + unexpected_ops = { + torch.ops.aten.view_as_complex.default, + torch.ops.aten.view_as_real.default, + } + + unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( + model, + inputs, + expected_ops=expected_ops, + unexpected_ops=unexpected_ops, + min_block_size=1, + ) + + self.assertEqual( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + self.assertEqual( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + model, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + ) + optimized_model_results = optimized_model(*inputs)[0].detach().cpu() + torch_model_results = model(*inputs)[0].detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"ComplexSubgraph TRT outputs don't match with the original model.", + ) + torch._dynamo.reset() + + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/testing_utilities.py b/tests/py/dynamo/testing_utilities.py index 7894c49967..7adf2c8a58 100644 --- a/tests/py/dynamo/testing_utilities.py +++ b/tests/py/dynamo/testing_utilities.py @@ -92,7 +92,7 @@ def compile_module_testing( ) # Store intermediate graph from partitioned module - store_intermediate_graphs.append(deepcopy(partitioned_module)) + store_intermediate_graphs.append(partitioned_module) return partitioned_module