use as class methods
This commit is contained in:
@@ -107,19 +107,19 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
||||
compile(forward_source, "<forward>", "exec"), self
|
||||
)
|
||||
|
||||
def enable_act_offloading(self):
|
||||
self.act_offloading_ctx_manager = OffloadActivations()
|
||||
|
||||
forward_source = inspect.getsource(self.forward)
|
||||
@classmethod
|
||||
def enable_act_offloading(cls):
|
||||
forward_source = inspect.getsource(cls.forward)
|
||||
forward_source = forward_source.replace(
|
||||
HF_MODEL_OUTPUTS, PATCHED_HF_MODEL_OUTPUTS
|
||||
)
|
||||
# replace forward method with patched version
|
||||
self.forward = types.MethodType(
|
||||
compile(forward_source, "<llama_forward_w_act_offloading>", "exec"), self
|
||||
cls.forward = types.MethodType(
|
||||
compile(forward_source, "<llama_forward_w_act_offloading>", "exec"), cls
|
||||
)
|
||||
|
||||
def enable_liger_fce(self, enable_act_offloading=True):
|
||||
@classmethod
|
||||
def enable_liger_fce(cls, enable_act_offloading=True):
|
||||
from liger_kernel.transformers.model.llama import (
|
||||
lce_forward as llama_lce_forward,
|
||||
)
|
||||
@@ -128,16 +128,17 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
||||
lce_source = inspect.getsource(llama_lce_forward)
|
||||
lce_source = lce_source.replace(LCE_MODEL_OUTPUTS, PATCHED_LCE_OUTPUTS)
|
||||
# replace forward method with patched version
|
||||
self.forward = types.MethodType(
|
||||
cls.forward = types.MethodType(
|
||||
compile(lce_source, "<llama_lce_forward_w_act_offloading>", "exec"),
|
||||
self,
|
||||
cls,
|
||||
)
|
||||
else:
|
||||
self.forward = types.methodType(llama_lce_forward, self)
|
||||
cls.forward = types.methodType(llama_lce_forward, cls)
|
||||
|
||||
def patch_hf_ga(self):
|
||||
@classmethod
|
||||
def patch_hf_ga(cls):
|
||||
# bugfix patch for gradient accumulation
|
||||
forward_source = inspect.getsource(self.forward)
|
||||
forward_source = inspect.getsource(cls.forward)
|
||||
forward_source = forward_source.replace(
|
||||
HF_GA_FORWARD_1, PATCHED_HF_GA_FORWARD_1
|
||||
)
|
||||
@@ -145,8 +146,8 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
||||
HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2
|
||||
)
|
||||
# replace forward method with patched version
|
||||
self.forward = types.MethodType(
|
||||
compile(forward_source, "<llama_forward_ga_fix>", "exec"), self
|
||||
cls.forward = types.MethodType(
|
||||
compile(forward_source, "<llama_forward_ga_fix>", "exec"), cls
|
||||
)
|
||||
|
||||
|
||||
@@ -155,3 +156,5 @@ def replace_auto_model():
|
||||
from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING
|
||||
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM
|
||||
|
||||
return AxolotlLlamaForCausalLM
|
||||
|
||||
@@ -380,6 +380,15 @@ class ModelLoader:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.pre_model_load(self.cfg)
|
||||
|
||||
if self.cfg.model_config_type == "llama":
|
||||
from axolotl.monkeypatch.models.llama.modeling_llama import replace_auto_model
|
||||
|
||||
AxolotlLlamaForCausalLM = replace_auto_model()
|
||||
|
||||
AxolotlLlamaForCausalLM.patch_hf_ga()
|
||||
if self.cfg.activation_offloading:
|
||||
AxolotlLlamaForCausalLM.enable_act_offloading()
|
||||
|
||||
if self.cfg.fsdp:
|
||||
from axolotl.monkeypatch.trainer_fsdp_optim import (
|
||||
patch_training_loop_for_fsdp,
|
||||
|
||||
Reference in New Issue
Block a user