fix generic patch for cce (#3405)
This commit is contained in:
@@ -104,7 +104,7 @@ class CutCrossEntropyPlugin(BasePlugin):
|
|||||||
|
|
||||||
def patch_llama_like(
|
def patch_llama_like(
|
||||||
self,
|
self,
|
||||||
model_type: str,
|
model_type_to_patch: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Generic patch for model architectures with causal lm similar to llama
|
Generic patch for model architectures with causal lm similar to llama
|
||||||
@@ -112,7 +112,10 @@ class CutCrossEntropyPlugin(BasePlugin):
|
|||||||
from cut_cross_entropy.transformers.patch import PATCH_FNS
|
from cut_cross_entropy.transformers.patch import PATCH_FNS
|
||||||
|
|
||||||
def patch_generic(
|
def patch_generic(
|
||||||
maybe_model, patch_options, model_type: str, remote_model_id: str | None
|
maybe_model,
|
||||||
|
patch_options,
|
||||||
|
remote_model_id: str | None,
|
||||||
|
model_type: str,
|
||||||
):
|
):
|
||||||
import cut_cross_entropy.transformers.llama
|
import cut_cross_entropy.transformers.llama
|
||||||
from cut_cross_entropy.transformers.llama import cce_forward
|
from cut_cross_entropy.transformers.llama import cce_forward
|
||||||
@@ -136,11 +139,13 @@ class CutCrossEntropyPlugin(BasePlugin):
|
|||||||
f"Error: {str(e)}"
|
f"Error: {str(e)}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
if model_type not in PATCH_FNS:
|
if model_type_to_patch not in PATCH_FNS:
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
"Setting up generic cce patch for model type: %s", model_type
|
"Setting up generic cce patch for model type: %s", model_type_to_patch
|
||||||
)
|
)
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
|
f"Generic Cut Cross Entropy + {model_type_to_patch} support is experimental and may not work as expected."
|
||||||
|
)
|
||||||
|
PATCH_FNS[model_type_to_patch] = partial(
|
||||||
|
patch_generic, model_type=model_type_to_patch
|
||||||
)
|
)
|
||||||
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user