diff --git a/requirements.txt b/requirements.txt
index 6a5491d38..472b77a79 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -43,3 +43,4 @@
 trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda50034822
 fastcore>=1.5.29
 lpmm @ git+https://github.com/thu-ml/low-bit-optimizers.git@main
+yacs
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index d52ac7b07..01d3356d2 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -23,6 +23,7 @@ import transformers
 from accelerate import FullyShardedDataParallelPlugin
 from accelerate.utils import str_to_bool
 from datasets import Dataset
+from torch import nn
 from torch.distributed.fsdp import MixedPrecision
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
@@ -270,25 +271,71 @@ class AxolotlTrainer(Trainer):
         )
 
     def create_optimizer(self):
-        if self.args.loraplus_lr_ratio is None:
-            return super().create_optimizer()
-
         opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-        if self.optimizer is None:  # pylint: disable=access-member-before-definition
-            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
-                self.args,
-                opt_model,
-            )
-            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
-            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
-            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
-                opt_model,
+        if self.optimizer is None:  # pylint: disable=access-member-before-definition
+            decay_parameters = self.get_decay_parameter_names(opt_model)
+            optimizer_grouped_parameters = [
+                {
+                    "params": [
+                        p
+                        for n, p in opt_model.named_parameters()
+                        if (n in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": self.args.weight_decay,
+                },
+                {
+                    "params": [
+                        p
+                        for n, p in opt_model.named_parameters()
+                        if (n not in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": 0.0,
+                },
+            ]
+
+            (
                 optimizer_cls,
                 optimizer_kwargs,
-                loraplus_lr_ratio,
-                loraplus_lr_embedding,
-            )
+            ) = AxolotlTrainer.get_optimizer_cls_and_kwargs(self.args)
+
+            if self.args.loraplus_lr_ratio:
+                loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
+                loraplus_lr_embedding = getattr(
+                    self.args, "loraplus_lr_embedding", None
+                )
+                self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
+                    opt_model,
+                    optimizer_cls,
+                    optimizer_kwargs,
+                    loraplus_lr_ratio,
+                    loraplus_lr_embedding,
+                )
+
+            else:
+                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
+                    optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+                )
+
+                if optimizer_cls.__name__ == "Adam8bit":
+                    import bitsandbytes
+
+                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+                    skipped = 0
+                    for module in opt_model.modules():
+                        if isinstance(module, nn.Embedding):
+                            skipped += sum(
+                                {
+                                    p.data_ptr(): p.numel() for p in module.parameters()
+                                }.values()
+                            )
+                            LOG.info(f"skipped {module}: {skipped/2**20}M params")
+                            manager.register_module_override(
+                                module, "weight", {"optim_bits": 32}
+                            )
+                            LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
+                    LOG.info(f"skipped: {skipped/2**20}M params")
 
         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init