ADOPT optimizer integration (#2032) [skip ci]
* adopt integration * stuff * doc and test for ADOPT * rearrangement * fixed formatting * hacking pre-commit * chore: lint * update module doc for adopt optimizer * remove un-necessary example yaml for adopt optimizer * skip test adopt if torch<2.5.1 * formatting * use version.parse * specifies required torch version for adopt_adamw --------- Co-authored-by: sunny <sunnyliu19981005@gmail.com> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -405,6 +405,7 @@ lr_div_factor: # Learning rate div factor
|
|||||||
# - adamw_torch_fused
|
# - adamw_torch_fused
|
||||||
# - adamw_torch_xla
|
# - adamw_torch_xla
|
||||||
# - adamw_apex_fused
|
# - adamw_apex_fused
|
||||||
|
# - adopt_adamw (only for torch version >= 2.5.1)
|
||||||
# - adafactor
|
# - adafactor
|
||||||
# - adamw_anyprecision
|
# - adamw_anyprecision
|
||||||
# - sgd
|
# - sgd
|
||||||
|
|||||||
@@ -436,7 +436,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
if (
|
if (
|
||||||
self.args.loraplus_lr_ratio is None
|
self.args.loraplus_lr_ratio is None
|
||||||
and self.args.alternate_optimizer
|
and self.args.alternate_optimizer
|
||||||
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
|
not in [
|
||||||
|
"optimi_adamw",
|
||||||
|
"ao_adamw_8bit",
|
||||||
|
"ao_adamw_4bit",
|
||||||
|
"ao_adamw_fp8",
|
||||||
|
"adopt_adamw",
|
||||||
|
]
|
||||||
):
|
):
|
||||||
return super().create_optimizer()
|
return super().create_optimizer()
|
||||||
|
|
||||||
@@ -505,6 +511,14 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
|
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
)
|
)
|
||||||
|
elif self.args.alternate_optimizer == "adopt_adamw":
|
||||||
|
from axolotl.utils.optimizers.adopt import ADOPT
|
||||||
|
|
||||||
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
ADOPT(
|
||||||
|
optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
@@ -1625,11 +1639,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
trainer_kwargs["max_length"] = self.cfg.sequence_len
|
trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
if self.cfg.optimizer in [
|
if self.cfg.optimizer in [
|
||||||
"optimi_adamw",
|
"optimi_adamw",
|
||||||
"ao_adamw_4bit",
|
"ao_adamw_4bit",
|
||||||
"ao_adamw_8bit",
|
"ao_adamw_8bit",
|
||||||
"ao_adamw_fp8",
|
"ao_adamw_fp8",
|
||||||
|
"adopt_adamw",
|
||||||
]:
|
]:
|
||||||
# Set default so transformers doesn't throw
|
# Set default so transformers doesn't throw
|
||||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
training_arguments_kwargs["optim"] = "adamw_hf"
|
||||||
|
|||||||
@@ -428,6 +428,7 @@ class HyperparametersConfig(BaseModel):
|
|||||||
"ao_adamw_4bit",
|
"ao_adamw_4bit",
|
||||||
"ao_adamw_8bit",
|
"ao_adamw_8bit",
|
||||||
"ao_adamw_fp8",
|
"ao_adamw_fp8",
|
||||||
|
"adopt_adamw",
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
] = OptimizerNames.ADAMW_HF.value
|
] = OptimizerNames.ADAMW_HF.value
|
||||||
|
|||||||
508
src/axolotl/utils/optimizers/adopt.py
Normal file
508
src/axolotl/utils/optimizers/adopt.py
Normal file
@@ -0,0 +1,508 @@
|
|||||||
|
"""
|
||||||
|
Copied from https://github.com/iShohei220/adopt
|
||||||
|
|
||||||
|
ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate (2024)
|
||||||
|
Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka
|
||||||
|
"""
|
||||||
|
# mypy: ignore-errors
|
||||||
|
# pylint: skip-file
|
||||||
|
# mypy: allow-untyped-decorators
|
||||||
|
# mypy: allow-untyped-defs
|
||||||
|
from typing import List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.optim.optimizer import (
|
||||||
|
Optimizer,
|
||||||
|
ParamsT,
|
||||||
|
_default_to_fused_or_foreach,
|
||||||
|
_device_dtype_check_for_fused,
|
||||||
|
_disable_dynamo_if_unsupported,
|
||||||
|
_get_capturable_supported_devices,
|
||||||
|
_get_scalar_dtype,
|
||||||
|
_get_value,
|
||||||
|
_use_grad_for_differentiable,
|
||||||
|
_view_as_real,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["ADOPT", "adopt"]
|
||||||
|
|
||||||
|
|
||||||
|
class ADOPT(Optimizer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params: ParamsT,
|
||||||
|
lr: Union[float, Tensor] = 1e-3,
|
||||||
|
betas: Tuple[float, float] = (0.9, 0.9999),
|
||||||
|
eps: float = 1e-6,
|
||||||
|
weight_decay: float = 0.0,
|
||||||
|
decoupled: bool = False,
|
||||||
|
*,
|
||||||
|
foreach: Optional[bool] = None,
|
||||||
|
maximize: bool = False,
|
||||||
|
capturable: bool = False,
|
||||||
|
differentiable: bool = False,
|
||||||
|
fused: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
if isinstance(lr, Tensor):
|
||||||
|
if foreach and not capturable:
|
||||||
|
raise ValueError(
|
||||||
|
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||||
|
)
|
||||||
|
if lr.numel() != 1:
|
||||||
|
raise ValueError("Tensor lr must be 1-element")
|
||||||
|
if not 0.0 <= lr:
|
||||||
|
raise ValueError(f"Invalid learning rate: {lr}")
|
||||||
|
if not 0.0 <= eps:
|
||||||
|
raise ValueError(f"Invalid epsilon value: {eps}")
|
||||||
|
if not 0.0 <= betas[0] < 1.0:
|
||||||
|
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
|
||||||
|
if not 0.0 <= betas[1] < 1.0:
|
||||||
|
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
|
||||||
|
if not 0.0 <= weight_decay:
|
||||||
|
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
||||||
|
|
||||||
|
defaults = dict(
|
||||||
|
lr=lr,
|
||||||
|
betas=betas,
|
||||||
|
eps=eps,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
decoupled=decoupled,
|
||||||
|
maximize=maximize,
|
||||||
|
foreach=foreach,
|
||||||
|
capturable=capturable,
|
||||||
|
differentiable=differentiable,
|
||||||
|
fused=fused,
|
||||||
|
)
|
||||||
|
super().__init__(params, defaults)
|
||||||
|
|
||||||
|
if fused:
|
||||||
|
# TODO: support fused
|
||||||
|
raise RuntimeError("`fused` is not currently supported")
|
||||||
|
|
||||||
|
if differentiable:
|
||||||
|
raise RuntimeError("`fused` does not support `differentiable`")
|
||||||
|
self._step_supports_amp_scaling = True
|
||||||
|
# TODO(crcrpar): [low prec params & their higher prec copy]
|
||||||
|
# Support AMP with FP16/BF16 model params which would need
|
||||||
|
# higher prec copy of params to do update math in higher prec to
|
||||||
|
# alleviate the loss of information.
|
||||||
|
if foreach:
|
||||||
|
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
super().__setstate__(state)
|
||||||
|
for group in self.param_groups:
|
||||||
|
group.setdefault("maximize", False)
|
||||||
|
group.setdefault("foreach", None)
|
||||||
|
group.setdefault("capturable", False)
|
||||||
|
group.setdefault("differentiable", False)
|
||||||
|
fused = group.setdefault("fused", None)
|
||||||
|
for p in group["params"]:
|
||||||
|
p_state = self.state.get(p, [])
|
||||||
|
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
|
||||||
|
step_val = float(p_state["step"])
|
||||||
|
p_state["step"] = (
|
||||||
|
torch.tensor(
|
||||||
|
step_val,
|
||||||
|
dtype=_get_scalar_dtype(is_fused=fused),
|
||||||
|
device=p.device,
|
||||||
|
)
|
||||||
|
if group["capturable"] or group["fused"]
|
||||||
|
else torch.tensor(step_val, dtype=_get_scalar_dtype())
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_group(
|
||||||
|
self,
|
||||||
|
group,
|
||||||
|
params_with_grad,
|
||||||
|
grads,
|
||||||
|
exp_avgs,
|
||||||
|
exp_avg_sqs,
|
||||||
|
state_steps,
|
||||||
|
):
|
||||||
|
has_complex = False
|
||||||
|
for p in group["params"]:
|
||||||
|
if p.grad is not None:
|
||||||
|
has_complex |= torch.is_complex(p)
|
||||||
|
params_with_grad.append(p)
|
||||||
|
if p.grad.is_sparse:
|
||||||
|
raise RuntimeError("ADOPT does not support sparse gradients")
|
||||||
|
grads.append(p.grad)
|
||||||
|
|
||||||
|
state = self.state[p]
|
||||||
|
# Lazy state initialization
|
||||||
|
if len(state) == 0:
|
||||||
|
if group["fused"]:
|
||||||
|
_device_dtype_check_for_fused(p)
|
||||||
|
# note(crcrpar): [special device hosting for step]
|
||||||
|
# Deliberately host `step` on CPU if both capturable and fused are off.
|
||||||
|
# This is because kernel launches are costly on CUDA and XLA.
|
||||||
|
state["step"] = (
|
||||||
|
torch.zeros(
|
||||||
|
(),
|
||||||
|
dtype=_get_scalar_dtype(is_fused=group["fused"]),
|
||||||
|
device=p.device,
|
||||||
|
)
|
||||||
|
if group["capturable"] or group["fused"]
|
||||||
|
else torch.tensor(0.0, dtype=_get_scalar_dtype())
|
||||||
|
)
|
||||||
|
# Exponential moving average of gradient values
|
||||||
|
state["exp_avg"] = torch.zeros_like(
|
||||||
|
p, memory_format=torch.preserve_format
|
||||||
|
)
|
||||||
|
# Exponential moving average of squared gradient values
|
||||||
|
state["exp_avg_sq"] = torch.zeros_like(
|
||||||
|
p, memory_format=torch.preserve_format
|
||||||
|
)
|
||||||
|
|
||||||
|
exp_avgs.append(state["exp_avg"])
|
||||||
|
exp_avg_sqs.append(state["exp_avg_sq"])
|
||||||
|
|
||||||
|
if group["differentiable"] and state["step"].requires_grad:
|
||||||
|
raise RuntimeError(
|
||||||
|
"`requires_grad` is not supported for `step` in differentiable mode"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Foreach without capturable does not support a tensor lr
|
||||||
|
if (
|
||||||
|
group["foreach"]
|
||||||
|
and torch.is_tensor(group["lr"])
|
||||||
|
and not group["capturable"]
|
||||||
|
):
|
||||||
|
raise RuntimeError(
|
||||||
|
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||||
|
)
|
||||||
|
|
||||||
|
state_steps.append(state["step"])
|
||||||
|
return has_complex
|
||||||
|
|
||||||
|
@_use_grad_for_differentiable
|
||||||
|
def step(self, closure=None):
|
||||||
|
"""Perform a single optimization step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
closure (Callable, optional): A closure that reevaluates the model
|
||||||
|
and returns the loss.
|
||||||
|
"""
|
||||||
|
self._cuda_graph_capture_health_check()
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
with torch.enable_grad():
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
for group in self.param_groups:
|
||||||
|
params_with_grad: List[Tensor] = []
|
||||||
|
grads: List[Tensor] = []
|
||||||
|
exp_avgs: List[Tensor] = []
|
||||||
|
exp_avg_sqs: List[Tensor] = []
|
||||||
|
state_steps: List[Tensor] = []
|
||||||
|
beta1, beta2 = group["betas"]
|
||||||
|
|
||||||
|
has_complex = self._init_group(
|
||||||
|
group,
|
||||||
|
params_with_grad,
|
||||||
|
grads,
|
||||||
|
exp_avgs,
|
||||||
|
exp_avg_sqs,
|
||||||
|
state_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
adopt(
|
||||||
|
params_with_grad,
|
||||||
|
grads,
|
||||||
|
exp_avgs,
|
||||||
|
exp_avg_sqs,
|
||||||
|
state_steps,
|
||||||
|
has_complex=has_complex,
|
||||||
|
beta1=beta1,
|
||||||
|
beta2=beta2,
|
||||||
|
lr=group["lr"],
|
||||||
|
weight_decay=group["weight_decay"],
|
||||||
|
decoupled=group["decoupled"],
|
||||||
|
eps=group["eps"],
|
||||||
|
maximize=group["maximize"],
|
||||||
|
foreach=group["foreach"],
|
||||||
|
capturable=group["capturable"],
|
||||||
|
differentiable=group["differentiable"],
|
||||||
|
fused=group["fused"],
|
||||||
|
grad_scale=getattr(self, "grad_scale", None),
|
||||||
|
found_inf=getattr(self, "found_inf", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def _single_tensor_adopt(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    has_complex: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    decoupled: bool,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    """Apply one ADOPT update to each parameter, one tensor at a time.

    All tensors in ``params``, ``exp_avgs``, ``exp_avg_sqs`` and
    ``state_steps`` are updated in place; nothing is returned.

    Per parameter the schedule is:

    * step 1: only seed ``exp_avg_sq`` with ``grad * grad``; the parameter is
      not moved.
    * step 2: set ``exp_avg`` to ``grad / max(sqrt(exp_avg_sq), eps)``.
    * later steps: blend that normalized gradient into ``exp_avg`` with
      ``beta1``.

    From step 2 on, the parameter is moved by ``-lr * exp_avg`` and only then
    is ``exp_avg_sq`` refreshed with ``beta2``, so the denominator always comes
    from the previous step's second-moment estimate.

    ``grad_scale`` and ``found_inf`` must be ``None``. ``has_complex`` is
    accepted for signature parity and is not read here; complexity is checked
    per parameter instead.
    """
    assert grad_scale is None and found_inf is None

    if torch.jit.is_scripting():
        # this assert is due to JIT being dumb and not realizing that the ops below
        # have overloads to handle both float and Tensor lrs, so we just assert it's
        # a float since most people using JIT are using floats
        assert isinstance(lr, float)

    for i, param in enumerate(params):
        # Flip the gradient sign when maximizing the objective.
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        # update step
        step_t += 1

        if weight_decay != 0:
            if decoupled:
                # Decoupled decay: shrink the parameter directly, scaled by lr.
                param.add_(param, alpha=-lr * weight_decay)
            else:
                # Coupled decay: fold the penalty into the gradient (out of
                # place, so the caller's gradient tensor is left untouched).
                grad = grad.add(param, alpha=weight_decay)

        if torch.is_complex(param):
            # Work on real views so the in-place ops below update the
            # underlying complex storage.
            grad = torch.view_as_real(grad)
            if exp_avg is not None:
                exp_avg = torch.view_as_real(exp_avg)
            if exp_avg_sq is not None:
                exp_avg_sq = torch.view_as_real(exp_avg_sq)
            param = torch.view_as_real(param)

        # Keep `step` as a tensor in capturable/differentiable mode; otherwise
        # unwrap it with `_get_value`.
        step = step_t if capturable or differentiable else _get_value(step_t)
        if step == 1:
            # First step: initialize the second-moment estimate only.
            exp_avg_sq.addcmul_(grad, grad.conj())
            continue

        # Denominator uses the second moment from *before* this step's update.
        denom = torch.clamp(exp_avg_sq.sqrt(), eps)
        if step == 2:
            # exp_avg is still zero here, so this sets it to grad / denom.
            exp_avg.addcdiv_(grad, denom)
        else:
            exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)

        param.add_(exp_avg, alpha=-lr)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
|
||||||
|
|
||||||
|
|
||||||
|
def _multi_tensor_adopt(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    has_complex: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    decoupled: bool,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    """Apply one ADOPT update using ``torch._foreach_*`` batched ops.

    Tensors are bucketed by device and dtype and each bucket is updated in
    place with the same three-phase schedule as the single-tensor path
    (seed ``exp_avg_sq`` on step 1, seed ``exp_avg`` on step 2, exponential
    averaging afterwards).

    The phase for a whole bucket is decided from the step counter of its
    *first* tensor only.
    NOTE(review): this presumably assumes every tensor in a bucket shares the
    same step count — confirm for parameters that first receive a gradient
    later than others.

    Raises:
        RuntimeError: if ``lr`` is a tensor while ``capturable`` is ``False``.

    ``grad_scale`` / ``found_inf`` must be ``None`` and ``differentiable``
    must be ``False`` (enforced with asserts).
    """
    if len(params) == 0:
        return

    if isinstance(lr, Tensor) and not capturable:
        raise RuntimeError(
            "lr as a Tensor is not supported for capturable=False and foreach=True"
        )

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    assert grad_scale is None and found_inf is None

    assert not differentiable, "_foreach ops don't support autograd"

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, state_steps]  # type: ignore[list-item]
    )
    for (
        device_params_,
        device_grads_,
        device_exp_avgs_,
        device_exp_avg_sqs_,
        device_state_steps_,
    ), _ in grouped_tensors.values():
        # The casts only narrow the static types; no data is copied.
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        # Handle complex parameters
        if has_complex:
            _view_as_real(
                device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
            )

        if maximize:
            # Out-of-place negation: yields fresh tensors that later in-place
            # ops may safely overwrite.
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1)

        if weight_decay != 0:
            if decoupled:
                # Decoupled decay: shrink the parameters directly, scaled by lr.
                torch._foreach_add_(
                    device_params, device_params, alpha=-lr * weight_decay
                )
            else:
                # Re-use the intermediate memory (device_grads) already allocated for maximize
                if maximize:
                    torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
                else:
                    device_grads = torch._foreach_add(  # type: ignore[assignment]
                        device_grads, device_params, alpha=weight_decay
                    )

        if device_state_steps[0] == 1:
            # First step: only seed the second-moment estimate with grad * grad.
            torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
            continue

        # Denominator uses the second moment from *before* this step's update,
        # bounded below by eps.
        exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
        exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)

        if device_state_steps[0] == 2:
            # Second step: exp_avg += grad / denom, with no beta1 blending.
            torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
        else:
            torch._foreach_mul_(device_exp_avgs, beta1)
            torch._foreach_addcdiv_(
                device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
            )

        torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
        torch._foreach_mul_(device_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(
            device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
        )
|
||||||
|
|
||||||
|
|
||||||
|
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
def adopt(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    decoupled: bool,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs ADOPT algorithm computation.

    Resolves the ``foreach`` / ``fused`` defaults, validates ``state_steps``
    and dispatches to ``_multi_tensor_adopt`` (when ``foreach`` is enabled and
    the code is not being scripted) or ``_single_tensor_adopt`` otherwise.
    The tensors in the five list arguments are updated in place by the chosen
    implementation; nothing is returned.

    ``fused`` is only normalized and checked against TorchScript here; it is
    not forwarded, and no fused implementation is dispatched to.

    Raises:
        RuntimeError: if ``state_steps`` contains non-tensors (checked only
            when not compiling), or if ``foreach`` / ``fused`` is requested
            while running under ``torch.jit.script``.
    """
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )
        # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
        if foreach and isinstance(lr, Tensor) and not capturable:
            foreach = False
    # Whatever is still unspecified at this point defaults to "off".
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    # NOTE: the commented-out branch below is a placeholder for a future fused
    # implementation; `_fused_adopt` is not defined in the code shown here.
    # if fused and not torch.jit.is_scripting():
    #     func = _fused_adopt
    # elif foreach and not torch.jit.is_scripting():
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adopt
    else:
        func = _single_tensor_adopt

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        has_complex=has_complex,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        decoupled=decoupled,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )
|
||||||
@@ -13,7 +13,7 @@ from axolotl.train import train
|
|||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import with_temp_dir
|
from .utils import require_torch_2_5_1, with_temp_dir
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
@@ -65,3 +65,46 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||||
|
|
||||||
|
@with_temp_dir
@require_torch_2_5_1
def test_adopt_adamw(self, temp_dir):
    """Run a one-epoch LoRA fine-tune with ``adopt_adamw`` and check an adapter is saved."""
    # pylint: disable=duplicate-code
    settings = {
        "base_model": "JackFram/llama-68m",
        "tokenizer_type": "LlamaTokenizer",
        "sequence_len": 1024,
        "load_in_8bit": True,
        "adapter": "lora",
        "lora_r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
        "lora_target_linear": True,
        "val_set_size": 0.1,
        "special_tokens": {
            "unk_token": "<unk>",
            "bos_token": "<s>",
            "eos_token": "</s>",
        },
        "datasets": [
            {
                "path": "mhenrichsen/alpaca_2k_test",
                "type": "alpaca",
            },
        ],
        "num_epochs": 1,
        "micro_batch_size": 8,
        "gradient_accumulation_steps": 1,
        "output_dir": temp_dir,
        "learning_rate": 0.00001,
        # The optimizer under test.
        "optimizer": "adopt_adamw",
        "lr_scheduler": "cosine",
    }
    cfg = DictDefault(settings)
    normalize_config(cfg)

    trainer_args = TrainerCliArgs()
    meta = load_datasets(cfg=cfg, cli_args=trainer_args)
    train(cfg=cfg, cli_args=trainer_args, dataset_meta=meta)

    # Training must have written the LoRA adapter weights to the output dir.
    adapter_file = Path(temp_dir) / "adapter_model.bin"
    assert adapter_file.exists()
|
||||||
|
|||||||
@@ -6,11 +6,13 @@ import shutil
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from importlib.metadata import version
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
# from importlib.metadata import version
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
|
||||||
def with_temp_dir(test_func):
|
def with_temp_dir(test_func):
|
||||||
@wraps(test_func)
|
@wraps(test_func)
|
||||||
@@ -43,12 +45,24 @@ def require_torch_2_3_1(test_case):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def is_min_2_3_1():
|
def is_min_2_3_1():
|
||||||
torch_version = version("torch")
|
torch_version = version.parse(torch.__version__)
|
||||||
return torch_version >= "2.3.1"
|
return torch_version >= version.parse("2.3.1")
|
||||||
|
|
||||||
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
|
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_torch_2_5_1(test_case):
    """
    Decorator marking a test that requires torch >= 2.5.1
    """

    def is_min_2_5_1():
        # Compare parsed versions rather than raw strings so that ordering is
        # numeric per component.
        torch_version = version.parse(torch.__version__)
        return torch_version >= version.parse("2.5.1")

    return unittest.skipUnless(is_min_2_5_1(), "test torch 2.5.1")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def is_hopper():
|
def is_hopper():
|
||||||
compute_capability = torch.cuda.get_device_capability()
|
compute_capability = torch.cuda.get_device_capability()
|
||||||
return compute_capability == (9, 0)
|
return compute_capability == (9, 0)
|
||||||
|
|||||||
Reference in New Issue
Block a user