make sure to set forward first
This commit is contained in:
@@ -105,9 +105,12 @@ class AxolotlLlamaForCausalLM(LlamaForCausalLM):
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@classmethod
|
||||
def set_forward(cls):
|
||||
forward_source = inspect.getsource(LlamaForCausalLM.forward)
|
||||
self.forward = types.MethodType(
|
||||
compile(forward_source, "<forward>", "exec"), self
|
||||
forward_source, _ = detab_code(forward_source)
|
||||
cls.forward = types.MethodType(
|
||||
compile(forward_source, "<forward>", "exec"), cls
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -162,5 +165,6 @@ def replace_auto_model():
|
||||
from transformers.models.auto import MODEL_FOR_CAUSAL_LM_MAPPING
|
||||
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig] = AxolotlLlamaForCausalLM
|
||||
AxolotlLlamaForCausalLM.set_forward()
|
||||
|
||||
return AxolotlLlamaForCausalLM
|
||||
|
||||
Reference in New Issue
Block a user