Compare commits

...

4 Commits

Author SHA1 Message Date
NanoCode012
ef883b6960 chore: refactor normalize_attn to use mapping and loop 2025-05-07 17:10:18 +07:00
NanoCode012
d0c4930dd5 fix: set replit mpt model to use eager attention 2025-05-07 17:10:18 +07:00
Wing Lian
6ee7cb30fa fixes from PR feedback 2025-05-07 17:10:18 +07:00
Wing Lian
ba47adc24b replace attention in the yaml config with an enum 2025-05-07 17:10:18 +07:00
101 changed files with 247 additions and 122 deletions

View File

@@ -59,9 +59,7 @@ gradient_checkpointing: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
sdp_attention:
flash_optimum:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:

View File

@@ -39,8 +39,7 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
xformers_attention: true attention: xformers
flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 10

View File

@@ -45,7 +45,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -46,7 +46,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -45,7 +45,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -46,7 +46,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -45,7 +45,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -46,7 +46,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -49,7 +49,8 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: evals_per_epoch:

View File

@@ -112,9 +112,7 @@
"early_stopping_patience:\n", "early_stopping_patience:\n",
"resume_from_checkpoint:\n", "resume_from_checkpoint:\n",
"logging_steps: 1\n", "logging_steps: 1\n",
"xformers_attention:\n", "attention: sdpa\n",
"flash_attention: false\n",
"sdp_attention: true\n",
"\n", "\n",
"warmup_steps: 1\n", "warmup_steps: 1\n",
"max_steps: 25\n", "max_steps: 25\n",

View File

@@ -52,7 +52,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: evals_per_epoch:

View File

@@ -55,7 +55,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: evals_per_epoch:

View File

@@ -39,7 +39,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: evals_per_epoch:

View File

@@ -35,7 +35,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 100 warmup_steps: 100
evals_per_epoch: 2 evals_per_epoch: 2

View File

@@ -59,7 +59,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 100 warmup_steps: 100
evals_per_epoch: 2 evals_per_epoch: 2

View File

@@ -43,8 +43,7 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
xformers_attention: true attention: xformers
flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 40 warmup_steps: 40

View File

@@ -73,8 +73,7 @@ early_stopping_patience: 3
resume_from_checkpoint: resume_from_checkpoint:
auto_resume_from_checkpoints: true auto_resume_from_checkpoints: true
logging_steps: 1 logging_steps: 1
xformers_attention: true attention: xformers
flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 10

View File

@@ -40,8 +40,7 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
xformers_attention: true attention: xformers
flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 40 warmup_steps: 40

View File

@@ -47,7 +47,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -53,7 +53,8 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: evals_per_epoch:

View File

@@ -43,7 +43,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: evals_per_epoch:

View File

@@ -57,7 +57,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: evals_per_epoch:

View File

@@ -51,8 +51,7 @@ gradient_checkpointing: true
gradient_checkpointing_kwargs: gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
eager_attention:
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: 1 evals_per_epoch: 1

View File

@@ -53,8 +53,7 @@ gradient_checkpointing: true
gradient_checkpointing_kwargs: gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
eager_attention:
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: 1 evals_per_epoch: 1

View File

@@ -36,8 +36,7 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
xformers_attention: true attention: xformers
flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 10

View File

@@ -47,7 +47,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: evals_per_epoch:

View File

@@ -46,7 +46,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: evals_per_epoch:

View File

@@ -45,7 +45,8 @@ gradient_checkpointing: true
gradient_checkpointing_kwargs: gradient_checkpointing_kwargs:
use_reentrant: true use_reentrant: true
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 1 evals_per_epoch: 1

View File

@@ -37,8 +37,7 @@ bf16: auto
tf32: true tf32: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 5 logging_steps: 5
xformers_attention: true attention: xformers
flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20

View File

@@ -42,7 +42,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
flash_attn_cross_entropy: false flash_attn_cross_entropy: false
flash_attn_rms_norm: true flash_attn_rms_norm: true
flash_attn_fuse_qkv: false flash_attn_fuse_qkv: false

View File

@@ -53,9 +53,7 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: attention: flash
sdp_attention:
flash_optimum:
warmup_steps: 100 warmup_steps: 100
evals_per_epoch: 4 evals_per_epoch: 4
saves_per_epoch: 1 saves_per_epoch: 1

View File

@@ -46,7 +46,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
flash_attn_cross_entropy: false flash_attn_cross_entropy: false
flash_attn_rms_norm: true flash_attn_rms_norm: true
flash_attn_fuse_qkv: false flash_attn_fuse_qkv: false

View File

@@ -45,7 +45,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -45,7 +45,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -48,7 +48,8 @@ gradient_checkpointing_kwargs:
use_reentrant: true use_reentrant: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -46,7 +46,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -48,7 +48,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -50,8 +50,7 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
eager_attention:
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: 1 evals_per_epoch: 1

View File

@@ -49,7 +49,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 100 warmup_steps: 100
evals_per_epoch: 2 evals_per_epoch: 2

View File

@@ -34,7 +34,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 100 warmup_steps: 100
evals_per_epoch: 2 evals_per_epoch: 2

View File

@@ -61,7 +61,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -56,7 +56,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -77,7 +77,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -53,7 +53,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -54,7 +54,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
loss_watchdog_threshold: 5.0 loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3 loss_watchdog_patience: 3

View File

@@ -48,7 +48,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
loss_watchdog_threshold: 5.0 loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3 loss_watchdog_patience: 3

View File

@@ -55,7 +55,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -48,7 +48,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
loss_watchdog_threshold: 5.0 loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3 loss_watchdog_patience: 3

View File

@@ -49,7 +49,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -53,7 +53,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 20 warmup_steps: 20
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -51,7 +51,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
loss_watchdog_threshold: 5.0 loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3 loss_watchdog_patience: 3

View File

@@ -39,7 +39,8 @@ gradient_checkpointing: true
gradient_checkpointing_kwargs: gradient_checkpointing_kwargs:
use_reentrant: true use_reentrant: true
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -48,7 +48,8 @@ gradient_checkpointing_kwargs:
use_reentrant: true use_reentrant: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -46,7 +46,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -46,8 +46,7 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
eager_attention:
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: 1 evals_per_epoch: 1

View File

@@ -39,7 +39,7 @@ tf32: true
gradient_checkpointing: false gradient_checkpointing: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: attention: eager
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -42,7 +42,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
save_total_limit: 1 save_total_limit: 1
save_steps: save_steps:

View File

@@ -36,7 +36,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -53,8 +53,7 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: false attention: sdpa
sdp_attention: true
loss_watchdog_threshold: 5.0 loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3 loss_watchdog_patience: 3

View File

@@ -54,7 +54,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
loss_watchdog_threshold: 5.0 loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3 loss_watchdog_patience: 3

View File

@@ -71,7 +71,7 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: false attention: eager
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -51,7 +51,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
loss_watchdog_threshold: 5.0 loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3 loss_watchdog_patience: 3

View File

@@ -59,7 +59,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
loss_watchdog_threshold: 5.0 loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3 loss_watchdog_patience: 3

View File

@@ -48,9 +48,7 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
logging_steps: 1 logging_steps: 1
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet. attention: eager # PixtralVisionModel does not support Flash Attention 2.0 yet.
eager_attention:
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: 1 evals_per_epoch: 1
saves_per_epoch: 1 saves_per_epoch: 1

View File

@@ -49,7 +49,8 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
loss_watchdog_threshold: 5.0 loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3 loss_watchdog_patience: 3

View File

@@ -51,7 +51,8 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
loss_watchdog_threshold: 5.0 loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3 loss_watchdog_patience: 3

View File

@@ -69,7 +69,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
loss_watchdog_threshold: 5.0 loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3 loss_watchdog_patience: 3

View File

@@ -40,7 +40,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
save_total_limit: 1 save_total_limit: 1
save_steps: save_steps:

View File

@@ -54,7 +54,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
loss_watchdog_threshold: 5.0 loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3 loss_watchdog_patience: 3

View File

@@ -39,7 +39,7 @@ bf16: auto
tf32: true tf32: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 5 logging_steps: 5
flash_attention: attention: eager
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20

View File

@@ -39,7 +39,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20

View File

@@ -47,7 +47,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20

View File

@@ -40,7 +40,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20

View File

@@ -48,7 +48,8 @@ gradient_checkpointing_kwargs:
use_reentrant: True use_reentrant: True
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 100 warmup_steps: 100
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -51,7 +51,8 @@ gradient_checkpointing_kwargs:
use_reentrant: True use_reentrant: True
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 100 warmup_steps: 100
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -48,7 +48,8 @@ gradient_checkpointing_kwargs:
use_reentrant: True use_reentrant: True
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 100 warmup_steps: 100
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -49,7 +49,8 @@ gradient_checkpointing_kwargs:
use_reentrant: true use_reentrant: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 100 warmup_steps: 100
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -44,7 +44,8 @@ gradient_checkpointing_kwargs:
use_reentrant: True use_reentrant: True
early_stopping_patience: 3 early_stopping_patience: 3
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
eval_steps: 1000 eval_steps: 1000
save_steps: 5000 save_steps: 5000

View File

@@ -46,8 +46,7 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
logging_steps: 1 logging_steps: 1
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet attention: eager # PixtralVisionModel does not support Flash Attention 2.0 yet
eager_attention:
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: 1 evals_per_epoch: 1

View File

@@ -47,7 +47,7 @@ tf32: false
gradient_checkpointing: false gradient_checkpointing: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: attention: eager
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -47,7 +47,7 @@ tf32: false
gradient_checkpointing: false gradient_checkpointing: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -43,7 +43,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -46,7 +46,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -46,8 +46,7 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
eager_attention:
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: 1 evals_per_epoch: 1

View File

@@ -49,7 +49,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -47,7 +47,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: evals_per_epoch:

View File

@@ -47,7 +47,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -43,7 +43,8 @@ gradient_checkpointing_kwargs:
use_reentrant: false use_reentrant: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: evals_per_epoch:

View File

@@ -40,7 +40,7 @@ bf16: auto
tf32: true tf32: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 5 logging_steps: 5
flash_attention: attention: flash
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20

View File

@@ -38,7 +38,7 @@ tf32: true
gradient_checkpointing: gradient_checkpointing:
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: attention: eager
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 20 warmup_steps: 20

View File

@@ -44,7 +44,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
flash_attn_cross_entropy: false flash_attn_cross_entropy: false
flash_attn_rms_norm: true flash_attn_rms_norm: true
flash_attn_fuse_qkv: false flash_attn_fuse_qkv: false

View File

@@ -47,7 +47,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
flash_attn_cross_entropy: false flash_attn_cross_entropy: false
flash_attn_rms_norm: true flash_attn_rms_norm: true

View File

@@ -46,7 +46,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 20 warmup_steps: 20
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -47,7 +47,7 @@ tf32: true
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: false attention: eager
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 0 evals_per_epoch: 0

View File

@@ -45,7 +45,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -36,7 +36,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: evals_per_epoch:

View File

@@ -47,7 +47,8 @@ tf32: false
gradient_checkpointing: true gradient_checkpointing: true
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true attention: flash
warmup_steps: 10 warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4

View File

@@ -71,8 +71,7 @@ early_stopping_patience: 3
resume_from_checkpoint: resume_from_checkpoint:
auto_resume_from_checkpoints: true auto_resume_from_checkpoints: true
logging_steps: 1 logging_steps: 1
xformers_attention: true attention: xformers
flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 10

View File

@@ -10,7 +10,8 @@ load_in_4bit: true
sequence_len: 1024 sequence_len: 1024
bf16: auto bf16: auto
tf32: false tf32: false
flash_attention: true attention: flash
special_tokens: special_tokens:
bos_token: "<|startoftext|>" bos_token: "<|startoftext|>"
eos_token: "<|endoftext|>" eos_token: "<|endoftext|>"

View File

@@ -27,7 +27,7 @@ from axolotl.utils.schemas.datasets import (
StepwiseSupervisedDataset, StepwiseSupervisedDataset,
) )
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.enums import ChatTemplate, RLType from axolotl.utils.schemas.enums import AttentionBackend, ChatTemplate, RLType
from axolotl.utils.schemas.integrations import ( from axolotl.utils.schemas.integrations import (
CometConfig, CometConfig,
GradioConfig, GradioConfig,
@@ -222,6 +222,10 @@ class AxolotlInputConfig(
}, },
) )
# New-style attention backend selector (eager/flash/flex/s2/sdpa/xformers).
# Reconciled with the legacy `*_attention` boolean flags by the
# `normalize_attn` model validator defined later in this class.
attention: AttentionBackend | None = Field(
    default=None,
    json_schema_extra={"description": "attention backend to use"},
)
xformers_attention: bool | None = None xformers_attention: bool | None = None
sdp_attention: bool | None = None sdp_attention: bool | None = None
s2_attention: bool | None = None s2_attention: bool | None = None
@@ -436,6 +440,65 @@ class AxolotlInputConfig(
) )
return data return data
@model_validator(mode="before")
@classmethod
def normalize_attn(cls, data):  # pylint: disable=too-many-return-statements
    """Reconcile the `attention` enum with the legacy `*_attention` flags.

    Three cases are handled on the raw (pre-validation) config dict:
      1. `attention` set, no legacy flag set  -> mirror into the legacy flag.
      2. `attention` unset, a legacy flag set -> derive `attention` from the
         flag and warn that the flags are deprecated.
      3. both set -> every set flag must agree with `attention`, else raise.

    Raises:
        ValueError: if a legacy flag contradicts the chosen `attention` value.
    """
    attention = data.get("attention")
    # Map each enum value to the legacy boolean flag it replaces.
    backend_mapping = {
        AttentionBackend.eager: "eager_attention",
        AttentionBackend.flash: "flash_attention",
        AttentionBackend.flex: "flex_attention",
        AttentionBackend.s2: "s2_attention",
        AttentionBackend.sdpa: "sdp_attention",
        AttentionBackend.xformers: "xformers_attention",
    }
    # Check if any legacy attention flag is set
    any_flag_set = any(
        data.get(flag_name) for flag_name in backend_mapping.values()
    )
    # CASE 1: attention is set but no flags are set - set the corresponding flag
    if attention and not any_flag_set:
        flag_name = backend_mapping.get(attention)
        if flag_name:
            data[flag_name] = True
        return data
    # CASE 2: no attention set - derive it from a legacy flag if one is in
    # use. Only warn when a deprecated flag is actually set; a config that
    # sets neither should not trigger the deprecation warning.
    if not attention:
        if any_flag_set:
            LOG.warning(
                "*_attention will be deprecated soon. One of `attention: eager | flash | flex | s2 | sdpa | xformers` is recommended"
            )
            # Find the first True flag and set attention accordingly
            for backend, flag_name in backend_mapping.items():
                if data.get(flag_name):
                    data["attention"] = backend
                    return data
        return data
    # CASE 3: both attention and flags are set - check for consistency
    expected_flag = backend_mapping.get(attention)
    for flag_name in backend_mapping.values():
        # If a flag is set that doesn't match the attention value
        if data.get(flag_name) and flag_name != expected_flag:
            raise ValueError(f"attention mismatch with {flag_name} already set")
    return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_xformers(cls, data):
    """Reject the unsupported combination of sample packing and xformers.

    Raises:
        ValueError: when both `sample_packing` and `xformers_attention`
            are truthy in the raw config dict.
    """
    packing_enabled = bool(data.get("sample_packing"))
    uses_xformers = bool(data.get("xformers_attention"))
    if packing_enabled and uses_xformers:
        raise ValueError(
            "sample_packing not compatible with xformers_attention. Use flash_attention"
        )
    return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
# pylint: disable=duplicate-code # pylint: disable=duplicate-code

Some files were not shown because too many files have changed in this diff Show More