Distributed Muon Optimizer (#3264)
* init * working * updating configs * removing unneeded files * lint * comments * lint * fix regex match * bump contribs version * comments * fixing tests and imports * muon imports in test v2 * test cleanup * bump contribs version --------- Co-authored-by: Salman Mohammadi <“salman.mohammadi@outlook.com”>
This commit is contained in:
70
examples/qwen2/adamw-pretrain-fsdp2.yaml
Normal file
70
examples/qwen2/adamw-pretrain-fsdp2.yaml
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
base_model: Qwen/Qwen2.5-0.5B
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
# Use random initialization for fair comparison
|
||||||
|
reinit_weights: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# Pretraining dataset
|
||||||
|
pretraining_dataset:
|
||||||
|
- path: allenai/c4
|
||||||
|
name: en
|
||||||
|
type: pretrain
|
||||||
|
split: train
|
||||||
|
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/compare-adamw-pretrain
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project: dist_muon
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name: adamw
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 4
|
||||||
|
num_epochs: 1
|
||||||
|
max_steps: 305
|
||||||
|
|
||||||
|
# AdamW optimizer settings (standard LR for AdamW)
|
||||||
|
optimizer: adamw_torch_fused
|
||||||
|
learning_rate: 0.0002
|
||||||
|
weight_decay: 0.01
|
||||||
|
lr_scheduler: cosine
|
||||||
|
|
||||||
|
train_on_inputs: true
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: false
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 0
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
# Reproducibility
|
||||||
|
seed: 42
|
||||||
|
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_offload_params: false
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_cpu_ram_efficient_loading: false
|
||||||
|
fsdp_reshard_after_forward: true
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
70
examples/qwen2/muon-pretrain-fsdp2.yaml
Normal file
70
examples/qwen2/muon-pretrain-fsdp2.yaml
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
base_model: Qwen/Qwen2.5-0.5B
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
# Use random initialization for fair comparison
|
||||||
|
reinit_weights: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# Pretraining dataset
|
||||||
|
pretraining_dataset:
|
||||||
|
- path: allenai/c4
|
||||||
|
name: en
|
||||||
|
type: pretrain
|
||||||
|
split: train
|
||||||
|
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/compare-muon-pretrain
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project: dist_muon
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name: muon
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 4
|
||||||
|
num_epochs: 1
|
||||||
|
max_steps: 305
|
||||||
|
|
||||||
|
# Muon optimizer settings
|
||||||
|
optimizer: muon
|
||||||
|
learning_rate: 0.02
|
||||||
|
weight_decay: 0.01
|
||||||
|
lr_scheduler: cosine
|
||||||
|
|
||||||
|
train_on_inputs: true
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: false
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 0
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
# Reproducibility
|
||||||
|
seed: 42
|
||||||
|
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_offload_params: false
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_cpu_ram_efficient_loading: false
|
||||||
|
fsdp_reshard_after_forward: true
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
@@ -67,8 +67,7 @@ openenv-core==0.1.0
|
|||||||
schedulefree==1.4.1
|
schedulefree==1.4.1
|
||||||
|
|
||||||
axolotl-contribs-lgpl==0.0.7
|
axolotl-contribs-lgpl==0.0.7
|
||||||
axolotl-contribs-mit==0.0.5
|
axolotl-contribs-mit==0.0.6
|
||||||
|
|
||||||
# telemetry
|
# telemetry
|
||||||
posthog==6.7.11
|
posthog==6.7.11
|
||||||
|
|
||||||
|
|||||||
@@ -281,11 +281,22 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
|
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
|
||||||
|
|
||||||
if self.cfg.optimizer == "muon":
|
if self.cfg.optimizer == "muon":
|
||||||
from axolotl.contribs.mit.muon import (
|
_, device_mesh = build_parallelism_config(self.cfg)
|
||||||
MuonOptimizerFactory,
|
|
||||||
)
|
if device_mesh is not None:
|
||||||
|
from axolotl.contribs.mit.muon.dist_muon import (
|
||||||
|
DistMuonOptimizerFactory,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_cls = DistMuonOptimizerFactory
|
||||||
|
optimizer_kwargs["device_mesh"] = device_mesh
|
||||||
|
else:
|
||||||
|
from axolotl.contribs.mit.muon import (
|
||||||
|
MuonOptimizerFactory,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_cls = MuonOptimizerFactory
|
||||||
|
|
||||||
optimizer_cls = MuonOptimizerFactory
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
elif self.cfg.optimizer == "dion":
|
elif self.cfg.optimizer == "dion":
|
||||||
from axolotl.contribs.mit.dion import (
|
from axolotl.contribs.mit.dion import (
|
||||||
|
|||||||
@@ -751,12 +751,19 @@ class OptimizationValidationMixin:
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_muon_deepspeed_fsdp(cls, data):
|
def check_muon_deepspeed_fsdp(cls, data):
|
||||||
if data.get("optimizer") == "muon" and (
|
if data.get("optimizer") == "muon":
|
||||||
data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config")
|
if data.get("deepspeed"):
|
||||||
):
|
raise ValueError(
|
||||||
raise ValueError(
|
"Muon optimizer is currently incompatible with DeepSpeed"
|
||||||
"Muon optimizer is currently incompatible with DeepSpeed and FSDP"
|
)
|
||||||
)
|
if data.get("fsdp") or data.get("fsdp_config"):
|
||||||
|
fsdp_version = data.get("fsdp_version")
|
||||||
|
if fsdp_version is None:
|
||||||
|
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
|
||||||
|
if str(fsdp_version) != "2":
|
||||||
|
raise ValueError(
|
||||||
|
"Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP."
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@@ -840,40 +847,6 @@ class OptimizationValidationMixin:
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_fsdp_version_in_fsdp_config(cls, data):
|
|
||||||
fsdp_config = data.get("fsdp_config") or {}
|
|
||||||
if fsdp_config and fsdp_config.get("fsdp_version"):
|
|
||||||
LOG.warning(
|
|
||||||
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
|
|
||||||
"Please configure `fsdp_version` as a top-level field."
|
|
||||||
)
|
|
||||||
data["fsdp_version"] = fsdp_config.pop("fsdp_version")
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_fsdp_config_kwargs_prefix(cls, data):
|
|
||||||
if fsdp_config := data.get("fsdp_config"):
|
|
||||||
should_fix = False
|
|
||||||
for key, _ in fsdp_config.items():
|
|
||||||
if key.startswith("fsdp_"):
|
|
||||||
should_fix = True
|
|
||||||
LOG.warning_once(
|
|
||||||
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
|
|
||||||
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
|
|
||||||
)
|
|
||||||
if should_fix:
|
|
||||||
update_fsdp_config = {}
|
|
||||||
for key, value in fsdp_config.items():
|
|
||||||
if key.startswith("fsdp_") and key != "fsdp_version":
|
|
||||||
update_fsdp_config[key.replace("fsdp_", "")] = value
|
|
||||||
else:
|
|
||||||
update_fsdp_config[key] = value
|
|
||||||
data["fsdp_config"] = update_fsdp_config
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_fsdp_offload_w_8bit_optimizer(self):
|
def check_fsdp_offload_w_8bit_optimizer(self):
|
||||||
if (
|
if (
|
||||||
@@ -975,6 +948,40 @@ class OptimizationValidationMixin:
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_fsdp_version_in_fsdp_config(cls, data):
|
||||||
|
fsdp_config = data.get("fsdp_config") or {}
|
||||||
|
if fsdp_config and fsdp_config.get("fsdp_version"):
|
||||||
|
LOG.warning(
|
||||||
|
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
|
||||||
|
"Please configure `fsdp_version` as a top-level field."
|
||||||
|
)
|
||||||
|
data["fsdp_version"] = fsdp_config.pop("fsdp_version")
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_fsdp_config_kwargs_prefix(cls, data):
|
||||||
|
if fsdp_config := data.get("fsdp_config"):
|
||||||
|
should_fix = False
|
||||||
|
for key, _ in fsdp_config.items():
|
||||||
|
if key.startswith("fsdp_"):
|
||||||
|
should_fix = True
|
||||||
|
LOG.warning_once(
|
||||||
|
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
|
||||||
|
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
|
||||||
|
)
|
||||||
|
if should_fix:
|
||||||
|
update_fsdp_config = {}
|
||||||
|
for key, value in fsdp_config.items():
|
||||||
|
if key.startswith("fsdp_") and key != "fsdp_version":
|
||||||
|
update_fsdp_config[key.replace("fsdp_", "")] = value
|
||||||
|
else:
|
||||||
|
update_fsdp_config[key] = value
|
||||||
|
data["fsdp_config"] = update_fsdp_config
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class SystemValidationMixin:
|
class SystemValidationMixin:
|
||||||
"""Validation methods related to system and hardware configuration."""
|
"""Validation methods related to system and hardware configuration."""
|
||||||
|
|||||||
@@ -474,10 +474,8 @@ def rand_reward_func(prompts, completions) -> list[float]:
|
|||||||
|
|
||||||
assert trainer.optimizer_cls_and_kwargs is not None
|
assert trainer.optimizer_cls_and_kwargs is not None
|
||||||
|
|
||||||
from axolotl.contribs.mit.muon import (
|
from axolotl.contribs.mit.muon import MuonOptimizerFactory
|
||||||
Muon,
|
from axolotl.contribs.mit.muon.muon import Muon
|
||||||
MuonOptimizerFactory,
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
|
optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
|
||||||
assert optimizer_cls is MuonOptimizerFactory
|
assert optimizer_cls is MuonOptimizerFactory
|
||||||
@@ -556,10 +554,8 @@ class TestHFCausalTrainerBuilder:
|
|||||||
|
|
||||||
assert trainer.optimizer_cls_and_kwargs is not None
|
assert trainer.optimizer_cls_and_kwargs is not None
|
||||||
|
|
||||||
from axolotl.contribs.mit.muon import (
|
from axolotl.contribs.mit.muon import MuonOptimizerFactory
|
||||||
Muon,
|
from axolotl.contribs.mit.muon.muon import Muon
|
||||||
MuonOptimizerFactory,
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
|
optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
|
||||||
assert optimizer_cls is MuonOptimizerFactory
|
assert optimizer_cls is MuonOptimizerFactory
|
||||||
|
|||||||
168
tests/e2e/multigpu/test_dist_muon_fsdp2.py
Normal file
168
tests/e2e/multigpu/test_dist_muon_fsdp2.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""Test module for DistMuon optimizer with FSDP2 multi-GPU functionality."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import yaml
|
||||||
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
|
from tbparse import SummaryReader
|
||||||
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
|
||||||
|
|
||||||
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
def verify_training_success(temp_dir):
|
||||||
|
"""Verify that training completed successfully by checking artifacts and loss."""
|
||||||
|
output_path = Path(temp_dir)
|
||||||
|
|
||||||
|
model_files = list(output_path.glob("*.bin")) + list(
|
||||||
|
output_path.glob("*.safetensors")
|
||||||
|
)
|
||||||
|
assert len(model_files) > 0, "No model files found - training may have failed"
|
||||||
|
|
||||||
|
checkpoint_files = list(output_path.glob("checkpoint-*"))
|
||||||
|
assert len(checkpoint_files) > 0, (
|
||||||
|
"No checkpoint files found - training may have failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||||
|
if tb_log_path:
|
||||||
|
event_files = sorted(os.listdir(tb_log_path))
|
||||||
|
if event_files:
|
||||||
|
event_file = os.path.join(tb_log_path, event_files[0])
|
||||||
|
reader = SummaryReader(event_file)
|
||||||
|
df = reader.scalars
|
||||||
|
train_loss_df = df[df.tag == "train/train_loss"]
|
||||||
|
if len(train_loss_df) > 0:
|
||||||
|
final_loss = train_loss_df.value.values[-1]
|
||||||
|
assert not torch.isnan(torch.tensor(final_loss)), (
|
||||||
|
f"Training loss is NaN: {final_loss}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDistMuon:
|
||||||
|
"""Test class for DistMuon optimizer with FSDP2 functionality."""
|
||||||
|
|
||||||
|
@require_torch_2_7_0
|
||||||
|
def test_fft_sft(self, temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"val_set_size": 0.01,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "tatsu-lab/alpaca",
|
||||||
|
"type": "alpaca",
|
||||||
|
"split": "train[:10%]",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.02,
|
||||||
|
"optimizer": "muon",
|
||||||
|
"weight_decay": 0.01,
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"fsdp_version": 2,
|
||||||
|
"fsdp_config": {
|
||||||
|
"offload_params": False,
|
||||||
|
"cpu_ram_efficient_loading": False,
|
||||||
|
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
|
||||||
|
"state_dict_type": "FULL_STATE_DICT",
|
||||||
|
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
|
"reshard_after_forward": True,
|
||||||
|
},
|
||||||
|
"use_tensorboard": True,
|
||||||
|
"bf16": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# write cfg to yaml file
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
"--num-processes",
|
||||||
|
"2",
|
||||||
|
"--main-process-port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
verify_training_success(temp_dir)
|
||||||
|
|
||||||
|
@require_torch_2_7_0
|
||||||
|
def test_lora_sft(self, temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"val_set_size": 0.01,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "tatsu-lab/alpaca",
|
||||||
|
"type": "alpaca",
|
||||||
|
"split": "train[:10%]",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.02,
|
||||||
|
"optimizer": "muon",
|
||||||
|
"weight_decay": 0.01,
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"fsdp_version": 2,
|
||||||
|
"fsdp_config": {
|
||||||
|
"offload_params": False,
|
||||||
|
"cpu_ram_efficient_loading": False,
|
||||||
|
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
|
||||||
|
"state_dict_type": "FULL_STATE_DICT",
|
||||||
|
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
|
"reshard_after_forward": True,
|
||||||
|
},
|
||||||
|
"use_tensorboard": True,
|
||||||
|
"bf16": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# write cfg to yaml file
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
"--num-processes",
|
||||||
|
"2",
|
||||||
|
"--main-process-port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
verify_training_success(temp_dir)
|
||||||
@@ -363,5 +363,5 @@ class TestOptimizerValidation(BaseValidation):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
|
with pytest.raises(ValueError, match=r".*only compatible with FSDP2.*"):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|||||||
@@ -123,6 +123,17 @@ class TestFSDPValidation:
|
|||||||
assert cfg.fsdp_config.transformer_layer_cls_to_wrap == "LlamaDecoderLayer"
|
assert cfg.fsdp_config.transformer_layer_cls_to_wrap == "LlamaDecoderLayer"
|
||||||
assert cfg.fsdp_config.reshard_after_forward is True
|
assert cfg.fsdp_config.reshard_after_forward is True
|
||||||
|
|
||||||
|
def test_muon_fsdp1_rejected(self, min_base_cfg):
|
||||||
|
cfg = min_base_cfg | DictDefault(
|
||||||
|
optimizer="muon",
|
||||||
|
fsdp_version=1,
|
||||||
|
fsdp_config={"reshard_after_forward": True},
|
||||||
|
)
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match="Muon optimizer is only compatible with FSDP2"
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"rl",
|
"rl",
|
||||||
[
|
[
|
||||||
|
|||||||
Reference in New Issue
Block a user