Fix gradient-checkpointing argument order and read past_key_values from the outputs object

This commit is contained in:
NanoCode012
2023-06-09 14:28:44 +09:00
parent e44c9e0b3e
commit 2a801b001a

View File

@@ -27,7 +27,6 @@ from typing import List, Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
import transformers
from torch import nn from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
@@ -52,10 +51,6 @@ _CONFIG_FOR_DOC = "LlamaConfig"
MEM_TOKEN = "<landmark>" # nosec MEM_TOKEN = "<landmark>" # nosec
def hijack_llama_landmark_attn():
transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
# Copied from transformers.models.bart.modeling_bart._make_causal_mask # Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask( def _make_causal_mask(
input_ids_shape: torch.Size, input_ids_shape: torch.Size,
@@ -1125,7 +1120,7 @@ class LlamaModel(LlamaPreTrainedModel):
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
# None for past_key_value # None for past_key_value
return module(*inputs, output_attentions, None) return module(*inputs)
return custom_forward return custom_forward
@@ -1135,6 +1130,8 @@ class LlamaModel(LlamaPreTrainedModel):
attention_mask, attention_mask,
position_ids, position_ids,
None, None,
output_attentions,
None,
is_mem, is_mem,
last_section_mask, last_section_mask,
) )
@@ -1300,7 +1297,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
return_dict=return_dict, return_dict=return_dict,
offload_cache_to_cpu=offload_cache_to_cpu, offload_cache_to_cpu=offload_cache_to_cpu,
) )
past_key_values = outputs[1] past_key_values = outputs.past_key_values
if last_logits is not None: if last_logits is not None:
last_logits = torch.cat((last_logits, outputs[0]), dim=-2) last_logits = torch.cat((last_logits, outputs[0]), dim=-2)
last_logits = outputs[0] last_logits = outputs[0]