fix llama modeling

@@ -49,7 +49,6 @@ class LlamaDecoderLayer(GradientCheckpointingLayer):
         attention_mask: torch.Tensor | None = None,
         position_ids: torch.LongTensor | None = None,
         past_key_value: Cache | None = None,
-        output_attentions: bool | None = False,
         use_cache: bool | None = False,
         cache_position: torch.LongTensor | None = None,
         position_embeddings: (
@@ -62,12 +61,11 @@ class LlamaDecoderLayer(GradientCheckpointingLayer):
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        hidden_states, self_attn_weights = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
-            output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
@@ -76,12 +74,7 @@ class LlamaDecoderLayer(GradientCheckpointingLayer):
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
-
-        outputs = (hidden_states,)
-        if output_attentions:
-            outputs += (self_attn_weights,)  # type: ignore
-
-        return outputs  # type: ignore
+        return hidden_states
 
 
 def patch_llama():
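
In plain terms, the commit removes the `output_attentions` plumbing from `LlamaDecoderLayer.forward`: the parameter disappears from the signature and from the kwargs forwarded to `self.self_attn`, the attention weights are discarded (`hidden_states, _ = ...`), and the layer returns the hidden-states tensor directly instead of a one- or two-element tuple. A caller-side sketch of the new contract; `layer`, `x`, `mask`, `position_ids`, `cos`, and `sin` are hypothetical placeholders, not names from this commit:

    # Before this commit (tuple contract):
    #   outputs = layer(x, attention_mask=mask, output_attentions=True, ...)
    #   hidden_states, attn_weights = outputs
    # After this commit (plain tensor; output_attentions is no longer accepted):
    hidden_states = layer(
        x,
        attention_mask=mask,             # torch.Tensor | None
        position_ids=position_ids,       # torch.LongTensor | None
        position_embeddings=(cos, sin),  # rotary cos/sin for these positions
    )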
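
One line worth flagging for readers used to upstream `transformers`: `hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)` is not the stock single-argument `LlamaRMSNorm`; it threads the residual stream through the norm, i.e. a fused add-then-normalize. A minimal sketch of the contract the diff implies (the class name, `eps`, and the implementation are assumptions; the real norm lives elsewhere in this repo):

    import torch

    class FusedAddRMSNorm(torch.nn.Module):
        # Hypothetical stand-in: fold the incoming branch into the residual
        # stream, RMS-normalize the sum, and return both tensors.
        def __init__(self, hidden_size: int, eps: float = 1e-6):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(hidden_size))
            self.eps = eps

        def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor):
            residual = residual + hidden_states  # updated residual stream
            var = residual.float().pow(2).mean(-1, keepdim=True)
            normed = residual.float() * torch.rsqrt(var + self.eps)
            return self.weight * normed.to(residual.dtype), residual

This matches the tail of the diff: the normalized tensor feeds `self.mlp`, and the returned `residual` is what `hidden_states = residual + hidden_states` adds back in.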
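
The diff cuts off at `def patch_llama():`, so its body is not shown here. A plausible sketch, assuming the function monkey-patches the Hugging Face module so that Llama models built after the call use the simplified `LlamaDecoderLayer` defined above (the upstream module path is real; the rebinding strategy is a guess):

    import transformers.models.llama.modeling_llama as hf_llama

    def patch_llama():
        # Rebind the upstream class to the patched LlamaDecoderLayer from
        # this module. LlamaModel resolves the name at construction time,
        # so models instantiated after patch_llama() get the new forward.
        hf_llama.LlamaDecoderLayer = LlamaDecoderLayer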