From 43a2f9a155c05ba67d5b52245a1e7dec2a156c8e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 12 Dec 2024 17:22:34 -0500 Subject: [PATCH] fix enable_act_offloading --- src/axolotl/monkeypatch/models/llama/modeling_llama.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/axolotl/monkeypatch/models/llama/modeling_llama.py b/src/axolotl/monkeypatch/models/llama/modeling_llama.py index 94a4f6580..2d8d06168 100644 --- a/src/axolotl/monkeypatch/models/llama/modeling_llama.py +++ b/src/axolotl/monkeypatch/models/llama/modeling_llama.py @@ -5,6 +5,8 @@ import types from torchtune.training import OffloadActivations from transformers import LlamaConfig, LlamaForCausalLM +from axolotl.monkeypatch.unsloth_ import detab_code + HF_MODEL_OUTPUTS = """ outputs = self.model( input_ids=input_ids, @@ -98,9 +100,10 @@ PATCHED_HF_GA_FORWARD_2 = """ class AxolotlLlamaForCausalLM(LlamaForCausalLM): + act_offloading_ctx_manager = contextlib.nullcontext() + def __init__(self, config: LlamaConfig): super().__init__(config) - self.act_offloading_ctx_manager = contextlib.nullcontext() forward_source = inspect.getsource(LlamaForCausalLM.forward) self.forward = types.MethodType( @@ -117,6 +120,7 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM): cls.forward = types.MethodType( compile(forward_source, "", "exec"), cls ) + cls.act_offloading_ctx_manager = OffloadActivations() @classmethod def enable_liger_fce(cls, enable_act_offloading=True): @@ -145,6 +149,7 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM): forward_source = forward_source.replace( HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2 ) + forward_source = detab_code(forward_source) # replace forward method with patched version cls.forward = types.MethodType( compile(forward_source, "", "exec"), cls