From 7ac9cbebb981e6c1b99305edf270997fe16453c5 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 12 Dec 2024 17:29:34 -0500 Subject: [PATCH] make sure to set forward first --- src/axolotl/monkeypatch/models/llama/modeling_llama.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/axolotl/monkeypatch/models/llama/modeling_llama.py b/src/axolotl/monkeypatch/models/llama/modeling_llama.py index 6105bfe6d..3fa5bf20b 100644 --- a/src/axolotl/monkeypatch/models/llama/modeling_llama.py +++ b/src/axolotl/monkeypatch/models/llama/modeling_llama.py @@ -105,9 +105,12 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM): def __init__(self, config: LlamaConfig): super().__init__(config) + @classmethod + def set_forward(cls): forward_source = inspect.getsource(LlamaForCausalLM.forward) - self.forward = types.MethodType( - compile(forward_source, "", "exec"), self + forward_source, _ = detab_code(forward_source) + cls.forward = types.MethodType( + compile(forward_source, "", "exec"), cls ) @classmethod @@ -162,5 +165,6 @@ def replace_auto_model(): from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM + AxolotlLlamaForCausalLM.set_forward() return AxolotlLlamaForCausalLM