fix enable_act_offloading
This commit is contained in:
@@ -5,6 +5,8 @@ import types
|
||||
from torchtune.training import OffloadActivations
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
|
||||
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||
|
||||
HF_MODEL_OUTPUTS = """
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
@@ -98,9 +100,10 @@ PATCHED_HF_GA_FORWARD_2 = """
|
||||
|
||||
|
||||
class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
||||
act_offloading_ctx_manager = contextlib.nullcontext()
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
self.act_offloading_ctx_manager = contextlib.nullcontext()
|
||||
|
||||
forward_source = inspect.getsource(LlamaForCausalLM.forward)
|
||||
self.forward = types.MethodType(
|
||||
@@ -117,6 +120,7 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
||||
cls.forward = types.MethodType(
|
||||
compile(forward_source, "<llama_forward_w_act_offloading>", "exec"), cls
|
||||
)
|
||||
cls.act_offloading_ctx_manager = OffloadActivations()
|
||||
|
||||
@classmethod
|
||||
def enable_liger_fce(cls, enable_act_offloading=True):
|
||||
@@ -145,6 +149,7 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
||||
forward_source = forward_source.replace(
|
||||
HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2
|
||||
)
|
||||
forward_source = detab_code(forward_source)
|
||||
# replace forward method with patched version
|
||||
cls.forward = types.MethodType(
|
||||
compile(forward_source, "<llama_forward_ga_fix>", "exec"), cls
|
||||
|
||||
Reference in New Issue
Block a user