Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions mamba_ssm/models/mixer_seq_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def tie_weights(self):
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, attention_mask=None, labels=None):
"""
"position_ids" is just to be compatible with Transformer generation. We don't use it.
num_last_tokens: if > 0, only return the logits for the last n tokens
Expand All @@ -234,8 +234,25 @@ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
lm_logits = self.lm_head(hidden_states)
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)

# Source: https://github.com/huggingface/transformers/blob/80377eb018c077dba434bc8e7912bcaed3a64d09/src/transformers/models/llama/modeling_llama.py#L1196
if labels is not None:
logits = lm_logits
vocab_size = logits.shape[-1]
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = torch.nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
return (loss,)
else:
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
return CausalLMOutput(logits=lm_logits)

@classmethod
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
Expand Down