fix enable_act_offloading
This commit is contained in:
@@ -5,6 +5,8 @@ import types
|
|||||||
from torchtune.training import OffloadActivations
|
from torchtune.training import OffloadActivations
|
||||||
from transformers import LlamaConfig, LlamaForCausalLM
|
from transformers import LlamaConfig, LlamaForCausalLM
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||||
|
|
||||||
HF_MODEL_OUTPUTS = """
|
HF_MODEL_OUTPUTS = """
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@@ -98,9 +100,10 @@ PATCHED_HF_GA_FORWARD_2 = """
|
|||||||
|
|
||||||
|
|
||||||
class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
||||||
|
act_offloading_ctx_manager = contextlib.nullcontext()
|
||||||
|
|
||||||
def __init__(self, config: LlamaConfig):
|
def __init__(self, config: LlamaConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.act_offloading_ctx_manager = contextlib.nullcontext()
|
|
||||||
|
|
||||||
forward_source = inspect.getsource(LlamaForCausalLM.forward)
|
forward_source = inspect.getsource(LlamaForCausalLM.forward)
|
||||||
self.forward = types.MethodType(
|
self.forward = types.MethodType(
|
||||||
@@ -117,6 +120,7 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
|||||||
cls.forward = types.MethodType(
|
cls.forward = types.MethodType(
|
||||||
compile(forward_source, "<llama_forward_w_act_offloading>", "exec"), cls
|
compile(forward_source, "<llama_forward_w_act_offloading>", "exec"), cls
|
||||||
)
|
)
|
||||||
|
cls.act_offloading_ctx_manager = OffloadActivations()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enable_liger_fce(cls, enable_act_offloading=True):
|
def enable_liger_fce(cls, enable_act_offloading=True):
|
||||||
@@ -145,6 +149,7 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
|||||||
forward_source = forward_source.replace(
|
forward_source = forward_source.replace(
|
||||||
HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2
|
HF_GA_FORWARD_2, PATCHED_HF_GA_FORWARD_2
|
||||||
)
|
)
|
||||||
|
forward_source = detab_code(forward_source)
|
||||||
# replace forward method with patched version
|
# replace forward method with patched version
|
||||||
cls.forward = types.MethodType(
|
cls.forward = types.MethodType(
|
||||||
compile(forward_source, "<llama_forward_ga_fix>", "exec"), cls
|
compile(forward_source, "<llama_forward_ga_fix>", "exec"), cls
|
||||||
|
|||||||
Reference in New Issue
Block a user