Check torch version for ADOPT optimizer + integrating new ADOPT updates (#2104)

* added torch check for adopt, wip

* lint

* gonna put torch version checking somewhere else

* added ENVcapabilities class for torch version checking

* lint + pydantic

* ENVCapabilities -> EnvCapabilities

* forgot to git add v0_4_1/__init__.py

* removed redundancy

* add check if env_capabilities not specified

* make env_capabilities compulsory [skip e2e]

* fixup env_capabilities

* modified test_validation.py to accommodate env_capabilities

* adopt torch version test [skip e2e]

* raise error

* test correct torch version

* test torch version above requirement

* Update src/axolotl/utils/config/models/input/v0_4_1/__init__.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* removed unused is_torch_min

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Sunny Liu
2024-12-02 20:15:39 -05:00
committed by GitHub
parent 9f6d0b5587
commit d5f58b6509
10 changed files with 189 additions and 66 deletions

View File

@@ -409,7 +409,7 @@ lr_div_factor: # Learning rate div factor
# - adamw_torch_fused # - adamw_torch_fused
# - adamw_torch_xla # - adamw_torch_xla
# - adamw_apex_fused # - adamw_apex_fused
# - adopt_adamw (only for torch version >= 2.5.1) # - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
# - adafactor # - adafactor
# - adamw_anyprecision # - adamw_anyprecision
# - sgd # - sgd

View File

@@ -100,8 +100,8 @@ def print_dep_versions():
print("*" * 40) print("*" * 40)
print("**** Axolotl Dependency Versions *****") print("**** Axolotl Dependency Versions *****")
for pkg in packages: for pkg in packages:
version = _is_package_available(pkg, return_version=True) pkg_version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >{max_len}}: {version[1]: <15}") print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
print("*" * 40) print("*" * 40)
@@ -444,6 +444,9 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)), "n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version, "compute_capability": gpu_version,
}, },
env_capabilities={
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
},
) )
prepare_optim_env(cfg) prepare_optim_env(cfg)

View File

@@ -562,7 +562,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init self.optimizer = ( # pylint: disable=attribute-defined-outside-init
ADOPT( ADOPT(
optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs optimizer_grouped_parameters,
decouple=True,
**optimizer_kwargs,
) )
) )

View File

@@ -229,7 +229,11 @@ def normalize_cfg_datasets(cfg):
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None): def validate_config(
cfg: DictDefault,
capabilities: Optional[dict] = None,
env_capabilities: Optional[dict] = None,
):
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
AxolotlInputConfig = AxolotlInputConfigBase AxolotlInputConfig = AxolotlInputConfigBase
@@ -239,14 +243,24 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
AxolotlInputConfig, # pylint: disable=invalid-name AxolotlInputConfig, # pylint: disable=invalid-name
) = merge_input_args() ) = merge_input_args()
if capabilities: if capabilities or env_capabilities:
if (capabilities and not env_capabilities) or (
env_capabilities and not capabilities
):
raise ValueError(
"Both capabilities and env_capabilities must be provided or not provided."
)
return DictDefault( return DictDefault(
dict( dict(
AxolotlConfigWCapabilities( AxolotlConfigWCapabilities(
**cfg.to_dict(), capabilities=capabilities **cfg.to_dict(),
capabilities=capabilities,
env_capabilities=env_capabilities,
).model_dump(exclude_none=True) ).model_dump(exclude_none=True)
) )
) )
return DictDefault( return DictDefault(
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True)) dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
) )

View File

@@ -9,6 +9,7 @@ import os
from enum import Enum from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
from packaging import version
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
Field, Field,
@@ -21,7 +22,7 @@ from transformers import SchedulerType
from transformers.training_args import OptimizerNames from transformers.training_args import OptimizerNames
from transformers.utils.import_utils import is_torch_npu_available from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.config.models.internals import GPUCapabilities from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
LOG = logging.getLogger("axolotl.utils.config.models.input") LOG = logging.getLogger("axolotl.utils.config.models.input")
@@ -1477,6 +1478,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options""" """wrapper to valdiate gpu capabilities with the configured options"""
capabilities: GPUCapabilities capabilities: GPUCapabilities
env_capabilities: EnvCapabilities
@model_validator(mode="after") @model_validator(mode="after")
def check_bf16(self): def check_bf16(self):
@@ -1551,3 +1553,21 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training." "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
) )
return data return data
@model_validator(mode="before")
@classmethod
def check_adopt_torch_version(cls, data):
    """Reject ADOPT optimizers when the torch version is older than 2.5.1.

    Registered as a "before" validator; ``data`` is read with ``.get`` and
    returned unchanged when the check passes.

    The torch version is taken from ``data["env_capabilities"]["torch_version"]``
    when supplied. If it is missing, the locally installed torch is queried
    and any local build suffix (the part after ``+``) is dropped.

    Raises:
        ValueError: if an optimizer whose name contains "adopt" is requested
            and the resolved torch version is lower than 2.5.1.
    """
    optimizer = data.get("optimizer")
    if optimizer is None or "adopt" not in optimizer:
        return data

    # The key may be absent or explicitly None; treat both as "nothing
    # supplied" so the lookup below cannot fail on a None value.
    env_capabilities = data.get("env_capabilities") or {}
    torch_version = env_capabilities.get("torch_version")

    if torch_version is None:
        # Deferred import: only needed when the caller gave no version.
        import torch

        torch_version = str(torch.__version__).split("+", maxsplit=1)[0]

    if version.parse(torch_version) < version.parse("2.5.1"):
        raise ValueError(
            "ADOPT optimizer is incompatible with torch version < 2.5.1"
        )
    return data

View File

@@ -12,3 +12,9 @@ class GPUCapabilities(BaseModel):
n_gpu: int = Field(default=1) n_gpu: int = Field(default=1)
n_node: int = Field(default=1) n_node: int = Field(default=1)
compute_capability: Optional[str] = Field(default=None) compute_capability: Optional[str] = Field(default=None)
class EnvCapabilities(BaseModel):
    """Static snapshot of the software environment, passed in for config validation."""

    # Torch version string such as "2.5.1"; None when the caller did not supply one.
    torch_version: Optional[str] = None

View File

@@ -6,21 +6,29 @@ Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeo
""" """
# mypy: ignore-errors # mypy: ignore-errors
# pylint: skip-file # pylint: skip-file
# flake8: noqa
# mypy: allow-untyped-decorators # mypy: allow-untyped-decorators
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union, cast from typing import Callable, List, Optional, Tuple, Union, cast
import torch import torch
from torch import Tensor from torch import Tensor
from torch.optim.optimizer import ( from torch.optim.optimizer import ( # DeviceDict,; _capturable_doc,; _differentiable_doc,; _foreach_doc,; _fused_doc,; _maximize_doc,; _stack_if_compiling,
DeviceDict,
Optimizer, Optimizer,
ParamsT, ParamsT,
_capturable_doc,
_default_to_fused_or_foreach, _default_to_fused_or_foreach,
_device_dtype_check_for_fused, _device_dtype_check_for_fused,
_differentiable_doc,
_disable_dynamo_if_unsupported, _disable_dynamo_if_unsupported,
_foreach_doc,
_fused_doc,
_get_capturable_supported_devices, _get_capturable_supported_devices,
_get_scalar_dtype, _get_scalar_dtype,
_get_value, _get_value,
_maximize_doc,
_stack_if_compiling,
_use_grad_for_differentiable, _use_grad_for_differentiable,
_view_as_real, _view_as_real,
) )
@@ -35,8 +43,9 @@ class ADOPT(Optimizer):
lr: Union[float, Tensor] = 1e-3, lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.9999), betas: Tuple[float, float] = (0.9, 0.9999),
eps: float = 1e-6, eps: float = 1e-6,
clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
weight_decay: float = 0.0, weight_decay: float = 0.0,
decoupled: bool = False, decouple: bool = False,
*, *,
foreach: Optional[bool] = None, foreach: Optional[bool] = None,
maximize: bool = False, maximize: bool = False,
@@ -62,12 +71,14 @@ class ADOPT(Optimizer):
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}") raise ValueError(f"Invalid weight_decay value: {weight_decay}")
self.clip_lambda = clip_lambda
defaults = dict( defaults = dict(
lr=lr, lr=lr,
betas=betas, betas=betas,
eps=eps, eps=eps,
weight_decay=weight_decay, weight_decay=weight_decay,
decoupled=decoupled, decouple=decouple,
maximize=maximize, maximize=maximize,
foreach=foreach, foreach=foreach,
capturable=capturable, capturable=capturable,
@@ -219,8 +230,9 @@ class ADOPT(Optimizer):
beta1=beta1, beta1=beta1,
beta2=beta2, beta2=beta2,
lr=group["lr"], lr=group["lr"],
clip_lambda=self.clip_lambda,
weight_decay=group["weight_decay"], weight_decay=group["weight_decay"],
decoupled=group["decoupled"], decouple=group["decouple"],
eps=group["eps"], eps=group["eps"],
maximize=group["maximize"], maximize=group["maximize"],
foreach=group["foreach"], foreach=group["foreach"],
@@ -247,8 +259,9 @@ def _single_tensor_adopt(
beta1: float, beta1: float,
beta2: float, beta2: float,
lr: Union[float, Tensor], lr: Union[float, Tensor],
clip_lambda: Optional[Callable[[int], float]],
weight_decay: float, weight_decay: float,
decoupled: bool, decouple: bool,
eps: float, eps: float,
maximize: bool, maximize: bool,
capturable: bool, capturable: bool,
@@ -276,14 +289,10 @@ def _single_tensor_adopt(
and param.device.type in capturable_supported_devices and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
# update step step = step_t if capturable or differentiable else _get_value(step_t)
step_t += 1
if weight_decay != 0: if weight_decay != 0 and not decouple:
if decoupled: grad = grad.add(param, alpha=weight_decay)
param.add_(param, alpha=-lr * weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
if torch.is_complex(param): if torch.is_complex(param):
grad = torch.view_as_real(grad) grad = torch.view_as_real(grad)
@@ -293,20 +302,29 @@ def _single_tensor_adopt(
exp_avg_sq = torch.view_as_real(exp_avg_sq) exp_avg_sq = torch.view_as_real(exp_avg_sq)
param = torch.view_as_real(param) param = torch.view_as_real(param)
step = step_t if capturable or differentiable else _get_value(step_t) if step == 0:
if step == 1:
exp_avg_sq.addcmul_(grad, grad.conj()) exp_avg_sq.addcmul_(grad, grad.conj())
# update step
step_t += 1
continue continue
if weight_decay != 0 and decouple:
param.add_(param, alpha=-lr * weight_decay)
denom = torch.clamp(exp_avg_sq.sqrt(), eps) denom = torch.clamp(exp_avg_sq.sqrt(), eps)
if step == 2: normed_grad = grad.div(denom)
exp_avg.addcdiv_(grad, denom) if clip_lambda is not None:
else: clip = clip_lambda(step)
exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1) normed_grad.clamp_(-clip, clip)
exp_avg.lerp_(normed_grad, 1 - beta1)
param.add_(exp_avg, alpha=-lr) param.add_(exp_avg, alpha=-lr)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
# update step
step_t += 1
def _multi_tensor_adopt( def _multi_tensor_adopt(
params: List[Tensor], params: List[Tensor],
@@ -321,8 +339,9 @@ def _multi_tensor_adopt(
beta1: float, beta1: float,
beta2: float, beta2: float,
lr: Union[float, Tensor], lr: Union[float, Tensor],
clip_lambda: Optional[Callable[[int], float]],
weight_decay: float, weight_decay: float,
decoupled: bool, decouple: bool,
eps: float, eps: float,
maximize: bool, maximize: bool,
capturable: bool, capturable: bool,
@@ -376,6 +395,51 @@ def _multi_tensor_adopt(
if maximize: if maximize:
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
if weight_decay != 0 and not decouple:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)
if device_state_steps[0] == 0:
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
else:
torch._foreach_add_(device_state_steps, 1)
continue
if weight_decay != 0 and decouple:
torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
if clip_lambda is not None:
clip = clip_lambda(device_state_steps[0])
torch._foreach_maximum_(normed_grad, -clip)
torch._foreach_minimum_(normed_grad, clip)
torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
torch._foreach_mul_(device_exp_avg_sqs, beta2)
torch._foreach_addcmul_(
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
)
# Update steps # Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
@@ -387,41 +451,6 @@ def _multi_tensor_adopt(
else: else:
torch._foreach_add_(device_state_steps, 1) torch._foreach_add_(device_state_steps, 1)
if weight_decay != 0:
if decoupled:
torch._foreach_add_(
device_params, device_params, alpha=-lr * weight_decay
)
else:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)
if device_state_steps[0] == 1:
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
continue
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
if device_state_steps[0] == 2:
torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
else:
torch._foreach_mul_(device_exp_avgs, beta1)
torch._foreach_addcdiv_(
device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
)
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
torch._foreach_mul_(device_exp_avg_sqs, beta2)
torch._foreach_addcmul_(
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt) @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
def adopt( def adopt(
@@ -443,8 +472,9 @@ def adopt(
beta1: float, beta1: float,
beta2: float, beta2: float,
lr: Union[float, Tensor], lr: Union[float, Tensor],
clip_lambda: Optional[Callable[[int], float]],
weight_decay: float, weight_decay: float,
decoupled: bool, decouple: bool,
eps: float, eps: float,
maximize: bool, maximize: bool,
): ):
@@ -497,8 +527,9 @@ def adopt(
beta1=beta1, beta1=beta1,
beta2=beta2, beta2=beta2,
lr=lr, lr=lr,
clip_lambda=clip_lambda,
weight_decay=weight_decay, weight_decay=weight_decay,
decoupled=decoupled, decouple=decouple,
eps=eps, eps=eps,
maximize=maximize, maximize=maximize,
capturable=capturable, capturable=capturable,

View File

@@ -53,7 +53,7 @@ def require_torch_2_3_1(test_case):
def require_torch_2_5_1(test_case): def require_torch_2_5_1(test_case):
""" """
Decorator marking a test that requires torch >= 2.3.1 Decorator marking a test that requires torch >= 2.5.1
""" """
def is_min_2_5_1(): def is_min_2_5_1():

View File

@@ -672,6 +672,9 @@ class TestValidation(BaseValidation):
{ {
"bf16": True, "bf16": True,
"capabilities": {"bf16": False}, "capabilities": {"bf16": False},
"env_capabilities": {
"torch_version": "2.5.1",
},
} }
) )
| minimal_cfg | minimal_cfg
@@ -1160,6 +1163,38 @@ class TestValidation(BaseValidation):
in self._caplog.records[0].message in self._caplog.records[0].message
) )
def test_torch_version_adopt_req(self, minimal_cfg):
    """ADOPT is rejected on torch < 2.5.1 and accepted on 2.5.1 and newer."""
    cfg = DictDefault({"optimizer": "adopt_adamw"}) | minimal_cfg
    capabilities = {"bf16": False}

    # Too-old torch: validation must fail with the ADOPT-specific message.
    # Only the failing call lives inside the raises-block.
    with pytest.raises(
        ValueError,
        match=r"ADOPT optimizer is incompatible with torch version",
    ):
        validate_config(
            cfg,
            capabilities=capabilities,
            env_capabilities={"torch_version": "2.3.0"},
        )

    # The minimum supported version and a newer one must both validate.
    # These calls sit outside the raises-block so they actually execute.
    for torch_version in ("2.5.1", "2.5.2"):
        validate_config(
            cfg,
            capabilities=capabilities,
            env_capabilities={"torch_version": torch_version},
        )
class TestValidationCheckModelConfig(BaseValidation): class TestValidationCheckModelConfig(BaseValidation):
""" """

View File

@@ -72,6 +72,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
"n_gpu": 1, "n_gpu": 1,
"compute_capability": "8.0", "compute_capability": "8.0",
}, },
env_capabilities={
"torch_version": "2.5.1",
},
) )
_check_config() _check_config()
@@ -124,6 +127,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
"n_gpu": 1, "n_gpu": 1,
"compute_capability": "8.0", "compute_capability": "8.0",
}, },
env_capabilities={
"torch_version": "2.5.1",
},
) )
_check_config() _check_config()
@@ -177,6 +183,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
"n_gpu": 1, "n_gpu": 1,
"compute_capability": "8.0", "compute_capability": "8.0",
}, },
env_capabilities={
"torch_version": "2.5.1",
},
) )
_check_config() _check_config()
@@ -231,6 +240,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
"n_gpu": 1, "n_gpu": 1,
"compute_capability": "8.0", "compute_capability": "8.0",
}, },
env_capabilities={
"torch_version": "2.5.1",
},
) )
_check_config() _check_config()