Compare commits
flex_patch...optimizers
2 Commits: 76bb09784d, 0542c7dd56
@@ -63,3 +63,4 @@ torchao==0.7.0
 schedulefree==1.3.0
 axolotl-contribs-lgpl==0.0.3
+axolotl-contribs-mit==0.0.3
@@ -41,11 +41,12 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
     else:
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-    model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
+    model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
     plugin_manager = PluginManager.get_instance()

     del model
     del tokenizer
+    del trainer

     plugin_manager.post_train_unload(cfg)
@@ -35,6 +35,7 @@ from transformers import (
|
|||||||
EarlyStoppingCallback,
|
EarlyStoppingCallback,
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
)
|
)
|
||||||
|
from transformers.training_args import OptimizerNames
|
||||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||||
|
|
||||||
from axolotl.core.trainers.base import (
|
from axolotl.core.trainers.base import (
|
||||||
@@ -84,6 +85,7 @@ from axolotl.utils.collators import (
     V2BatchSamplerDataCollatorForSeq2Seq,
 )
 from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
+from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
 from axolotl.utils.models import ensure_dtype

 try:
@@ -549,28 +551,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
         else:
             training_arguments_kwargs["run_name"] = None
-        training_arguments_kwargs["optim"] = (
-            self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
-        )
-        if self.cfg.optim_args:
-            if isinstance(self.cfg.optim_args, dict):
-                optim_args = ",".join(
-                    [f"{key}={value}" for key, value in self.cfg.optim_args.items()]
-                )
-            else:
-                optim_args = self.cfg.optim_args
-            training_arguments_kwargs["optim_args"] = optim_args
-        if self.cfg.optim_target_modules:
-            training_arguments_kwargs[
-                "optim_target_modules"
-            ] = self.cfg.optim_target_modules
-        training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
-        training_arguments_kwargs[
-            "loraplus_lr_embedding"
-        ] = self.cfg.loraplus_lr_embedding
-        training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
-        training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
-        training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups

         if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
             training_arguments_kwargs["lr_scheduler_type"] = "cosine"
@@ -656,46 +636,114 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.reward_model:
             training_arguments_kwargs["max_length"] = self.cfg.sequence_len

-        # pylint: disable=duplicate-code
-        if self.cfg.optimizer in [
-            "optimi_adamw",
-            "ao_adamw_4bit",
-            "ao_adamw_8bit",
-            "ao_adamw_fp8",
-            "adopt_adamw",
-        ]:
-            # Set default so transformers doesn't throw
-            training_arguments_kwargs["optim"] = "adamw_hf"
-            training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
+        # Handle custom optimizer
+        custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
+        if self.cfg.optimizer in custom_supported_optimizers:
+            # Common optimizer kwargs
+            optimizer_kwargs = {
+                "lr": training_arguments_kwargs.get("learning_rate"),
+                "weight_decay": training_arguments_kwargs.get("weight_decay"),
+            }

-        if self.cfg.optimizer == "lion_pytorch":
-            from lion_pytorch import Lion
-
-            lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
-            if "weight_decay" in training_arguments_kwargs:
-                lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
-
-            if (
-                "adam_beta1" in training_arguments_kwargs
-                and "adam_beta2" in training_arguments_kwargs
-            ):
-                lion_kwargs["betas"] = (
-                    training_arguments_kwargs["adam_beta1"],
-                    training_arguments_kwargs["adam_beta2"],
-                )
-
-            trainer_kwargs["optimizers"] = (
-                Lion(params=self.model.parameters(), **lion_kwargs),
-                None,
-            )
-            # Set default so transformers doesn't throw
-            training_arguments_kwargs["optim"] = "adamw_hf"
+            # Adam-specific kwargs
+            adam_kwargs = {}
+            if training_arguments_kwargs.get(
+                "adam_beta1"
+            ) and training_arguments_kwargs.get("adam_beta2"):
+                adam_kwargs["betas"] = (
+                    training_arguments_kwargs.get("adam_beta1"),
+                    training_arguments_kwargs.get("adam_beta2"),
+                )
+            if training_arguments_kwargs.get("adam_epsilon"):
+                adam_kwargs["eps"] = training_arguments_kwargs.get("adam_epsilon")
+
+            if self.cfg.optimizer == "muon":
+                from axolotl.contribs.mit.muon import (  # pylint: disable=no-name-in-module
+                    MuonOptimizerFactory,
+                )
+
+                optimizer_cls = MuonOptimizerFactory
+                optimizer_kwargs.update(adam_kwargs)
+            elif self.cfg.optimizer == "optimi_adamw":
+                from optimi import AdamW
+
+                optimizer_kwargs["foreach"] = False
+                optimizer_cls = AdamW
+                optimizer_kwargs.update(adam_kwargs)
+            elif self.cfg.optimizer == "ao_adamw_4bit":
+                # TODO remove 20250401
+                from torchao.prototype.low_bit_optim import AdamW4bit
+
+                optimizer_cls = AdamW4bit
+                optimizer_kwargs.update(adam_kwargs)
+
+                LOG.warning(
+                    f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
+                )
+            elif self.cfg.optimizer == "ao_adamw_8bit":
+                from torchao.prototype.low_bit_optim import AdamW8bit
+
+                optimizer_cls = AdamW8bit
+                optimizer_kwargs.update(adam_kwargs)
+            elif self.cfg.optimizer == "ao_adamw_fp8":
+                from torchao.prototype.low_bit_optim import AdamWFp8
+
+                optimizer_cls = AdamWFp8
+                optimizer_kwargs.update(adam_kwargs)
+            elif self.cfg.optimizer == "adopt_adamw":
+                from axolotl.utils.optimizers.adopt import ADOPT
+
+                optimizer_cls = ADOPT
+                adam_kwargs["decouple"] = True
+                optimizer_kwargs.update(adam_kwargs)
+
+            # Parse any additional optimizer args from config
+            if self.cfg.optim_args:
+                if isinstance(self.cfg.optim_args, dict):
+                    optimizer_kwargs.update(self.cfg.optim_args)
+                else:
+                    # Parse string format "key1=value1,key2=value2"
+                    for mapping in self.cfg.optim_args.replace(" ", "").split(","):
+                        key, value = mapping.split("=")
+                        optimizer_kwargs[key] = value
+
+            trainer_kwargs["optimizer_cls_and_kwargs"] = (
+                optimizer_cls,
+                optimizer_kwargs,
+            )
+        else:
+            # Use transformers' optimizer
+            training_arguments_kwargs["optim"] = self.cfg.optimizer
+
+            # Parse any additional optimizer args from config
+            if self.cfg.optim_args:
+                if isinstance(self.cfg.optim_args, dict):
+                    optim_args = ",".join(
+                        [f"{key}={value}" for key, value in self.cfg.optim_args.items()]
+                    )
+                else:
+                    optim_args = self.cfg.optim_args
+                training_arguments_kwargs["optim_args"] = optim_args

         if self.cfg.optimizer == "adamw_anyprecision":
             if Path(self.cfg.torchdistx_path).exists():
                 sys.path.append(self.cfg.torchdistx_path)
                 importlib.import_module("torchdistx")

+        if self.cfg.optim_target_modules:
+            training_arguments_kwargs[
+                "optim_target_modules"
+            ] = self.cfg.optim_target_modules
+
+        training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
+        training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
+
+        training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
+        training_arguments_kwargs[
+            "loraplus_lr_embedding"
+        ] = self.cfg.loraplus_lr_embedding
+        training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
+
         if self.cfg.accelerator_config:
             training_arguments_kwargs[
                 "accelerator_config"
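The string form of `optim_args` accepted above follows a simple `key1=value1,key2=value2` convention. A minimal standalone sketch of that parsing step (the helper name `parse_optim_args` is illustrative, not part of this change):

```python
# Illustrative helper mirroring the string parsing added above; values stay
# strings exactly as in the diff, so numeric args must be cast by the optimizer.
def parse_optim_args(optim_args: str) -> dict:
    kwargs = {}
    for mapping in optim_args.replace(" ", "").split(","):
        key, value = mapping.split("=")
        kwargs[key] = value
    return kwargs


print(parse_optim_args("eps=1e-8, decouple=True"))
# {'eps': '1e-8', 'decouple': 'True'}
```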
@@ -14,6 +14,7 @@ from typing import Dict, Literal, Optional
 import torch
 from datasets import Dataset
 from peft.optimizers import create_loraplus_optimizer
+from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import Trainer
@@ -22,6 +23,7 @@ from transformers.utils import is_sagemaker_mp_enabled
 from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
 from trl.trainer.utils import pad_to_length

+from axolotl.integrations.base import BaseOptimizerFactory
 from axolotl.monkeypatch.relora import ReLoRAScheduler
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
@@ -166,47 +168,18 @@ class SchedulerMixin(Trainer):
         return self.lr_scheduler


-class AxolotlTrainer(SchedulerMixin, Trainer):
+class OptimizerMixin(Trainer):
     """
-    Extend the base Trainer for axolotl helpers
+    Mixin class for shared handling of building custom optimizers
     """

     args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]
-    tag_names = ["axolotl"]

-    def __init__(
-        self,
-        *_args,
-        bench_data_collator=None,
-        eval_data_collator=None,
-        dataset_tags=None,
-        **kwargs,
-    ):
-        self.bench_data_collator = bench_data_collator
-        self.eval_data_collator = eval_data_collator
-        self.dataset_tags = dataset_tags
-        self._signature_columns = None  # workaround for pylint
-        super().__init__(*_args, **kwargs)
-        self.train_data_collator = self.data_collator
-        self._stored_metrics = defaultdict(lambda: defaultdict(list))
-        if self.args.orpo_alpha:
-            self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
-
-    def _wrap_model(self, model, training=True, dataloader=None):
-        if self.args.torch_compile:
-            torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
-                256
-            )
-            model = torch.compile(
-                model,
-                backend=self.args.torch_compile_backend,
-                mode=self.args.torch_compile_mode,
-            )
-        return super()._wrap_model(model, training=training, dataloader=dataloader)
-
-    def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
+    def create_optimizer_grouped_parameters(
+        self, opt_model, optimizer_kwargs
+    ) -> list[dict]:
         decay_parameters = self.get_decay_parameter_names(opt_model)
-        params = {
+        params: dict = {
             "to_weight_decay": {},  # LayerNorm and bias
             "embeddings": {},  # lm_head, embed_tokens,
             "no_weight_decay": {},
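The grouped-parameters helper retained by the new mixin splits parameters into weight-decay, embedding, and no-decay buckets. A generic, simplified sketch of that bucketing idea (the function and module below are made up for illustration; the real method also handles embedding_lr_scale, lr_groups, and more):

```python
import torch
from torch import nn


def bucket_parameters(model: nn.Module, decay_names: set[str]) -> dict:
    # Simplified illustration of splitting params into decay / embedding / no-decay groups.
    buckets = {"to_weight_decay": {}, "embeddings": {}, "no_weight_decay": {}}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "embed_tokens" in name or "lm_head" in name:
            buckets["embeddings"][name] = param
        elif name in decay_names:
            buckets["to_weight_decay"][name] = param
        else:
            buckets["no_weight_decay"][name] = param
    return buckets


model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
print({k: list(v) for k, v in bucket_parameters(model, {"0.weight"}).items()})
```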
@@ -293,23 +266,30 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
             and self.args.embedding_lr_scale is None
             and self.args.embedding_lr is None
             and self.args.lr_groups is None
-            and self.args.alternate_optimizer
-            not in [
-                "optimi_adamw",
-                "ao_adamw_8bit",
-                "ao_adamw_4bit",
-                "ao_adamw_fp8",
-                "adopt_adamw",
-            ]
+            and self.optimizer_cls_and_kwargs is None
         ):
             return super().create_optimizer()

         opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-        if self.optimizer is None:  # pylint: disable=access-member-before-definition
-            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
-                self.args,
-                opt_model,
+        if (
+            not self.optimizer
+            and self.optimizer_cls_and_kwargs is not None
+            and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
+        ):
+            optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
+            self.optimizer = optimizer_factory_cls()(
+                opt_model, self.args, **optimizer_kwargs
             )

+        if not self.optimizer:
+            if self.optimizer_cls_and_kwargs is not None:
+                optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
+            else:
+                optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
+                    self.args, opt_model
+                )
+
             optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
                 opt_model, optimizer_kwargs
             )
@@ -326,50 +306,47 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
                     loraplus_lr_embedding=loraplus_lr_embedding,
                     **optimizer_kwargs,
                 )
-            elif (
-                self.args.embedding_lr_scale is not None
-                or self.args.embedding_lr is not None
-                or self.args.lr_groups is not None
-            ):
-                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
-                    optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
-                )
-            elif self.args.alternate_optimizer == "optimi_adamw":
-                from optimi import AdamW
-
-                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
-                    AdamW(
-                        optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
-                    )
-                )
-            elif self.args.alternate_optimizer == "ao_adamw_4bit":
-                from torchao.prototype.low_bit_optim import AdamW4bit
-
-                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
-                    AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
-                )
-            elif self.args.alternate_optimizer == "ao_adamw_8bit":
-                from torchao.prototype.low_bit_optim import AdamW8bit
-
-                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
-                    AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
-                )
-            elif self.args.alternate_optimizer == "ao_adamw_fp8":
-                from torchao.prototype.low_bit_optim import AdamWFp8
-
-                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
-                    AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
-                )
-            elif self.args.alternate_optimizer == "adopt_adamw":
-                from axolotl.utils.optimizers.adopt import ADOPT
-
-                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
-                    ADOPT(
-                        optimizer_grouped_parameters,
-                        decouple=True,
-                        **optimizer_kwargs,
-                    )
-                )
+            else:
+                # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
+                # e.g. for GaLore optimizer.
+                if "params" in optimizer_kwargs:
+                    optimizer_grouped_parameters = optimizer_kwargs.pop("params")
+
+                # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
+                # e.g. for LOMO optimizer.
+                if "model" in optimizer_kwargs:
+                    optimizer_grouped_parameters = optimizer_kwargs.pop("model")
+
+                # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
+                # to avoid arguments conflicts.
+                if "optimizer_dict" in optimizer_kwargs:
+                    optimizer_grouped_parameters = optimizer_kwargs.pop(
+                        "optimizer_dict"
+                    )
+
+                self.optimizer = optimizer_cls(
+                    optimizer_grouped_parameters, **optimizer_kwargs
+                )
+
+                if optimizer_cls.__name__ == "Adam8bit":
+                    import bitsandbytes
+
+                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+                    skipped = 0
+                    for module in opt_model.modules():
+                        if isinstance(module, nn.Embedding):
+                            skipped += sum(
+                                {
+                                    p.data_ptr(): p.numel() for p in module.parameters()
+                                }.values()
+                            )
+                            LOG.info(f"skipped {module}: {skipped/2**20}M params")
+                            manager.register_module_override(
+                                module, "weight", {"optim_bits": 32}
+                            )
+                            LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
+                    LOG.info(f"skipped: {skipped/2**20}M params")

         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
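Some optimizers returned by get_optimizer_cls_and_kwargs smuggle extra objects into the kwargs dict (params for GaLore, model for LOMO, optimizer_dict for layer-wise dummy optimizers), which is why the block above pops them before instantiating the optimizer. A tiny illustration of that popping pattern with dummy placeholder values (not a real optimizer):

```python
# Dummy values to illustrate why "params" is popped before calling optimizer_cls.
optimizer_kwargs = {"lr": 1e-4, "params": ["<galore param groups>"]}
optimizer_grouped_parameters = ["<default param groups>"]

if "params" in optimizer_kwargs:
    optimizer_grouped_parameters = optimizer_kwargs.pop("params")

print(optimizer_grouped_parameters)  # ['<galore param groups>']
print(optimizer_kwargs)              # {'lr': 0.0001}
```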
@@ -378,6 +355,45 @@ class AxolotlTrainer(SchedulerMixin, Trainer):

         return self.optimizer


+class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
+    """
+    Extend the base Trainer for axolotl helpers
+    """
+
+    args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]
+    tag_names = ["axolotl"]
+
+    def __init__(
+        self,
+        *_args,
+        bench_data_collator=None,
+        eval_data_collator=None,
+        dataset_tags=None,
+        **kwargs,
+    ):
+        self.bench_data_collator = bench_data_collator
+        self.eval_data_collator = eval_data_collator
+        self.dataset_tags = dataset_tags
+        self._signature_columns = None  # workaround for pylint
+        super().__init__(*_args, **kwargs)
+        self.train_data_collator = self.data_collator
+        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+        if self.args.orpo_alpha:
+            self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
+
+    def _wrap_model(self, model, training=True, dataloader=None):
+        if self.args.torch_compile:
+            torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
+                256
+            )
+            model = torch.compile(
+                model,
+                backend=self.args.torch_compile_backend,
+                mode=self.args.torch_compile_mode,
+            )
+        return super()._wrap_model(model, training=training, dataloader=dataloader)
+
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
         if self.args.sample_packing and not self.args.pretraining:
             if self.args.multipack_real_batches:
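AxolotlTrainer now composes the new OptimizerMixin alongside SchedulerMixin. A small self-contained sketch of why the ordering works, using stand-in classes (these are illustrative, not the real transformers/axolotl classes): Python's MRO resolves create_optimizer on the mixin before the base Trainer, so AxolotlTrainer picks up the custom optimizer path automatically.

```python
# Stand-ins to illustrate method resolution order; not the real classes.
class Trainer:
    def create_optimizer(self):
        return "transformers default"


class SchedulerMixin(Trainer):
    pass


class OptimizerMixin(Trainer):
    def create_optimizer(self):
        return "axolotl custom optimizer path"


class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
    pass


print([cls.__name__ for cls in AxolotlTrainer.__mro__])
# ['AxolotlTrainer', 'SchedulerMixin', 'OptimizerMixin', 'Trainer', 'object']
print(AxolotlTrainer().create_optimizer())  # axolotl custom optimizer path
```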
@@ -23,6 +23,8 @@ import importlib
 import logging
 from typing import OrderedDict

+import torch
+

 class BasePlugin:
     """
@@ -469,3 +471,14 @@ class PluginManager:
         """
         for plugin in self.plugins.values():
             plugin.post_train_unload(cfg)
+
+
+class BaseOptimizerFactory:
+    """
+    Base class for factories to create custom optimizers
+    """
+
+    def __call__(
+        self, opt_model, training_args, **optimizer_kwargs
+    ) -> "torch.optim.Optimizer":
+        pass
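Integrations can subclass the new BaseOptimizerFactory hook; when such a class is passed as the first element of optimizer_cls_and_kwargs, the OptimizerMixin above instantiates the factory and calls it with the model, the training args, and the collected kwargs. A hedged sketch of a hypothetical factory (SGDFactory is made up for illustration, it is not part of this change):

```python
import torch

from axolotl.integrations.base import BaseOptimizerFactory


class SGDFactory(BaseOptimizerFactory):
    """Hypothetical factory that builds a plain SGD optimizer."""

    def __call__(self, opt_model, training_args, **optimizer_kwargs) -> "torch.optim.Optimizer":
        # Fall back to the learning rate from TrainingArguments when not given explicitly.
        lr = optimizer_kwargs.pop("lr", None) or training_args.learning_rate
        return torch.optim.SGD(opt_model.parameters(), lr=lr, **optimizer_kwargs)


# The trainer-side call then reduces to roughly:
#   factory_cls, kwargs = self.optimizer_cls_and_kwargs
#   self.optimizer = factory_cls()(opt_model, self.args, **kwargs)
```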
@@ -461,7 +461,7 @@ def setup_model_and_trainer(

 def train(
     cfg: DictDefault, dataset_meta: TrainDatasetMeta
-) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer]:
+) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]:
     """
     Train a model on the given dataset.
@@ -510,4 +510,4 @@ def train(
     # Create model card
     create_model_card(cfg, trainer)

-    return model, tokenizer
+    return model, tokenizer, trainer
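Callers of train() now receive the trainer as a third value, which the updated tests use to inspect the instantiated optimizer. A minimal usage sketch, assuming cfg and dataset_meta are prepared as in do_train above and that train is imported from axolotl.train:

```python
# Sketch only: cfg and dataset_meta are assumed to be built beforehand.
from axolotl.train import train

model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)

print(type(trainer).__name__)                          # e.g. AxolotlTrainer
print(trainer.optimizer.optimizer.__class__.__name__)  # wrapped optimizer, as asserted in the tests
```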
@@ -64,6 +64,18 @@ class ChatTemplate(str, Enum):
     metharme = "metharme"  # pylint: disable=invalid-name


+class CustomSupportedOptimizers(str, Enum):
+    """Custom supported optimizers"""
+
+    optimi_adamw = "optimi_adamw"  # pylint: disable=invalid-name
+    ao_adamw_4bit = "ao_adamw_4bit"  # pylint: disable=invalid-name
+    ao_adamw_8bit = "ao_adamw_8bit"  # pylint: disable=invalid-name
+    ao_adamw_fp8 = "ao_adamw_fp8"  # pylint: disable=invalid-name
+    adopt_adamw = "adopt_adamw"  # pylint: disable=invalid-name
+    lion_pytorch = "lion_pytorch"  # pylint: disable=invalid-name
+    muon = "muon"  # pylint: disable=invalid-name
+
+
 class DeprecatedParameters(BaseModel):
     """configurations that are deprecated"""
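Because CustomSupportedOptimizers subclasses both str and Enum, its members compare equal to the plain strings users put in their config, and the trainer builder can derive a plain list of values for the membership check shown earlier. A small sketch with only two members reproduced for brevity:

```python
from enum import Enum


class CustomSupportedOptimizers(str, Enum):
    """Trimmed copy of the enum above, for illustration."""

    optimi_adamw = "optimi_adamw"
    muon = "muon"


custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
print(custom_supported_optimizers)               # ['optimi_adamw', 'muon']
print(CustomSupportedOptimizers.muon == "muon")  # True, thanks to the str base class
print("muon" in custom_supported_optimizers)     # True
```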
@@ -494,17 +506,7 @@ class HyperparametersConfig(BaseModel):
     embedding_lr_scale: Optional[float] = None
     weight_decay: Optional[float] = 0.0
     optimizer: Optional[
-        Union[
-            OptimizerNames,
-            Literal[
-                "lion_pytorch",
-                "optimi_adamw",
-                "ao_adamw_4bit",
-                "ao_adamw_8bit",
-                "ao_adamw_fp8",
-                "adopt_adamw",
-            ],
-        ]
+        Union[OptimizerNames, CustomSupportedOptimizers]
     ] = OptimizerNames.ADAMW_HF
     optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
         default=None,
@@ -1177,6 +1179,13 @@ class AxolotlInputConfig(
             LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
         return self

+    @model_validator(mode="before")
+    @classmethod
+    def check_lr_groups(cls, data):
+        if data.get("lr_groups") and data.get("loraplus_lr_ratio"):
+            raise ValueError("lr_groups and loraplus_lr_ratio cannot be used together.")
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_saves(cls, data):
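The new check_lr_groups validator rejects configs that set both lr_groups and loraplus_lr_ratio. A hedged sketch of that behaviour using a stripped-down Pydantic model (MiniConfig and its field shapes are simplified stand-ins; the real validator lives on AxolotlInputConfig):

```python
from typing import Any, Optional

from pydantic import BaseModel, model_validator


class MiniConfig(BaseModel):
    """Simplified stand-in for AxolotlInputConfig."""

    lr_groups: Optional[Any] = None
    loraplus_lr_ratio: Optional[float] = None

    @model_validator(mode="before")
    @classmethod
    def check_lr_groups(cls, data):
        if data.get("lr_groups") and data.get("loraplus_lr_ratio"):
            raise ValueError("lr_groups and loraplus_lr_ratio cannot be used together.")
        return data


MiniConfig(lr_groups=[{"name": "embeddings", "lr": 1e-5}])  # fine
MiniConfig(loraplus_lr_ratio=16)                            # fine
# MiniConfig(lr_groups=[...], loraplus_lr_ratio=16)  -> raises ValueError
```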
@@ -28,7 +28,7 @@ class TestTrainCommand(BaseCliTest):
         config_path.write_text(valid_test_config)

         with patch("axolotl.cli.train.train") as mock_train:
-            mock_train.return_value = (MagicMock(), MagicMock())
+            mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())

             result = cli_runner.invoke(
                 cli,
@@ -48,7 +48,7 @@ class TestTrainCommand(BaseCliTest):
         config_path = self._test_cli_overrides(tmp_path, valid_test_config)

         with patch("axolotl.cli.train.train") as mock_train:
-            mock_train.return_value = (MagicMock(), MagicMock())
+            mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())

             result = cli_runner.invoke(
                 cli,
@@ -75,7 +75,7 @@ class TestMixtral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
+        model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
         assert (
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
             == torch.float32
@@ -131,7 +131,7 @@ class TestMixtral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
+        model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
         assert (
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
             == torch.float32
@@ -190,7 +190,7 @@ class TestMixtral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
+        model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
         assert (
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
             == torch.float32
@@ -249,7 +249,7 @@ class TestMixtral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
+        model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
         assert (
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
             == torch.float32
@@ -65,8 +65,9 @@ class TestCustomOptimizers(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
+        _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
+        assert trainer.optimizer.optimizer.__class__.__name__ == "AdamW"

     @with_temp_dir
     @require_torch_2_5_1
@@ -111,8 +112,57 @@ class TestCustomOptimizers(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
+        _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
+        assert "ADOPT" in trainer.optimizer.optimizer.__class__.__name__
+
+    @with_temp_dir
+    @require_torch_2_5_1
+    def test_muon(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 5,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "muon",
+                "lr_scheduler": "cosine",
+                "weight_decay": 0.01,
+            }
+        )
+
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
+        check_model_output_exists(temp_dir, cfg)
+        assert "Muon" in trainer.optimizer.optimizer.__class__.__name__

     @with_temp_dir
     def test_fft_schedule_free_adamw(self, temp_dir):