Distributed llama3 example #3785

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft · wants to merge 1 commit into base: abose/torchTRT_trt_llm_load_parallel
496 changes: 496 additions & 0 deletions examples/distributed_inference/llama3_model.py

Large diffs are not rendered by default.
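The 496-line model file is not rendered above. Judging only from how it is used in tensor_parallel_llama3.py below, it exposes a ModelArgs config and a ParallelTransformer built over a tensor-parallel device mesh. The stub below is an assumed interface sketch (field names are taken from the example; everything else is hypothetical), not the contents of the actual file.

# Assumed interface of llama3_model.py, inferred from its usage below; not the real file.
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh


@dataclass
class ModelArgs:
    vocab_size: int = 32000
    dim: int = 1024
    n_layers: int = 4
    n_heads: int = 8
    n_kv_heads: int = 8
    rope_theta: float = 500000.0
    device: str = "cuda"


class ParallelTransformer(nn.Module):
    def __init__(self, args: ModelArgs, tp_mesh: DeviceMesh) -> None:
        super().__init__()
        # Builds llama3-style blocks and applies a ColwiseParallel/RowwiseParallel
        # plan over tp_mesh (details live in the unrendered file).
        ...

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # (batch, seq) int64 token ids -> (batch, seq, vocab_size) logits.
        ...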

6 changes: 3 additions & 3 deletions examples/distributed_inference/rotary_embedding.py
@@ -84,20 +84,20 @@ def parallel_rotary_block(rotary_block, tp_mesh):
"wk": ColwiseParallel(),
"wo": RowwiseParallel(output_layouts=Shard(0)),
}
-    rotary_block.n_parallel = 1 # this is for single GPU, to do remove this hardcode
+    rotary_block.n_parallel = tp_mesh.size()

parallelize_module(rotary_block, tp_mesh, plan)


class RotaryAttention(nn.Module):
-    def __init__(self, dim: int, seq_len: int):
+    def __init__(self, dim: int, seq_len: int, n_parallel: int = 1):
super().__init__()
self.dim = dim
self.wq = nn.Linear(dim, dim)
self.wk = nn.Linear(dim, dim)
self.wo = nn.Linear(dim, dim)
self.seq_len = seq_len
-        self.n_parallel = 1
+        self.n_parallel = n_parallel
self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
self.init_weights()

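For context on why n_parallel is threaded through: under ColwiseParallel each rank only holds dim / n_parallel of the projection output, so the precomputed rotary table has to be sized per rank. The snippet below is a hedged sketch of such a frequency precompute, assuming the usual llama-style formulation; it is not the implementation in this file.

# Hedged sketch (not the PR's code): how n_parallel would typically enter
# the rotary frequency precompute under column-wise tensor parallelism.
import torch


def precompute_freqs_cis(dim: int, seq_len: int, n_parallel: int, theta: float = 10000.0) -> torch.Tensor:
    head_dim = dim // n_parallel  # per-rank width after ColwiseParallel sharding
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len, dtype=freqs.dtype)
    freqs = torch.outer(t, freqs)
    # complex64 tensor of shape (seq_len, head_dim // 2)
    return torch.polar(torch.ones_like(freqs), freqs)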
68 changes: 68 additions & 0 deletions examples/distributed_inference/tensor_parallel_llama3.py
@@ -0,0 +1,68 @@
# Taken and modified from PyTorch Lightning:
# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
import logging
import os
import time

import torch
import torch.distributed as dist
import torch_tensorrt
from llama3_model import ModelArgs, ParallelTransformer
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
)
from torch_tensorrt.dynamo.distributed.utils import (
cleanup_distributed_env,
get_tensor_parallel_device_mesh,
initialize_distributed_env,
initialize_logger,
)

if not dist.is_initialized():
initialize_distributed_env()

device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
logger = initialize_logger(_rank, "tensor_parallel_simple_example")

logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
_world_size % 2 == 0
), f"TP examples require even number of GPUs, but got {_world_size} gpus"

model_args = ModelArgs(
vocab_size=32000,
dim=1024,
n_layers=4,
n_heads=8,
rope_theta=500000.0,
n_kv_heads=8,
device="cuda",
)

with torch.no_grad():
model = ParallelTransformer(model_args, device_mesh)
torch.manual_seed(0)
inp = torch.randint(32000, (8, 256), device="cuda")
python_result = model(inp)
torch_tensorrt.runtime.set_multi_device_safe_mode(True)
model = torch.compile(
model,
fullgraph=True,
backend="torch_tensorrt",
options={
"use_python_runtime": True,
"use_distributed_mode_trace": True,
"debug": True,
},
dynamic=False,
)

start = time.time()
output = model(inp)
end = time.time()
logger.info(f"Compilation time is {end-start}")
assert (python_result - output).std() < 0.01, "Compilation result is not correct."

cleanup_distributed_env()
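The example assumes one process per GPU. Below is a hedged sketch of how it would typically be launched and of what initialize_distributed_env / get_tensor_parallel_device_mesh are assumed to do under the hood; the torchrun command and the helper body are assumptions for illustration, not part of this PR.

# Launch (assumed): torchrun --nproc_per_node=2 tensor_parallel_llama3.py
# Rough equivalent of the helper setup used above -- an assumption for illustration,
# not the torch_tensorrt.dynamo.distributed.utils implementation.
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def init_tp_env():
    dist.init_process_group(backend="nccl")      # one NCCL process per GPU
    local_rank = int(os.environ["LOCAL_RANK"])   # set by torchrun
    torch.cuda.set_device(local_rank)
    world_size = dist.get_world_size()
    # 1-D mesh over all ranks -> pure tensor parallelism, as in the example
    mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))
    return mesh, world_size, dist.get_rank()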
@@ -41,7 +41,7 @@
DIM = 128

with torch.no_grad():
-    model = RotaryAttention(DIM, SEQ_LEN)
+    model = RotaryAttention(DIM, SEQ_LEN, device_mesh.size())
parallel_rotary_block(model, device_mesh)
device = torch.device("cuda", device_mesh.get_rank())
model.to(device)
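Since n_parallel defaults to 1, the single-GPU path of the rotary example is unchanged; passing device_mesh.size() only matters once the block is parallelized. A brief usage sketch follows; the import paths and the SEQ_LEN value are assumptions based on the example layout.

# Assumed usage sketch; import paths and SEQ_LEN follow the example layout.
from rotary_embedding import RotaryAttention, parallel_rotary_block
from torch_tensorrt.dynamo.distributed.utils import get_tensor_parallel_device_mesh

DIM, SEQ_LEN = 128, 128

# Single GPU: the default n_parallel=1 reproduces the original behavior.
single_gpu_model = RotaryAttention(DIM, SEQ_LEN)

# Tensor parallel: size per-rank buffers by the mesh width, then shard the block.
device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
tp_model = RotaryAttention(DIM, SEQ_LEN, device_mesh.size())
parallel_rotary_block(tp_model, device_mesh)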
9 changes: 1 addition & 8 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -13,7 +13,6 @@
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt._features import needs_cross_compile
-from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import _defaults, partitioning
from torch_tensorrt.dynamo._DryRunTracker import (
DryRunTracker,
@@ -296,7 +295,6 @@ def cross_compile_for_windows(
arg_inputs = [arg_inputs] # type: ignore

# Prepare torch_trt inputs
-    trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
device = to_torch_tensorrt_device(device)
enabled_precisions = {dtype._from(p) for p in enabled_precisions}
@@ -386,7 +384,6 @@
)
trt_gm = compile_module(
gm,
-        trt_arg_inputs,
trt_kwarg_inputs,
settings,
)
@@ -632,7 +629,6 @@ def compile(
arg_inputs = [arg_inputs] # type: ignore

# Prepare torch_trt inputs
-    trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
device = to_torch_tensorrt_device(device)
enabled_precisions = {dtype._from(p) for p in enabled_precisions}
@@ -723,16 +719,13 @@ def compile(
logger.warning(
"Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True"
)
-    trt_gm = compile_module(
-        gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
-    )
+    trt_gm = compile_module(gm, trt_kwarg_inputs, settings, engine_cache)
return trt_gm


@fn_supports_debugger
def compile_module(
gm: torch.fx.GraphModule,
-    sample_arg_inputs: Sequence[Input],
sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
settings: CompilationSettings = CompilationSettings(),
engine_cache: Optional[BaseEngineCache] = None,
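For orientation: after this change, compile() and cross_compile_for_windows() no longer normalize positional example inputs into torch_tensorrt Input objects before calling compile_module; only the kwarg inputs are still run through prepare_inputs. The snippet below is a hedged sketch of the user-facing AOT path that ends up in this code; the model and shapes are hypothetical.

# Hypothetical caller of the dynamo AOT path; the positional tensors below reach
# compile_module without the removed prepare_inputs() normalization step.
import torch
import torch.nn as nn
import torch_tensorrt


class TinyNet(nn.Module):  # stand-in model, not from the PR
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


model = TinyNet().cuda().eval()
inp = torch.randn(4, 16, device="cuda")

exported = torch.export.export(model, (inp,))
trt_mod = torch_tensorrt.dynamo.compile(exported, arg_inputs=[inp])
print(trt_mod(inp).shape)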
5 changes: 0 additions & 5 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -22,7 +22,6 @@
from torch_tensorrt.dynamo.utils import (
is_tegra_platform,
parse_dynamo_kwargs,
-    prepare_inputs,
set_log_level,
)

@@ -150,9 +149,6 @@ def _pretraced_backend(

logger.debug("Lowered Input graph:\n " + str(gm.graph))

-        torchtrt_inputs = prepare_inputs(
-            torch_inputs, disable_memory_format_check=True
-        )
if settings.require_full_compilation:
logger.warning(
"require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt"
@@ -163,7 +159,6 @@
)
trt_compiled = compile_module(
gm,
-            torchtrt_inputs,
settings=settings,
engine_cache=engine_cache,
)
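The torch.compile backend makes the matching change: _pretraced_backend no longer builds Input descriptors from the traced sample tensors, and compile_module is called with just the graph, settings, and engine cache. Below is a hedged sketch of how this backend is reached from user code; the model and option values are illustrative.

# Illustrative use of the "torch_tensorrt" dynamo backend; options flow through
# parse_dynamo_kwargs into the settings that _pretraced_backend passes on.
import torch
import torch.nn as nn
import torch_tensorrt  # noqa: F401  (registers the backend)


class Toy(nn.Module):  # stand-in model, not from the PR
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sin(x) + torch.cos(x)


model = Toy().cuda()
x = torch.randn(8, 8, device="cuda")

compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"use_python_runtime": True, "min_block_size": 1},
)
print(compiled(x).shape)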

This file was deleted.

This file was deleted.
