
asking for pmap training demo #1

Open
guyujun opened this issue Mar 16, 2025 · 1 comment
guyujun commented Mar 16, 2025

Can you provide a more detailed demo of training with pmap and multiple GPUs? It would be a great help!

@southfreebird
Collaborator

kvax.flash_attention works with multiple GPUs by default when attention_specs and a Mesh are specified.

Under the hood, it is called within jax.experimental.shard_map, so you can use kvax.flash_attention just like any other function wrapped in shard_map.

Here is a simple example of running kvax.flash_attention inside pmap:

import jax
import functools
import jax.numpy as jnp
from jax.experimental import mesh_utils
from kvax.ops import (
    create_attention_mask,
    flash_attention,
)
from kvax.utils import (
    attention_specs,
)

# Define model parameters
embedding = 1024  # Embedding dimension
seq_len = 16   # Sequence length
num_heads = 8  # Number of attention heads
num_kv_heads = 8 # Number of kv attention heads
head_dim = 128  # Dim of each head
num_devices = jax.device_count()
batch_size = num_devices

# QKV projection helper
def project_qkv(x, w_q, w_k, w_v):
    # Weights have shape (embedding, num_heads * head_dim), so no transpose is needed
    query = x @ w_q
    key = x @ w_k
    value = x @ w_v

    # Reshape tensors to separate num_heads and head_dim axes
    query = query.reshape((-1, seq_len, num_heads, head_dim))
    key = key.reshape((-1, seq_len, num_kv_heads, head_dim))
    value = value.reshape((-1, seq_len, num_kv_heads, head_dim))

    return query, key, value

# Define the full attention computation inside pmap
@functools.partial(
    jax.pmap,
    in_axes=(0, 0, 0, None, None, None),
)
def attention_fn(x, positions, segment_ids, w_q, w_k, w_v):
    
    # pmap drops the mapped (device) axis from its inputs, so we add a batch axis
    # back here: the kvax ops expect inputs with an explicit batch dimension
    x = jnp.expand_dims(x, 0)
    positions = jnp.expand_dims(positions, 0)
    segment_ids = jnp.expand_dims(segment_ids, 0)

    query, key, value = project_qkv(x, w_q, w_k, w_v)
    
    # Calculate mask
    mask = create_attention_mask(
        positions,
        segment_ids,
        positions,
        segment_ids,
    )

    # Calculate attention
    output = flash_attention(
        query,
        key,
        value,
        positions,
        segment_ids,
        positions,
        segment_ids,
        mask,
    )

    # Squeeze the batch axis back out so pmap can stack the per-device results
    return jnp.squeeze(output, axis=(0,))

# Simulated input data (batch size = number of devices)
x = jnp.ones((batch_size, seq_len, embedding), dtype=jnp.bfloat16)

# Initialize random projection weights
key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(42), 3)
w_q = jax.random.normal(key_q, (embedding, num_heads * head_dim), dtype=jnp.bfloat16)
w_k = jax.random.normal(key_k, (embedding, num_kv_heads * head_dim), dtype=jnp.bfloat16)
w_v = jax.random.normal(key_v, (embedding, num_kv_heads * head_dim), dtype=jnp.bfloat16)

positions = jnp.broadcast_to(jnp.arange(seq_len), (batch_size, seq_len))
segment_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)

# Create mesh for kvax attention
device_mesh = mesh_utils.create_device_mesh(mesh_shape=(num_devices,))
mesh = jax.sharding.Mesh(device_mesh, ('inner_shard',))

# Define attention specs. We do not shard inside the kvax ops here because pmap
# already does the data-parallel work for us.
# It is also possible to combine data parallelism via pmap with tensor or
# context parallelism via shard_map.
with mesh, attention_specs(
        query_specs=(None, None, None, None),
        kv_specs=(None, None, None, None),
    ):
    # Run the computation
    output = attention_fn(x, positions, segment_ids, w_q, w_k, w_v)

# Expected: (num_devices, seq_len, num_heads, head_dim)
print("Attention Output Shape:", output.shape)

Please feel free to reach out if anything is unclear.

@southfreebird southfreebird self-assigned this Mar 18, 2025