pre-commit fix
This commit is contained in:
@@ -1,48 +0,0 @@
|
|||||||
"""Patches related to differential transformers implementation."""
|
|
||||||
|
|
||||||
from transformers import PreTrainedModel
|
|
||||||
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
|
|
||||||
|
|
||||||
from axolotl.integrations.diff_transformer.diff_attn import (
|
|
||||||
LlamaDifferentialAttention,
|
|
||||||
LlamaDifferentialFlashAttention2,
|
|
||||||
LlamaDifferentialSdpaAttention,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_llama_attention_classes():
|
|
||||||
"""Patch transformers to support differential attention"""
|
|
||||||
# Add our attention class to the registry
|
|
||||||
LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
|
|
||||||
LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
|
|
||||||
LLAMA_ATTENTION_CLASSES[
|
|
||||||
"differential_flash_attention_2"
|
|
||||||
] = LlamaDifferentialFlashAttention2
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def new_autoset(_, config, **kwargs): # pylint: disable=unused-argument
|
|
||||||
config._attn_implementation_autoset = True # pylint: disable=protected-access
|
|
||||||
attn_implementation = getattr(config, "_attn_implementation", None)
|
|
||||||
|
|
||||||
valid_impls = [
|
|
||||||
None,
|
|
||||||
"eager",
|
|
||||||
"sdpa",
|
|
||||||
"flash_attention_2",
|
|
||||||
"differential_eager",
|
|
||||||
"differential_sdpa",
|
|
||||||
"differential_flash_attention_2",
|
|
||||||
]
|
|
||||||
if attn_implementation not in valid_impls:
|
|
||||||
message = (
|
|
||||||
f"Specified `attn_implementation={attn_implementation}` is not supported. "
|
|
||||||
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
|
|
||||||
)
|
|
||||||
raise ValueError(message + ".")
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
# Apply patch
|
|
||||||
PreTrainedModel._autoset_attn_implementation = ( # pylint: disable=protected-access
|
|
||||||
new_autoset
|
|
||||||
)
|
|
||||||
Reference in New Issue
Block a user