diff --git a/requirements.txt b/requirements.txt index 18bb00692..cd5690f0b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -63,3 +63,4 @@ torchao==0.7.0 schedulefree==1.3.0 axolotl-contribs-lgpl==0.0.3 +axolotl-contribs-mit==0.0.3 diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 7ac15e04f..032f12b66 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -41,11 +41,12 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta) + model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta) plugin_manager = PluginManager.get_instance() del model del tokenizer + del trainer plugin_manager.post_train_unload(cfg) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index fe9c8bcae..0c9204747 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -35,6 +35,7 @@ from transformers import ( EarlyStoppingCallback, TrainerCallback, ) +from transformers.training_args import OptimizerNames from trl.trainer.utils import RewardDataCollatorWithPadding from axolotl.core.trainers.base import ( @@ -84,6 +85,7 @@ from axolotl.utils.collators import ( V2BatchSamplerDataCollatorForSeq2Seq, ) from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator +from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers from axolotl.utils.models import ensure_dtype try: @@ -549,28 +551,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name else: training_arguments_kwargs["run_name"] = None - training_arguments_kwargs["optim"] = ( - self.cfg.optimizer if self.cfg.optimizer else "adamw_hf" - ) - if self.cfg.optim_args: - if isinstance(self.cfg.optim_args, dict): - optim_args = ",".join( - [f"{key}={value}" for key, value in self.cfg.optim_args.items()] - ) - else: - optim_args = self.cfg.optim_args - training_arguments_kwargs["optim_args"] = optim_args - if self.cfg.optim_target_modules: - training_arguments_kwargs[ - "optim_target_modules" - ] = self.cfg.optim_target_modules - training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio - training_arguments_kwargs[ - "loraplus_lr_embedding" - ] = self.cfg.loraplus_lr_embedding - training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr - training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale - training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]: training_arguments_kwargs["lr_scheduler_type"] = "cosine" @@ -656,46 +636,114 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if self.cfg.reward_model: training_arguments_kwargs["max_length"] = self.cfg.sequence_len - # pylint: disable=duplicate-code - if self.cfg.optimizer in [ - "optimi_adamw", - "ao_adamw_4bit", - "ao_adamw_8bit", - "ao_adamw_fp8", - "adopt_adamw", - ]: - # Set default so transformers doesn't throw - training_arguments_kwargs["optim"] = "adamw_hf" - training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer + # Handle custom optimizer + custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers] + if self.cfg.optimizer in custom_supported_optimizers: + # Common optimizer kwargs + optimizer_kwargs = { + "lr": training_arguments_kwargs.get("learning_rate"), + "weight_decay": 
training_arguments_kwargs.get("weight_decay"), + } - if self.cfg.optimizer == "lion_pytorch": - from lion_pytorch import Lion + # Adam-specific kwargs + adam_kwargs = {} + if training_arguments_kwargs.get( + "adam_beta1" + ) and training_arguments_kwargs.get("adam_beta2"): + adam_kwargs["betas"] = ( + training_arguments_kwargs.get("adam_beta1"), + training_arguments_kwargs.get("adam_beta2"), + ) + if training_arguments_kwargs.get("adam_epsilon"): + adam_kwargs["eps"] = training_arguments_kwargs.get("adam_epsilon") - lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]} - if "weight_decay" in training_arguments_kwargs: - lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"] - - if ( - "adam_beta1" in training_arguments_kwargs - and "adam_beta2" in training_arguments_kwargs - ): - lion_kwargs["betas"] = ( - training_arguments_kwargs["adam_beta1"], - training_arguments_kwargs["adam_beta2"], + if self.cfg.optimizer == "muon": + from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module + MuonOptimizerFactory, ) - trainer_kwargs["optimizers"] = ( - Lion(params=self.model.parameters(), **lion_kwargs), - None, + optimizer_cls = MuonOptimizerFactory + optimizer_kwargs.update(adam_kwargs) + elif self.cfg.optimizer == "optimi_adamw": + from optimi import AdamW + + optimizer_kwargs["foreach"] = False + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + elif self.cfg.optimizer == "ao_adamw_4bit": + # TODO remove 20250401 + from torchao.prototype.low_bit_optim import AdamW4bit + + optimizer_cls = AdamW4bit + optimizer_kwargs.update(adam_kwargs) + + LOG.warning( + f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead." + ) + elif self.cfg.optimizer == "ao_adamw_8bit": + from torchao.prototype.low_bit_optim import AdamW8bit + + optimizer_cls = AdamW8bit + optimizer_kwargs.update(adam_kwargs) + elif self.cfg.optimizer == "ao_adamw_fp8": + from torchao.prototype.low_bit_optim import AdamWFp8 + + optimizer_cls = AdamWFp8 + optimizer_kwargs.update(adam_kwargs) + elif self.cfg.optimizer == "adopt_adamw": + from axolotl.utils.optimizers.adopt import ADOPT + + optimizer_cls = ADOPT + adam_kwargs["decouple"] = True + optimizer_kwargs.update(adam_kwargs) + + # Parse any additional optimizer args from config + if self.cfg.optim_args: + if isinstance(self.cfg.optim_args, dict): + optimizer_kwargs.update(self.cfg.optim_args) + else: + # Parse string format "key1=value1,key2=value2" + for mapping in self.cfg.optim_args.replace(" ", "").split(","): + key, value = mapping.split("=") + optimizer_kwargs[key] = value + + trainer_kwargs["optimizer_cls_and_kwargs"] = ( + optimizer_cls, + optimizer_kwargs, ) - # Set default so transformers doesn't throw - training_arguments_kwargs["optim"] = "adamw_hf" + else: + # Use transformers' optimizer + training_arguments_kwargs["optim"] = self.cfg.optimizer + + # Parse any additional optimizer args from config + if self.cfg.optim_args: + if isinstance(self.cfg.optim_args, dict): + optim_args = ",".join( + [f"{key}={value}" for key, value in self.cfg.optim_args.items()] + ) + else: + optim_args = self.cfg.optim_args + training_arguments_kwargs["optim_args"] = optim_args if self.cfg.optimizer == "adamw_anyprecision": if Path(self.cfg.torchdistx_path).exists(): sys.path.append(self.cfg.torchdistx_path) importlib.import_module("torchdistx") + if self.cfg.optim_target_modules: + training_arguments_kwargs[ + "optim_target_modules" + ] = self.cfg.optim_target_modules + + 
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr + training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale + + training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio + training_arguments_kwargs[ + "loraplus_lr_embedding" + ] = self.cfg.loraplus_lr_embedding + training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups + if self.cfg.accelerator_config: training_arguments_kwargs[ "accelerator_config" diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 27f00f1fd..c14ed59b5 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -14,6 +14,7 @@ from typing import Dict, Literal, Optional import torch from datasets import Dataset from peft.optimizers import create_loraplus_optimizer +from torch import nn from torch.optim.lr_scheduler import OneCycleLR from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from transformers import Trainer @@ -22,6 +23,7 @@ from transformers.utils import is_sagemaker_mp_enabled from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer from trl.trainer.utils import pad_to_length +from axolotl.integrations.base import BaseOptimizerFactory from axolotl.monkeypatch.relora import ReLoRAScheduler from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.schedulers import ( @@ -166,47 +168,18 @@ class SchedulerMixin(Trainer): return self.lr_scheduler -class AxolotlTrainer(SchedulerMixin, Trainer): +class OptimizerMixin(Trainer): """ - Extend the base Trainer for axolotl helpers + Mixin class for shared handling of building custom optimizers """ args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] - tag_names = ["axolotl"] - def __init__( - self, - *_args, - bench_data_collator=None, - eval_data_collator=None, - dataset_tags=None, - **kwargs, - ): - self.bench_data_collator = bench_data_collator - self.eval_data_collator = eval_data_collator - self.dataset_tags = dataset_tags - self._signature_columns = None # workaround for pylint - super().__init__(*_args, **kwargs) - self.train_data_collator = self.data_collator - self._stored_metrics = defaultdict(lambda: defaultdict(list)) - if self.args.orpo_alpha: - self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") - - def _wrap_model(self, model, training=True, dataloader=None): - if self.args.torch_compile: - torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access - 256 - ) - model = torch.compile( - model, - backend=self.args.torch_compile_backend, - mode=self.args.torch_compile_mode, - ) - return super()._wrap_model(model, training=training, dataloader=dataloader) - - def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs): + def create_optimizer_grouped_parameters( + self, opt_model, optimizer_kwargs + ) -> list[dict]: decay_parameters = self.get_decay_parameter_names(opt_model) - params = { + params: dict = { "to_weight_decay": {}, # LayerNorm and bias "embeddings": {}, # lm_head, embed_tokens, "no_weight_decay": {}, @@ -293,23 +266,30 @@ class AxolotlTrainer(SchedulerMixin, Trainer): and self.args.embedding_lr_scale is None and self.args.embedding_lr is None and self.args.lr_groups is None - and self.args.alternate_optimizer - not in [ - "optimi_adamw", - "ao_adamw_8bit", - "ao_adamw_4bit", - "ao_adamw_fp8", - "adopt_adamw", - ] + and self.optimizer_cls_and_kwargs is None ): return super().create_optimizer() opt_model = 
self.model_wrapped if is_sagemaker_mp_enabled() else self.model - if self.optimizer is None: # pylint: disable=access-member-before-definition - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( - self.args, - opt_model, + + if ( + not self.optimizer + and self.optimizer_cls_and_kwargs is not None + and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory) + ): + optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs + self.optimizer = optimizer_factory_cls()( + opt_model, self.args, **optimizer_kwargs ) + + if not self.optimizer: + if self.optimizer_cls_and_kwargs is not None: + optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs + else: + optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs( + self.args, opt_model + ) + optimizer_grouped_parameters = self.create_optimizer_grouped_parameters( opt_model, optimizer_kwargs ) @@ -326,50 +306,47 @@ class AxolotlTrainer(SchedulerMixin, Trainer): loraplus_lr_embedding=loraplus_lr_embedding, **optimizer_kwargs, ) - elif ( - self.args.embedding_lr_scale is not None - or self.args.embedding_lr is not None - or self.args.lr_groups is not None - ): - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) - ) - elif self.args.alternate_optimizer == "optimi_adamw": - from optimi import AdamW + else: + # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs` + # e.g. for GaLore optimizer. + if "params" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("params") - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - AdamW( - optimizer_grouped_parameters, foreach=False, **optimizer_kwargs + # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs` + # e.g. for LOMO optimizer. + if "model" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("model") + + # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict` + # to avoid arguments conflicts. 
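+                # (the `params`/`model`/`optimizer_dict` handling below mirrors
+                # transformers' Trainer.create_optimizer, so optimizer kwargs built
+                # by get_optimizer_cls_and_kwargs keep working through this mixin)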
+ if "optimizer_dict" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop( + "optimizer_dict" ) - ) - elif self.args.alternate_optimizer == "ao_adamw_4bit": - from torchao.prototype.low_bit_optim import AdamW4bit - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs) + self.optimizer = optimizer_cls( + optimizer_grouped_parameters, **optimizer_kwargs ) - elif self.args.alternate_optimizer == "ao_adamw_8bit": - from torchao.prototype.low_bit_optim import AdamW8bit - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs) - ) - elif self.args.alternate_optimizer == "ao_adamw_fp8": - from torchao.prototype.low_bit_optim import AdamWFp8 + if optimizer_cls.__name__ == "Adam8bit": + import bitsandbytes - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs) - ) - elif self.args.alternate_optimizer == "adopt_adamw": - from axolotl.utils.optimizers.adopt import ADOPT + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - ADOPT( - optimizer_grouped_parameters, - decouple=True, - **optimizer_kwargs, - ) - ) + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum( + { + p.data_ptr(): p.numel() for p in module.parameters() + }.values() + ) + LOG.info(f"skipped {module}: {skipped/2**20}M params") + manager.register_module_override( + module, "weight", {"optim_bits": 32} + ) + LOG.debug(f"bitsandbytes: will optimize {module} in fp32") + LOG.info(f"skipped: {skipped/2**20}M params") if is_sagemaker_mp_enabled(): self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init @@ -378,6 +355,45 @@ class AxolotlTrainer(SchedulerMixin, Trainer): return self.optimizer + +class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): + """ + Extend the base Trainer for axolotl helpers + """ + + args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] + tag_names = ["axolotl"] + + def __init__( + self, + *_args, + bench_data_collator=None, + eval_data_collator=None, + dataset_tags=None, + **kwargs, + ): + self.bench_data_collator = bench_data_collator + self.eval_data_collator = eval_data_collator + self.dataset_tags = dataset_tags + self._signature_columns = None # workaround for pylint + super().__init__(*_args, **kwargs) + self.train_data_collator = self.data_collator + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + if self.args.orpo_alpha: + self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + + def _wrap_model(self, model, training=True, dataloader=None): + if self.args.torch_compile: + torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access + 256 + ) + model = torch.compile( + model, + backend=self.args.torch_compile_backend, + mode=self.args.torch_compile_mode, + ) + return super()._wrap_model(model, training=training, dataloader=dataloader) + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: if self.args.sample_packing and not self.args.pretraining: if self.args.multipack_real_batches: diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index 211d5e51b..11015e31a 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -23,6 +23,8 @@ import 
importlib import logging from typing import OrderedDict +import torch + class BasePlugin: """ @@ -469,3 +471,14 @@ class PluginManager: """ for plugin in self.plugins.values(): plugin.post_train_unload(cfg) + + +class BaseOptimizerFactory: + """ + Base class for factories to create custom optimizers + """ + + def __call__( + self, opt_model, training_args, **optimizer_kwargs + ) -> "torch.optim.Optimizer": + pass diff --git a/src/axolotl/train.py b/src/axolotl/train.py index b2f4bf1e9..aa10d0e06 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -15,7 +15,7 @@ from accelerate.logging import get_logger from accelerate.utils import save_fsdp_model from datasets import Dataset from peft import PeftConfig, PeftModel -from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin +from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin, Trainer from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.trainer import Trainer @@ -461,7 +461,7 @@ def setup_model_and_trainer( def train( cfg: DictDefault, dataset_meta: TrainDatasetMeta -) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer]: +) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]: """ Train a model on the given dataset. @@ -510,4 +510,4 @@ def train( # Create model card create_model_card(cfg, trainer) - return model, tokenizer + return model, tokenizer, trainer diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 180e02823..fb23c3dfb 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -64,6 +64,18 @@ class ChatTemplate(str, Enum): metharme = "metharme" # pylint: disable=invalid-name +class CustomSupportedOptimizers(str, Enum): + """Custom supported optimizers""" + + optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name + ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name + ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name + ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name + adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name + lion_pytorch = "lion_pytorch" # pylint: disable=invalid-name + muon = "muon" # pylint: disable=invalid-name + + class DeprecatedParameters(BaseModel): """configurations that are deprecated""" @@ -494,17 +506,7 @@ class HyperparametersConfig(BaseModel): embedding_lr_scale: Optional[float] = None weight_decay: Optional[float] = 0.0 optimizer: Optional[ - Union[ - OptimizerNames, - Literal[ - "lion_pytorch", - "optimi_adamw", - "ao_adamw_4bit", - "ao_adamw_8bit", - "ao_adamw_fp8", - "adopt_adamw", - ], - ] + Union[OptimizerNames, CustomSupportedOptimizers] ] = OptimizerNames.ADAMW_HF optim_args: Optional[Union[str, Dict[str, Any]]] = Field( default=None, @@ -1177,6 +1179,13 @@ class AxolotlInputConfig( LOG.warning("adamw hyperparameters found, but no adamw optimizer set") return self + @model_validator(mode="before") + @classmethod + def check_lr_groups(cls, data): + if data.get("lr_groups") and data.get("loraplus_lr_ratio"): + raise ValueError("lr_groups and loraplus_lr_ratio cannot be used together.") + return data + @model_validator(mode="before") @classmethod def check_saves(cls, data): diff --git a/tests/cli/test_cli_train.py b/tests/cli/test_cli_train.py index 560f3caf5..a51251033 100644 --- a/tests/cli/test_cli_train.py +++ b/tests/cli/test_cli_train.py @@ -28,7 +28,7 @@ class 
TestTrainCommand(BaseCliTest): config_path.write_text(valid_test_config) with patch("axolotl.cli.train.train") as mock_train: - mock_train.return_value = (MagicMock(), MagicMock()) + mock_train.return_value = (MagicMock(), MagicMock(), MagicMock()) result = cli_runner.invoke( cli, @@ -48,7 +48,7 @@ class TestTrainCommand(BaseCliTest): config_path = self._test_cli_overrides(tmp_path, valid_test_config) with patch("axolotl.cli.train.train") as mock_train: - mock_train.return_value = (MagicMock(), MagicMock()) + mock_train.return_value = (MagicMock(), MagicMock(), MagicMock()) result = cli_runner.invoke( cli, diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py index 5de5ab403..f31920be6 100644 --- a/tests/e2e/test_mixtral.py +++ b/tests/e2e/test_mixtral.py @@ -75,7 +75,7 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, _ = train(cfg=cfg, dataset_meta=dataset_meta) + model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype == torch.float32 @@ -131,7 +131,7 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, _ = train(cfg=cfg, dataset_meta=dataset_meta) + model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype == torch.float32 @@ -190,7 +190,7 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, _ = train(cfg=cfg, dataset_meta=dataset_meta) + model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype == torch.float32 @@ -249,7 +249,7 @@ class TestMixtral(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - model, _ = train(cfg=cfg, dataset_meta=dataset_meta) + model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) assert ( model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype == torch.float32 diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index 4b0ad1142..43a4735aa 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -65,8 +65,9 @@ class TestCustomOptimizers(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, dataset_meta=dataset_meta) + _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) + assert trainer.optimizer.optimizer.__class__.__name__ == "AdamW" @with_temp_dir @require_torch_2_5_1 @@ -111,8 +112,57 @@ class TestCustomOptimizers(unittest.TestCase): cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - train(cfg=cfg, dataset_meta=dataset_meta) + _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) + assert "ADOPT" in trainer.optimizer.optimizer.__class__.__name__ + + @with_temp_dir + @require_torch_2_5_1 + def test_muon(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.1, + "special_tokens": { 
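+                    # explicit Llama special tokens, as in the other e2e optimizer tests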
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 5,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "muon",
+                "lr_scheduler": "cosine",
+                "weight_decay": 0.01,
+            }
+        )
+
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
+        check_model_output_exists(temp_dir, cfg)
+        assert "Muon" in trainer.optimizer.optimizer.__class__.__name__

     @with_temp_dir
     def test_fft_schedule_free_adamw(self, temp_dir):