replace attention in the yaml config with an enum

Wing Lian
2025-04-04 23:37:30 -04:00
committed by NanoCode012
parent 0d71b0aa5f
commit ba47adc24b
101 changed files with 268 additions and 122 deletions
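
Nearly every hunk below follows the same pattern: the legacy boolean flags (flash_attention, xformers_attention, sdp_attention, eager_attention, s2_attention) give way to a single attention key holding an enum value (flash, xformers, sdpa, eager, s2). The diff imports AttentionBackend but never shows its definition; the following is a minimal sketch of what axolotl.utils.schemas.enums.AttentionBackend plausibly looks like, inferred from the YAML values and the validator at the end of this diff:

# Minimal sketch, not the actual definition: the real AttentionBackend lives in
# axolotl/utils/schemas/enums.py and is not part of this diff. Member names are
# taken from the validator below; string values from the YAML hunks.
from enum import Enum


class AttentionBackend(str, Enum):
    eager = "eager"        # was: eager_attention: true (or all flags unset)
    flash = "flash"        # was: flash_attention: true (the new default)
    s2 = "s2"              # was: s2_attention: true
    sdpa = "sdpa"          # was: sdp_attention: true
    xformers = "xformers"  # was: xformers_attention: true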


@@ -59,9 +59,7 @@ gradient_checkpointing: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
sdp_attention:
flash_optimum:
attention: flash
gptq_groupsize:
gptq_model_v1:


@@ -39,8 +39,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
xformers_attention: true
flash_attention:
attention: xformers
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10


@@ -45,7 +45,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -46,7 +46,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -45,7 +45,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -46,7 +46,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -45,7 +45,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -46,7 +46,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -49,7 +49,8 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_ratio: 0.1
evals_per_epoch:


@@ -112,9 +112,7 @@
"early_stopping_patience:\n",
"resume_from_checkpoint:\n",
"logging_steps: 1\n",
"xformers_attention:\n",
"flash_attention: false\n",
"sdp_attention: true\n",
"attention: sdpa\n",
"\n",
"warmup_steps: 1\n",
"max_steps: 25\n",


@@ -52,7 +52,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch:


@@ -55,7 +55,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch:


@@ -39,7 +39,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch:


@@ -35,7 +35,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 100
evals_per_epoch: 2


@@ -59,7 +59,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 100
evals_per_epoch: 2


@@ -43,8 +43,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
xformers_attention: true
flash_attention:
attention: xformers
gptq_groupsize:
gptq_model_v1:
warmup_steps: 40


@@ -73,8 +73,7 @@ early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
logging_steps: 1
xformers_attention: true
flash_attention:
attention: xformers
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10


@@ -40,8 +40,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
xformers_attention: true
flash_attention:
attention: xformers
gptq_groupsize:
gptq_model_v1:
warmup_steps: 40


@@ -47,7 +47,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_ratio: 0.1
evals_per_epoch: 4


@@ -53,7 +53,8 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_ratio: 0.1
evals_per_epoch:


@@ -43,7 +43,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_ratio: 0.1
evals_per_epoch:


@@ -57,7 +57,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_ratio: 0.1
evals_per_epoch:


@@ -51,8 +51,7 @@ gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
flash_attention: true
eager_attention:
attention: flash
warmup_ratio: 0.1
evals_per_epoch: 1


@@ -53,8 +53,7 @@ gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
flash_attention: true
eager_attention:
attention: flash
warmup_ratio: 0.1
evals_per_epoch: 1


@@ -36,8 +36,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
xformers_attention: true
flash_attention:
attention: xformers
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10


@@ -47,7 +47,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch:


@@ -46,7 +46,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch:


@@ -45,7 +45,8 @@ gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 1


@@ -37,8 +37,7 @@ bf16: auto
tf32: true
resume_from_checkpoint:
logging_steps: 5
xformers_attention: true
flash_attention:
attention: xformers
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20


@@ -42,7 +42,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false


@@ -53,9 +53,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention:
sdp_attention:
flash_optimum:
attention: flash
warmup_steps: 100
evals_per_epoch: 4
saves_per_epoch: 1


@@ -46,7 +46,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false


@@ -45,7 +45,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -45,7 +45,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -48,7 +48,8 @@ gradient_checkpointing_kwargs:
use_reentrant: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -46,7 +46,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -48,7 +48,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -50,8 +50,7 @@ tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
attention: flash
warmup_ratio: 0.1
evals_per_epoch: 1


@@ -49,7 +49,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 100
evals_per_epoch: 2


@@ -34,7 +34,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 100
evals_per_epoch: 2


@@ -61,7 +61,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -56,7 +56,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -77,7 +77,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -53,7 +53,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -54,7 +54,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3


@@ -48,7 +48,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3


@@ -55,7 +55,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -48,7 +48,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3


@@ -49,7 +49,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -53,7 +53,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 20
evals_per_epoch: 4


@@ -51,7 +51,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3


@@ -39,7 +39,8 @@ gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -48,7 +48,8 @@ gradient_checkpointing_kwargs:
use_reentrant: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -46,7 +46,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -46,8 +46,7 @@ tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
attention: flash
warmup_ratio: 0.1
evals_per_epoch: 1


@@ -39,7 +39,7 @@ tf32: true
gradient_checkpointing: false
resume_from_checkpoint:
logging_steps: 1
flash_attention:
attention: eager
warmup_steps: 10
evals_per_epoch: 4


@@ -42,7 +42,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
save_total_limit: 1
save_steps:


@@ -36,7 +36,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -53,8 +53,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: false
sdp_attention: true
attention: sdpa
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3


@@ -54,7 +54,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3


@@ -71,7 +71,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: false
attention: eager
warmup_steps: 10
evals_per_epoch: 4


@@ -51,7 +51,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3


@@ -59,7 +59,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3


@@ -48,9 +48,7 @@ tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet.
eager_attention:
attention: eager # PixtralVisionModel does not support Flash Attention 2.0 yet.
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1


@@ -49,7 +49,8 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3


@@ -51,7 +51,8 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3


@@ -69,7 +69,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3


@@ -40,7 +40,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
save_total_limit: 1
save_steps:


@@ -54,7 +54,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3


@@ -39,7 +39,7 @@ bf16: auto
tf32: true
resume_from_checkpoint:
logging_steps: 5
flash_attention:
attention: eager
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20


@@ -39,7 +39,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20


@@ -47,7 +47,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20


@@ -40,7 +40,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20


@@ -48,7 +48,8 @@ gradient_checkpointing_kwargs:
use_reentrant: True
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 100
evals_per_epoch: 4


@@ -51,7 +51,8 @@ gradient_checkpointing_kwargs:
use_reentrant: True
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 100
evals_per_epoch: 4


@@ -48,7 +48,8 @@ gradient_checkpointing_kwargs:
use_reentrant: True
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 100
evals_per_epoch: 4


@@ -49,7 +49,8 @@ gradient_checkpointing_kwargs:
use_reentrant: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 100
evals_per_epoch: 4


@@ -44,7 +44,8 @@ gradient_checkpointing_kwargs:
use_reentrant: True
early_stopping_patience: 3
logging_steps: 1
flash_attention: true
attention: flash
eval_steps: 1000
save_steps: 5000


@@ -46,8 +46,7 @@ tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet
eager_attention:
attention: eager # PixtralVisionModel does not support Flash Attention 2.0 yet
warmup_ratio: 0.1
evals_per_epoch: 1


@@ -47,7 +47,7 @@ tf32: false
gradient_checkpointing: false
resume_from_checkpoint:
logging_steps: 1
flash_attention:
attention: eager
warmup_steps: 10
evals_per_epoch: 4


@@ -47,7 +47,7 @@ tf32: false
gradient_checkpointing: false
resume_from_checkpoint:
logging_steps: 1
flash_attention:
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -43,7 +43,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -46,7 +46,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -46,8 +46,7 @@ tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
attention: flash
warmup_ratio: 0.1
evals_per_epoch: 1


@@ -49,7 +49,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -47,7 +47,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_ratio: 0.1
evals_per_epoch:


@@ -47,7 +47,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -43,7 +43,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_ratio: 0.1
evals_per_epoch:


@@ -40,7 +40,7 @@ bf16: auto
tf32: true
resume_from_checkpoint:
logging_steps: 5
flash_attention:
attention: flash
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20


@@ -38,7 +38,7 @@ tf32: true
gradient_checkpointing:
resume_from_checkpoint:
logging_steps: 1
flash_attention:
attention: flash
gptq_groupsize:
gptq_model_v1:
warmup_steps: 20


@@ -44,7 +44,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false


@@ -47,7 +47,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
flash_attn_cross_entropy: false
flash_attn_rms_norm: true


@@ -46,7 +46,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 20
evals_per_epoch: 4


@@ -47,7 +47,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: false
attention: eager
warmup_steps: 10
evals_per_epoch: 0


@@ -45,7 +45,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -36,7 +36,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch:


@@ -47,7 +47,8 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attention: flash
warmup_steps: 10
evals_per_epoch: 4


@@ -71,8 +71,7 @@ early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
logging_steps: 1
xformers_attention: true
flash_attention:
attention: xformers
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10


@@ -10,7 +10,8 @@ load_in_4bit: true
sequence_len: 1024
bf16: auto
tf32: false
flash_attention: true
attention: flash
special_tokens:
bos_token: "<|startoftext|>"
eos_token: "<|endoftext|>"


@@ -27,7 +27,8 @@ from axolotl.utils.schemas.datasets import (
    StepwiseSupervisedDataset,
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.enums import ChatTemplate, RLType
from axolotl.utils.schemas.enums import AttentionBackend, ChatTemplate, RLType
from axolotl.utils.schemas.integrations import (
    CometConfig,
    GradioConfig,
@@ -222,6 +222,10 @@ class AxolotlInputConfig(
        },
    )
    attention: AttentionBackend = Field(
        default=AttentionBackend.flash,
        json_schema_extra={"description": "attention backend to use"},
    )
    xformers_attention: bool | None = None
    sdp_attention: bool | None = None
    s2_attention: bool | None = None
@@ -436,6 +440,86 @@ class AxolotlInputConfig(
        )
        return data

    @model_validator(mode="before")
    @classmethod
    def normalize_attn(cls, data):  # pylint: disable=too-many-return-statements
        # cases where both are set and already match
        if data.get("attention") == AttentionBackend.eager and data.get(
            "eager_attention"
        ):
            return data
        if data.get("attention") == AttentionBackend.flash and data.get(
            "flash_attention"
        ):
            return data
        if data.get("attention") == AttentionBackend.s2 and data.get("s2_attention"):
            return data
        if data.get("attention") == AttentionBackend.sdpa and data.get("sdp_attention"):
            return data
        if data.get("attention") == AttentionBackend.xformers and data.get(
            "xformers_attention"
        ):
            return data

        # cases where attention is set and the specific *_attention is not set
        if not (
            data.get("flash_attention")
            or data.get("eager_attention")
            or data.get("s2_attention")
            or data.get("sdp_attention")
            or data.get("xformers_attention")
        ):
            if data.get("attention") == AttentionBackend.eager:
                data["eager_attention"] = True
            elif data.get("attention") == AttentionBackend.flash:
                data["flash_attention"] = True
            elif data.get("attention") == AttentionBackend.s2:
                data["s2_attention"] = True
            elif data.get("attention") == AttentionBackend.sdpa:
                data["sdp_attention"] = True
            elif data.get("attention") == AttentionBackend.xformers:
                data["xformers_attention"] = True
            return data

        # attention should always be set since that's a requirement, defaults to flash
        if (
            data.get("eager_attention")
            and not data.get("attention") == AttentionBackend.eager
        ):
            raise ValueError("attention mismatch with eager_attention already set")
        if (
            data.get("flash_attention")
            and not data.get("attention") == AttentionBackend.flash
        ):
            raise ValueError("attention mismatch with flash_attention already set")
        if (
            data.get("s2_attention")
            and not data.get("attention") == AttentionBackend.s2
        ):
            raise ValueError("attention mismatch with s2_attention already set")
        if (
            data.get("sdp_attention")
            and not data.get("attention") == AttentionBackend.sdpa
        ):
            raise ValueError("attention mismatch with sdp_attention already set")
        if (
            data.get("xformers_attention")
            and not data.get("attention") == AttentionBackend.xformers
        ):
            raise ValueError("attention mismatch with xformers_attention already set")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_sample_packing_w_xformers(cls, data):
        if data.get("sample_packing") and data.get("xformers_attention"):
            raise ValueError(
                "sample_packing not compatible with xformers_attention. Use flash_attention"
            )
        return data

    @model_validator(mode="before")
    @classmethod
    # pylint: disable=duplicate-code
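
Taken together, normalize_attn keeps old configs working: when only attention is set, the matching legacy flag is backfilled; when a legacy flag is set and agrees with attention, the data passes through untouched; when they disagree, validation fails. A hedged, self-contained sketch of that behavior, using a stripped-down stand-in model rather than axolotl's actual AxolotlInputConfig, with the per-backend if/elif chains condensed into a lookup table:

from enum import Enum

from pydantic import BaseModel, model_validator


class AttentionBackend(str, Enum):  # repeats the earlier sketch, for self-containment
    eager = "eager"
    flash = "flash"
    s2 = "s2"
    sdpa = "sdpa"
    xformers = "xformers"


# Maps each backend to the legacy flag it replaces (names taken from the diff).
LEGACY_FLAGS = {
    AttentionBackend.eager: "eager_attention",
    AttentionBackend.flash: "flash_attention",
    AttentionBackend.s2: "s2_attention",
    AttentionBackend.sdpa: "sdp_attention",
    AttentionBackend.xformers: "xformers_attention",
}


class AttnConfigSketch(BaseModel):
    attention: AttentionBackend = AttentionBackend.flash
    eager_attention: bool | None = None
    flash_attention: bool | None = None
    s2_attention: bool | None = None
    sdp_attention: bool | None = None
    xformers_attention: bool | None = None

    @model_validator(mode="before")
    @classmethod
    def normalize_attn(cls, data):
        # Condensed equivalent of the diff's if/elif chains; the real validator
        # expects "attention" to be present (its comment says it defaults to flash).
        backend = AttentionBackend(data.get("attention", AttentionBackend.flash))
        set_flags = [flag for flag in LEGACY_FLAGS.values() if data.get(flag)]
        if not set_flags:
            # Only the enum was given: backfill the matching legacy flag.
            data[LEGACY_FLAGS[backend]] = True
        elif set_flags != [LEGACY_FLAGS[backend]]:
            # A legacy flag disagrees with the enum: reject, as the diff does.
            raise ValueError(f"attention mismatch with {set_flags[0]} already set")
        return data


print(AttnConfigSketch(attention="sdpa").sdp_attention)  # True (backfilled)
print(AttnConfigSketch(attention="flash", flash_attention=True).attention.value)  # flash
# AttnConfigSketch(attention="flash", xformers_attention=True)  # raises ValueError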

Some files were not shown because too many files have changed in this diff.