various fixes 20250305 (#2384)
* various validation fixes
* fix check for non-truthy value
@@ -17,7 +17,7 @@ Module for handling Spectrum input arguments.
 """
 from typing import Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator
 
 
 class SpectrumArgs(BaseModel):
@@ -27,3 +27,20 @@ class SpectrumArgs(BaseModel):
 
     spectrum_top_fraction: Optional[float] = 0.5
     spectrum_model_name: Optional[str] = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_fsdp_use_orig_params(cls, data):
+        if (
+            data.get("fsdp")
+            and data.get("fsdp_config")
+            and not data["fsdp_config"].get("use_orig_params")
+            and data.get("plugins")
+            and any("SpectrumPlugin" in plugin for plugin in data["plugins"])
+        ):
+            # would otherwise raise
+            # ValueError: Must flatten tensors with uniform `requires_grad` when `use_orig_params=False`
+            raise ValueError(
+                "FSDP + SpectrumPlugin cannot be used together when `use_orig_params=False` is set"
+            )
+        return data
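For reference, a minimal self-contained sketch (assuming pydantic v2; ExampleArgs and the config values below are illustrative, not axolotl's actual classes) of how a mode="before" validator like the one added above sees the raw config dict and can reject the FSDP + SpectrumPlugin combination before any field parsing happens:

    # Illustrative sketch only -- not the axolotl SpectrumArgs class.
    from typing import Optional

    from pydantic import BaseModel, ValidationError, model_validator


    class ExampleArgs(BaseModel):
        spectrum_top_fraction: Optional[float] = 0.5
        spectrum_model_name: Optional[str] = None

        @model_validator(mode="before")
        @classmethod
        def check_fsdp_use_orig_params(cls, data):
            # `data` is the raw input dict here, so unrelated config keys are visible.
            if (
                data.get("fsdp")
                and data.get("fsdp_config")
                and not data["fsdp_config"].get("use_orig_params")
                and data.get("plugins")
                and any("SpectrumPlugin" in plugin for plugin in data["plugins"])
            ):
                raise ValueError(
                    "FSDP + SpectrumPlugin cannot be used together when "
                    "`use_orig_params=False` is set"
                )
            return data


    try:
        # use_orig_params is falsy while FSDP and the Spectrum plugin are both enabled,
        # so validation fails up front instead of surfacing a flattening error later.
        ExampleArgs(
            fsdp=["full_shard"],
            fsdp_config={"use_orig_params": False},
            plugins=["axolotl.integrations.spectrum.SpectrumPlugin"],
        )
    except ValidationError as err:
        print(err)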
@@ -778,9 +778,9 @@ class AxolotlInputConfig(
 
     # torch_dtype: Optional[torch.dtype]
 
-    gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
-        default=False
-    )
+    gradient_checkpointing: Optional[
+        Union[Literal["unsloth", "offload"], bool]
+    ] = Field(default=False)
     gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
 
     unfrozen_parameters: Optional[List[str]] = None
@@ -1154,6 +1154,15 @@ class AxolotlInputConfig(
             raise ValueError("gradient_checkpointing is not supported for MPT models")
         return self
 
+    @model_validator(mode="after")
+    def check_offload_grad_checkpointing(self):
+        if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
+            LOG.warning(
+                "`unsloth` is deprecated for gradient_checkpointing, use `offload`"
+            )
+            self.gradient_checkpointing = "offload"
+        return self
+
     @model_validator(mode="after")
     def check_better_transformers(self):
         if self.flash_optimum is True:
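Similarly, a minimal sketch (assuming pydantic v2; ExampleConfig is illustrative, not the real AxolotlInputConfig) of the deprecation pattern used here: an "after" validator warns about the legacy "unsloth" value and rewrites it to "offload", so the rest of the code base only needs to handle the new spelling:

    # Illustrative sketch only -- not the axolotl AxolotlInputConfig class.
    import logging
    from typing import Literal, Optional, Union

    from pydantic import BaseModel, model_validator

    LOG = logging.getLogger(__name__)


    class ExampleConfig(BaseModel):
        gradient_checkpointing: Optional[Union[Literal["unsloth", "offload"], bool]] = False

        @model_validator(mode="after")
        def check_offload_grad_checkpointing(self):
            # Rewrite the deprecated value in place and warn once at validation time.
            if self.gradient_checkpointing == "unsloth":
                LOG.warning(
                    "`unsloth` is deprecated for gradient_checkpointing, use `offload`"
                )
                self.gradient_checkpointing = "offload"
            return self


    cfg = ExampleConfig(gradient_checkpointing="unsloth")
    print(cfg.gradient_checkpointing)  # -> "offload" after the validator rewrites it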
@@ -4,7 +4,7 @@ from axolotl.utils.gradient_checkpointing.unsloth import (
 )
 
 
-def hf_grad_checkpoint_unsloth_wrapper(
+def hf_grad_checkpoint_offload_wrapper(
     decoder_layer, *args, use_reentrant=None
 ):  # pylint: disable=unused-argument
     return Unsloth_Offloaded_Gradient_Checkpointer.apply(
@@ -64,7 +64,7 @@ from axolotl.utils.distributed import (
     is_local_main_process,
     zero_only,
 )
-from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
+from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -493,8 +493,8 @@ class ModelLoader:
 
         patch_fa_peft_integration()
 
-        if self.cfg.gradient_checkpointing == "unsloth":
-            transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
+        if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
+            transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
 
         if self.cfg.flash_attention:
             self.patch_attention()