From 08aa74e418e8ffe9e27aada78c594294b96dbb98 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 30 Jul 2025 11:37:58 -0400 Subject: [PATCH] fix llama modeling --- .../integrations/modeling/llama/modeling_llama.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/axolotl/integrations/modeling/llama/modeling_llama.py b/src/axolotl/integrations/modeling/llama/modeling_llama.py index b2df51242..1fe9983a4 100644 --- a/src/axolotl/integrations/modeling/llama/modeling_llama.py +++ b/src/axolotl/integrations/modeling/llama/modeling_llama.py @@ -49,7 +49,6 @@ class LlamaDecoderLayer(GradientCheckpointingLayer): attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_value: Cache | None = None, - output_attentions: bool | None = False, use_cache: bool | None = False, cache_position: torch.LongTensor | None = None, position_embeddings: ( @@ -62,12 +61,11 @@ class LlamaDecoderLayer(GradientCheckpointingLayer): hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, - output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -76,12 +74,7 @@ class LlamaDecoderLayer(GradientCheckpointingLayer): hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) # type: ignore - - return outputs # type: ignore + return hidden_states def patch_llama():