make sure to set forward first

This commit is contained in:
Wing Lian
2024-12-12 17:29:34 -05:00
parent 15f2fa4c8e
commit 7ac9cbebb9

View File

@@ -105,9 +105,12 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
def __init__(self, config: LlamaConfig):
super().__init__(config)
@classmethod
def set_forward(cls):
forward_source = inspect.getsource(LlamaForCausalLM.forward)
self.forward = types.MethodType(
compile(forward_source, "<forward>", "exec"), self
forward_source, _ = detab_code(forward_source)
cls.forward = types.MethodType(
compile(forward_source, "<forward>", "exec"), cls
)
@classmethod
@@ -162,5 +165,6 @@ def replace_auto_model():
from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING
MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM
AxolotlLlamaForCausalLM.set_forward()
return AxolotlLlamaForCausalLM