From 80d5b066ecfc3ae6e97ad70d93942e816b4a9a72 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sat, 14 Jun 2025 11:53:43 -0700
Subject: [PATCH] Fix: adding magistral fsdp config, fixing not eval with
 test_datasets, handle mllama attention (#2789) [skip ci]

* feat: add fsdp config for magistral

* fix: add mllama self attention handling for lora kernels

* fix: no eval if val_set_size 0 despite having test_datasets

* fix: add note for cce for vlm in newer model
---
 .../magistral/magistral-small-fsdp-qlora.yaml | 72 +++++++++++++++++++
 src/axolotl/core/builders/base.py             |  4 +-
 .../integrations/cut_cross_entropy/README.md  |  8 +++
 src/axolotl/monkeypatch/lora_kernels.py       |  5 ++
 4 files changed, 87 insertions(+), 2 deletions(-)
 create mode 100644 examples/magistral/magistral-small-fsdp-qlora.yaml

diff --git a/examples/magistral/magistral-small-fsdp-qlora.yaml b/examples/magistral/magistral-small-fsdp-qlora.yaml
new file mode 100644
index 000000000..b10e8baf6
--- /dev/null
+++ b/examples/magistral/magistral-small-fsdp-qlora.yaml
@@ -0,0 +1,72 @@
+base_model: mistralai/Magistral-Small-2506
+
+# Enable to use mistral-common tokenizer
+tokenizer_use_mistral_common: true
+
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+load_in_8bit: false
+load_in_4bit: true
+
+datasets:
+  - path: fozziethebeat/alpaca_messages_2k_test
+    type: chat_template
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.1
+output_dir: ./outputs/lora-out
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 2048
+sample_packing: true
+eval_sample_packing: false
+pad_to_sequence_len: true
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_target_modules:
+  - gate_proj
+  - down_proj
+  - up_proj
+  - q_proj
+  - v_proj
+  - k_proj
+  - o_proj
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_torch_fused
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: false
+
+gradient_checkpointing:
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+
+fsdp:
+  - full_shard
+  - auto_wrap
+fsdp_config:
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
+  fsdp_activation_checkpointing: true
diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index ac49b4e88..e399cf3c5 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -380,8 +380,8 @@ class TrainerBuilderBase(abc.ABC):
             )
 
         # eval_strategy and eval_steps
-        if not self.eval_dataset or self.cfg.val_set_size == 0:
-            # do not eval if no eval_dataset or val_set_size=0
+        if not self.eval_dataset and self.cfg.val_set_size == 0:
+            # do not eval if no eval_dataset and val_set_size=0
             training_args_kwargs["eval_strategy"] = "no"
         elif self.cfg.eval_steps:
             training_args_kwargs["eval_strategy"] = "steps"
diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md
index 627ebd935..bddf3ced2 100644
--- a/src/axolotl/integrations/cut_cross_entropy/README.md
+++ b/src/axolotl/integrations/cut_cross_entropy/README.md
@@ -24,6 +24,14 @@ pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transform
 
 ## Usage
 
+**NOTE**: If you are training a VLM, please use an older version of Axolotl, as upstream has applied a major VLM refactor and our patches have not been updated yet.
+
+```bash
+git checkout 787880215b3ab32ccaf81c1b2e9588c6f3e6e764
+
+pip3 install --no-build-isolation -e .
+```
+
 ```yaml
 plugins:
   - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py
index a7875eefe..63fbfa359 100644
--- a/src/axolotl/monkeypatch/lora_kernels.py
+++ b/src/axolotl/monkeypatch/lora_kernels.py
@@ -145,6 +145,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
 
         return Qwen2Attention
 
+    if model_type == "mllama":
+        from transformers.models.mllama.modeling_mllama import MllamaTextSelfAttention
+
+        return MllamaTextSelfAttention
+
     try:
         # Dynamically import the module and attention class
         module_path = f"transformers.models.{model_type}.modeling_{model_type}"
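
Note on the `src/axolotl/core/builders/base.py` change above: Axolotl configs that supply `test_datasets` are expected to set `val_set_size: 0`, so the previous `or` condition disabled evaluation for exactly those runs. Below is a minimal sketch of a config that hits the fixed code path; the dataset paths, type, and split are illustrative placeholders, not taken from this patch.

```yaml
# Illustrative only: an explicit eval set with no validation split carved out.
# Before this patch, `not eval_dataset or val_set_size == 0` forced
# eval_strategy to "no"; with the `and` condition, the test set is evaluated.
base_model: mistralai/Magistral-Small-2506
val_set_size: 0

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test   # placeholder training set
    type: chat_template

test_datasets:
  - path: fozziethebeat/alpaca_messages_2k_test   # placeholder eval set
    type: chat_template
    split: train

evals_per_epoch: 1
```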