diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index a842f5961..d79801a6a 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -389,6 +389,117 @@ class TrainerBuilderBase(abc.ABC):
                 self.cfg.cosine_constant_lr_ratio
             )
 
+        # Handle custom optimizer
+        custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
+        if self.cfg.optimizer in custom_supported_optimizers:
+            # Common optimizer kwargs
+            optimizer_kwargs = {
+                "lr": training_args_kwargs.get("learning_rate"),
+                "weight_decay": training_args_kwargs.get("weight_decay"),
+            }
+
+            # Adam-specific kwargs
+            adam_kwargs: dict = {}
+            if training_args_kwargs.get("adam_beta1") and training_args_kwargs.get(
+                "adam_beta2"
+            ):
+                adam_kwargs["betas"] = (
+                    training_args_kwargs.get("adam_beta1"),
+                    training_args_kwargs.get("adam_beta2"),
+                )
+            if training_args_kwargs.get("adam_epsilon"):
+                adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
+
+            if self.cfg.optimizer == "muon":
+                from axolotl.contribs.mit.muon import (  # pylint: disable=no-name-in-module
+                    MuonOptimizerFactory,
+                )
+
+                optimizer_cls = MuonOptimizerFactory
+                optimizer_kwargs.update(adam_kwargs)
+            elif self.cfg.optimizer == "optimi_adamw":
+                from optimi import AdamW
+
+                optimizer_kwargs["foreach"] = False
+                optimizer_cls = AdamW
+                optimizer_kwargs.update(adam_kwargs)
+            elif self.cfg.optimizer == "ao_adamw_4bit":
+                # TODO remove 20250401
+                from torchao.prototype.low_bit_optim import AdamW4bit
+
+                optimizer_cls = AdamW4bit
+                optimizer_kwargs.update(adam_kwargs)
+
+                LOG.warning(
+                    f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
+                )
+            elif self.cfg.optimizer == "ao_adamw_8bit":
+                from torchao.prototype.low_bit_optim import AdamW8bit
+
+                optimizer_cls = AdamW8bit
+                optimizer_kwargs.update(adam_kwargs)
+            elif self.cfg.optimizer == "ao_adamw_fp8":
+                from torchao.prototype.low_bit_optim import AdamWFp8
+
+                optimizer_cls = AdamWFp8
+                optimizer_kwargs.update(adam_kwargs)
+            elif self.cfg.optimizer == "adopt_adamw":
+                from axolotl.utils.optimizers.adopt import ADOPT
+
+                optimizer_cls = ADOPT
+                adam_kwargs["decouple"] = True
+                optimizer_kwargs.update(adam_kwargs)
+            elif self.cfg.optimizer == "came_pytorch":
+                from came_pytorch import CAME
+
+                optimizer_cls = CAME
+
+                beta1 = training_args_kwargs.get("adam_beta1", 0.9)
+                beta2 = training_args_kwargs.get("adam_beta2", 0.999)
+                beta3 = training_args_kwargs.get("adam_beta2", 0.9999)
+                eps1 = training_args_kwargs.get("adam_epsilon", 1e-30)
+                eps2 = training_args_kwargs.get("adam_epsilon2", 1e-16)
+                adam_kwargs["betas"] = (beta1, beta2, beta3)
+                adam_kwargs["eps"] = (eps1, eps2)
+
+                optimizer_kwargs.update(adam_kwargs)
+
+            # Parse any additional optimizer args from config
+            if self.cfg.optim_args:
+                if isinstance(self.cfg.optim_args, dict):
+                    optimizer_kwargs.update(self.cfg.optim_args)
+                else:
+                    # Parse string format "key1=value1,key2=value2"
+                    for mapping in self.cfg.optim_args.replace(" ", "").split(","):
+                        key, value = mapping.split("=")
+                        optimizer_kwargs[key] = value
+
+            training_args_kwargs["optimizer_cls_and_kwargs"] = (
+                optimizer_cls,
+                optimizer_kwargs,
+            )
+        else:
+            # Use transformers' optimizer
+            training_args_kwargs["optim"] = self.cfg.optimizer
+
+            # Parse any additional optimizer args from config
+            if self.cfg.optim_args:
+                if isinstance(self.cfg.optim_args, dict):
+                    optim_args = ",".join(
+                        [f"{key}={value}" for key, value in self.cfg.optim_args.items()]
+                    )
+                else:
+                    optim_args = self.cfg.optim_args
+                training_args_kwargs["optim_args"] = optim_args
+
+            if self.cfg.optimizer == "adamw_anyprecision":
+                if Path(self.cfg.torchdistx_path).exists():
+                    sys.path.append(self.cfg.torchdistx_path)
+                    importlib.import_module("torchdistx")
+
+        if self.cfg.optim_target_modules:
+            training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules
+
         return training_args_kwargs
@@ -675,119 +786,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         trainer_kwargs = {}
 
-        # Handle custom optimizer
-        custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
-        if self.cfg.optimizer in custom_supported_optimizers:
-            # Common optimizer kwargs
-            optimizer_kwargs = {
-                "lr": training_arguments_kwargs.get("learning_rate"),
-                "weight_decay": training_arguments_kwargs.get("weight_decay"),
-            }
-
-            # Adam-specific kwargs
-            adam_kwargs = {}
-            if training_arguments_kwargs.get(
-                "adam_beta1"
-            ) and training_arguments_kwargs.get("adam_beta2"):
-                adam_kwargs["betas"] = (
-                    training_arguments_kwargs.get("adam_beta1"),
-                    training_arguments_kwargs.get("adam_beta2"),
-                )
-            if training_arguments_kwargs.get("adam_epsilon"):
-                adam_kwargs["eps"] = training_arguments_kwargs.get("adam_epsilon")
-
-            if self.cfg.optimizer == "muon":
-                from axolotl.contribs.mit.muon import (  # pylint: disable=no-name-in-module
-                    MuonOptimizerFactory,
-                )
-
-                optimizer_cls = MuonOptimizerFactory
-                optimizer_kwargs.update(adam_kwargs)
-            elif self.cfg.optimizer == "optimi_adamw":
-                from optimi import AdamW
-
-                optimizer_kwargs["foreach"] = False
-                optimizer_cls = AdamW
-                optimizer_kwargs.update(adam_kwargs)
-            elif self.cfg.optimizer == "ao_adamw_4bit":
-                # TODO remove 20250401
-                from torchao.prototype.low_bit_optim import AdamW4bit
-
-                optimizer_cls = AdamW4bit
-                optimizer_kwargs.update(adam_kwargs)
-
-                LOG.warning(
-                    f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
-                )
-            elif self.cfg.optimizer == "ao_adamw_8bit":
-                from torchao.prototype.low_bit_optim import AdamW8bit
-
-                optimizer_cls = AdamW8bit
-                optimizer_kwargs.update(adam_kwargs)
-            elif self.cfg.optimizer == "ao_adamw_fp8":
-                from torchao.prototype.low_bit_optim import AdamWFp8
-
-                optimizer_cls = AdamWFp8
-                optimizer_kwargs.update(adam_kwargs)
-            elif self.cfg.optimizer == "adopt_adamw":
-                from axolotl.utils.optimizers.adopt import ADOPT
-
-                optimizer_cls = ADOPT
-                adam_kwargs["decouple"] = True
-                optimizer_kwargs.update(adam_kwargs)
-            elif self.cfg.optimizer == "came_pytorch":
-                from came_pytorch import CAME
-
-                optimizer_cls = CAME
-
-                beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
-                beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
-                beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999)
-                eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
-                eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
-                adam_kwargs["betas"] = (beta1, beta2, beta3)
-                adam_kwargs["eps"] = (eps1, eps2)
-
-                optimizer_kwargs.update(adam_kwargs)
-
-            # Parse any additional optimizer args from config
-            if self.cfg.optim_args:
-                if isinstance(self.cfg.optim_args, dict):
-                    optimizer_kwargs.update(self.cfg.optim_args)
-                else:
-                    # Parse string format "key1=value1,key2=value2"
-                    for mapping in self.cfg.optim_args.replace(" ", "").split(","):
-                        key, value = mapping.split("=")
-                        optimizer_kwargs[key] = value
-
-            trainer_kwargs["optimizer_cls_and_kwargs"] = (
-                optimizer_cls,
-                optimizer_kwargs,
-            )
-        else:
-            # Use transformers' optimizer
-            training_arguments_kwargs["optim"] = self.cfg.optimizer
-
-            # Parse any additional optimizer args from config
-            if self.cfg.optim_args:
-                if isinstance(self.cfg.optim_args, dict):
-                    optim_args = ",".join(
-                        [f"{key}={value}" for key, value in self.cfg.optim_args.items()]
-                    )
-                else:
-                    optim_args = self.cfg.optim_args
-                training_arguments_kwargs["optim_args"] = optim_args
-
-            if self.cfg.optimizer == "adamw_anyprecision":
-                if Path(self.cfg.torchdistx_path).exists():
-                    sys.path.append(self.cfg.torchdistx_path)
-                    importlib.import_module("torchdistx")
-
-        if self.cfg.optim_target_modules:
-            training_arguments_kwargs["optim_target_modules"] = (
-                self.cfg.optim_target_modules
-            )
-
         training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
         training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
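
Note (not part of the patch): a minimal sketch of how the string form of `optim_args` is parsed by the loop added above. The config values here are hypothetical, and parsed values remain strings for the chosen optimizer to coerce:

    # Illustrative only: mirrors the "key1=value1,key2=value2" parsing above.
    optim_args = "eps=1e-8, weight_decay=0.01"
    optimizer_kwargs = {"lr": 2e-5, "weight_decay": 0.0}

    for mapping in optim_args.replace(" ", "").split(","):
        key, value = mapping.split("=")
        optimizer_kwargs[key] = value  # note: values stay as strings

    print(optimizer_kwargs)
    # {'lr': 2e-05, 'weight_decay': '0.01', 'eps': '1e-8'}

The resulting kwargs are paired with the selected optimizer class and stored under `optimizer_cls_and_kwargs`, which recent versions of transformers' `Trainer` accept when constructing the optimizer.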