use as class methods
This commit is contained in:
@@ -107,19 +107,19 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
|||||||
compile(forward_source, "<forward>", "exec"), self
|
compile(forward_source, "<forward>", "exec"), self
|
||||||
)
|
)
|
||||||
|
|
||||||
def enable_act_offloading(self):
|
@classmethod
|
||||||
self.act_offloading_ctx_manager = OffloadActivations()
|
def enable_act_offloading(cls):
|
||||||
|
forward_source = inspect.getsource(cls.forward)
|
||||||
forward_source = inspect.getsource(self.forward)
|
|
||||||
forward_source = forward_source.replace(
|
forward_source = forward_source.replace(
|
||||||
HF_MODEL_OUTPUTS, PATCHED_HF_MODEL_OUTPUTS
|
HF_MODEL_OUTPUTS, PATCHED_HF_MODEL_OUTPUTS
|
||||||
)
|
)
|
||||||
# replace forward method with patched version
|
# replace forward method with patched version
|
||||||
self.forward = types.MethodType(
|
cls.forward = types.MethodType(
|
||||||
compile(forward_source, "<llama_forward_w_act_offloading>", "exec"), self
|
compile(forward_source, "<llama_forward_w_act_offloading>", "exec"), cls
|
||||||
)
|
)
|
||||||
|
|
||||||
def enable_liger_fce(self, enable_act_offloading=True):
|
@classmethod
|
||||||
|
def enable_liger_fce(cls, enable_act_offloading=True):
|
||||||
from liger_kernel.transformers.model.llama import (
|
from liger_kernel.transformers.model.llama import (
|
||||||
lce_forward as llama_lce_forward,
|
lce_forward as llama_lce_forward,
|
||||||
)
|
)
|
||||||
@@ -128,16 +128,17 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
|||||||
lce_source = inspect.getsource(llama_lce_forward)
|
lce_source = inspect.getsource(llama_lce_forward)
|
||||||
lce_source = lce_source.replace(LCE_MODEL_OUTPUTS, PATCHED_LCE_OUTPUTS)
|
lce_source = lce_source.replace(LCE_MODEL_OUTPUTS, PATCHED_LCE_OUTPUTS)
|
||||||
# replace forward method with patched version
|
# replace forward method with patched version
|
||||||
self.forward = types.MethodType(
|
cls.forward = types.MethodType(
|
||||||
compile(lce_source, "<llama_lce_forward_w_act_offloading>", "exec"),
|
compile(lce_source, "<llama_lce_forward_w_act_offloading>", "exec"),
|
||||||
self,
|
cls,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.forward = types.methodType(llama_lce_forward, self)
|
cls.forward = types.methodType(llama_lce_forward, cls)
|
||||||
|
|
||||||
def patch_hf_ga(self):
|
@classmethod
|
||||||
|
def patch_hf_ga(cls):
|
||||||
# bugfix patch for gradient accumulation
|
# bugfix patch for gradient accumulation
|
||||||
forward_source = inspect.getsource(self.forward)
|
forward_source = inspect.getsource(cls.forward)
|
||||||
forward_source = forward_source.replace(
|
forward_source = forward_source.replace(
|
||||||
HF_GA_FORWARD_1, PATCHED_HF_GA_FORWARD_1
|
HF_GA_FORWARD_1, PATCHED_HF_GA_FORWARD_1
|
||||||
)
|
)
|
||||||
@@ -145,8 +146,8 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
|||||||
HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2
|
HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2
|
||||||
)
|
)
|
||||||
# replace forward method with patched version
|
# replace forward method with patched version
|
||||||
self.forward = types.MethodType(
|
cls.forward = types.MethodType(
|
||||||
compile(forward_source, "<llama_forward_ga_fix>", "exec"), self
|
compile(forward_source, "<llama_forward_ga_fix>", "exec"), cls
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -155,3 +156,5 @@ def replace_auto_model():
|
|||||||
from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING
|
from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING
|
||||||
|
|
||||||
MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM
|
MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM
|
||||||
|
|
||||||
|
return AxolotlLlamaForCausalLM
|
||||||
|
|||||||
@@ -380,6 +380,15 @@ class ModelLoader:
|
|||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
plugin_manager.pre_model_load(self.cfg)
|
plugin_manager.pre_model_load(self.cfg)
|
||||||
|
|
||||||
|
if self.cfg.model_config_type == "llama":
|
||||||
|
from axolotl.monkeypatch.models.llama.modeling_llama import replace_auto_model
|
||||||
|
|
||||||
|
AxolotlLlamaForCausalLM = replace_auto_model()
|
||||||
|
|
||||||
|
AxolotlLlamaForCausalLM.patch_hf_ga()
|
||||||
|
if self.cfg.activation_offloading:
|
||||||
|
AxolotlLlamaForCausalLM.enable_act_offloading()
|
||||||
|
|
||||||
if self.cfg.fsdp:
|
if self.cfg.fsdp:
|
||||||
from axolotl.monkeypatch.trainer_fsdp_optim import (
|
from axolotl.monkeypatch.trainer_fsdp_optim import (
|
||||||
patch_training_loop_for_fsdp,
|
patch_training_loop_for_fsdp,
|
||||||
|
|||||||
Reference in New Issue
Block a user