Compare commits

..

5 Commits

Author SHA1 Message Date
Wing Lian
4dc75cc713 Merge branch 'main' into kwargs-refactor 2025-07-16 17:10:26 -04:00
Wing Lian
36cbe13d18 activation offloading with cuda streams doesn't work with LoRA (#2927) 2025-07-16 11:59:20 -04:00
Wing Lian
2c408b5c5e Apply generic fused liger ce, cce, and tiledmlp for arbitrary models (#2908)
* Apply generic fused liger ce for unknown models

* fix deepseek liger modeling

* generic cce and config tiled mlp to use original mlp and auto detect compute params

* fix weight and lint

* update warnings

* address PR feedback

* use lookup for model class prefixes

* revert inadvertent change to flash attn verison

* remove un-needed pylint annotations

* fix import
2025-07-15 22:40:41 -04:00
Wing Lian
942005f526 use modal==1.0.2 for nightlies and for cli (#2925) [skip ci]
* use modal==1.0.2 for nightlies and for cli

* use latest cce fork for upstream changes

* increase timeout
2025-07-15 20:31:23 -04:00
Wing Lian
9394983633 fix for upstream refactor of KwargsForCausalLM 2025-07-12 13:57:36 -04:00
162 changed files with 1016 additions and 206 deletions

View File

@@ -92,7 +92,7 @@ jobs:
if: github.repository_owner == 'axolotl-ai-cloud' if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners... # this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal] runs-on: [self-hosted, modal]
timeout-minutes: 60 timeout-minutes: 120
needs: [pre-commit, pytest] needs: [pre-commit, pytest]
strategy: strategy:
@@ -116,7 +116,7 @@ jobs:
- name: Install Modal - name: Install Modal
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2 pip install modal==1.0.2 jinja2
- name: Update env vars - name: Update env vars
run: | run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -26,3 +26,5 @@ timeout: 86400
# Preprocess specific configurations # Preprocess specific configurations
memory_preprocess: 32 memory_preprocess: 32
timeout_preprocess: 14400 timeout_preprocess: 14400
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -35,7 +35,6 @@ wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 4
@@ -56,3 +55,5 @@ evals_per_epoch:
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -40,7 +40,7 @@
"%%capture\n", "%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n", "# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n", "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@78b2a45713a54c9bedf8b33f5e31cf07a1a57154\"" "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19\""
] ]
}, },
{ {

View File

@@ -56,3 +56,5 @@ evals_per_epoch: 1
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -56,3 +56,5 @@ evals_per_epoch: 1
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -55,3 +55,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD fsdp_sharding_strategy: FULL_SHARD
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -79,3 +79,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD fsdp_sharding_strategy: FULL_SHARD
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -62,3 +62,5 @@ saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -69,3 +69,5 @@ evals_per_epoch:
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -46,7 +46,6 @@ wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 4
@@ -69,3 +68,5 @@ evals_per_epoch:
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -69,3 +69,5 @@ evals_per_epoch:
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -69,3 +69,5 @@ evals_per_epoch: 1
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -69,3 +69,5 @@ evals_per_epoch:
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -69,3 +69,5 @@ evals_per_epoch: 1
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -60,3 +60,5 @@ evals_per_epoch:
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -50,3 +50,5 @@ evals_per_epoch:
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -66,3 +66,5 @@ evals_per_epoch:
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -60,3 +60,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1 evals_per_epoch: 1
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -62,3 +62,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1 evals_per_epoch: 1
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -60,3 +60,5 @@ evals_per_epoch: 1
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -54,3 +54,5 @@ evals_per_epoch:
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -55,3 +55,5 @@ saves_per_epoch: 1
deepspeed: deepspeed_configs/zero2.json deepspeed: deepspeed_configs/zero2.json
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -64,3 +64,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD fsdp_sharding_strategy: FULL_SHARD
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -46,3 +46,5 @@ evals_per_epoch: 2
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -55,3 +55,5 @@ saves_per_epoch: 1
deepspeed: #deepspeed_configs/zero2.json # multi-gpu only deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
weight_decay: 0.1 weight_decay: 0.1
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -64,3 +64,5 @@ special_tokens:
bos_token: "<s>" bos_token: "<s>"
eos_token: "</s>" eos_token: "</s>"
unk_token: "<unk>" unk_token: "<unk>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -60,3 +60,5 @@ special_tokens:
bos_token: "<s>" bos_token: "<s>"
eos_token: "</s>" eos_token: "</s>"
unk_token: "<unk>" unk_token: "<unk>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -52,3 +52,5 @@ evals_per_epoch: 4
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -52,3 +52,5 @@ evals_per_epoch: 4
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -67,3 +67,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT fsdp_state_dict_type: FULL_STATE_DICT
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -53,3 +53,5 @@ evals_per_epoch: 4
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -58,3 +58,5 @@ special_tokens:
bos_token: "<s>" bos_token: "<s>"
eos_token: "</s>" eos_token: "</s>"
unk_token: "<unk>" unk_token: "<unk>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -57,3 +57,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1 evals_per_epoch: 1
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -77,3 +77,5 @@ fsdp_config:
special_tokens: special_tokens:
pad_token: <|end_of_text|> pad_token: <|end_of_text|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -72,3 +72,5 @@ fsdp_config:
special_tokens: special_tokens:
pad_token: <|finetune_right_pad_id|> pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|> eos_token: <|eot_id|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -42,3 +42,5 @@ saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
pad_token: <|end_of_text|> pad_token: <|end_of_text|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -71,3 +71,5 @@ warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -64,3 +64,5 @@ saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
pad_token: <|end_of_text|> pad_token: <|end_of_text|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -83,3 +83,5 @@ warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -61,3 +61,5 @@ saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
pad_token: <|end_of_text|> pad_token: <|end_of_text|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -65,3 +65,5 @@ saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
pad_token: "<|end_of_text|>" pad_token: "<|end_of_text|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -64,3 +64,5 @@ special_tokens:
use_ray: true use_ray: true
ray_num_workers: 4 ray_num_workers: 4
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -63,3 +63,5 @@ saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
pad_token: <|end_of_text|> pad_token: <|end_of_text|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -60,3 +60,5 @@ saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
pad_token: "<|end_of_text|>" pad_token: "<|end_of_text|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -57,3 +57,5 @@ saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
pad_token: <|end_of_text|> pad_token: <|end_of_text|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -61,3 +61,5 @@ saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
pad_token: "<|end_of_text|>" pad_token: "<|end_of_text|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -62,3 +62,5 @@ saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
pad_token: "<|end_of_text|>" pad_token: "<|end_of_text|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -60,3 +60,5 @@ fsdp_config:
fsdp_sharding_strategy: FULL_SHARD fsdp_sharding_strategy: FULL_SHARD
special_tokens: special_tokens:
pad_token: <|finetune_right_pad_id|> pad_token: <|finetune_right_pad_id|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -69,3 +69,5 @@ fsdp_config:
fsdp_sharding_strategy: FULL_SHARD fsdp_sharding_strategy: FULL_SHARD
special_tokens: special_tokens:
pad_token: <|end_of_text|> pad_token: <|end_of_text|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -54,3 +54,5 @@ saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
pad_token: "<|end_of_text|>" pad_token: "<|end_of_text|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -75,3 +75,5 @@ llmcompressor:
] ]
start: 0 start: 0
save_compressed: true save_compressed: true
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -86,3 +86,5 @@ fsdp_config:
special_tokens: special_tokens:
pad_token: <|finetune_right_pad_id|> pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|> eos_token: <|eot|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -90,3 +90,5 @@ fsdp_config:
special_tokens: special_tokens:
pad_token: <|finetune_right_pad_id|> pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|> eos_token: <|eot|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -83,3 +83,5 @@ weight_decay: 0.0
special_tokens: special_tokens:
pad_token: <|finetune_right_pad_id|> pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|> eos_token: <|eot|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -86,3 +86,5 @@ fsdp_config:
special_tokens: special_tokens:
pad_token: <|finetune_right_pad_id|> pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|> eos_token: <|eot|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -84,3 +84,5 @@ fsdp_config:
special_tokens: special_tokens:
pad_token: <|finetune_right_pad_id|> pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|> eos_token: <|eot|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -82,3 +82,5 @@ weight_decay: 0.0
special_tokens: special_tokens:
pad_token: <|finetune_right_pad_id|> pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|> eos_token: <|eot|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -87,3 +87,5 @@ fsdp_config:
special_tokens: special_tokens:
pad_token: <|finetune_right_pad_id|> pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|> eos_token: <|eot|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -53,3 +53,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1 evals_per_epoch: 1
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -70,3 +70,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
fsdp_activation_checkpointing: true fsdp_activation_checkpointing: true
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -61,3 +61,5 @@ flash_attention: true
warmup_ratio: 0.1 warmup_ratio: 0.1
evals_per_epoch: 1 evals_per_epoch: 1
saves_per_epoch: 1 saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -48,3 +48,5 @@ weight_decay: 0.0
special_tokens: special_tokens:
tokens: tokens:
save_safetensors: False save_safetensors: False
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -53,3 +53,5 @@ special_tokens:
eos_token: "<|im_end|>" eos_token: "<|im_end|>"
tokens: tokens:
- "<|im_start|>" - "<|im_start|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -43,3 +43,5 @@ evals_per_epoch: 4
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -64,3 +64,5 @@ evals_per_epoch: 4
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -64,3 +64,5 @@ evals_per_epoch: 4
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -80,3 +80,5 @@ weight_decay: 0.0
special_tokens: special_tokens:
bos_token: "<|im_start|>" bos_token: "<|im_start|>"
eos_token: "<|im_end|>" eos_token: "<|im_end|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -74,3 +74,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -69,3 +69,5 @@ evals_per_epoch: 4
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -56,3 +56,5 @@ evals_per_epoch: 1
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -72,3 +72,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -77,3 +77,5 @@ fsdp_config:
fsdp_forward_prefetch: false fsdp_forward_prefetch: false
fsdp_backward_prefetch: BACKWARD_PRE fsdp_backward_prefetch: BACKWARD_PRE
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -81,3 +81,5 @@ saves_per_epoch: 1
deepspeed: deepspeed_configs/zero2.json deepspeed: deepspeed_configs/zero2.json
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -51,3 +51,5 @@ special_tokens:
eos_token: "<|im_end|>" eos_token: "<|im_end|>"
tokens: tokens:
- "<|im_start|>" - "<|im_start|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -64,3 +64,5 @@ evals_per_epoch: 4
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -50,3 +50,5 @@ weight_decay: 0.05
special_tokens: special_tokens:
pad_token: <custom_token_7> pad_token: <custom_token_7>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -63,3 +63,5 @@ warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
saves_per_epoch: 4 saves_per_epoch: 4
weight_decay: 0.0 weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -57,3 +57,5 @@ weight_decay: 0.1
resize_token_embeddings_to_32x: true resize_token_embeddings_to_32x: true
special_tokens: special_tokens:
pad_token: "<|endoftext|>" pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -60,3 +60,5 @@ weight_decay: 0.1
resize_token_embeddings_to_32x: true resize_token_embeddings_to_32x: true
special_tokens: special_tokens:
pad_token: "<|endoftext|>" pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -57,3 +57,5 @@ weight_decay: 0.1
resize_token_embeddings_to_32x: true resize_token_embeddings_to_32x: true
special_tokens: special_tokens:
pad_token: "<|endoftext|>" pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -71,3 +71,5 @@ fsdp_config:
resize_token_embeddings_to_32x: true resize_token_embeddings_to_32x: true
special_tokens: special_tokens:
pad_token: "<|endoftext|>" pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -59,3 +59,5 @@ warmup_ratio: 0.2
debug: true debug: true
weight_decay: 0.1 weight_decay: 0.1
resize_token_embeddings_to_32x: true resize_token_embeddings_to_32x: true
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -55,3 +55,5 @@ saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
pad_token: <pad> pad_token: <pad>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -53,3 +53,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1 evals_per_epoch: 1
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -54,3 +54,5 @@ warmup_steps: 10
evals_per_epoch: 4 evals_per_epoch: 4
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -55,3 +55,5 @@ eval_steps: 100
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -67,3 +67,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD fsdp_sharding_strategy: FULL_SHARD
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -26,7 +26,6 @@ wandb_watch:
wandb_name: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 4
@@ -50,3 +49,5 @@ evals_per_epoch:
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -53,3 +53,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1 evals_per_epoch: 1
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -67,3 +67,5 @@ evals_per_epoch: 4
saves_per_epoch: 1 saves_per_epoch: 1
weight_decay: 0.0 weight_decay: 0.0
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -76,3 +76,5 @@ fsdp_config:
fsdp_activation_checkpointing: true fsdp_activation_checkpointing: true
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -66,3 +66,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD fsdp_sharding_strategy: FULL_SHARD
special_tokens: special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -26,7 +26,7 @@ hf_transfer
sentencepiece sentencepiece
gradio==5.23.3 gradio==5.23.3
modal==0.70.5 modal==1.0.2
pydantic==2.10.6 pydantic==2.10.6
addict addict
fire fire

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print( print(
UNINSTALL_PREFIX UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"' + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19"'
) )

View File

@@ -36,6 +36,7 @@ from axolotl.utils.callbacks import (
GCCallback, GCCallback,
GPUStatsCallback, GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback, SaveAxolotlConfigtoWandBCallback,
SaveModelOnFirstStepCallback,
) )
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.schemas.enums import CustomSupportedOptimizers from axolotl.utils.schemas.enums import CustomSupportedOptimizers
@@ -135,6 +136,8 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append( callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path) SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
) )
if self.cfg.save_first_step:
callbacks.append(SaveModelOnFirstStepCallback())
callbacks.append(GPUStatsCallback(cfg=self.cfg)) callbacks.append(GPUStatsCallback(cfg=self.cfg))

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip - If you are installing from pip
```bash ```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899" pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19"
``` ```
## Usage ## Usage

View File

@@ -19,11 +19,13 @@ Cut Cross Entropy is an optimized implementation of cross entropy loss
from Apple's ML team. from Apple's ML team.
""" """
import importlib import importlib
from functools import partial
import torch import torch
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version from axolotl.utils import get_pytorch_version
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401 from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
@@ -32,7 +34,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = ( _CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using " "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"`' '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19"`'
) )
@@ -84,6 +86,7 @@ class CutCrossEntropyPlugin(BasePlugin):
"""Apply cut cross entropy before model loading if enabled.""" """Apply cut cross entropy before model loading if enabled."""
if cfg.cut_cross_entropy: if cfg.cut_cross_entropy:
self._check_requirements() self._check_requirements()
self.patch_llama_like(cfg.model_config_type)
from cut_cross_entropy.transformers.patch import cce_patch from cut_cross_entropy.transformers.patch import cce_patch
@@ -93,3 +96,48 @@ class CutCrossEntropyPlugin(BasePlugin):
# The patch checks model_type internally # The patch checks model_type internally
cce_patch(cfg.model_config_type) cce_patch(cfg.model_config_type)
def patch_llama_like(
self,
model_type: str,
) -> None:
"""
Generic patch for model architectures with causal lm similar to llama
"""
from cut_cross_entropy.transformers.patch import PATCH_FNS
def patch_generic(
maybe_model, patch_options, model_type: str
): # pylint: disable=unused-argument
import cut_cross_entropy.transformers.llama
from cut_cross_entropy.transformers.llama import cce_forward
try:
# Dynamically import the module and CausalLM class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(
module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]
)
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
cut_cross_entropy.transformers.llama._PATCH_OPTS = ( # pylint: disable=protected-access
patch_options
)
model_cls.forward = cce_forward
# pylint: disable=duplicate-code
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import ForCausalLM class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e
if model_type not in PATCH_FNS:
LOG.warning_once(
"Setting up generic cce patch for model type: %s", model_type
)
LOG.warning_once(
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
)
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)

View File

@@ -22,6 +22,8 @@ except ImportError:
TransformersKwargs, TransformersKwargs,
) )
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
def kldiv_forward_llama_like( def kldiv_forward_llama_like(
self, self,
@@ -97,7 +99,7 @@ def kldiv_forward_llama_like(
def apply_kernel(model_type): def apply_kernel(model_type):
# Dynamically import the module and attention class # Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}" module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")]) model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]) module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
model_cls.forward = kldiv_forward_llama_like model_cls.forward = kldiv_forward_llama_like

View File

@@ -18,170 +18,10 @@ Module for the Plugin for LIGER integraton with Axolotl.
Liger Kernel is the collection of Triton-native kernels for LLM Training. Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight. It is designed to be performant, correct, and light-weight.
""" """
import inspect from .args import LigerArgs
import sys from .plugin import LigerPlugin
from axolotl.integrations.base import BasePlugin __all__ = [
from axolotl.utils.logging import get_logger "LigerArgs",
"LigerPlugin",
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 ]
from .utils import patch_with_compile_disable
LOG = get_logger(__name__)
class LigerPlugin(BasePlugin):
"""
Plugin for LIGER integraton with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
if cfg.torch_compile:
# torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
import liger_kernel.ops.fused_linear_cross_entropy
patch_with_compile_disable(
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_forward",
)
patch_with_compile_disable(
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_backward",
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
raise ValueError(
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
)
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)
kwargs = {}
if "rope" in liger_fn_sig.parameters:
kwargs["rope"] = cfg.liger_rope
if "cross_entropy" in liger_fn_sig.parameters:
kwargs["cross_entropy"] = cfg.liger_cross_entropy
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
kwargs["fused_linear_cross_entropy"] = (
cfg.liger_fused_linear_cross_entropy
)
if "rms_norm" in liger_fn_sig.parameters:
kwargs["rms_norm"] = cfg.liger_rms_norm
if "layer_norm" in liger_fn_sig.parameters:
kwargs["layer_norm"] = cfg.liger_layer_norm
if "geglu" in liger_fn_sig.parameters:
kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation
LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}")
apply_liger_fn(**kwargs)
elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba
from .models.jamba import lce_forward as jamba_lce_forward
if cfg.liger_rope:
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_layer_norm:
modeling_jamba.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
elif cfg.model_config_type == "deepseek_v2":
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM
with init_empty_weights():
model = AutoModelForCausalLM.from_pretrained(
cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
)
modeling_mod = sys.modules[model.__class__.__module__]
from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward
if cfg.liger_rope:
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
LOG.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_glu_activation:
LOG.warning("liger_glu_activation is not supported for DeepseekV2.")
if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_layer_norm:
modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward
if cfg.liger_cross_entropy:
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
# nn.CrossEntropyLoss in the forward method.
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type == "llama4":
from axolotl.integrations.liger.models.llama4 import (
apply_liger_kernel_to_llama4,
)
apply_liger_kernel_to_llama4(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3":
from axolotl.integrations.liger.models.qwen3 import (
apply_liger_kernel_to_qwen3,
)
apply_liger_kernel_to_qwen3(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_moe":
from axolotl.integrations.liger.models.qwen3_moe import (
apply_liger_kernel_to_qwen3_moe,
)
apply_liger_kernel_to_qwen3_moe(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "granitemoe":
from liger_kernel.transformers import apply_liger_kernel_to_granite
apply_liger_kernel_to_granite(
rope=cfg.liger_rope,
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
rms_norm=cfg.liger_rms_norm,
swiglu=cfg.liger_glu_activation,
)
else:
LOG.warning(
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
)

Some files were not shown because too many files have changed in this diff Show More