feat: allow custom optim for rl methods
This commit is contained in:
@@ -389,6 +389,117 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
self.cfg.cosine_constant_lr_ratio
|
self.cfg.cosine_constant_lr_ratio
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Handle custom optimizer
|
||||||
|
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
|
||||||
|
if self.cfg.optimizer in custom_supported_optimizers:
|
||||||
|
# Common optimizer kwargs
|
||||||
|
optimizer_kwargs = {
|
||||||
|
"lr": training_args_kwargs.get("learning_rate"),
|
||||||
|
"weight_decay": training_args_kwargs.get("weight_decay"),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Adam-specific kwargs
|
||||||
|
adam_kwargs: dict = {}
|
||||||
|
if training_args_kwargs.get("adam_beta1") and training_args_kwargs.get(
|
||||||
|
"adam_beta2"
|
||||||
|
):
|
||||||
|
adam_kwargs["betas"] = (
|
||||||
|
training_args_kwargs.get("adam_beta1"),
|
||||||
|
training_args_kwargs.get("adam_beta2"),
|
||||||
|
)
|
||||||
|
if training_args_kwargs.get("adam_epsilon"):
|
||||||
|
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
|
||||||
|
|
||||||
|
if self.cfg.optimizer == "muon":
|
||||||
|
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
|
||||||
|
MuonOptimizerFactory,
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_cls = MuonOptimizerFactory
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "optimi_adamw":
|
||||||
|
from optimi import AdamW
|
||||||
|
|
||||||
|
optimizer_kwargs["foreach"] = False
|
||||||
|
optimizer_cls = AdamW
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "ao_adamw_4bit":
|
||||||
|
# TODO remove 20250401
|
||||||
|
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||||
|
|
||||||
|
optimizer_cls = AdamW4bit
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
|
||||||
|
LOG.warning(
|
||||||
|
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
|
||||||
|
)
|
||||||
|
elif self.cfg.optimizer == "ao_adamw_8bit":
|
||||||
|
from torchao.prototype.low_bit_optim import AdamW8bit
|
||||||
|
|
||||||
|
optimizer_cls = AdamW8bit
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "ao_adamw_fp8":
|
||||||
|
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||||
|
|
||||||
|
optimizer_cls = AdamWFp8
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "adopt_adamw":
|
||||||
|
from axolotl.utils.optimizers.adopt import ADOPT
|
||||||
|
|
||||||
|
optimizer_cls = ADOPT
|
||||||
|
adam_kwargs["decouple"] = True
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "came_pytorch":
|
||||||
|
from came_pytorch import CAME
|
||||||
|
|
||||||
|
optimizer_cls = CAME
|
||||||
|
|
||||||
|
beta1 = training_args_kwargs.get("adam_beta1", 0.9)
|
||||||
|
beta2 = training_args_kwargs.get("adam_beta2", 0.999)
|
||||||
|
beta3 = training_args_kwargs.get("adam_beta2", 0.9999)
|
||||||
|
eps1 = training_args_kwargs.get("adam_epsilon", 1e-30)
|
||||||
|
eps2 = training_args_kwargs.get("adam_epsilon2", 1e-16)
|
||||||
|
adam_kwargs["betas"] = (beta1, beta2, beta3)
|
||||||
|
adam_kwargs["eps"] = (eps1, eps2)
|
||||||
|
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
|
||||||
|
# Parse any additional optimizer args from config
|
||||||
|
if self.cfg.optim_args:
|
||||||
|
if isinstance(self.cfg.optim_args, dict):
|
||||||
|
optimizer_kwargs.update(self.cfg.optim_args)
|
||||||
|
else:
|
||||||
|
# Parse string format "key1=value1,key2=value2"
|
||||||
|
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
|
||||||
|
key, value = mapping.split("=")
|
||||||
|
optimizer_kwargs[key] = value
|
||||||
|
|
||||||
|
training_args_kwargs["optimizer_cls_and_kwargs"] = (
|
||||||
|
optimizer_cls,
|
||||||
|
optimizer_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use transformers' optimizer
|
||||||
|
training_args_kwargs["optim"] = self.cfg.optimizer
|
||||||
|
|
||||||
|
# Parse any additional optimizer args from config
|
||||||
|
if self.cfg.optim_args:
|
||||||
|
if isinstance(self.cfg.optim_args, dict):
|
||||||
|
optim_args = ",".join(
|
||||||
|
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
optim_args = self.cfg.optim_args
|
||||||
|
training_args_kwargs["optim_args"] = optim_args
|
||||||
|
|
||||||
|
if self.cfg.optimizer == "adamw_anyprecision":
|
||||||
|
if Path(self.cfg.torchdistx_path).exists():
|
||||||
|
sys.path.append(self.cfg.torchdistx_path)
|
||||||
|
importlib.import_module("torchdistx")
|
||||||
|
|
||||||
|
if self.cfg.optim_target_modules:
|
||||||
|
training_args_kwargs["optim_target_modules"] = self.cfg.optim_target_modules
|
||||||
|
|
||||||
return training_args_kwargs
|
return training_args_kwargs
|
||||||
|
|
||||||
|
|
||||||
@@ -675,119 +786,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
|
|
||||||
# Handle custom optimizer
|
|
||||||
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
|
|
||||||
if self.cfg.optimizer in custom_supported_optimizers:
|
|
||||||
# Common optimizer kwargs
|
|
||||||
optimizer_kwargs = {
|
|
||||||
"lr": training_arguments_kwargs.get("learning_rate"),
|
|
||||||
"weight_decay": training_arguments_kwargs.get("weight_decay"),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Adam-specific kwargs
|
|
||||||
adam_kwargs = {}
|
|
||||||
if training_arguments_kwargs.get(
|
|
||||||
"adam_beta1"
|
|
||||||
) and training_arguments_kwargs.get("adam_beta2"):
|
|
||||||
adam_kwargs["betas"] = (
|
|
||||||
training_arguments_kwargs.get("adam_beta1"),
|
|
||||||
training_arguments_kwargs.get("adam_beta2"),
|
|
||||||
)
|
|
||||||
if training_arguments_kwargs.get("adam_epsilon"):
|
|
||||||
adam_kwargs["eps"] = training_arguments_kwargs.get("adam_epsilon")
|
|
||||||
|
|
||||||
if self.cfg.optimizer == "muon":
|
|
||||||
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
|
|
||||||
MuonOptimizerFactory,
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer_cls = MuonOptimizerFactory
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
|
||||||
elif self.cfg.optimizer == "optimi_adamw":
|
|
||||||
from optimi import AdamW
|
|
||||||
|
|
||||||
optimizer_kwargs["foreach"] = False
|
|
||||||
optimizer_cls = AdamW
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
|
||||||
elif self.cfg.optimizer == "ao_adamw_4bit":
|
|
||||||
# TODO remove 20250401
|
|
||||||
from torchao.prototype.low_bit_optim import AdamW4bit
|
|
||||||
|
|
||||||
optimizer_cls = AdamW4bit
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
|
||||||
|
|
||||||
LOG.warning(
|
|
||||||
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
|
|
||||||
)
|
|
||||||
elif self.cfg.optimizer == "ao_adamw_8bit":
|
|
||||||
from torchao.prototype.low_bit_optim import AdamW8bit
|
|
||||||
|
|
||||||
optimizer_cls = AdamW8bit
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
|
||||||
elif self.cfg.optimizer == "ao_adamw_fp8":
|
|
||||||
from torchao.prototype.low_bit_optim import AdamWFp8
|
|
||||||
|
|
||||||
optimizer_cls = AdamWFp8
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
|
||||||
elif self.cfg.optimizer == "adopt_adamw":
|
|
||||||
from axolotl.utils.optimizers.adopt import ADOPT
|
|
||||||
|
|
||||||
optimizer_cls = ADOPT
|
|
||||||
adam_kwargs["decouple"] = True
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
|
||||||
elif self.cfg.optimizer == "came_pytorch":
|
|
||||||
from came_pytorch import CAME
|
|
||||||
|
|
||||||
optimizer_cls = CAME
|
|
||||||
|
|
||||||
beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
|
|
||||||
beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
|
|
||||||
beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999)
|
|
||||||
eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
|
|
||||||
eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
|
|
||||||
adam_kwargs["betas"] = (beta1, beta2, beta3)
|
|
||||||
adam_kwargs["eps"] = (eps1, eps2)
|
|
||||||
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
|
||||||
|
|
||||||
# Parse any additional optimizer args from config
|
|
||||||
if self.cfg.optim_args:
|
|
||||||
if isinstance(self.cfg.optim_args, dict):
|
|
||||||
optimizer_kwargs.update(self.cfg.optim_args)
|
|
||||||
else:
|
|
||||||
# Parse string format "key1=value1,key2=value2"
|
|
||||||
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
|
|
||||||
key, value = mapping.split("=")
|
|
||||||
optimizer_kwargs[key] = value
|
|
||||||
|
|
||||||
trainer_kwargs["optimizer_cls_and_kwargs"] = (
|
|
||||||
optimizer_cls,
|
|
||||||
optimizer_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Use transformers' optimizer
|
|
||||||
training_arguments_kwargs["optim"] = self.cfg.optimizer
|
|
||||||
|
|
||||||
# Parse any additional optimizer args from config
|
|
||||||
if self.cfg.optim_args:
|
|
||||||
if isinstance(self.cfg.optim_args, dict):
|
|
||||||
optim_args = ",".join(
|
|
||||||
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
optim_args = self.cfg.optim_args
|
|
||||||
training_arguments_kwargs["optim_args"] = optim_args
|
|
||||||
|
|
||||||
if self.cfg.optimizer == "adamw_anyprecision":
|
|
||||||
if Path(self.cfg.torchdistx_path).exists():
|
|
||||||
sys.path.append(self.cfg.torchdistx_path)
|
|
||||||
importlib.import_module("torchdistx")
|
|
||||||
|
|
||||||
if self.cfg.optim_target_modules:
|
|
||||||
training_arguments_kwargs["optim_target_modules"] = (
|
|
||||||
self.cfg.optim_target_modules
|
|
||||||
)
|
|
||||||
|
|
||||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user