From bbd3486f57ab7894ecf8db62527c1d28a61d22fc Mon Sep 17 00:00:00 2001
From: salman
Date: Fri, 19 Dec 2025 16:43:47 +0100
Subject: [PATCH] Distributed Muon Optimizer (#3264)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* init

* working

* updating configs

* removing unneeded files

* lint

* comments

* lint

* fix regex match

* bump contribs version

* comments

* fixing tests and imports

* muon imports in test v2

* test cleanup

* bump contribs version

---------

Co-authored-by: Salman Mohammadi <“salman.mohammadi@outlook.com”>
---
 examples/qwen2/adamw-pretrain-fsdp2.yaml    |  70 ++++++++
 examples/qwen2/muon-pretrain-fsdp2.yaml     |  70 ++++++++
 requirements.txt                            |   3 +-
 src/axolotl/core/builders/base.py           |  19 ++-
 src/axolotl/utils/schemas/validation.py     |  87 +++++-----
 tests/core/test_builders.py                 |  12 +-
 tests/e2e/multigpu/test_dist_muon_fsdp2.py  | 168 ++++++++++++++++++++
 tests/test_validation_dataset.py            |   2 +-
 tests/utils/schemas/validation/test_fsdp.py |  11 ++
 9 files changed, 387 insertions(+), 55 deletions(-)
 create mode 100644 examples/qwen2/adamw-pretrain-fsdp2.yaml
 create mode 100644 examples/qwen2/muon-pretrain-fsdp2.yaml
 create mode 100644 tests/e2e/multigpu/test_dist_muon_fsdp2.py

diff --git a/examples/qwen2/adamw-pretrain-fsdp2.yaml b/examples/qwen2/adamw-pretrain-fsdp2.yaml
new file mode 100644
index 000000000..43fb17aab
--- /dev/null
+++ b/examples/qwen2/adamw-pretrain-fsdp2.yaml
@@ -0,0 +1,70 @@
+base_model: Qwen/Qwen2.5-0.5B
+model_type: AutoModelForCausalLM
+tokenizer_type: AutoTokenizer
+
+# Use random initialization for fair comparison
+reinit_weights: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+# Pretraining dataset
+pretraining_dataset:
+  - path: allenai/c4
+    name: en
+    type: pretrain
+    split: train
+
+dataset_prepared_path:
+val_set_size: 0.0
+output_dir: ./outputs/compare-adamw-pretrain
+
+sequence_len: 2048
+sample_packing: true
+pad_to_sequence_len: true
+
+wandb_project: dist_muon
+wandb_entity:
+wandb_watch:
+wandb_name: adamw
+wandb_log_model:
+
+gradient_accumulation_steps: 1
+micro_batch_size: 4
+num_epochs: 1
+max_steps: 305
+
+# AdamW optimizer settings (standard LR for AdamW)
+optimizer: adamw_torch_fused
+learning_rate: 0.0002
+weight_decay: 0.01
+lr_scheduler: cosine
+
+train_on_inputs: true
+group_by_length: false
+bf16: auto
+fp16: false
+tf32: false
+
+gradient_checkpointing: false
+logging_steps: 1
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch: 0
+saves_per_epoch: 1
+
+# Reproducibility
+seed: 42
+
+fsdp_config:
+  fsdp_version: 2
+  fsdp_offload_params: false
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_cpu_ram_efficient_loading: false
+  fsdp_reshard_after_forward: true
+
+special_tokens:
diff --git a/examples/qwen2/muon-pretrain-fsdp2.yaml b/examples/qwen2/muon-pretrain-fsdp2.yaml
new file mode 100644
index 000000000..35c0b71f4
--- /dev/null
+++ b/examples/qwen2/muon-pretrain-fsdp2.yaml
@@ -0,0 +1,70 @@
+base_model: Qwen/Qwen2.5-0.5B
+model_type: AutoModelForCausalLM
+tokenizer_type: AutoTokenizer
+
+# Use random initialization for fair comparison
+reinit_weights: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+# Pretraining dataset
+pretraining_dataset:
+  - path: allenai/c4
+    name: en
+    type: pretrain
+    split: train
+
+dataset_prepared_path:
+val_set_size: 0.0
+output_dir: ./outputs/compare-muon-pretrain
+
+sequence_len: 2048
+sample_packing: true
+pad_to_sequence_len: true
+
+wandb_project: dist_muon
+wandb_entity:
+wandb_watch:
+wandb_name: muon
+wandb_log_model:
+
+gradient_accumulation_steps: 1
+micro_batch_size: 4
+num_epochs: 1
+max_steps: 305
+
+# Muon optimizer settings
+optimizer: muon
+learning_rate: 0.02
+weight_decay: 0.01
+lr_scheduler: cosine
+
+train_on_inputs: true
+group_by_length: false
+bf16: auto
+fp16: false
+tf32: false
+
+gradient_checkpointing: false
+logging_steps: 1
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch: 0
+saves_per_epoch: 1
+
+# Reproducibility
+seed: 42
+
+fsdp_config:
+  fsdp_version: 2
+  fsdp_offload_params: false
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_cpu_ram_efficient_loading: false
+  fsdp_reshard_after_forward: true
+
+special_tokens:
diff --git a/requirements.txt b/requirements.txt
index 0989325ac..093546815 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -67,8 +67,7 @@
 openenv-core==0.1.0
 schedulefree==1.4.1
 axolotl-contribs-lgpl==0.0.7
-axolotl-contribs-mit==0.0.5
-
+axolotl-contribs-mit==0.0.6
 
 # telemetry
 posthog==6.7.11
diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index 0d19b369f..06d15ffc8 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -281,11 +281,22 @@ class TrainerBuilderBase(abc.ABC):
             adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
 
         if self.cfg.optimizer == "muon":
-            from axolotl.contribs.mit.muon import (
-                MuonOptimizerFactory,
-            )
+            _, device_mesh = build_parallelism_config(self.cfg)
+
+            if device_mesh is not None:
+                from axolotl.contribs.mit.muon.dist_muon import (
+                    DistMuonOptimizerFactory,
+                )
+
+                optimizer_cls = DistMuonOptimizerFactory
+                optimizer_kwargs["device_mesh"] = device_mesh
+            else:
+                from axolotl.contribs.mit.muon import (
+                    MuonOptimizerFactory,
+                )
+
+                optimizer_cls = MuonOptimizerFactory
 
-            optimizer_cls = MuonOptimizerFactory
             optimizer_kwargs.update(adam_kwargs)
         elif self.cfg.optimizer == "dion":
             from axolotl.contribs.mit.dion import (
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 368976831..36565fb03 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -751,12 +751,19 @@ class OptimizationValidationMixin:
     @model_validator(mode="before")
     @classmethod
     def check_muon_deepspeed_fsdp(cls, data):
-        if data.get("optimizer") == "muon" and (
-            data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config")
-        ):
-            raise ValueError(
-                "Muon optimizer is currently incompatible with DeepSpeed and FSDP"
-            )
+        if data.get("optimizer") == "muon":
+            if data.get("deepspeed"):
+                raise ValueError(
+                    "Muon optimizer is currently incompatible with DeepSpeed"
+                )
+            if data.get("fsdp") or data.get("fsdp_config"):
+                fsdp_version = data.get("fsdp_version")
+                if fsdp_version is None:
+                    fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
+                if str(fsdp_version) != "2":
+                    raise ValueError(
+                        "Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP."
+                    )
         return data
 
     @model_validator(mode="before")
@@ -840,40 +847,6 @@ class OptimizationValidationMixin:
         return data
 
-    @model_validator(mode="before")
-    @classmethod
-    def check_fsdp_version_in_fsdp_config(cls, data):
-        fsdp_config = data.get("fsdp_config") or {}
-        if fsdp_config and fsdp_config.get("fsdp_version"):
-            LOG.warning(
-                "Configuring `fsdp_version` in `fsdp_config` is deprecated. "
-                "Please configure `fsdp_version` as a top-level field."
-            )
-            data["fsdp_version"] = fsdp_config.pop("fsdp_version")
-        return data
-
-    @model_validator(mode="before")
-    @classmethod
-    def check_fsdp_config_kwargs_prefix(cls, data):
-        if fsdp_config := data.get("fsdp_config"):
-            should_fix = False
-            for key, _ in fsdp_config.items():
-                if key.startswith("fsdp_"):
-                    should_fix = True
-                    LOG.warning_once(
-                        "Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
-                        "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
-                    )
-            if should_fix:
-                update_fsdp_config = {}
-                for key, value in fsdp_config.items():
-                    if key.startswith("fsdp_") and key != "fsdp_version":
-                        update_fsdp_config[key.replace("fsdp_", "")] = value
-                    else:
-                        update_fsdp_config[key] = value
-                data["fsdp_config"] = update_fsdp_config
-        return data
-
     @model_validator(mode="after")
     def check_fsdp_offload_w_8bit_optimizer(self):
         if (
@@ -975,6 +948,40 @@ class OptimizationValidationMixin:
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_fsdp_version_in_fsdp_config(cls, data):
+        fsdp_config = data.get("fsdp_config") or {}
+        if fsdp_config and fsdp_config.get("fsdp_version"):
+            LOG.warning(
+                "Configuring `fsdp_version` in `fsdp_config` is deprecated. "
+                "Please configure `fsdp_version` as a top-level field."
+            )
+            data["fsdp_version"] = fsdp_config.pop("fsdp_version")
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_fsdp_config_kwargs_prefix(cls, data):
+        if fsdp_config := data.get("fsdp_config"):
+            should_fix = False
+            for key, _ in fsdp_config.items():
+                if key.startswith("fsdp_"):
+                    should_fix = True
+                    LOG.warning_once(
+                        "Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
+                        "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
+                    )
+            if should_fix:
+                update_fsdp_config = {}
+                for key, value in fsdp_config.items():
+                    if key.startswith("fsdp_") and key != "fsdp_version":
+                        update_fsdp_config[key.replace("fsdp_", "")] = value
+                    else:
+                        update_fsdp_config[key] = value
+                data["fsdp_config"] = update_fsdp_config
+        return data
+
 
 class SystemValidationMixin:
     """Validation methods related to system and hardware configuration."""
diff --git a/tests/core/test_builders.py b/tests/core/test_builders.py
index 199777896..f9db4d013 100644
--- a/tests/core/test_builders.py
+++ b/tests/core/test_builders.py
@@ -474,10 +474,8 @@ def rand_reward_func(prompts, completions) -> list[float]:
 
         assert trainer.optimizer_cls_and_kwargs is not None
 
-        from axolotl.contribs.mit.muon import (
-            Muon,
-            MuonOptimizerFactory,
-        )
+        from axolotl.contribs.mit.muon import MuonOptimizerFactory
+        from axolotl.contribs.mit.muon.muon import Muon
 
         optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
         assert optimizer_cls is MuonOptimizerFactory
@@ -556,10 +554,8 @@ class TestHFCausalTrainerBuilder:
 
         assert trainer.optimizer_cls_and_kwargs is not None
 
-        from axolotl.contribs.mit.muon import (
-            Muon,
-            MuonOptimizerFactory,
-        )
+        from axolotl.contribs.mit.muon import MuonOptimizerFactory
+        from axolotl.contribs.mit.muon.muon import Muon
 
         optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
         assert optimizer_cls is MuonOptimizerFactory
diff --git a/tests/e2e/multigpu/test_dist_muon_fsdp2.py b/tests/e2e/multigpu/test_dist_muon_fsdp2.py
new file mode 100644
index 000000000..93db473a9
--- /dev/null
+++ b/tests/e2e/multigpu/test_dist_muon_fsdp2.py
@@ -0,0 +1,168 @@
+"""Test module for DistMuon optimizer with FSDP2 multi-GPU functionality."""
+
+import os
+from pathlib import Path
+
+import torch
+import yaml
+from accelerate.test_utils import execute_subprocess_async
+from tbparse import SummaryReader
+from transformers.testing_utils import get_torch_dist_unique_port
+
+from axolotl.utils.dict import DictDefault
+
+from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
+
+AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
+
+
+def verify_training_success(temp_dir):
+    """Verify that training completed successfully by checking artifacts and loss."""
+    output_path = Path(temp_dir)
+
+    model_files = list(output_path.glob("*.bin")) + list(
+        output_path.glob("*.safetensors")
+    )
+    assert len(model_files) > 0, "No model files found - training may have failed"
+
+    checkpoint_files = list(output_path.glob("checkpoint-*"))
+    assert len(checkpoint_files) > 0, (
+        "No checkpoint files found - training may have failed"
+    )
+
+    tb_log_path = most_recent_subdir(temp_dir + "/runs")
+    if tb_log_path:
+        event_files = sorted(os.listdir(tb_log_path))
+        if event_files:
+            event_file = os.path.join(tb_log_path, event_files[0])
+            reader = SummaryReader(event_file)
+            df = reader.scalars
+            train_loss_df = df[df.tag == "train/train_loss"]
+            if len(train_loss_df) > 0:
+                final_loss = train_loss_df.value.values[-1]
+                assert not torch.isnan(torch.tensor(final_loss)), (
+                    f"Training loss is NaN: {final_loss}"
+                )
+
+
+class TestDistMuon:
+    """Test class for DistMuon optimizer with FSDP2 functionality."""
+
+    @require_torch_2_7_0
+    def test_fft_sft(self, temp_dir):
+        cfg = DictDefault(
+            {
+                "base_model": "Qwen/Qwen2.5-0.5B",
+                "sequence_len": 2048,
+                "val_set_size": 0.01,
+                "datasets": [
+                    {
+                        "path": "tatsu-lab/alpaca",
+                        "type": "alpaca",
+                        "split": "train[:10%]",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.02,
+                "optimizer": "muon",
+                "weight_decay": 0.01,
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "fsdp_version": 2,
+                "fsdp_config": {
+                    "offload_params": False,
+                    "cpu_ram_efficient_loading": False,
+                    "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
+                    "state_dict_type": "FULL_STATE_DICT",
+                    "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+                    "reshard_after_forward": True,
+                },
+                "use_tensorboard": True,
+                "bf16": True,
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
+                "--num-processes",
+                "2",
+                "--main-process-port",
+                f"{get_torch_dist_unique_port()}",
+            ]
+        )
+
+        verify_training_success(temp_dir)
+
+    @require_torch_2_7_0
+    def test_lora_sft(self, temp_dir):
+        cfg = DictDefault(
+            {
+                "base_model": "Qwen/Qwen2.5-0.5B",
+                "sequence_len": 2048,
+                "val_set_size": 0.01,
+                "datasets": [
+                    {
+                        "path": "tatsu-lab/alpaca",
+                        "type": "alpaca",
+                        "split": "train[:10%]",
+                    },
+                ],
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "num_epochs": 1,
+                "max_steps": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.02,
+                "optimizer": "muon",
+                "weight_decay": 0.01,
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "fsdp_version": 2,
+                "fsdp_config": {
+                    "offload_params": False,
+                    "cpu_ram_efficient_loading": False,
+                    "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
+                    "state_dict_type": "FULL_STATE_DICT",
+                    "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+                    "reshard_after_forward": True,
+                },
+                "use_tensorboard": True,
+                "bf16": True,
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
+                "--num-processes",
+                "2",
+                "--main-process-port",
+                f"{get_torch_dist_unique_port()}",
+            ]
+        )
+
+        verify_training_success(temp_dir)
diff --git a/tests/test_validation_dataset.py b/tests/test_validation_dataset.py
index 3d3b5db96..464812a90 100644
--- a/tests/test_validation_dataset.py
+++ b/tests/test_validation_dataset.py
@@ -363,5 +363,5 @@ class TestOptimizerValidation(BaseValidation):
             }
         )
 
-        with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
+        with pytest.raises(ValueError, match=r".*only compatible with FSDP2.*"):
             validate_config(cfg)
diff --git a/tests/utils/schemas/validation/test_fsdp.py b/tests/utils/schemas/validation/test_fsdp.py
index 65f9c66a3..9fa327797 100644
--- a/tests/utils/schemas/validation/test_fsdp.py
+++ b/tests/utils/schemas/validation/test_fsdp.py
@@ -123,6 +123,17 @@ class TestFSDPValidation:
         assert cfg.fsdp_config.transformer_layer_cls_to_wrap == "LlamaDecoderLayer"
         assert cfg.fsdp_config.reshard_after_forward is True
 
+    def test_muon_fsdp1_rejected(self, min_base_cfg):
+        cfg = min_base_cfg | DictDefault(
+            optimizer="muon",
+            fsdp_version=1,
+            fsdp_config={"reshard_after_forward": True},
+        )
+        with pytest.raises(
+            ValueError, match="Muon optimizer is only compatible with FSDP2"
+        ):
+            validate_config(cfg)
+
     @pytest.mark.parametrize(
         "rl",
         [
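
Usage sketch (editor's note, not part of the patch): with the change to src/axolotl/core/builders/base.py, setting `optimizer: muon` now resolves to the distributed factory whenever a device mesh is built from the parallelism config (as in the FSDP2 example configs above), and falls back to the single-device factory otherwise. Below is a minimal, condensed illustration of that selection logic, reusing the names from the diff; the import location of `build_parallelism_config` and the standalone helper function are assumptions for illustration only, and the surrounding trainer-builder plumbing is omitted.

    # Hypothetical, condensed mirror of the selection logic added to TrainerBuilderBase.
    # Assumes axolotl-contribs-mit>=0.0.6 is installed; the import path below is an assumption.
    from axolotl.utils.distributed import build_parallelism_config  # assumed location


    def pick_muon_factory(cfg, optimizer_kwargs: dict):
        """Return the Muon optimizer factory class, preferring the distributed variant."""
        _, device_mesh = build_parallelism_config(cfg)
        if device_mesh is not None:
            # Sharded (FSDP2) runs: DistMuon needs the device mesh to coordinate updates.
            from axolotl.contribs.mit.muon.dist_muon import DistMuonOptimizerFactory

            optimizer_kwargs["device_mesh"] = device_mesh
            return DistMuonOptimizerFactory

        # Single-device / non-sharded runs keep the original factory.
        from axolotl.contribs.mit.muon import MuonOptimizerFactory

        return MuonOptimizerFactory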