From 2a801b001a04049c644ede0225c71ece017d2a95 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Fri, 9 Jun 2023 14:28:44 +0900
Subject: [PATCH] Fix grad checkpoint and outputs param

---
 src/axolotl/monkeypatch/llama_landmark_attn.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/src/axolotl/monkeypatch/llama_landmark_attn.py b/src/axolotl/monkeypatch/llama_landmark_attn.py
index 64719639e..18e913f09 100644
--- a/src/axolotl/monkeypatch/llama_landmark_attn.py
+++ b/src/axolotl/monkeypatch/llama_landmark_attn.py
@@ -27,7 +27,6 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-import transformers
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
@@ -52,10 +51,6 @@ _CONFIG_FOR_DOC = "LlamaConfig"
 MEM_TOKEN = ""  # nosec
 
 
-def hijack_llama_landmark_attn():
-    transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
-
-
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
     input_ids_shape: torch.Size,
@@ -1125,7 +1120,7 @@ class LlamaModel(LlamaPreTrainedModel):
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         # None for past_key_value
-                        return module(*inputs, output_attentions, None)
+                        return module(*inputs)
 
                     return custom_forward
 
@@ -1135,6 +1130,8 @@ class LlamaModel(LlamaPreTrainedModel):
                     attention_mask,
                     position_ids,
                     None,
+                    output_attentions,
+                    None,
                     is_mem,
                     last_section_mask,
                 )
@@ -1300,7 +1297,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
                 return_dict=return_dict,
                 offload_cache_to_cpu=offload_cache_to_cpu,
             )
-            past_key_values = outputs[1]
+            past_key_values = outputs.past_key_values
             if last_logits is not None:
                 last_logits = torch.cat((last_logits, outputs[0]), dim=-2)
                 last_logits = outputs[0]
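
Note (not part of the patch): torch.utils.checkpoint.checkpoint replays the
wrapped function with exactly the positional arguments it was given, so every
value the decoder layer needs has to travel through checkpoint's own argument
list, in the layer's parameter order. The old code closed over
output_attentions and appended it after *inputs, which landed it behind is_mem
and last_section_mask and scrambled the layer's arguments; the new code
forwards *inputs untouched and threads output_attentions (plus a None
placeholder, presumably for use_cache) through the checkpoint call in the
right slot. Below is a minimal runnable sketch of the corrected pattern;
toy_layer and its tensors are hypothetical stand-ins, not the real
LlamaDecoderLayer:

    import torch
    import torch.utils.checkpoint


    def toy_layer(hidden_states, attention_mask, past_key_value, output_attentions):
        # Stand-in for a decoder layer: its parameter order must match the
        # order of the arguments handed to checkpoint() below.
        hidden_states = hidden_states * 2.0
        return (hidden_states,)


    def create_custom_forward(module):
        def custom_forward(*inputs):
            # Forward everything untouched; flags such as output_attentions
            # arrive as checkpoint arguments rather than closed-over values.
            return module(*inputs)

        return custom_forward


    hidden_states = torch.randn(2, 4, requires_grad=True)
    layer_outputs = torch.utils.checkpoint.checkpoint(
        create_custom_forward(toy_layer),
        hidden_states,
        None,   # attention_mask
        None,   # past_key_value
        False,  # output_attentions
    )
    layer_outputs[0].sum().backward()

The outputs[1] -> outputs.past_key_values change follows the same logic on the
return side: with return_dict=True the model returns a transformers
ModelOutput whose tuple indices shift depending on which optional fields
(attentions, hidden states) are populated, so attribute access is the stable
way to read the cache.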