Fix grad checkpoint and outputs param
This commit is contained in:
@@ -27,7 +27,6 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
import transformers
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from transformers.activations import ACT2FN
|
||||
@@ -52,10 +51,6 @@ _CONFIG_FOR_DOC = "LlamaConfig"
|
||||
MEM_TOKEN = "<landmark>" # nosec
|
||||
|
||||
|
||||
def hijack_llama_landmark_attn():
|
||||
transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
||||
def _make_causal_mask(
|
||||
input_ids_shape: torch.Size,
|
||||
@@ -1125,7 +1120,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
@@ -1135,6 +1130,8 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
output_attentions,
|
||||
None,
|
||||
is_mem,
|
||||
last_section_mask,
|
||||
)
|
||||
@@ -1300,7 +1297,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
||||
return_dict=return_dict,
|
||||
offload_cache_to_cpu=offload_cache_to_cpu,
|
||||
)
|
||||
past_key_values = outputs[1]
|
||||
past_key_values = outputs.past_key_values
|
||||
if last_logits is not None:
|
||||
last_logits = torch.cat((last_logits, outputs[0]), dim=-2)
|
||||
last_logits = outputs[0]
|
||||
|
||||
Reference in New Issue
Block a user