fix llama modeling

This commit is contained in:
Wing Lian
2025-07-30 11:37:58 -04:00
parent dfa14f87ab
commit 08aa74e418

View File

@@ -49,7 +49,6 @@ class LlamaDecoderLayer(GradientCheckpointingLayer):
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool | None = False,
use_cache: bool | None = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: (
@@ -62,12 +61,11 @@ class LlamaDecoderLayer(GradientCheckpointingLayer):
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
@@ -76,12 +74,7 @@ class LlamaDecoderLayer(GradientCheckpointingLayer):
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,) # type: ignore
return outputs # type: ignore
return hidden_states
def patch_llama():