diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index 08d168025..5abc38cff 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -104,7 +104,7 @@ class CutCrossEntropyPlugin(BasePlugin): def patch_llama_like( self, - model_type: str, + model_type_to_patch: str, ) -> None: """ Generic patch for model architectures with causal lm similar to llama @@ -112,7 +112,10 @@ class CutCrossEntropyPlugin(BasePlugin): from cut_cross_entropy.transformers.patch import PATCH_FNS def patch_generic( - maybe_model, patch_options, model_type: str, remote_model_id: str | None + maybe_model, + patch_options, + remote_model_id: str | None, + model_type: str, ): import cut_cross_entropy.transformers.llama from cut_cross_entropy.transformers.llama import cce_forward @@ -136,11 +139,13 @@ class CutCrossEntropyPlugin(BasePlugin): f"Error: {str(e)}" ) from e - if model_type not in PATCH_FNS: + if model_type_to_patch not in PATCH_FNS: LOG.warning_once( - "Setting up generic cce patch for model type: %s", model_type + "Setting up generic cce patch for model type: %s", model_type_to_patch ) LOG.warning_once( - f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected." + f"Generic Cut Cross Entropy + {model_type_to_patch} support is experimental and may not work as expected." + ) + PATCH_FNS[model_type_to_patch] = partial( + patch_generic, model_type=model_type_to_patch ) - PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)