use as class methods

This commit is contained in:
Wing Lian
2024-12-12 17:19:43 -05:00
parent 3872d5eaed
commit 8b79f1cbf6
2 changed files with 26 additions and 14 deletions

View File

@@ -107,19 +107,19 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
compile(forward_source, "<forward>", "exec"), self
)
def enable_act_offloading(self):
self.act_offloading_ctx_manager = OffloadActivations()
forward_source = inspect.getsource(self.forward)
@classmethod
def enable_act_offloading(cls):
forward_source = inspect.getsource(cls.forward)
forward_source = forward_source.replace(
HF_MODEL_OUTPUTS, PATCHED_HF_MODEL_OUTPUTS
)
# replace forward method with patched version
self.forward = types.MethodType(
compile(forward_source, "<llama_forward_w_act_offloading>", "exec"), self
cls.forward = types.MethodType(
compile(forward_source, "<llama_forward_w_act_offloading>", "exec"), cls
)
def enable_liger_fce(self, enable_act_offloading=True):
@classmethod
def enable_liger_fce(cls, enable_act_offloading=True):
from liger_kernel.transformers.model.llama import (
lce_forward as llama_lce_forward,
)
@@ -128,16 +128,17 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
lce_source = inspect.getsource(llama_lce_forward)
lce_source = lce_source.replace(LCE_MODEL_OUTPUTS, PATCHED_LCE_OUTPUTS)
# replace forward method with patched version
self.forward = types.MethodType(
cls.forward = types.MethodType(
compile(lce_source, "<llama_lce_forward_w_act_offloading>", "exec"),
self,
cls,
)
else:
self.forward = types.methodType(llama_lce_forward, self)
cls.forward = types.methodType(llama_lce_forward, cls)
def patch_hf_ga(self):
@classmethod
def patch_hf_ga(cls):
# bugfix patch for gradient accumulation
forward_source = inspect.getsource(self.forward)
forward_source = inspect.getsource(cls.forward)
forward_source = forward_source.replace(
HF_GA_FORWARD_1, PATCHED_HF_GA_FORWARD_1
)
@@ -145,8 +146,8 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2
)
# replace forward method with patched version
self.forward = types.MethodType(
compile(forward_source, "<llama_forward_ga_fix>", "exec"), self
cls.forward = types.MethodType(
compile(forward_source, "<llama_forward_ga_fix>", "exec"), cls
)
@@ -155,3 +156,5 @@ def replace_auto_model():
from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING
MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM
return AxolotlLlamaForCausalLM

View File

@@ -380,6 +380,15 @@ class ModelLoader:
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg)
if self.cfg.model_config_type == "llama":
from axolotl.monkeypatch.models.llama.modeling_llama import replace_auto_model
AxolotlLlamaForCausalLM = replace_auto_model()
AxolotlLlamaForCausalLM.patch_hf_ga()
if self.cfg.activation_offloading:
AxolotlLlamaForCausalLM.enable_act_offloading()
if self.cfg.fsdp:
from axolotl.monkeypatch.trainer_fsdp_optim import (
patch_training_loop_for_fsdp,