fix llama modeling
This commit is contained in:
@@ -49,7 +49,6 @@ class LlamaDecoderLayer(GradientCheckpointingLayer):
|
|||||||
attention_mask: torch.Tensor | None = None,
|
attention_mask: torch.Tensor | None = None,
|
||||||
position_ids: torch.LongTensor | None = None,
|
position_ids: torch.LongTensor | None = None,
|
||||||
past_key_value: Cache | None = None,
|
past_key_value: Cache | None = None,
|
||||||
output_attentions: bool | None = False,
|
|
||||||
use_cache: bool | None = False,
|
use_cache: bool | None = False,
|
||||||
cache_position: torch.LongTensor | None = None,
|
cache_position: torch.LongTensor | None = None,
|
||||||
position_embeddings: (
|
position_embeddings: (
|
||||||
@@ -62,12 +61,11 @@ class LlamaDecoderLayer(GradientCheckpointingLayer):
|
|||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
hidden_states, self_attn_weights = self.self_attn(
|
hidden_states, _ = self.self_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
position_embeddings=position_embeddings,
|
position_embeddings=position_embeddings,
|
||||||
@@ -76,12 +74,7 @@ class LlamaDecoderLayer(GradientCheckpointingLayer):
|
|||||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(hidden_states)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
|
return hidden_states
|
||||||
outputs = (hidden_states,)
|
|
||||||
if output_attentions:
|
|
||||||
outputs += (self_attn_weights,) # type: ignore
|
|
||||||
|
|
||||||
return outputs # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
def patch_llama():
|
def patch_llama():
|
||||||
|
|||||||
Reference in New Issue
Block a user