fix enable_act_offloading

This commit is contained in:
Wing Lian
2024-12-12 17:22:34 -05:00
parent 8b79f1cbf6
commit 43a2f9a155

View File

@@ -5,6 +5,8 @@ import types
from torchtune.training import OffloadActivations
from transformers import LlamaConfig, LlamaForCausalLM
from axolotl.monkeypatch.unsloth_ import detab_code
HF_MODEL_OUTPUTS = """
outputs = self.model(
input_ids=input_ids,
@@ -98,9 +100,10 @@ PATCHED_HF_GA_FORWARD_2 = """
class AxolotlLlamaForCausalLM(LlamaForCausalLM):
act_offloading_ctx_manager = contextlib.nullcontext()
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.act_offloading_ctx_manager = contextlib.nullcontext()
forward_source = inspect.getsource(LlamaForCausalLM.forward)
self.forward = types.MethodType(
@@ -117,6 +120,7 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
cls.forward = types.MethodType(
compile(forward_source, "<llama_forward_w_act_offloading>", "exec"), cls
)
cls.act_offloading_ctx_manager = OffloadActivations()
@classmethod
def enable_liger_fce(cls, enable_act_offloading=True):
@@ -145,6 +149,7 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
forward_source = forward_source.replace(
HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2
)
forward_source = detab_code(forward_source)
# replace forward method with patched version
cls.forward = types.MethodType(
compile(forward_source, "<llama_forward_ga_fix>", "exec"), cls