Fix grad checkpoint and outputs param
This commit is contained in:
@@ -27,7 +27,6 @@ from typing import List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
import transformers
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
@@ -52,10 +51,6 @@ _CONFIG_FOR_DOC = "LlamaConfig"
|
|||||||
MEM_TOKEN = "<landmark>" # nosec
|
MEM_TOKEN = "<landmark>" # nosec
|
||||||
|
|
||||||
|
|
||||||
def hijack_llama_landmark_attn():
|
|
||||||
transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
||||||
def _make_causal_mask(
|
def _make_causal_mask(
|
||||||
input_ids_shape: torch.Size,
|
input_ids_shape: torch.Size,
|
||||||
@@ -1125,7 +1120,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
# None for past_key_value
|
# None for past_key_value
|
||||||
return module(*inputs, output_attentions, None)
|
return module(*inputs)
|
||||||
|
|
||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
@@ -1135,6 +1130,8 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
None,
|
None,
|
||||||
|
output_attentions,
|
||||||
|
None,
|
||||||
is_mem,
|
is_mem,
|
||||||
last_section_mask,
|
last_section_mask,
|
||||||
)
|
)
|
||||||
@@ -1300,7 +1297,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
offload_cache_to_cpu=offload_cache_to_cpu,
|
offload_cache_to_cpu=offload_cache_to_cpu,
|
||||||
)
|
)
|
||||||
past_key_values = outputs[1]
|
past_key_values = outputs.past_key_values
|
||||||
if last_logits is not None:
|
if last_logits is not None:
|
||||||
last_logits = torch.cat((last_logits, outputs[0]), dim=-2)
|
last_logits = torch.cat((last_logits, outputs[0]), dim=-2)
|
||||||
last_logits = outputs[0]
|
last_logits = outputs[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user