diff --git a/src/axolotl/monkeypatch/models/llama/modeling_llama.py b/src/axolotl/monkeypatch/models/llama/modeling_llama.py index 6105bfe6d..3fa5bf20b 100644 --- a/src/axolotl/monkeypatch/models/llama/modeling_llama.py +++ b/src/axolotl/monkeypatch/models/llama/modeling_llama.py @@ -105,9 +105,12 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM): def __init__(self, config: LlamaConfig): super().__init__(config) + @classmethod + def set_forward(cls): forward_source = inspect.getsource(LlamaForCausalLM.forward) - self.forward = types.MethodType( - compile(forward_source, "", "exec"), self + forward_source, _ = detab_code(forward_source) + cls.forward = types.MethodType( + compile(forward_source, "", "exec"), cls ) @classmethod @@ -162,5 +165,6 @@ def replace_auto_model(): from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM + AxolotlLlamaForCausalLM.set_forward() return AxolotlLlamaForCausalLM