Fix grad checkpoint and outputs param

NanoCode012
2023-06-09 14:28:44 +09:00
parent e44c9e0b3e
commit 2a801b001a


@@ -27,7 +27,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 import transformers
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
@@ -52,10 +51,6 @@ _CONFIG_FOR_DOC = "LlamaConfig"
MEM_TOKEN = "<landmark>" # nosec
def hijack_llama_landmark_attn():
transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size,
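The hijack takes effect by reassigning the class attribute on the transformers module itself, so it only applies to models constructed after it runs. A minimal usage sketch; the import path and model id are assumptions, since the diff does not show the file's path:

    import transformers

    # Assumed import path for this patch module; the diff does not name the file.
    from axolotl.monkeypatch.llama_landmark_attn import hijack_llama_landmark_attn

    # Reassigns transformers.models.llama.modeling_llama.LlamaForCausalLM,
    # so it must run before the model is instantiated.
    hijack_llama_landmark_attn()

    model = transformers.AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")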
@@ -1125,7 +1120,7 @@ class LlamaModel(LlamaPreTrainedModel):
             def create_custom_forward(module):
                 def custom_forward(*inputs):
                     # None for past_key_value
-                    return module(*inputs, output_attentions, None)
+                    return module(*inputs)
 
                 return custom_forward
@@ -1135,6 +1130,8 @@ class LlamaModel(LlamaPreTrainedModel):
                 attention_mask,
                 position_ids,
                 None,
+                output_attentions,
+                None,
                 is_mem,
                 last_section_mask,
             )
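These two hunks are one fix: the old custom_forward injected output_attentions and None after *inputs, while the checkpoint call also appended is_mem and last_section_mask, so the decoder layer received its positional arguments out of order. The fix makes custom_forward a pure pass-through and lists every argument, in order, at the checkpoint call site. A self-contained sketch of the corrected pattern, using a toy layer in place of the real decoder layer (the signature shown is an assumption inferred from the argument list above):

    import torch
    import torch.utils.checkpoint
    from torch import nn

    class ToyLayer(nn.Module):
        """Stand-in for a decoder layer taking the same positional args."""

        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)

        def forward(self, hidden_states, attention_mask, position_ids,
                    past_key_value, output_attentions, use_cache,
                    is_mem, last_section_mask):
            return (self.proj(hidden_states),)

    def create_custom_forward(module):
        def custom_forward(*inputs):
            # Pure pass-through: the layer sees arguments in exactly the
            # order they are given to checkpoint() below.
            return module(*inputs)

        return custom_forward

    layer = ToyLayer()
    hidden_states = torch.randn(1, 4, 8, requires_grad=True)
    layer_outputs = torch.utils.checkpoint.checkpoint(
        create_custom_forward(layer),
        hidden_states,
        None,   # attention_mask
        None,   # position_ids
        None,   # past_key_value (unused under checkpointing)
        False,  # output_attentions
        None,   # use_cache
        None,   # is_mem
        None,   # last_section_mask
    )
    layer_outputs[0].sum().backward()  # recomputes the layer in backward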
@@ -1300,7 +1297,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
                 return_dict=return_dict,
                 offload_cache_to_cpu=offload_cache_to_cpu,
             )
-            past_key_values = outputs[1]
+            past_key_values = outputs.past_key_values
             if last_logits is not None:
                 last_logits = torch.cat((last_logits, outputs[0]), dim=-2)
             last_logits = outputs[0]
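The outputs[1] change fixes a misindexing hazard rather than style: with return_dict=True the model returns a ModelOutput, and integer indexing counts only the fields that are not None, so what lives at position 1 shifts with flags like use_cache and output_hidden_states. A small demonstration with dummy tensors (this constructs the output by hand; it is not the model's actual return value):

    import torch
    from transformers.modeling_outputs import BaseModelOutputWithPast

    # With use_cache=False, past_key_values is None, and integer indexing
    # skips None fields entirely.
    outputs = BaseModelOutputWithPast(
        last_hidden_state=torch.zeros(1, 4, 8),
        past_key_values=None,
        hidden_states=(torch.zeros(1, 4, 8),),
    )

    print(outputs.past_key_values)  # None, the field we actually wanted
    print(type(outputs[1]))         # <class 'tuple'>: hidden_states, mis-indexed

Accessing outputs.past_key_values always names the intended field, regardless of which optional outputs are enabled.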