
[JAX] THD ring attention #1454

Open
wants to merge 11 commits into main

zlsh80826
Collaborator

zlsh80826 commented Feb 4, 2025

Description

Support P2P context parallelism (ring attention) with the THD format. This feature is only available for self-attention with a causal mask, segment_ids/pos, and load balancing (reorder before the attention and inverse-reorder after it).
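To make the load-balancing step concrete, here is a minimal NumPy sketch of a dual-chunk-swap reorder. It assumes the sequence length is divisible by 2 * cp_size. The sequence is split into 2 * cp_size chunks, and rank r keeps chunks r and 2 * cp_size - 1 - r, so each rank pairs an early chunk (cheap under a causal mask) with a late one (expensive). The helper names and the exact `to_contiguous` semantics are illustrative assumptions, not Transformer Engine's implementation.

```python
import numpy as np


def _dual_chunk_swap_order(cp_size):
    """Chunk order after reordering: rank r owns chunks r and 2*cp_size-1-r."""
    order = []
    for r in range(cp_size):
        order.extend([r, 2 * cp_size - 1 - r])
    return np.array(order)


def reorder_dual_chunk_swap(x, cp_size, seq_dim, to_contiguous=False):
    """Hypothetical helper: reorder along seq_dim, or undo it if to_contiguous."""
    shape = x.shape
    num_chunks = 2 * cp_size
    if shape[seq_dim] % num_chunks != 0:
        raise ValueError("sequence length must be divisible by 2 * cp_size")
    chunk_len = shape[seq_dim] // num_chunks
    # View the sequence axis as (num_chunks, chunk_len).
    x = x.reshape(shape[:seq_dim] + (num_chunks, chunk_len) + shape[seq_dim + 1:])
    order = _dual_chunk_swap_order(cp_size)
    if to_contiguous:
        # The reorder is not its own inverse, so use the inverse permutation.
        order = np.argsort(order)
    x = np.take(x, order, axis=seq_dim)
    return x.reshape(shape)


# Reorder before attention, inverse-reorder after it.
tokens = np.arange(8)
balanced = reorder_dual_chunk_swap(tokens, cp_size=2, seq_dim=0)
restored = reorder_dual_chunk_swap(balanced, cp_size=2, seq_dim=0, to_contiguous=True)
print(balanced)  # [0 1 6 7 2 3 4 5]
print(restored)  # [0 1 2 3 4 5 6 7]
```

With cp_size=2, rank 0 receives the first half of the reordered sequence ([0 1 6 7]) and rank 1 receives the second half ([2 3 4 5]). Applying the forward reorder twice does not restore the original order, which is why a separate direction flag is needed.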

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Refactor the reorder/inverse_reorder_causal_load_balancing API to support different reorder strategies.
  • Support P2P context parallelism. The limitations are listed above.
  • Reduce the number of kv_groups test configs in test_distributed_fused_attn.
  • Use AttnBiasType, AttnMaskType, and QKVLayout in cpp_extensions/attention.py to improve readability.
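Alongside DualChunkSwap, the PR introduces a Striped reorder strategy. The NumPy sketch below shows one plausible striped reorder. It assumes the round-robin assignment used by striped attention, where token t goes to rank t % cp_size. `reorder_striped` and its `inverse` flag are hypothetical and not the PR's code. They follow the single-function-with-a-boolean shape discussed in the review.

```python
import numpy as np


def reorder_striped(x, cp_size, seq_dim, inverse=False):
    """Hypothetical helper: round-robin tokens across ranks, or undo it."""
    shape = x.shape
    if shape[seq_dim] % cp_size != 0:
        raise ValueError("sequence length must be divisible by cp_size")
    per_rank = shape[seq_dim] // cp_size
    # Forward: token t = i * cp_size + r sits at (i, r); swapping axes groups tokens by rank r.
    # Inverse: the reordered axis is (r, i); swapping back restores the original order.
    split = (cp_size, per_rank) if inverse else (per_rank, cp_size)
    x = x.reshape(shape[:seq_dim] + split + shape[seq_dim + 1:])
    x = np.swapaxes(x, seq_dim, seq_dim + 1)
    return x.reshape(shape)


tokens = np.arange(8)
striped = reorder_striped(tokens, cp_size=2, seq_dim=0)
print(striped)  # [0 2 4 6 1 3 5 7]
print(reorder_striped(striped, cp_size=2, seq_dim=0, inverse=True))  # [0 1 2 3 4 5 6 7]
```

With cp_size=2, rank 0 gets the even tokens and rank 1 the odd ones. Here the forward reorder is a transpose of the (per_rank, cp_size) view, so the inverse is the same transpose applied to the (cp_size, per_rank) view.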

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

zlsh80826 marked this pull request as draft February 4, 2025 20:27
@zlsh80826
Collaborator Author

/te-ci jax L1

zlsh80826 marked this pull request as ready for review February 6, 2025 09:46
zlsh80826 requested a review from phu0ngng February 6, 2025 09:56
@zlsh80826
Collaborator Author

/te-ci jax L1

zlsh80826 force-pushed the rewang/thd-ring-attn branch from d486032 to 33ac4d2 February 8, 2025 09:41
@zlsh80826
Collaborator Author

/te-ci jax L1

zlsh80826 force-pushed the rewang/thd-ring-attn branch from 33ac4d2 to ddd6a8b February 10, 2025 06:38
@zlsh80826
Collaborator Author

/te-ci jax L1

phu0ngng requested a review from kocchop February 13, 2025 01:54
zlsh80826 force-pushed the rewang/thd-ring-attn branch from 6dd5fdb to 4c17948 February 19, 2025 15:13
@zlsh80826
Collaborator Author

/te-ci jax L1

Comment on lines 345 to 350
```python
if strategy == ReorderStrategy.DualChunkSwap:
    return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True)
if strategy == ReorderStrategy.Striped:
    return _inverse_reorder_causal_striped(tensor, cp_size, seq_dim)
```
Collaborator

Hi,

Why did we implement reorder_causal_load_balancing() in jax/cpp_extensions/attention.py but _inverse_reorder_causal_striped in attention.py?

I think we should give _reorder_causal_striped the same API as tex.reorder_causal_load_balancing(), which accepts the boolean if_inverse and can handle both cases.

Collaborator Author

The last argument of reorder_causal_load_balancing is actually to_contiguous, not an inverse/non-inverse flag. reorder_causal_load_balancing needs to live in cpp_extensions/attention.py because cpp_extensions/attention.py itself also uses it, whereas _reorder_causal_striped has no such requirement.

Collaborator Author

But for better alignment, I can move it into cpp_extensions/attention.py.

Collaborator Author

Done

```diff
@@ -310,7 +312,7 @@ def abstract(
         rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
         rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)
 
-        if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
+        if config.attn_bias_type == AttnBiasType.NO_BIAS:
             bias_batch = bias_heads = 0
         else:
             *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
```
Collaborator

Hi, what does the full shape of bias_aval look like here? Is this bias for PreBias, PostBias, or both?

Collaborator Author

With no_bias, a zero-sized bias is passed. Otherwise, the bias is intended for both PreBias and PostBias.
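A minimal sketch of the shape logic in the abstract() hunk above. It assumes the bias is laid out as (*batch_dims, heads, q_seqlen, kv_seqlen), which is what the unpacking `*bias_batch_shape, bias_heads, _, _` implies. The enum members other than NO_BIAS and the helper name `bias_batch_and_heads` are hypothetical stand-ins. Collapsing the batch dims with `math.prod` is also an assumption.

```python
import math
from enum import Enum

import numpy as np


class AttnBiasType(Enum):
    # Hypothetical stand-in; only NO_BIAS appears in the PR excerpt.
    NO_BIAS = 0
    PRE_SCALE_BIAS = 1
    POST_SCALE_BIAS = 2


def bias_batch_and_heads(bias_type, bias_shape):
    """Return (batch, heads) of the bias, or (0, 0) when there is no bias."""
    if bias_type == AttnBiasType.NO_BIAS:
        # A zero-sized placeholder is passed, so its shape is never unpacked.
        return 0, 0
    # Assumed layout: (*batch_dims, heads, q_seqlen, kv_seqlen).
    *batch_shape, heads, _, _ = bias_shape
    return math.prod(batch_shape), heads


placeholder = np.zeros((0,), dtype=np.float32)
print(bias_batch_and_heads(AttnBiasType.NO_BIAS, placeholder.shape))  # (0, 0)
print(bias_batch_and_heads(AttnBiasType.POST_SCALE_BIAS, (2, 16, 128, 128)))  # (2, 16)
```

The same code path serves both pre- and post-scale bias. Only the NO_BIAS case short-circuits, because the placeholder has no head or sequence dimensions to unpack.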

zlsh80826 force-pushed the rewang/thd-ring-attn branch from 4c17948 to fc2ebcb February 24, 2025 14:45
@zlsh80826
Collaborator Author

/te-ci jax L1
