[shardformer] Upgrade transformers version: falcon model #6283

Open · wants to merge 3 commits into base: upgrade_transformers

135 changes: 48 additions & 87 deletions colossalai/shardformer/modeling/falcon.py
@@ -1,16 +1,9 @@
import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -110,19 +103,17 @@ def forward(
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
layer_past: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs,
):
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
residual = hidden_states

if self.config.new_decoder_architecture:
if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2:
attention_layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
else:
@@ -138,7 +129,8 @@ def forward(
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
**kwargs,
cache_position=cache_position,
position_embeddings=position_embeddings,
)

attention_output = attn_outputs[0]
@@ -152,6 +144,13 @@ def forward(
)
mlp_layernorm_out = self.post_attention_layernorm(residual)

if (
self.config.new_decoder_architecture
and self.config.parallel_attn
and self.config.num_ln_in_parallel_attn == 1
):
mlp_layernorm_out = attention_layernorm_out

outputs = attn_outputs[1:]

# MLP.
@@ -167,7 +166,7 @@ def forward(
else:
outputs = (output,) + outputs[1:]

return outputs # hidden_states, present, attentions
return outputs # hidden_states, past_kv, attentions

return forward
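
The layer-norm routing that the patched block forward adopts above is easy to miss in the diff, so here is the branch logic in isolation. This is a minimal sketch: the config attributes (new_decoder_architecture, parallel_attn, num_ln_in_parallel_attn) are the real FalconConfig fields exercised in the hunk, but the helper name and signature are hypothetical and exist only for illustration.

def select_layernorm_outputs(config, ln_attn, ln_mlp, input_layernorm, hidden_states):
    # Two parallel layer norms: attention and MLP each get their own normalized input.
    if config.new_decoder_architecture and config.num_ln_in_parallel_attn == 2:
        attention_ln_out = ln_attn(hidden_states)
        mlp_ln_out = ln_mlp(hidden_states)
    else:
        # Single input layer norm; the classic (non-parallel) path re-normalizes
        # the residual after attention instead.
        attention_ln_out = input_layernorm(hidden_states)
        mlp_ln_out = None

    # New decoder architecture with parallel attention but only one layer norm:
    # the MLP simply reuses the attention layer-norm output.
    if (
        config.new_decoder_architecture
        and config.parallel_attn
        and config.num_ln_in_parallel_attn == 1
    ):
        mlp_ln_out = attention_ln_out
    return attention_ln_out, mlp_ln_out
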

@@ -190,11 +189,14 @@ def falcon_model_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
# Add cache_position and position_embeddings args for transformers v4.51.3

logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -206,9 +208,8 @@ def falcon_model_forward(
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False

if past_key_values is not None:
logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
past_key_values = None
logger.warning_once("past_key_values is not supported for pipeline models at the moment.")
past_key_values = None

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

@@ -221,20 +222,17 @@ def falcon_model_forward(
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape

if past_key_values is None:
past_key_values = tuple([None] * len(self.h))

if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
@@ -243,10 +241,10 @@ def falcon_model_forward(
all_hidden_states = () if output_hidden_states else None

# Compute alibi tensor: check build_alibi_tensor documentation
alibi = None
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[-2]

batch_size, seq_length, _ = hidden_states.shape
if self.use_alibi:
mask = (
torch.ones(
@@ -256,73 +254,31 @@ def falcon_model_forward(
else attention_mask
)
alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
else:
alibi = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0)

if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
if alibi is None:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
elif head_mask is None:
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])

attention_mask_2d = attention_mask
# We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched.
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)

# We take care to integrate alibi bias in the attention_mask here.
if attention_mask_2d is None:
attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
else:
min_dtype = torch.finfo(alibi.dtype).min
attention_mask = torch.masked_fill(
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
attention_mask < -1,
min_dtype,
)

# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
if seq_length > 1 and attention_mask.device.type == "cuda":
attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
else:
# PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# Use the new causal mask construction.
# In transformers v4.51.3, the sdpa, eager and flash attention backends are merged into a single attention class.
causal_mask = self._update_causal_mask(
attention_mask, hidden_states, cache_position, past_key_values, output_attentions, head_mask, alibi
)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)

start_idx, end_idx = stage_index[0], stage_index[1]
for i, (block, layer_past) in enumerate(
zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx
):
for i, block in enumerate(self.h[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

@@ -331,28 +287,32 @@ def falcon_model_forward(
block.__call__,
hidden_states,
alibi,
attention_mask,
causal_mask,
position_ids,
head_mask[i],
layer_past,
past_key_values,
use_cache,
output_attentions,
cache_position,
position_embeddings,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
layer_past=past_key_values,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
cache_position=cache_position,
position_embeddings=position_embeddings,
)

hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
outputs[1]

if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
@@ -365,6 +325,7 @@ def falcon_model_forward(
all_hidden_states = all_hidden_states + (hidden_states,)

if stage_manager.is_last_stage():

if not return_dict:
return tuple(
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
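
Taken together, the falcon_model_forward hunks above drop the per-backend mask helpers (_prepare_4d_causal_attention_mask and the sdpa/flash special cases) in favor of the preamble used by FalconModel in transformers v4.51.x: derive cache_position and position_ids once, build one causal mask through self._update_causal_mask, and compute the rotary position embeddings a single time so every decoder block can share them. The following is a minimal sketch of that preamble; the _update_causal_mask call signature and rotary_emb usage mirror the diff, while the wrapper itself (mask_and_position_preamble) is hypothetical.

import torch

def mask_and_position_preamble(model, hidden_states, attention_mask, past_key_values,
                               position_ids=None, cache_position=None,
                               output_attentions=False, head_mask=None, alibi=None):
    # Length of the cached prefix; 0 when no KV cache is passed (the pipeline case).
    past_len = past_key_values.get_seq_length() if past_key_values is not None else 0
    seq_len = hidden_states.shape[1]

    if cache_position is None:
        cache_position = torch.arange(past_len, past_len + seq_len, device=hidden_states.device)
    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    # One causal mask for all attention backends (eager / sdpa / flash).
    causal_mask = model._update_causal_mask(
        attention_mask, hidden_states, cache_position, past_key_values,
        output_attentions, head_mask, alibi
    )
    # Rotary embeddings are computed once and shared by every decoder block.
    position_embeddings = model.rotary_emb(hidden_states, position_ids)
    return causal_mask, position_ids, cache_position, position_embeddings
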
1 change: 1 addition & 0 deletions colossalai/shardformer/policies/falcon.py
@@ -246,6 +246,7 @@ def get_held_layers(self) -> List[Module]:
module = self.model.transformer
stage_manager = self.pipeline_stage_manager
held_layers = []
held_layers.append(module.rotary_emb)
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = stage_manager.distribute_layers(len(module.h))
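
The one-line policy change above makes every pipeline stage hold module.rotary_emb, because position_embeddings are now computed inside falcon_model_forward on whichever stage receives the hidden states. A rough sketch of the resulting get_held_layers logic for the non-interleaved schedule is shown below; distribute_layers and is_last_stage appear in the diff, while get_stage_index, is_first_stage and the surrounding helper are assumptions based on how other shardformer policies are typically written.

from typing import List
from torch.nn import Module

def held_layers_sketch(module, stage_manager) -> List[Module]:
    held: List[Module] = [module.rotary_emb]  # every stage needs the rotary embeddings locally
    layers_per_stage = stage_manager.distribute_layers(len(module.h))
    start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
    held.extend(module.h[start_idx:end_idx])  # this stage's slice of decoder blocks
    if stage_manager.is_first_stage():
        held.append(module.word_embeddings)   # token embeddings only on the first stage
    if stage_manager.is_last_stage():
        held.append(module.ln_f)              # final layer norm only on the last stage
    return held
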