feat: add support and end-to-end tests for multiple custom optimizers… (#3457) [skip ci]
* feat: add support and end-to-end tests for multiple custom optimizers including Optimi AdamW, ADOPT AdamW, Muon, Dion, Schedule-Free AdamW, CAME PyTorch, and Flash AdamW.
* feat: add standalone flashoptim integration test and E2E tests for various custom optimizers including FlashAdamW, FlashAdam, FlashSGD, FlashSGDW, FlashLion, optimi_adamw, adopt_adamw, muon, dion, and schedule_free_adamw.
* feat: introduce Pydantic schema validation for dataset, attention, and training configurations.
* test: add e2e tests for custom optimizers including optimi_adamw, adopt_adamw, muon, dion, schedule_free_adamw, came_pytorch, and flash optimizers.
* test: fix assertion in flash optimizers test to compare class names directly.
* fix: address PR review — reuse require_torch_2_7_0 decorator, remove fsdp_config.version check, extract shared FSDP version helper, remove unused imports and optim_args.
* chore: lint

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -353,6 +353,30 @@ class TrainerBuilderBase(abc.ABC):
|
||||
adam_kwargs["eps"] = (eps1, eps2)
|
||||
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "flash_adamw":
|
||||
from flashoptim import FlashAdamW
|
||||
|
||||
optimizer_cls = FlashAdamW
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "flash_adam":
|
||||
from flashoptim import FlashAdam
|
||||
|
||||
optimizer_cls = FlashAdam
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "flash_sgd":
|
||||
from flashoptim import FlashSGD
|
||||
|
||||
optimizer_cls = FlashSGD
|
||||
elif self.cfg.optimizer == "flash_sgdw":
|
||||
from flashoptim import FlashSGDW
|
||||
|
||||
optimizer_cls = FlashSGDW
|
||||
elif self.cfg.optimizer == "flash_lion":
|
||||
from flashoptim import FlashLion
|
||||
|
||||
optimizer_cls = FlashLion
|
||||
if "betas" in adam_kwargs:
|
||||
optimizer_kwargs["betas"] = adam_kwargs["betas"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
|
||||
|
||||
@@ -87,6 +87,11 @@ class CustomSupportedOptimizers(str, Enum):
|
||||
came_pytorch = "came_pytorch"
|
||||
muon = "muon"
|
||||
dion = "dion"
|
||||
flash_adamw = "flash_adamw"
|
||||
flash_adam = "flash_adam"
|
||||
flash_sgd = "flash_sgd"
|
||||
flash_sgdw = "flash_sgdw"
|
||||
flash_lion = "flash_lion"
|
||||
|
||||
|
||||
class RingAttnFunc(str, Enum):
|
||||
|
||||
@@ -790,6 +790,14 @@ class OptimizationValidationMixin:
|
||||
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _resolve_fsdp_version(data):
|
||||
"""Resolve FSDP version from top-level fsdp_version or fsdp_config.fsdp_version."""
|
||||
fsdp_version = data.get("fsdp_version")
|
||||
if fsdp_version is None:
|
||||
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
|
||||
return fsdp_version
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_muon_deepspeed_fsdp(cls, data):
|
||||
@@ -799,15 +807,32 @@ class OptimizationValidationMixin:
|
||||
"Muon optimizer is currently incompatible with DeepSpeed"
|
||||
)
|
||||
if data.get("fsdp") or data.get("fsdp_config"):
|
||||
fsdp_version = data.get("fsdp_version")
|
||||
if fsdp_version is None:
|
||||
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
|
||||
fsdp_version = cls._resolve_fsdp_version(data)
|
||||
if str(fsdp_version) != "2":
|
||||
raise ValueError(
|
||||
"Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_flashoptim_deepspeed_fsdp(cls, data):
|
||||
optimizer = data.get("optimizer") or ""
|
||||
if str(optimizer).startswith("flash_"):
|
||||
if data.get("deepspeed"):
|
||||
raise ValueError(
|
||||
f"{optimizer} optimizer is incompatible with DeepSpeed. "
|
||||
"Flash optimizers only support DDP and FSDP2."
|
||||
)
|
||||
if data.get("fsdp") or data.get("fsdp_config"):
|
||||
fsdp_version = cls._resolve_fsdp_version(data)
|
||||
if str(fsdp_version) != "2":
|
||||
raise ValueError(
|
||||
f"{optimizer} optimizer is only compatible with FSDP2. "
|
||||
"Set fsdp_version: 2 to use flash optimizers with FSDP."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_batch_flattening_fa(cls, data):
|
||||
|
||||
Reference in New Issue
Block a user