Compare commits


11 Commits

Author SHA1 Message Date
mhenrhcsen
380921ee56 Update ModelLoader to set default vocab_size if not defined in model config, enhancing compatibility with tokenizer defaults. 2025-07-17 19:53:41 +02:00
mhenrhcsen
6e71819560 Update ModelLoader to set vocab_size for GraniteSpeechConfig if not defined, ensuring compatibility with tokenizer defaults. 2025-07-17 19:49:10 +02:00
mhenrhcsen
ea234afa8a Enhance model loading logic to include support for GraniteSpeechConfig, allowing for the use of the specific model class for Granite Speech. 2025-07-17 19:45:23 +02:00
mhenrhcsen
738adb2258 fixes 2025-07-16 21:50:56 +02:00
mhenrhcsen
f40e8caa28 checks 2025-07-16 21:30:01 +02:00
mhenrhcsen
f9bdf1fb44 checks 2025-07-16 21:23:25 +02:00
mhenrhcsen
2f670a5988 Fix: Update model loading logic to conditionally upcast based on lm_head presence for btlm models 2025-07-16 21:16:47 +02:00
mhenrhcsen
84ad69afad Fix: Ensure tie_weights method is called only if it exists in the model class 2025-07-16 21:03:49 +02:00
Wing Lian
36cbe13d18 activation offloading with cuda streams doesn't work with LoRA (#2927) 2025-07-16 11:59:20 -04:00
Wing Lian
2c408b5c5e Apply generic fused liger ce, cce, and tiledmlp for arbitrary models (#2908)
* Apply generic fused liger ce for unknown models

* fix deepseek liger modeling

* generic cce and config tiled mlp to use original mlp and auto detect compute params

* fix weight and lint

* update warnings

* address PR feedback

* use lookup for model class prefixes

* revert inadvertent change to flash attn version

* remove un-needed pylint annotations

* fix import
2025-07-15 22:40:41 -04:00
Wing Lian
942005f526 use modal==1.0.2 for nightlies and for cli (#2925) [skip ci]
* use modal==1.0.2 for nightlies and for cli

* use latest cce fork for upstream changes

* increase timeout
2025-07-15 20:31:23 -04:00
164 changed files with 1140 additions and 240 deletions

View File

@@ -92,7 +92,7 @@ jobs:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 60
timeout-minutes: 120
needs: [pre-commit, pytest]
strategy:
@@ -116,7 +116,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -26,3 +26,5 @@ timeout: 86400
# Preprocess specific configurations
memory_preprocess: 32
timeout_preprocess: 14400
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -35,7 +35,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
@@ -56,3 +55,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@78b2a45713a54c9bedf8b33f5e31cf07a1a57154\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19\""
]
},
{

View File

@@ -56,3 +56,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -56,3 +56,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -55,3 +55,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -79,3 +79,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -62,3 +62,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -69,3 +69,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -46,7 +46,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
@@ -69,3 +68,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -69,3 +69,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -69,3 +69,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -69,3 +69,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -69,3 +69,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -60,3 +60,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -50,3 +50,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -66,3 +66,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -60,3 +60,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -62,3 +62,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -60,3 +60,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -54,3 +54,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -55,3 +55,5 @@ saves_per_epoch: 1
deepspeed: deepspeed_configs/zero2.json
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -64,3 +64,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -46,3 +46,5 @@ evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -55,3 +55,5 @@ saves_per_epoch: 1
deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
weight_decay: 0.1
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -64,3 +64,5 @@ special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -60,3 +60,5 @@ special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -52,3 +52,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -52,3 +52,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -67,3 +67,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -53,3 +53,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -58,3 +58,5 @@ special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -57,3 +57,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -77,3 +77,5 @@ fsdp_config:
special_tokens:
pad_token: <|end_of_text|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -72,3 +72,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -42,3 +42,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -71,3 +71,5 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -64,3 +64,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -83,3 +83,5 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -61,3 +61,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -65,3 +65,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -64,3 +64,5 @@ special_tokens:
use_ray: true
ray_num_workers: 4
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -63,3 +63,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -60,3 +60,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -57,3 +57,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -61,3 +61,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -62,3 +62,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -60,3 +60,5 @@ fsdp_config:
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
pad_token: <|finetune_right_pad_id|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -69,3 +69,5 @@ fsdp_config:
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
pad_token: <|end_of_text|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -54,3 +54,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -75,3 +75,5 @@ llmcompressor:
]
start: 0
save_compressed: true
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -86,3 +86,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -90,3 +90,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -83,3 +83,5 @@ weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -86,3 +86,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -84,3 +84,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -82,3 +82,5 @@ weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -87,3 +87,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -53,3 +53,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -70,3 +70,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
fsdp_activation_checkpointing: true
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -61,3 +61,5 @@ flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -48,3 +48,5 @@ weight_decay: 0.0
special_tokens:
tokens:
save_safetensors: False
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -53,3 +53,5 @@ special_tokens:
eos_token: "<|im_end|>"
tokens:
- "<|im_start|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -43,3 +43,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -64,3 +64,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -64,3 +64,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -80,3 +80,5 @@ weight_decay: 0.0
special_tokens:
bos_token: "<|im_start|>"
eos_token: "<|im_end|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -74,3 +74,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -69,3 +69,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -56,3 +56,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -72,3 +72,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -77,3 +77,5 @@ fsdp_config:
fsdp_forward_prefetch: false
fsdp_backward_prefetch: BACKWARD_PRE
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -81,3 +81,5 @@ saves_per_epoch: 1
deepspeed: deepspeed_configs/zero2.json
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -51,3 +51,5 @@ special_tokens:
eos_token: "<|im_end|>"
tokens:
- "<|im_start|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -64,3 +64,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -50,3 +50,5 @@ weight_decay: 0.05
special_tokens:
pad_token: <custom_token_7>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -63,3 +63,5 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 4
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -57,3 +57,5 @@ weight_decay: 0.1
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -60,3 +60,5 @@ weight_decay: 0.1
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -57,3 +57,5 @@ weight_decay: 0.1
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -71,3 +71,5 @@ fsdp_config:
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -59,3 +59,5 @@ warmup_ratio: 0.2
debug: true
weight_decay: 0.1
resize_token_embeddings_to_32x: true
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -55,3 +55,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <pad>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -53,3 +53,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -54,3 +54,5 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -55,3 +55,5 @@ eval_steps: 100
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -67,3 +67,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -26,7 +26,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
@@ -50,3 +49,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -53,3 +53,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -67,3 +67,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -76,3 +76,5 @@ fsdp_config:
fsdp_activation_checkpointing: true
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -66,3 +66,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -26,7 +26,7 @@ hf_transfer
sentencepiece
gradio==5.23.3
modal==0.70.5
modal==1.0.2
pydantic==2.10.6
addict
fire

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19"'
)

View File

@@ -36,6 +36,7 @@ from axolotl.utils.callbacks import (
GCCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
SaveModelOnFirstStepCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
@@ -135,6 +136,8 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
if self.cfg.save_first_step:
callbacks.append(SaveModelOnFirstStepCallback())
callbacks.append(GPUStatsCallback(cfg=self.cfg))
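The wiring above hinges on a callback that forces a checkpoint on the very first optimizer step, which is what the `save_first_step: true` lines added to the example configs toggle. A minimal sketch of how such a callback can be built on the standard transformers `TrainerCallback` hooks is shown below; this illustrates the mechanism only and is not axolotl's actual `SaveModelOnFirstStepCallback` implementation.

```python
# Hedged sketch of a "save on first step" callback using the transformers
# TrainerCallback API; the real SaveModelOnFirstStepCallback may differ.
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments


class SaveOnFirstStepSketch(TrainerCallback):
    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> TrainerControl:
        # Request a checkpoint right after the first optimizer step so that
        # saving is validated early instead of failing hours into a run.
        if state.global_step == 1:
            control.should_save = True
        return control
```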

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19"
```
## Usage

View File

@@ -19,11 +19,13 @@ Cut Cross Entropy is an optimized implementation of cross entropy loss
from Apple's ML team.
"""
import importlib
from functools import partial
import torch
from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
from axolotl.utils.logging import get_logger
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
@@ -32,7 +34,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19"`'
)
@@ -84,6 +86,7 @@ class CutCrossEntropyPlugin(BasePlugin):
"""Apply cut cross entropy before model loading if enabled."""
if cfg.cut_cross_entropy:
self._check_requirements()
self.patch_llama_like(cfg.model_config_type)
from cut_cross_entropy.transformers.patch import cce_patch
@@ -93,3 +96,48 @@ class CutCrossEntropyPlugin(BasePlugin):
# The patch checks model_type internally
cce_patch(cfg.model_config_type)
def patch_llama_like(
self,
model_type: str,
) -> None:
"""
Generic patch for model architectures with causal lm similar to llama
"""
from cut_cross_entropy.transformers.patch import PATCH_FNS
def patch_generic(
maybe_model, patch_options, model_type: str
): # pylint: disable=unused-argument
import cut_cross_entropy.transformers.llama
from cut_cross_entropy.transformers.llama import cce_forward
try:
# Dynamically import the module and CausalLM class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(
module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]
)
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
cut_cross_entropy.transformers.llama._PATCH_OPTS = ( # pylint: disable=protected-access
patch_options
)
model_cls.forward = cce_forward
# pylint: disable=duplicate-code
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import ForCausalLM class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e
if model_type not in PATCH_FNS:
LOG.warning_once(
"Setting up generic cce patch for model type: %s", model_type
)
LOG.warning_once(
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
)
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)
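The core of the hunk above is a dynamic lookup that resolves the `*ForCausalLM` class for an arbitrary `model_type` and swaps its `forward` for the fused CCE variant. A stripped-down sketch of that lookup pattern follows; it is independent of cut_cross_entropy, and `replacement_forward` is a placeholder for whatever forward is being installed.

```python
# Hedged sketch of the dynamic class-patching pattern used above. It relies
# only on the transformers module layout
# (transformers.models.<model_type>.modeling_<model_type>).
import importlib
from typing import Callable


def patch_causal_lm_forward(
    model_type: str, model_cls_prefix: str, replacement_forward: Callable
) -> type:
    module_path = f"transformers.models.{model_type}.modeling_{model_type}"
    try:
        module = importlib.import_module(module_path)
        model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
    except (ImportError, AttributeError) as err:
        raise RuntimeError(
            f"Could not resolve a ForCausalLM class for model_type={model_type!r}"
        ) from err
    # Monkey-patch at the class level so every instance picks up the new forward.
    model_cls.forward = replacement_forward
    return model_cls
```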

View File

@@ -22,6 +22,8 @@ except ImportError:
TransformersKwargs,
)
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
def kldiv_forward_llama_like(
self,
@@ -97,7 +99,7 @@ def kldiv_forward_llama_like(
def apply_kernel(model_type):
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
model_cls.forward = kldiv_forward_llama_like
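The one-line change above replaces a capitalize-each-underscore-part rule with the shared `get_causal_lm_model_cls_prefix` lookup. The likely reason is that the class prefix cannot always be derived mechanically from `model_type`; a quick illustration:

```python
# Why a lookup beats string-mangling: the naive rule fails for irregular names.
model_type = "gpt_neox"
naive_prefix = "".join(part.capitalize() for part in model_type.split("_"))
print(f"{naive_prefix}ForCausalLM")  # -> GptNeoxForCausalLM
# The actual class in transformers is GPTNeoXForCausalLM, so attribute lookup
# with the naive prefix would raise AttributeError for models like this.
```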

View File

@@ -18,170 +18,10 @@ Module for the Plugin for LIGER integraton with Axolotl.
Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight.
"""
import inspect
import sys
from .args import LigerArgs
from .plugin import LigerPlugin
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
from .utils import patch_with_compile_disable
LOG = get_logger(__name__)
class LigerPlugin(BasePlugin):
"""
Plugin for LIGER integraton with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
if cfg.torch_compile:
# torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
import liger_kernel.ops.fused_linear_cross_entropy
patch_with_compile_disable(
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_forward",
)
patch_with_compile_disable(
liger_kernel.ops.fused_linear_cross_entropy,
"fused_linear_cross_entropy_backward",
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
raise ValueError(
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
)
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)
kwargs = {}
if "rope" in liger_fn_sig.parameters:
kwargs["rope"] = cfg.liger_rope
if "cross_entropy" in liger_fn_sig.parameters:
kwargs["cross_entropy"] = cfg.liger_cross_entropy
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
kwargs["fused_linear_cross_entropy"] = (
cfg.liger_fused_linear_cross_entropy
)
if "rms_norm" in liger_fn_sig.parameters:
kwargs["rms_norm"] = cfg.liger_rms_norm
if "layer_norm" in liger_fn_sig.parameters:
kwargs["layer_norm"] = cfg.liger_layer_norm
if "geglu" in liger_fn_sig.parameters:
kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation
LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}")
apply_liger_fn(**kwargs)
elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba
from .models.jamba import lce_forward as jamba_lce_forward
if cfg.liger_rope:
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_layer_norm:
modeling_jamba.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
elif cfg.model_config_type == "deepseek_v2":
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM
with init_empty_weights():
model = AutoModelForCausalLM.from_pretrained(
cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
)
modeling_mod = sys.modules[model.__class__.__module__]
from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward
if cfg.liger_rope:
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
LOG.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_glu_activation:
LOG.warning("liger_glu_activation is not supported for DeepseekV2.")
if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_layer_norm:
modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward
if cfg.liger_cross_entropy:
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
# nn.CrossEntropyLoss in the forward method.
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type == "llama4":
from axolotl.integrations.liger.models.llama4 import (
apply_liger_kernel_to_llama4,
)
apply_liger_kernel_to_llama4(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3":
from axolotl.integrations.liger.models.qwen3 import (
apply_liger_kernel_to_qwen3,
)
apply_liger_kernel_to_qwen3(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_moe":
from axolotl.integrations.liger.models.qwen3_moe import (
apply_liger_kernel_to_qwen3_moe,
)
apply_liger_kernel_to_qwen3_moe(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "granitemoe":
from liger_kernel.transformers import apply_liger_kernel_to_granite
apply_liger_kernel_to_granite(
rope=cfg.liger_rope,
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
rms_norm=cfg.liger_rms_norm,
swiglu=cfg.liger_glu_activation,
)
else:
LOG.warning(
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
)
__all__ = [
"LigerArgs",
"LigerPlugin",
]
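For context on the `patch_with_compile_disable` helper referenced in the removed `pre_model_load` above: the idea is to wrap a module-level function so `torch.compile` skips it, which matters for hand-written Triton kernels. A minimal sketch under that assumption is below; names and behavior are illustrative, not axolotl's exact implementation.

```python
# Hedged sketch: replace a function on a module with a version that
# torch.compile will not attempt to trace or optimize. Illustrative only.
import torch


def patch_with_compile_disable_sketch(module, fn_name: str) -> None:
    original = getattr(module, fn_name)
    # torch.compiler.disable wraps a callable so execution falls back to eager
    # whenever it is hit inside a compiled region.
    setattr(module, fn_name, torch.compiler.disable(original))
```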

Some files were not shown because too many files have changed in this diff.