use as class methods

This commit is contained in:
Wing Lian
2024-12-12 17:19:43 -05:00
parent 3872d5eaed
commit 8b79f1cbf6
2 changed files with 26 additions and 14 deletions

View File

@@ -107,19 +107,19 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
compile(forward_source, "<forward>", "exec"), self compile(forward_source, "<forward>", "exec"), self
) )
def enable_act_offloading(self): @classmethod
self.act_offloading_ctx_manager = OffloadActivations() def enable_act_offloading(cls):
forward_source = inspect.getsource(cls.forward)
forward_source = inspect.getsource(self.forward)
forward_source = forward_source.replace( forward_source = forward_source.replace(
HF_MODEL_OUTPUTS, PATCHED_HF_MODEL_OUTPUTS HF_MODEL_OUTPUTS, PATCHED_HF_MODEL_OUTPUTS
) )
# replace forward method with patched version # replace forward method with patched version
self.forward = types.MethodType( cls.forward = types.MethodType(
compile(forward_source, "<llama_forward_w_act_offloading>", "exec"), self compile(forward_source, "<llama_forward_w_act_offloading>", "exec"), cls
) )
def enable_liger_fce(self, enable_act_offloading=True): @classmethod
def enable_liger_fce(cls, enable_act_offloading=True):
from liger_kernel.transformers.model.llama import ( from liger_kernel.transformers.model.llama import (
lce_forward as llama_lce_forward, lce_forward as llama_lce_forward,
) )
@@ -128,16 +128,17 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
lce_source = inspect.getsource(llama_lce_forward) lce_source = inspect.getsource(llama_lce_forward)
lce_source = lce_source.replace(LCE_MODEL_OUTPUTS, PATCHED_LCE_OUTPUTS) lce_source = lce_source.replace(LCE_MODEL_OUTPUTS, PATCHED_LCE_OUTPUTS)
# replace forward method with patched version # replace forward method with patched version
self.forward = types.MethodType( cls.forward = types.MethodType(
compile(lce_source, "<llama_lce_forward_w_act_offloading>", "exec"), compile(lce_source, "<llama_lce_forward_w_act_offloading>", "exec"),
self, cls,
) )
else: else:
self.forward = types.methodType(llama_lce_forward, self) cls.forward = types.methodType(llama_lce_forward, cls)
def patch_hf_ga(self): @classmethod
def patch_hf_ga(cls):
# bugfix patch for gradient accumulation # bugfix patch for gradient accumulation
forward_source = inspect.getsource(self.forward) forward_source = inspect.getsource(cls.forward)
forward_source = forward_source.replace( forward_source = forward_source.replace(
HF_GA_FORWARD_1, PATCHED_HF_GA_FORWARD_1 HF_GA_FORWARD_1, PATCHED_HF_GA_FORWARD_1
) )
@@ -145,8 +146,8 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2 HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2
) )
# replace forward method with patched version # replace forward method with patched version
self.forward = types.MethodType( cls.forward = types.MethodType(
compile(forward_source, "<llama_forward_ga_fix>", "exec"), self compile(forward_source, "<llama_forward_ga_fix>", "exec"), cls
) )
@@ -155,3 +156,5 @@ def replace_auto_model():
from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING
MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM
return AxolotlLlamaForCausalLM

View File

@@ -380,6 +380,15 @@ class ModelLoader:
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg) plugin_manager.pre_model_load(self.cfg)
if self.cfg.model_config_type == "llama":
from axolotl.monkeypatch.models.llama.modeling_llama import replace_auto_model
AxolotlLlamaForCausalLM = replace_auto_model()
AxolotlLlamaForCausalLM.patch_hf_ga()
if self.cfg.activation_offloading:
AxolotlLlamaForCausalLM.enable_act_offloading()
if self.cfg.fsdp: if self.cfg.fsdp:
from axolotl.monkeypatch.trainer_fsdp_optim import ( from axolotl.monkeypatch.trainer_fsdp_optim import (
patch_training_loop_for_fsdp, patch_training_loop_for_fsdp,