From 5e21b1a9daee7bba2b06a1f181909eb4a9a834ec Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 6 Mar 2025 11:48:44 -0500
Subject: [PATCH] various fixes 20250305 (#2384)

* various validation fixes

* fix check for non-truthy value
---
 src/axolotl/integrations/spectrum/args.py    | 19 ++++++++++++++++++-
 .../config/models/input/v0_4_1/__init__.py   | 15 ++++++++++++---
 .../utils/gradient_checkpointing/__init__.py |  2 +-
 src/axolotl/utils/models.py                  |  6 +++---
 4 files changed, 34 insertions(+), 8 deletions(-)

diff --git a/src/axolotl/integrations/spectrum/args.py b/src/axolotl/integrations/spectrum/args.py
index 03426d841..df5756038 100644
--- a/src/axolotl/integrations/spectrum/args.py
+++ b/src/axolotl/integrations/spectrum/args.py
@@ -17,7 +17,7 @@ Module for handling Spectrum input arguments.
 """
 from typing import Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator
 
 
 class SpectrumArgs(BaseModel):
@@ -27,3 +27,20 @@ class SpectrumArgs(BaseModel):
 
     spectrum_top_fraction: Optional[float] = 0.5
     spectrum_model_name: Optional[str] = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_fsdp_use_orig_params(cls, data):
+        if (
+            data.get("fsdp")
+            and data.get("fsdp_config")
+            and not data["fsdp_config"].get("use_orig_params")
+            and data.get("plugins")
+            and any("SpectrumPlugin" in plugin for plugin in data["plugins"])
+        ):
+            # would otherwise raise
+            # ValueError: Must flatten tensors with uniform `requires_grad` when `use_orig_params=False`
+            raise ValueError(
+                "FSDP + SpectrumPlugin cannot be used together when `use_orig_params=False` is set"
+            )
+        return data
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index ce2586afb..5143469be 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -778,9 +778,9 @@ class AxolotlInputConfig(
 
     # torch_dtype: Optional[torch.dtype]
 
-    gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
-        default=False
-    )
+    gradient_checkpointing: Optional[
+        Union[Literal["unsloth", "offload"], bool]
+    ] = Field(default=False)
     gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
 
     unfrozen_parameters: Optional[List[str]] = None
@@ -1154,6 +1154,15 @@ class AxolotlInputConfig(
             raise ValueError("gradient_checkpointing is not supported for MPT models")
         return self
 
+    @model_validator(mode="after")
+    def check_offload_grad_checkpointing(self):
+        if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
+            LOG.warning(
+                "`unsloth` is deprecated for gradient_checkpointing, use `offload`"
+            )
+            self.gradient_checkpointing = "offload"
+        return self
+
     @model_validator(mode="after")
     def check_better_transformers(self):
         if self.flash_optimum is True:
diff --git a/src/axolotl/utils/gradient_checkpointing/__init__.py b/src/axolotl/utils/gradient_checkpointing/__init__.py
index 4639fc266..8bbf878ad 100644
--- a/src/axolotl/utils/gradient_checkpointing/__init__.py
+++ b/src/axolotl/utils/gradient_checkpointing/__init__.py
@@ -4,7 +4,7 @@ from axolotl.utils.gradient_checkpointing.unsloth import (
 )
 
 
-def hf_grad_checkpoint_unsloth_wrapper(
+def hf_grad_checkpoint_offload_wrapper(
     decoder_layer, *args, use_reentrant=None
 ):  # pylint: disable=unused-argument
     return Unsloth_Offloaded_Gradient_Checkpointer.apply(
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index add690d9d..1805a749a 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -64,7 +64,7 @@ from axolotl.utils.distributed import (
     is_local_main_process,
     zero_only,
 )
-from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
+from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
 
@@ -493,8 +493,8 @@ class ModelLoader:
 
         patch_fa_peft_integration()
 
-        if self.cfg.gradient_checkpointing == "unsloth":
-            transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
+        if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
+            transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
 
         if self.cfg.flash_attention:
             self.patch_attention()
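
The new `check_fsdp_use_orig_params` validator runs with `mode="before"`, so
it inspects the raw config mapping before pydantic filters it down to the
declared SpectrumArgs fields. A minimal sketch of the failure it now surfaces
at config-load time; the fsdp, plugin, and model-name values below are
hypothetical examples, not taken from this patch:

    from axolotl.integrations.spectrum.args import SpectrumArgs

    # Hypothetical raw config dict; in practice this comes from the YAML file.
    cfg = {
        "spectrum_top_fraction": 0.5,
        "spectrum_model_name": "meta-llama/Meta-Llama-3-8B",  # hypothetical
        "fsdp": ["full_shard"],
        "fsdp_config": {"use_orig_params": False},
        "plugins": ["axolotl.integrations.spectrum.SpectrumPlugin"],
    }

    try:
        SpectrumArgs.model_validate(cfg)
    except ValueError as err:  # pydantic's ValidationError subclasses ValueError
        print(err)  # ... cannot be used together when `use_orig_params=False` is set

Failing fast here replaces the opaque FSDP flattening error ("Must flatten
tensors with uniform `requires_grad` when `use_orig_params=False`") with an
actionable message before any model is loaded.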
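
The `unsloth` gradient-checkpointing mode is renamed to `offload`: the config
validator rewrites the deprecated value with a warning, and `ModelLoader`
patches `transformers.modeling_utils.checkpoint` with the renamed
`hf_grad_checkpoint_offload_wrapper` for either spelling, so existing configs
keep working. A standalone sketch of the deprecation-shim pattern; `TinyCfg`
is a hypothetical stand-in for AxolotlInputConfig, not part of axolotl:

    import logging
    from typing import Literal, Optional, Union

    from pydantic import BaseModel, model_validator

    LOG = logging.getLogger(__name__)


    class TinyCfg(BaseModel):
        # Hypothetical stand-in for AxolotlInputConfig, reduced to one field.
        gradient_checkpointing: Optional[
            Union[Literal["unsloth", "offload"], bool]
        ] = False

        @model_validator(mode="after")
        def normalize_grad_checkpointing(self):
            # Mirror the patch: warn on the deprecated value, then rewrite it.
            if self.gradient_checkpointing == "unsloth":
                LOG.warning(
                    "`unsloth` is deprecated for gradient_checkpointing, use `offload`"
                )
                self.gradient_checkpointing = "offload"
            return self


    assert TinyCfg(gradient_checkpointing="unsloth").gradient_checkpointing == "offload"
    assert TinyCfg(gradient_checkpointing=True).gradient_checkpointing is True

Normalizing in an `after` validator keeps the downstream check in ModelLoader
(`in ["unsloth", "offload"]`) a simple membership test while preserving
backward compatibility with older configs.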