make sure to set forward first
This commit is contained in:
@@ -105,9 +105,12 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
|||||||
def __init__(self, config: LlamaConfig):
|
def __init__(self, config: LlamaConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_forward(cls):
|
||||||
forward_source = inspect.getsource(LlamaForCausalLM.forward)
|
forward_source = inspect.getsource(LlamaForCausalLM.forward)
|
||||||
self.forward = types.MethodType(
|
forward_source, _ = detab_code(forward_source)
|
||||||
compile(forward_source, "<forward>", "exec"), self
|
cls.forward = types.MethodType(
|
||||||
|
compile(forward_source, "<forward>", "exec"), cls
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -162,5 +165,6 @@ def replace_auto_model():
|
|||||||
from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING
|
from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING
|
||||||
|
|
||||||
MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM
|
MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM
|
||||||
|
AxolotlLlamaForCausalLM.set_forward()
|
||||||
|
|
||||||
return AxolotlLlamaForCausalLM
|
return AxolotlLlamaForCausalLM
|
||||||
|
|||||||
Reference in New Issue
Block a user