fix generic patch for cce (#3405)
This commit is contained in:
@@ -104,7 +104,7 @@ class CutCrossEntropyPlugin(BasePlugin):
|
||||
|
||||
def patch_llama_like(
|
||||
self,
|
||||
model_type: str,
|
||||
model_type_to_patch: str,
|
||||
) -> None:
|
||||
"""
|
||||
Generic patch for model architectures with causal lm similar to llama
|
||||
@@ -112,7 +112,10 @@ class CutCrossEntropyPlugin(BasePlugin):
|
||||
from cut_cross_entropy.transformers.patch import PATCH_FNS
|
||||
|
||||
def patch_generic(
|
||||
maybe_model, patch_options, model_type: str, remote_model_id: str | None
|
||||
maybe_model,
|
||||
patch_options,
|
||||
remote_model_id: str | None,
|
||||
model_type: str,
|
||||
):
|
||||
import cut_cross_entropy.transformers.llama
|
||||
from cut_cross_entropy.transformers.llama import cce_forward
|
||||
@@ -136,11 +139,13 @@ class CutCrossEntropyPlugin(BasePlugin):
|
||||
f"Error: {str(e)}"
|
||||
) from e
|
||||
|
||||
if model_type not in PATCH_FNS:
|
||||
if model_type_to_patch not in PATCH_FNS:
|
||||
LOG.warning_once(
|
||||
"Setting up generic cce patch for model type: %s", model_type
|
||||
"Setting up generic cce patch for model type: %s", model_type_to_patch
|
||||
)
|
||||
LOG.warning_once(
|
||||
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
|
||||
f"Generic Cut Cross Entropy + {model_type_to_patch} support is experimental and may not work as expected."
|
||||
)
|
||||
PATCH_FNS[model_type_to_patch] = partial(
|
||||
patch_generic, model_type=model_type_to_patch
|
||||
)
|
||||
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)
|
||||
|
||||
Reference in New Issue
Block a user