Compare commits
2 Commits
fix/cp-was
...
optimizers
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76bb09784d | ||
|
|
0542c7dd56 |
@@ -63,3 +63,4 @@ torchao==0.7.0
|
||||
schedulefree==1.3.0
|
||||
|
||||
axolotl-contribs-lgpl==0.0.3
|
||||
axolotl-contribs-mit==0.0.3
|
||||
|
||||
@@ -41,11 +41,12 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
|
||||
del model
|
||||
del tokenizer
|
||||
del trainer
|
||||
|
||||
plugin_manager.post_train_unload(cfg)
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ from transformers import (
|
||||
EarlyStoppingCallback,
|
||||
TrainerCallback,
|
||||
)
|
||||
from transformers.training_args import OptimizerNames
|
||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||
|
||||
from axolotl.core.trainers.base import (
|
||||
@@ -84,6 +85,7 @@ from axolotl.utils.collators import (
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
|
||||
from axolotl.utils.models import ensure_dtype
|
||||
|
||||
try:
|
||||
@@ -549,28 +551,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||
else:
|
||||
training_arguments_kwargs["run_name"] = None
|
||||
training_arguments_kwargs["optim"] = (
|
||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||
)
|
||||
if self.cfg.optim_args:
|
||||
if isinstance(self.cfg.optim_args, dict):
|
||||
optim_args = ",".join(
|
||||
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
|
||||
)
|
||||
else:
|
||||
optim_args = self.cfg.optim_args
|
||||
training_arguments_kwargs["optim_args"] = optim_args
|
||||
if self.cfg.optim_target_modules:
|
||||
training_arguments_kwargs[
|
||||
"optim_target_modules"
|
||||
] = self.cfg.optim_target_modules
|
||||
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||
training_arguments_kwargs[
|
||||
"loraplus_lr_embedding"
|
||||
] = self.cfg.loraplus_lr_embedding
|
||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
||||
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
|
||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||
@@ -656,46 +636,114 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.reward_model:
|
||||
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
if self.cfg.optimizer in [
|
||||
"optimi_adamw",
|
||||
"ao_adamw_4bit",
|
||||
"ao_adamw_8bit",
|
||||
"ao_adamw_fp8",
|
||||
"adopt_adamw",
|
||||
]:
|
||||
# Set default so transformers doesn't throw
|
||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
||||
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
|
||||
# Handle custom optimizer
|
||||
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
|
||||
if self.cfg.optimizer in custom_supported_optimizers:
|
||||
# Common optimizer kwargs
|
||||
optimizer_kwargs = {
|
||||
"lr": training_arguments_kwargs.get("learning_rate"),
|
||||
"weight_decay": training_arguments_kwargs.get("weight_decay"),
|
||||
}
|
||||
|
||||
if self.cfg.optimizer == "lion_pytorch":
|
||||
from lion_pytorch import Lion
|
||||
# Adam-specific kwargs
|
||||
adam_kwargs = {}
|
||||
if training_arguments_kwargs.get(
|
||||
"adam_beta1"
|
||||
) and training_arguments_kwargs.get("adam_beta2"):
|
||||
adam_kwargs["betas"] = (
|
||||
training_arguments_kwargs.get("adam_beta1"),
|
||||
training_arguments_kwargs.get("adam_beta2"),
|
||||
)
|
||||
if training_arguments_kwargs.get("adam_epsilon"):
|
||||
adam_kwargs["eps"] = training_arguments_kwargs.get("adam_epsilon")
|
||||
|
||||
lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
|
||||
if "weight_decay" in training_arguments_kwargs:
|
||||
lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
|
||||
|
||||
if (
|
||||
"adam_beta1" in training_arguments_kwargs
|
||||
and "adam_beta2" in training_arguments_kwargs
|
||||
):
|
||||
lion_kwargs["betas"] = (
|
||||
training_arguments_kwargs["adam_beta1"],
|
||||
training_arguments_kwargs["adam_beta2"],
|
||||
if self.cfg.optimizer == "muon":
|
||||
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
|
||||
MuonOptimizerFactory,
|
||||
)
|
||||
|
||||
trainer_kwargs["optimizers"] = (
|
||||
Lion(params=self.model.parameters(), **lion_kwargs),
|
||||
None,
|
||||
optimizer_cls = MuonOptimizerFactory
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "optimi_adamw":
|
||||
from optimi import AdamW
|
||||
|
||||
optimizer_kwargs["foreach"] = False
|
||||
optimizer_cls = AdamW
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "ao_adamw_4bit":
|
||||
# TODO remove 20250401
|
||||
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||
|
||||
optimizer_cls = AdamW4bit
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
|
||||
LOG.warning(
|
||||
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
|
||||
)
|
||||
elif self.cfg.optimizer == "ao_adamw_8bit":
|
||||
from torchao.prototype.low_bit_optim import AdamW8bit
|
||||
|
||||
optimizer_cls = AdamW8bit
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "ao_adamw_fp8":
|
||||
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||
|
||||
optimizer_cls = AdamWFp8
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "adopt_adamw":
|
||||
from axolotl.utils.optimizers.adopt import ADOPT
|
||||
|
||||
optimizer_cls = ADOPT
|
||||
adam_kwargs["decouple"] = True
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
|
||||
# Parse any additional optimizer args from config
|
||||
if self.cfg.optim_args:
|
||||
if isinstance(self.cfg.optim_args, dict):
|
||||
optimizer_kwargs.update(self.cfg.optim_args)
|
||||
else:
|
||||
# Parse string format "key1=value1,key2=value2"
|
||||
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
|
||||
key, value = mapping.split("=")
|
||||
optimizer_kwargs[key] = value
|
||||
|
||||
trainer_kwargs["optimizer_cls_and_kwargs"] = (
|
||||
optimizer_cls,
|
||||
optimizer_kwargs,
|
||||
)
|
||||
# Set default so transformers doesn't throw
|
||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
||||
else:
|
||||
# Use transformers' optimizer
|
||||
training_arguments_kwargs["optim"] = self.cfg.optimizer
|
||||
|
||||
# Parse any additional optimizer args from config
|
||||
if self.cfg.optim_args:
|
||||
if isinstance(self.cfg.optim_args, dict):
|
||||
optim_args = ",".join(
|
||||
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
|
||||
)
|
||||
else:
|
||||
optim_args = self.cfg.optim_args
|
||||
training_arguments_kwargs["optim_args"] = optim_args
|
||||
|
||||
if self.cfg.optimizer == "adamw_anyprecision":
|
||||
if Path(self.cfg.torchdistx_path).exists():
|
||||
sys.path.append(self.cfg.torchdistx_path)
|
||||
importlib.import_module("torchdistx")
|
||||
|
||||
if self.cfg.optim_target_modules:
|
||||
training_arguments_kwargs[
|
||||
"optim_target_modules"
|
||||
] = self.cfg.optim_target_modules
|
||||
|
||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||
|
||||
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||
training_arguments_kwargs[
|
||||
"loraplus_lr_embedding"
|
||||
] = self.cfg.loraplus_lr_embedding
|
||||
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
||||
|
||||
if self.cfg.accelerator_config:
|
||||
training_arguments_kwargs[
|
||||
"accelerator_config"
|
||||
|
||||
@@ -14,6 +14,7 @@ from typing import Dict, Literal, Optional
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import Trainer
|
||||
@@ -22,6 +23,7 @@ from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||
from trl.trainer.utils import pad_to_length
|
||||
|
||||
from axolotl.integrations.base import BaseOptimizerFactory
|
||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
from axolotl.utils.schedulers import (
|
||||
@@ -166,47 +168,18 @@ class SchedulerMixin(Trainer):
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
class OptimizerMixin(Trainer):
|
||||
"""
|
||||
Extend the base Trainer for axolotl helpers
|
||||
Mixin class for shared handling of building custom optimizers
|
||||
"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
tag_names = ["axolotl"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*_args,
|
||||
bench_data_collator=None,
|
||||
eval_data_collator=None,
|
||||
dataset_tags=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.bench_data_collator = bench_data_collator
|
||||
self.eval_data_collator = eval_data_collator
|
||||
self.dataset_tags = dataset_tags
|
||||
self._signature_columns = None # workaround for pylint
|
||||
super().__init__(*_args, **kwargs)
|
||||
self.train_data_collator = self.data_collator
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.torch_compile:
|
||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||
256
|
||||
)
|
||||
model = torch.compile(
|
||||
model,
|
||||
backend=self.args.torch_compile_backend,
|
||||
mode=self.args.torch_compile_mode,
|
||||
)
|
||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||
|
||||
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
|
||||
def create_optimizer_grouped_parameters(
|
||||
self, opt_model, optimizer_kwargs
|
||||
) -> list[dict]:
|
||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||
params = {
|
||||
params: dict = {
|
||||
"to_weight_decay": {}, # LayerNorm and bias
|
||||
"embeddings": {}, # lm_head, embed_tokens,
|
||||
"no_weight_decay": {},
|
||||
@@ -293,23 +266,30 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
and self.args.embedding_lr_scale is None
|
||||
and self.args.embedding_lr is None
|
||||
and self.args.lr_groups is None
|
||||
and self.args.alternate_optimizer
|
||||
not in [
|
||||
"optimi_adamw",
|
||||
"ao_adamw_8bit",
|
||||
"ao_adamw_4bit",
|
||||
"ao_adamw_fp8",
|
||||
"adopt_adamw",
|
||||
]
|
||||
and self.optimizer_cls_and_kwargs is None
|
||||
):
|
||||
return super().create_optimizer()
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||
self.args,
|
||||
opt_model,
|
||||
|
||||
if (
|
||||
not self.optimizer
|
||||
and self.optimizer_cls_and_kwargs is not None
|
||||
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
||||
):
|
||||
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||
self.optimizer = optimizer_factory_cls()(
|
||||
opt_model, self.args, **optimizer_kwargs
|
||||
)
|
||||
|
||||
if not self.optimizer:
|
||||
if self.optimizer_cls_and_kwargs is not None:
|
||||
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||
else:
|
||||
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
||||
self.args, opt_model
|
||||
)
|
||||
|
||||
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
||||
opt_model, optimizer_kwargs
|
||||
)
|
||||
@@ -326,50 +306,47 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
elif (
|
||||
self.args.embedding_lr_scale is not None
|
||||
or self.args.embedding_lr is not None
|
||||
or self.args.lr_groups is not None
|
||||
):
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
)
|
||||
elif self.args.alternate_optimizer == "optimi_adamw":
|
||||
from optimi import AdamW
|
||||
else:
|
||||
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for GaLore optimizer.
|
||||
if "params" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
AdamW(
|
||||
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
|
||||
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for LOMO optimizer.
|
||||
if "model" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||
|
||||
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||
# to avoid arguments conflicts.
|
||||
if "optimizer_dict" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
||||
"optimizer_dict"
|
||||
)
|
||||
)
|
||||
elif self.args.alternate_optimizer == "ao_adamw_4bit":
|
||||
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
self.optimizer = optimizer_cls(
|
||||
optimizer_grouped_parameters, **optimizer_kwargs
|
||||
)
|
||||
elif self.args.alternate_optimizer == "ao_adamw_8bit":
|
||||
from torchao.prototype.low_bit_optim import AdamW8bit
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
)
|
||||
elif self.args.alternate_optimizer == "ao_adamw_fp8":
|
||||
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||
if optimizer_cls.__name__ == "Adam8bit":
|
||||
import bitsandbytes
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
)
|
||||
elif self.args.alternate_optimizer == "adopt_adamw":
|
||||
from axolotl.utils.optimizers.adopt import ADOPT
|
||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
ADOPT(
|
||||
optimizer_grouped_parameters,
|
||||
decouple=True,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
)
|
||||
skipped = 0
|
||||
for module in opt_model.modules():
|
||||
if isinstance(module, nn.Embedding):
|
||||
skipped += sum(
|
||||
{
|
||||
p.data_ptr(): p.numel() for p in module.parameters()
|
||||
}.values()
|
||||
)
|
||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||
manager.register_module_override(
|
||||
module, "weight", {"optim_bits": 32}
|
||||
)
|
||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
@@ -378,6 +355,45 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
|
||||
return self.optimizer
|
||||
|
||||
|
||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
"""
|
||||
Extend the base Trainer for axolotl helpers
|
||||
"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
tag_names = ["axolotl"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*_args,
|
||||
bench_data_collator=None,
|
||||
eval_data_collator=None,
|
||||
dataset_tags=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.bench_data_collator = bench_data_collator
|
||||
self.eval_data_collator = eval_data_collator
|
||||
self.dataset_tags = dataset_tags
|
||||
self._signature_columns = None # workaround for pylint
|
||||
super().__init__(*_args, **kwargs)
|
||||
self.train_data_collator = self.data_collator
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.torch_compile:
|
||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||
256
|
||||
)
|
||||
model = torch.compile(
|
||||
model,
|
||||
backend=self.args.torch_compile_backend,
|
||||
mode=self.args.torch_compile_mode,
|
||||
)
|
||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
if self.args.multipack_real_batches:
|
||||
|
||||
@@ -23,6 +23,8 @@ import importlib
|
||||
import logging
|
||||
from typing import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
"""
|
||||
@@ -469,3 +471,14 @@ class PluginManager:
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_train_unload(cfg)
|
||||
|
||||
|
||||
class BaseOptimizerFactory:
|
||||
"""
|
||||
Base class for factories to create custom optimizers
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self, opt_model, training_args, **optimizer_kwargs
|
||||
) -> "torch.optim.Optimizer":
|
||||
pass
|
||||
|
||||
@@ -461,7 +461,7 @@ def setup_model_and_trainer(
|
||||
|
||||
def train(
|
||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer]:
|
||||
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
|
||||
"""
|
||||
Train a model on the given dataset.
|
||||
|
||||
@@ -510,4 +510,4 @@ def train(
|
||||
# Create model card
|
||||
create_model_card(cfg, trainer)
|
||||
|
||||
return model, tokenizer
|
||||
return model, tokenizer, trainer
|
||||
|
||||
@@ -64,6 +64,18 @@ class ChatTemplate(str, Enum):
|
||||
metharme = "metharme" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class CustomSupportedOptimizers(str, Enum):
|
||||
"""Custom supported optimizers"""
|
||||
|
||||
optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name
|
||||
ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name
|
||||
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
|
||||
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
|
||||
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
|
||||
lion_pytorch = "lion_pytorch" # pylint: disable=invalid-name
|
||||
muon = "muon" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class DeprecatedParameters(BaseModel):
|
||||
"""configurations that are deprecated"""
|
||||
|
||||
@@ -494,17 +506,7 @@ class HyperparametersConfig(BaseModel):
|
||||
embedding_lr_scale: Optional[float] = None
|
||||
weight_decay: Optional[float] = 0.0
|
||||
optimizer: Optional[
|
||||
Union[
|
||||
OptimizerNames,
|
||||
Literal[
|
||||
"lion_pytorch",
|
||||
"optimi_adamw",
|
||||
"ao_adamw_4bit",
|
||||
"ao_adamw_8bit",
|
||||
"ao_adamw_fp8",
|
||||
"adopt_adamw",
|
||||
],
|
||||
]
|
||||
Union[OptimizerNames, CustomSupportedOptimizers]
|
||||
] = OptimizerNames.ADAMW_HF
|
||||
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
||||
default=None,
|
||||
@@ -1177,6 +1179,13 @@ class AxolotlInputConfig(
|
||||
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||
return self
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_lr_groups(cls, data):
|
||||
if data.get("lr_groups") and data.get("loraplus_lr_ratio"):
|
||||
raise ValueError("lr_groups and loraplus_lr_ratio cannot be used together.")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_saves(cls, data):
|
||||
|
||||
@@ -28,7 +28,7 @@ class TestTrainCommand(BaseCliTest):
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("axolotl.cli.train.train") as mock_train:
|
||||
mock_train.return_value = (MagicMock(), MagicMock())
|
||||
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
|
||||
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
@@ -48,7 +48,7 @@ class TestTrainCommand(BaseCliTest):
|
||||
config_path = self._test_cli_overrides(tmp_path, valid_test_config)
|
||||
|
||||
with patch("axolotl.cli.train.train") as mock_train:
|
||||
mock_train.return_value = (MagicMock(), MagicMock())
|
||||
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
|
||||
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
|
||||
@@ -75,7 +75,7 @@ class TestMixtral(unittest.TestCase):
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
== torch.float32
|
||||
@@ -131,7 +131,7 @@ class TestMixtral(unittest.TestCase):
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
== torch.float32
|
||||
@@ -190,7 +190,7 @@ class TestMixtral(unittest.TestCase):
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
== torch.float32
|
||||
@@ -249,7 +249,7 @@ class TestMixtral(unittest.TestCase):
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
== torch.float32
|
||||
|
||||
@@ -65,8 +65,9 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
assert trainer.optimizer.optimizer.__class__.__name__ == "AdamW"
|
||||
|
||||
@with_temp_dir
|
||||
@require_torch_2_5_1
|
||||
@@ -111,8 +112,57 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
assert "ADOPT" in trainer.optimizer.optimizer.__class__.__name__
|
||||
|
||||
@with_temp_dir
|
||||
@require_torch_2_5_1
|
||||
def test_muon(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "muon",
|
||||
"lr_scheduler": "cosine",
|
||||
"weight_decay": 0.01,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft_schedule_free_adamw(self, temp_dir):
|
||||
|
||||
Reference in New Issue
Block a user