add muon optimizer

optimizer_cls_and_kwargs is on trainer_kwargs
only add adamw_kwargs if they're non-null
fix mocks
better handling of override and check the optimizer
unwrap optimizer
This commit is contained in:
Wing Lian
2025-02-27 23:15:41 -05:00
parent 0134093acc
commit 0542c7dd56
10 changed files with 296 additions and 158 deletions

View File

@@ -63,3 +63,4 @@ torchao==0.7.0
schedulefree==1.3.0 schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.3 axolotl-contribs-lgpl==0.0.3
axolotl-contribs-mit==0.0.3

View File

@@ -41,11 +41,12 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
else: else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta) model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
del model del model
del tokenizer del tokenizer
del trainer
plugin_manager.post_train_unload(cfg) plugin_manager.post_train_unload(cfg)

View File

@@ -35,6 +35,7 @@ from transformers import (
EarlyStoppingCallback, EarlyStoppingCallback,
TrainerCallback, TrainerCallback,
) )
from transformers.training_args import OptimizerNames
from trl.trainer.utils import RewardDataCollatorWithPadding from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import ( from axolotl.core.trainers.base import (
@@ -84,6 +85,7 @@ from axolotl.utils.collators import (
V2BatchSamplerDataCollatorForSeq2Seq, V2BatchSamplerDataCollatorForSeq2Seq,
) )
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
from axolotl.utils.models import ensure_dtype from axolotl.utils.models import ensure_dtype
try: try:
@@ -549,28 +551,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
else: else:
training_arguments_kwargs["run_name"] = None training_arguments_kwargs["run_name"] = None
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]: if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine" training_arguments_kwargs["lr_scheduler_type"] = "cosine"
@@ -656,46 +636,114 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.reward_model: if self.cfg.reward_model:
training_arguments_kwargs["max_length"] = self.cfg.sequence_len training_arguments_kwargs["max_length"] = self.cfg.sequence_len
# pylint: disable=duplicate-code # Handle custom optimizer
if self.cfg.optimizer in [ custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
"optimi_adamw", if self.cfg.optimizer in custom_supported_optimizers:
"ao_adamw_4bit", # Common optimizer kwargs
"ao_adamw_8bit", optimizer_kwargs = {
"ao_adamw_fp8", "lr": training_arguments_kwargs.get("learning_rate"),
"adopt_adamw", "weight_decay": training_arguments_kwargs.get("weight_decay"),
]: }
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
if self.cfg.optimizer == "lion_pytorch": # Adam-specific kwargs
from lion_pytorch import Lion adam_kwargs = {}
if training_arguments_kwargs.get(
"adam_beta1"
) and training_arguments_kwargs.get("adam_beta2"):
adam_kwargs["betas"] = (
training_arguments_kwargs.get("adam_beta1"),
training_arguments_kwargs.get("adam_beta2"),
)
if training_arguments_kwargs.get("adam_epsilon"):
adam_kwargs["eps"] = training_arguments_kwargs.get("adam_epsilon")
lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]} if self.cfg.optimizer == "muon":
if "weight_decay" in training_arguments_kwargs: from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"] MuonOptimizerFactory,
if (
"adam_beta1" in training_arguments_kwargs
and "adam_beta2" in training_arguments_kwargs
):
lion_kwargs["betas"] = (
training_arguments_kwargs["adam_beta1"],
training_arguments_kwargs["adam_beta2"],
) )
trainer_kwargs["optimizers"] = ( optimizer_cls = MuonOptimizerFactory
Lion(params=self.model.parameters(), **lion_kwargs), optimizer_kwargs.update(adam_kwargs)
None, elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
optimizer_kwargs["foreach"] = False
optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_4bit":
# TODO remove 20250401
from torchao.prototype.low_bit_optim import AdamW4bit
optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
)
elif self.cfg.optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
optimizer_cls = AdamW8bit
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
optimizer_cls = AdamWFp8
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
optimizer_cls = ADOPT
adam_kwargs["decouple"] = True
optimizer_kwargs.update(adam_kwargs)
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optimizer_kwargs.update(self.cfg.optim_args)
else:
# Parse string format "key1=value1,key2=value2"
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optimizer_kwargs[key] = value
trainer_kwargs["optimizer_cls_and_kwargs"] = (
optimizer_cls,
optimizer_kwargs,
) )
# Set default so transformers doesn't throw else:
training_arguments_kwargs["optim"] = "adamw_hf" # Use transformers' optimizer
training_arguments_kwargs["optim"] = self.cfg.optimizer
# Parse any additional optimizer args from config
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
if self.cfg.optimizer == "adamw_anyprecision": if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists(): if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path) sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx") importlib.import_module("torchdistx")
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.accelerator_config: if self.cfg.accelerator_config:
training_arguments_kwargs[ training_arguments_kwargs[
"accelerator_config" "accelerator_config"

View File

@@ -14,6 +14,7 @@ from typing import Dict, Literal, Optional
import torch import torch
from datasets import Dataset from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import Trainer from transformers import Trainer
@@ -22,6 +23,7 @@ from transformers.utils import is_sagemaker_mp_enabled
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length from trl.trainer.utils import pad_to_length
from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.monkeypatch.relora import ReLoRAScheduler from axolotl.monkeypatch.relora import ReLoRAScheduler
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import ( from axolotl.utils.schedulers import (
@@ -166,47 +168,18 @@ class SchedulerMixin(Trainer):
return self.lr_scheduler return self.lr_scheduler
class AxolotlTrainer(SchedulerMixin, Trainer): class OptimizerMixin(Trainer):
""" """
Extend the base Trainer for axolotl helpers Mixin class for shared handling of building custom optimizers
""" """
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
def __init__( def create_optimizer_grouped_parameters(
self, self, opt_model, optimizer_kwargs
*_args, ) -> list[dict]:
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
):
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
decay_parameters = self.get_decay_parameter_names(opt_model) decay_parameters = self.get_decay_parameter_names(opt_model)
params = { params: dict = {
"to_weight_decay": {}, # LayerNorm and bias "to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens, "embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {}, "no_weight_decay": {},
@@ -293,23 +266,30 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
and self.args.embedding_lr_scale is None and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None and self.args.embedding_lr is None
and self.args.lr_groups is None and self.args.lr_groups is None
and self.args.alternate_optimizer and self.optimizer_cls_and_kwargs is None
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"adopt_adamw",
]
): ):
return super().create_optimizer() return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( if (
self.args, not self.optimizer
opt_model, and self.optimizer_cls_and_kwargs is not None
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
):
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
self.optimizer = optimizer_factory_cls()(
opt_model, self.args, **optimizer_kwargs
) )
if not self.optimizer:
if self.optimizer_cls_and_kwargs is not None:
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
else:
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
self.args, opt_model
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters( optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs opt_model, optimizer_kwargs
) )
@@ -326,50 +306,47 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
loraplus_lr_embedding=loraplus_lr_embedding, loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs, **optimizer_kwargs,
) )
elif ( else:
self.args.embedding_lr_scale is not None # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
or self.args.embedding_lr is not None # e.g. for GaLore optimizer.
or self.args.lr_groups is not None if "params" in optimizer_kwargs:
): optimizer_grouped_parameters = optimizer_kwargs.pop("params")
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
self.optimizer = ( # pylint: disable=attribute-defined-outside-init # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
AdamW( # e.g. for LOMO optimizer.
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs if "model" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop(
"optimizer_dict"
) )
)
elif self.args.alternate_optimizer == "ao_adamw_4bit":
from torchao.prototype.low_bit_optim import AdamW4bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init self.optimizer = optimizer_cls(
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs) optimizer_grouped_parameters, **optimizer_kwargs
) )
elif self.args.alternate_optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init if optimizer_cls.__name__ == "Adam8bit":
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs) import bitsandbytes
)
elif self.args.alternate_optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
self.optimizer = ( # pylint: disable=attribute-defined-outside-init manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
self.optimizer = ( # pylint: disable=attribute-defined-outside-init skipped = 0
ADOPT( for module in opt_model.modules():
optimizer_grouped_parameters, if isinstance(module, nn.Embedding):
decouple=True, skipped += sum(
**optimizer_kwargs, {
) p.data_ptr(): p.numel() for p in module.parameters()
) }.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
if is_sagemaker_mp_enabled(): if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
@@ -378,6 +355,45 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
return self.optimizer return self.optimizer
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
def __init__(
self,
*_args,
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
):
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining: if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches: if self.args.multipack_real_batches:

View File

@@ -23,6 +23,8 @@ import importlib
import logging import logging
from typing import OrderedDict from typing import OrderedDict
import torch
class BasePlugin: class BasePlugin:
""" """
@@ -469,3 +471,14 @@ class PluginManager:
""" """
for plugin in self.plugins.values(): for plugin in self.plugins.values():
plugin.post_train_unload(cfg) plugin.post_train_unload(cfg)
class BaseOptimizerFactory:
"""
Base class for factories to create custom optimizers
"""
def __call__(
self, opt_model, training_args, **optimizer_kwargs
) -> "torch.optim.Optimizer":
pass

View File

@@ -15,7 +15,7 @@ from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model from accelerate.utils import save_fsdp_model
from datasets import Dataset from datasets import Dataset
from peft import PeftConfig, PeftModel from peft import PeftConfig, PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin, Trainer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer from transformers.trainer import Trainer
@@ -461,7 +461,7 @@ def setup_model_and_trainer(
def train( def train(
cfg: DictDefault, dataset_meta: TrainDatasetMeta cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer]: ) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
""" """
Train a model on the given dataset. Train a model on the given dataset.
@@ -510,4 +510,4 @@ def train(
# Create model card # Create model card
create_model_card(cfg, trainer) create_model_card(cfg, trainer)
return model, tokenizer return model, tokenizer, trainer

View File

@@ -64,6 +64,18 @@ class ChatTemplate(str, Enum):
metharme = "metharme" # pylint: disable=invalid-name metharme = "metharme" # pylint: disable=invalid-name
class CustomSupportedOptimizers(str, Enum):
"""Custom supported optimizers"""
optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name
ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
lion_pytorch = "lion_pytorch" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel): class DeprecatedParameters(BaseModel):
"""configurations that are deprecated""" """configurations that are deprecated"""
@@ -494,17 +506,7 @@ class HyperparametersConfig(BaseModel):
embedding_lr_scale: Optional[float] = None embedding_lr_scale: Optional[float] = None
weight_decay: Optional[float] = 0.0 weight_decay: Optional[float] = 0.0
optimizer: Optional[ optimizer: Optional[
Union[ Union[OptimizerNames, CustomSupportedOptimizers]
OptimizerNames,
Literal[
"lion_pytorch",
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
],
]
] = OptimizerNames.ADAMW_HF ] = OptimizerNames.ADAMW_HF
optim_args: Optional[Union[str, Dict[str, Any]]] = Field( optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None, default=None,
@@ -1177,6 +1179,13 @@ class AxolotlInputConfig(
LOG.warning("adamw hyperparameters found, but no adamw optimizer set") LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
return self return self
@model_validator(mode="before")
@classmethod
def check_lr_groups(cls, data):
if data.get("lr_groups") and data.get("loraplus_lr_ratio"):
raise ValueError("lr_groups and loraplus_lr_ratio cannot be used together.")
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_saves(cls, data): def check_saves(cls, data):

View File

@@ -28,7 +28,7 @@ class TestTrainCommand(BaseCliTest):
config_path.write_text(valid_test_config) config_path.write_text(valid_test_config)
with patch("axolotl.cli.train.train") as mock_train: with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock()) mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
result = cli_runner.invoke( result = cli_runner.invoke(
cli, cli,
@@ -48,7 +48,7 @@ class TestTrainCommand(BaseCliTest):
config_path = self._test_cli_overrides(tmp_path, valid_test_config) config_path = self._test_cli_overrides(tmp_path, valid_test_config)
with patch("axolotl.cli.train.train") as mock_train: with patch("axolotl.cli.train.train") as mock_train:
mock_train.return_value = (MagicMock(), MagicMock()) mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
result = cli_runner.invoke( result = cli_runner.invoke(
cli, cli,

View File

@@ -75,7 +75,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta) model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert ( assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32 == torch.float32
@@ -131,7 +131,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta) model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert ( assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32 == torch.float32
@@ -190,7 +190,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta) model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert ( assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32 == torch.float32
@@ -249,7 +249,7 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta) model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert ( assert (
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32 == torch.float32

View File

@@ -65,8 +65,9 @@ class TestCustomOptimizers(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
assert trainer.optimizer.optimizer.__class__.__name__ == "AdamW"
@with_temp_dir @with_temp_dir
@require_torch_2_5_1 @require_torch_2_5_1
@@ -111,8 +112,57 @@ class TestCustomOptimizers(unittest.TestCase):
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta) _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
assert "ADOPT" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir
@require_torch_2_5_1
def test_muon(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "muon",
"lr_scheduler": "cosine",
"weight_decay": 0.01,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir @with_temp_dir
def test_fft_schedule_free_adamw(self, temp_dir): def test_fft_schedule_free_adamw(self, temp_dir):