diff --git a/src/axolotl/monkeypatch/models/llama/modeling_llama.py b/src/axolotl/monkeypatch/models/llama/modeling_llama.py index c0a20fbd3..94a4f6580 100644 --- a/src/axolotl/monkeypatch/models/llama/modeling_llama.py +++ b/src/axolotl/monkeypatch/models/llama/modeling_llama.py @@ -107,19 +107,19 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM): compile(forward_source, "", "exec"), self ) - def enable_act_offloading(self): - self.act_offloading_ctx_manager = OffloadActivations() - - forward_source = inspect.getsource(self.forward) + @classmethod + def enable_act_offloading(cls): + forward_source = inspect.getsource(cls.forward) forward_source = forward_source.replace( HF_MODEL_OUTPUTS, PATCHED_HF_MODEL_OUTPUTS ) # replace forward method with patched version - self.forward = types.MethodType( - compile(forward_source, "", "exec"), self + cls.forward = types.MethodType( + compile(forward_source, "", "exec"), cls ) - def enable_liger_fce(self, enable_act_offloading=True): + @classmethod + def enable_liger_fce(cls, enable_act_offloading=True): from liger_kernel.transformers.model.llama import ( lce_forward as llama_lce_forward, ) @@ -128,16 +128,17 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM): lce_source = inspect.getsource(llama_lce_forward) lce_source = lce_source.replace(LCE_MODEL_OUTPUTS, PATCHED_LCE_OUTPUTS) # replace forward method with patched version - self.forward = types.MethodType( + cls.forward = types.MethodType( compile(lce_source, "", "exec"), - self, + cls, ) else: - self.forward = types.methodType(llama_lce_forward, self) + cls.forward = types.methodType(llama_lce_forward, cls) - def patch_hf_ga(self): + @classmethod + def patch_hf_ga(cls): # bugfix patch for gradient accumulation - forward_source = inspect.getsource(self.forward) + forward_source = inspect.getsource(cls.forward) forward_source = forward_source.replace( HF_GA_FORWARD_1, PATCHED_HF_GA_FORWARD_1 ) @@ -145,8 +146,8 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM): HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2 ) # replace forward method with patched version - self.forward = types.MethodType( - compile(forward_source, "", "exec"), self + cls.forward = types.MethodType( + compile(forward_source, "", "exec"), cls ) @@ -155,3 +156,5 @@ def replace_auto_model(): from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM + + return AxolotlLlamaForCausalLM diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 4db9107c7..b44e1ec38 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -380,6 +380,15 @@ class ModelLoader: plugin_manager = PluginManager.get_instance() plugin_manager.pre_model_load(self.cfg) + if self.cfg.model_config_type == "llama": + from axolotl.monkeypatch.models.llama.modeling_llama import replace_auto_model + + AxolotlLlamaForCausalLM = replace_auto_model() + + AxolotlLlamaForCausalLM.patch_hf_ga() + if self.cfg.activation_offloading: + AxolotlLlamaForCausalLM.enable_act_offloading() + if self.cfg.fsdp: from axolotl.monkeypatch.trainer_fsdp_optim import ( patch_training_loop_for_fsdp,