81 changes: 0 additions & 81 deletions examples/distributed_inference/tensor_parallel_initialize_dist.py

This file was deleted.
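Note: the distributed-setup helpers that lived in this example-local module appear to have been promoted into the library itself; the hunks below switch the examples over to torch_tensorrt.dynamo.distributed.utils. A minimal before/after sketch of the import change, using only names visible in this diff:

# Before: helpers copied alongside each example
# from tensor_parallel_initialize_dist import cleanup_distributed_env, initialize_distributed_env

# After: helpers shipped with torch_tensorrt
from torch_tensorrt.dynamo.distributed.utils import (
    cleanup_distributed_env,
    initialize_distributed_env,
)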

examples/distributed_inference/tensor_parallel_rotary_embedding.py
@@ -16,15 +16,19 @@
 import torch
 import torch_tensorrt
 from rotary_embedding import RotaryAttention, parallel_rotary_block
-from tensor_parallel_initialize_dist import (
+import torch.distributed as dist
+from torch_tensorrt.dynamo.distributed.utils import (
     cleanup_distributed_env,
+    get_tensor_parallel_device_mesh,
     initialize_distributed_env,
+    initialize_logger,
 )
 
-device_mesh, _world_size, _rank, logger = initialize_distributed_env(
-    "./tensor_parallel_rotary_embedding"
-)
+if not dist.is_initialized():
+    initialize_distributed_env()
+
+device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
+logger = initialize_logger(_rank, "tensor_parallel_rotary_embedding")
 
 """
 This example covers the rotary embedding in Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
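For orientation, here is a sketch of the full lifecycle the refactored examples appear to follow with the new utilities: initialize the process group only if needed, build the tensor-parallel device mesh, create a per-rank logger, and tear the environment down at exit. Only the function names and the shapes of their return values come from this diff; the try/finally scaffolding, the logger prefix, and the placement of cleanup_distributed_env() are illustrative assumptions.

import torch.distributed as dist
from torch_tensorrt.dynamo.distributed.utils import (
    cleanup_distributed_env,
    get_tensor_parallel_device_mesh,
    initialize_distributed_env,
    initialize_logger,
)

if not dist.is_initialized():
    initialize_distributed_env()  # one-time process-group / env setup per process

device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
logger = initialize_logger(_rank, "my_tp_example")  # illustrative log prefix

try:
    logger.info(f"rank {_rank}/{_world_size} using device mesh {device_mesh}")
    # ... build the model, parallelize it over device_mesh, compile with Torch-TensorRT ...
finally:
    cleanup_distributed_env()  # assumed to take no arguments; releases the process group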
15 changes: 12 additions & 3 deletions examples/distributed_inference/tensor_parallel_simple_example.py
@@ -36,11 +36,20 @@
     RowwiseParallel,
     parallelize_module,
 )
-
-device_mesh, _world_size, _rank, logger = initialize_distributed_env(
-    "./tensor_parallel_simple_example"
+from torch_tensorrt.dynamo.distributed.utils import (
+    cleanup_distributed_env,
+    get_tensor_parallel_device_mesh,
+    initialize_distributed_env,
+    initialize_logger,
 )
+
+if not dist.is_initialized():
+    initialize_distributed_env()
+
+device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
+logger = initialize_logger(_rank, "tensor_parallel_simple_example")
+
 
 """
 This example takes some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
 """
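The hunk above leaves the tensor-parallel plumbing of the simple example untouched; it still relies on RowwiseParallel and parallelize_module, now combined with the device mesh returned by get_tensor_parallel_device_mesh(). As a rough sketch of how that mesh is consumed, assuming the device_mesh variable from the example above plus a toy module and an illustrative ColwiseParallel/RowwiseParallel plan (none of which are taken from this diff):

import torch
import torch.nn as nn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class ToyMLP(nn.Module):
    def __init__(self, dim: int = 128):
        super().__init__()
        self.in_proj = nn.Linear(dim, 4 * dim)
        self.out_proj = nn.Linear(4 * dim, dim)

    def forward(self, x):
        return self.out_proj(torch.relu(self.in_proj(x)))

# Shard the first projection column-wise and the second row-wise across the mesh,
# so the intermediate activation stays sharded and only the output is all-reduced.
tp_model = parallelize_module(
    module=ToyMLP().to("cuda"),
    device_mesh=device_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    },
)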
py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
@@ -11,11 +11,11 @@
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     dynamo_tensorrt_converter,
 )
+from torch_tensorrt.dynamo.distributed.utils import load_tensorrt_llm_for_nccl
 from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
     tensorrt_fused_nccl_all_gather_op,
     tensorrt_fused_nccl_reduce_scatter_op,
 )
-from torch_tensorrt.dynamo.utils import load_tensorrt_llm_for_nccl
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
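The only change in this hunk is the import path of load_tensorrt_llm_for_nccl, which now lives under torch_tensorrt.dynamo.distributed.utils. For readers without the surrounding file open: this helper presumably gates whether the fused NCCL collectives can be converted at all. A rough sketch of that gating pattern, reusing the names imported in the hunk above; the converter bodies are illustrative placeholders, not the library's actual implementation:

if load_tensorrt_llm_for_nccl():
    # Register converters for the fused collectives only when the TensorRT-LLM
    # NCCL plugin library could actually be loaded on this system.
    @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
    def fused_nccl_gather(ctx, target, args, kwargs, name):
        ...  # placeholder: the real converter inserts an NCCL all-gather plugin layer

    @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
    def fused_nccl_reduce_scatter(ctx, target, args, kwargs, name):
        ...  # placeholder: the real converter inserts an NCCL reduce-scatter plugin layer
else:
    _LOGGER.debug("TensorRT-LLM NCCL plugins unavailable; fused NCCL converters not registered.")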