ADOPT optimizer integration (#2032) [skip ci]
* adopt integration
* stuff
* doc and test for ADOPT
* rearrangement
* fixed formatting
* hacking pre-commit
* chore: lint
* update module doc for adopt optimizer
* remove un-necessary example yaml for adopt optimizer
* skip test adopt if torch<2.5.1
* formatting
* use version.parse
* specifies required torch version for adopt_adamw

---------

Co-authored-by: sunny <sunnyliu19981005@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
@@ -405,6 +405,7 @@ lr_div_factor: # Learning rate div factor
# - adamw_torch_fused
# - adamw_torch_xla
# - adamw_apex_fused
# - adopt_adamw (only for torch version >= 2.5.1)
# - adafactor
# - adamw_anyprecision
# - sgd

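The torch-version caveat above appears to exist because the vendored ADOPT implementation imports several private helpers (`ParamsT`, `_device_dtype_check_for_fused`, and friends) from `torch.optim.optimizer` that only newer torch releases expose. A hedged way to check the environment before switching a config to `adopt_adamw`, mirroring the `version.parse` guard this commit adds to the e2e test utilities:

# Sketch only: verify the installed torch is new enough for optimizer: adopt_adamw.
import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse("2.5.1"):
    print("torch is new enough for adopt_adamw")
else:
    print("upgrade torch to >= 2.5.1 before using adopt_adamw")
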
@@ -436,7 +436,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
        if (
            self.args.loraplus_lr_ratio is None
            and self.args.alternate_optimizer
-            not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
+            not in [
+                "optimi_adamw",
+                "ao_adamw_8bit",
+                "ao_adamw_4bit",
+                "ao_adamw_fp8",
+                "adopt_adamw",
+            ]
        ):
            return super().create_optimizer()

@@ -505,6 +511,14 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
            self.optimizer = ( # pylint: disable=attribute-defined-outside-init
                AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
            )
        elif self.args.alternate_optimizer == "adopt_adamw":
            from axolotl.utils.optimizers.adopt import ADOPT

            self.optimizer = ( # pylint: disable=attribute-defined-outside-init
                ADOPT(
                    optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
                )
            )

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
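The `decoupled=True` argument passed above is what makes the config value read `adopt_adamw`: weight decay is applied AdamW-style, directly to the weights, rather than being folded into the gradient. A small self-contained illustration of the two branches that `_single_tensor_adopt` (in the new file below) implements; all names here are stand-ins for one parameter's update step:

import torch

# Stand-in values for a single parameter update (illustration only).
param = torch.ones(3)
grad = torch.full((3,), 0.5)
lr, weight_decay = 1e-3, 0.01
decoupled = True  # what create_optimizer passes for "adopt_adamw"

if weight_decay != 0:
    if decoupled:
        # AdamW-style: shrink the weights directly, leave the gradient untouched
        param.add_(param, alpha=-lr * weight_decay)
    else:
        # coupled (plain L2): fold weight decay into the gradient before the update
        grad = grad.add(param, alpha=weight_decay)
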
@@ -1625,11 +1639,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        if self.cfg.reward_model:
            trainer_kwargs["max_length"] = self.cfg.sequence_len

        # pylint: disable=duplicate-code
        if self.cfg.optimizer in [
            "optimi_adamw",
            "ao_adamw_4bit",
            "ao_adamw_8bit",
            "ao_adamw_fp8",
            "adopt_adamw",
        ]:
            # Set default so transformers doesn't throw
            training_arguments_kwargs["optim"] = "adamw_hf"

@@ -428,6 +428,7 @@ class HyperparametersConfig(BaseModel):
                "ao_adamw_4bit",
                "ao_adamw_8bit",
                "ao_adamw_fp8",
                "adopt_adamw",
            ],
        ]
    ] = OptimizerNames.ADAMW_HF.value

src/axolotl/utils/optimizers/adopt.py (new file, 508 lines)
@@ -0,0 +1,508 @@
"""
Copied from https://github.com/iShohei220/adopt

ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate (2024)
Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka
"""
# mypy: ignore-errors
# pylint: skip-file
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union, cast

import torch
from torch import Tensor
from torch.optim.optimizer import (
    Optimizer,
    ParamsT,
    _default_to_fused_or_foreach,
    _device_dtype_check_for_fused,
    _disable_dynamo_if_unsupported,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _use_grad_for_differentiable,
    _view_as_real,
)

__all__ = ["ADOPT", "adopt"]


class ADOPT(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.9999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        decoupled: bool = False,
        *,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        capturable: bool = False,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):
        if isinstance(lr, Tensor):
            if foreach and not capturable:
                raise ValueError(
                    "lr as a Tensor is not supported for capturable=False and foreach=True"
                )
            if lr.numel() != 1:
                raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            decoupled=decoupled,
            maximize=maximize,
            foreach=foreach,
            capturable=capturable,
            differentiable=differentiable,
            fused=fused,
        )
        super().__init__(params, defaults)

        if fused:
            # TODO: support fused
            raise RuntimeError("`fused` is not currently supported")

            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            self._step_supports_amp_scaling = True
            # TODO(crcrpar): [low prec params & their higher prec copy]
            # Support AMP with FP16/BF16 model params which would need
            # higher prec copy of params to do update math in higher prec to
            # alleviate the loss of information.
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)
            group.setdefault("capturable", False)
            group.setdefault("differentiable", False)
            fused = group.setdefault("fused", None)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(
                            step_val,
                            dtype=_get_scalar_dtype(is_fused=fused),
                            device=p.device,
                        )
                        if group["capturable"] or group["fused"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
    ):
        has_complex = False
        for p in group["params"]:
            if p.grad is not None:
                has_complex |= torch.is_complex(p)
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("ADOPT does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    if group["fused"]:
                        _device_dtype_check_for_fused(p)
                    # note(crcrpar): [special device hosting for step]
                    # Deliberately host `step` on CPU if both capturable and fused are off.
                    # This is because kernel launches are costly on CUDA and XLA.
                    state["step"] = (
                        torch.zeros(
                            (),
                            dtype=_get_scalar_dtype(is_fused=group["fused"]),
                            device=p.device,
                        )
                        if group["capturable"] or group["fused"]
                        else torch.tensor(0.0, dtype=_get_scalar_dtype())
                    )
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])

                if group["differentiable"] and state["step"].requires_grad:
                    raise RuntimeError(
                        "`requires_grad` is not supported for `step` in differentiable mode"
                    )

                # Foreach without capturable does not support a tensor lr
                if (
                    group["foreach"]
                    and torch.is_tensor(group["lr"])
                    and not group["capturable"]
                ):
                    raise RuntimeError(
                        "lr as a Tensor is not supported for capturable=False and foreach=True"
                    )

                state_steps.append(state["step"])
        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            exp_avgs: List[Tensor] = []
            exp_avg_sqs: List[Tensor] = []
            state_steps: List[Tensor] = []
            beta1, beta2 = group["betas"]

            has_complex = self._init_group(
                group,
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                state_steps,
            )

            adopt(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                state_steps,
                has_complex=has_complex,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                decoupled=group["decoupled"],
                eps=group["eps"],
                maximize=group["maximize"],
                foreach=group["foreach"],
                capturable=group["capturable"],
                differentiable=group["differentiable"],
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

        return loss

def _single_tensor_adopt(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    has_complex: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    decoupled: bool,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    assert grad_scale is None and found_inf is None

    if torch.jit.is_scripting():
        # this assert is due to JIT being dumb and not realizing that the ops below
        # have overloads to handle both float and Tensor lrs, so we just assert it's
        # a float since most people using JIT are using floats
        assert isinstance(lr, float)

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

        # update step
        step_t += 1

        if weight_decay != 0:
            if decoupled:
                param.add_(param, alpha=-lr * weight_decay)
            else:
                grad = grad.add(param, alpha=weight_decay)

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            if exp_avg is not None:
                exp_avg = torch.view_as_real(exp_avg)
            if exp_avg_sq is not None:
                exp_avg_sq = torch.view_as_real(exp_avg_sq)
            param = torch.view_as_real(param)

        step = step_t if capturable or differentiable else _get_value(step_t)
        if step == 1:
            exp_avg_sq.addcmul_(grad, grad.conj())
            continue

        denom = torch.clamp(exp_avg_sq.sqrt(), eps)
        if step == 2:
            exp_avg.addcdiv_(grad, denom)
        else:
            exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)

        param.add_(exp_avg, alpha=-lr)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

def _multi_tensor_adopt(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    has_complex: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    decoupled: bool,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    if len(params) == 0:
        return

    if isinstance(lr, Tensor) and not capturable:
        raise RuntimeError(
            "lr as a Tensor is not supported for capturable=False and foreach=True"
        )

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    assert grad_scale is None and found_inf is None

    assert not differentiable, "_foreach ops don't support autograd"

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, state_steps] # type: ignore[list-item]
    )
    for (
        device_params_,
        device_grads_,
        device_exp_avgs_,
        device_exp_avg_sqs_,
        device_state_steps_,
    ), _ in grouped_tensors.values():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
        device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        # Handle complex parameters
        if has_complex:
            _view_as_real(
                device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
            )

        if maximize:
            device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1)

        if weight_decay != 0:
            if decoupled:
                torch._foreach_add_(
                    device_params, device_params, alpha=-lr * weight_decay
                )
            else:
                # Re-use the intermediate memory (device_grads) already allocated for maximize
                if maximize:
                    torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
                else:
                    device_grads = torch._foreach_add( # type: ignore[assignment]
                        device_grads, device_params, alpha=weight_decay
                    )

        if device_state_steps[0] == 1:
            torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
            continue

        exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
        exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)

        if device_state_steps[0] == 2:
            torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
        else:
            torch._foreach_mul_(device_exp_avgs, beta1)
            torch._foreach_addcdiv_(
                device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
            )

        torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
        torch._foreach_mul_(device_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(
            device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
        )

@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
def adopt(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    *,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    decoupled: bool,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs ADOPT algorithm computation."""
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )
        # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
        if foreach and isinstance(lr, Tensor) and not capturable:
            foreach = False
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch._utils.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    # if fused and not torch.jit.is_scripting():
    #     func = _fused_adopt
    # elif foreach and not torch.jit.is_scripting():
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adopt
    else:
        func = _single_tensor_adopt

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        has_complex=has_complex,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        decoupled=decoupled,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )

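Stripped of the capturable/foreach/complex-tensor machinery, the per-parameter update the file above implements is short. A hedged reference sketch of the math in `_single_tensor_adopt` (plain tensors, no weight decay): the gradient is normalized by the *previous* second-moment estimate, and that estimate is only refreshed after the parameter step.

import torch


def adopt_step(param, grad, exp_avg, exp_avg_sq, step, lr=1e-3,
               beta1=0.9, beta2=0.9999, eps=1e-6):
    """One simplified ADOPT update mirroring _single_tensor_adopt (no weight decay)."""
    if step == 1:
        # the first step only seeds the second-moment estimate
        exp_avg_sq.addcmul_(grad, grad.conj())
        return
    # normalize by the previous second moment (the key difference from Adam)
    denom = torch.clamp(exp_avg_sq.sqrt(), eps)
    if step == 2:
        exp_avg.addcdiv_(grad, denom)
    else:
        exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
    param.add_(exp_avg, alpha=-lr)
    # the second moment is refreshed only after the parameter update
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)


# toy usage: three updates on a dummy parameter
p = torch.zeros(4)
m, v = torch.zeros_like(p), torch.zeros_like(p)
for t in range(1, 4):
    adopt_step(p, torch.randn(4), m, v, step=t)
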
@@ -13,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

-from .utils import with_temp_dir
+from .utils import require_torch_2_5_1, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -65,3 +65,46 @@ class TestCustomOptimizers(unittest.TestCase):

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        assert (Path(temp_dir) / "adapter_model.bin").exists()

    @with_temp_dir
    @require_torch_2_5_1
    def test_adopt_adamw(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "JackFram/llama-68m",
                "tokenizer_type": "LlamaTokenizer",
                "sequence_len": 1024,
                "load_in_8bit": True,
                "adapter": "lora",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_target_linear": True,
                "val_set_size": 0.1,
                "special_tokens": {
                    "unk_token": "<unk>",
                    "bos_token": "<s>",
                    "eos_token": "</s>",
                },
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "micro_batch_size": 8,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adopt_adamw",
                "lr_scheduler": "cosine",
            }
        )
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        assert (Path(temp_dir) / "adapter_model.bin").exists()

@@ -6,11 +6,13 @@ import shutil
import tempfile
import unittest
from functools import wraps
-from importlib.metadata import version
from pathlib import Path

import torch

+# from importlib.metadata import version
+from packaging import version


def with_temp_dir(test_func):
    @wraps(test_func)
@@ -43,12 +45,24 @@ def require_torch_2_3_1(test_case):
    """

    def is_min_2_3_1():
-        torch_version = version("torch")
-        return torch_version >= "2.3.1"
+        torch_version = version.parse(torch.__version__)
+        return torch_version >= version.parse("2.3.1")

    return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)


def require_torch_2_5_1(test_case):
    """
    Decorator marking a test that requires torch >= 2.5.1
    """

    def is_min_2_5_1():
        torch_version = version.parse(torch.__version__)
        return torch_version >= version.parse("2.5.1")

    return unittest.skipUnless(is_min_2_5_1(), "test torch 2.5.1")(test_case)


def is_hopper():
    compute_capability = torch.cuda.get_device_capability()
    return compute_capability == (9, 0)