various fixes 20250305 (#2384)
* various validation fixes * fix check for non-truthy value
This commit is contained in:
@@ -17,7 +17,7 @@ Module for handling Spectrum input arguments.
|
|||||||
"""
|
"""
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
|
|
||||||
class SpectrumArgs(BaseModel):
|
class SpectrumArgs(BaseModel):
|
||||||
@@ -27,3 +27,20 @@ class SpectrumArgs(BaseModel):
|
|||||||
|
|
||||||
spectrum_top_fraction: Optional[float] = 0.5
|
spectrum_top_fraction: Optional[float] = 0.5
|
||||||
spectrum_model_name: Optional[str] = None
|
spectrum_model_name: Optional[str] = None
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_fsdp_use_orig_params(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("fsdp")
|
||||||
|
and data.get("fsdp_config")
|
||||||
|
and not data["fsdp_config"].get("use_orig_params")
|
||||||
|
and data.get("plugins")
|
||||||
|
and any("SpectrumPlugin" in plugin for plugin in data["plugins"])
|
||||||
|
):
|
||||||
|
# would otherwise raise
|
||||||
|
# ValueError: Must flatten tensors with uniform `requires_grad` when `use_orig_params=False`
|
||||||
|
raise ValueError(
|
||||||
|
"FSDP + SpectrumPlugin cannot be used together when `use_orig_params=False` is set"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|||||||
@@ -778,9 +778,9 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
# torch_dtype: Optional[torch.dtype]
|
# torch_dtype: Optional[torch.dtype]
|
||||||
|
|
||||||
gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
|
gradient_checkpointing: Optional[
|
||||||
default=False
|
Union[Literal["unsloth", "offload"], bool]
|
||||||
)
|
] = Field(default=False)
|
||||||
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
unfrozen_parameters: Optional[List[str]] = None
|
unfrozen_parameters: Optional[List[str]] = None
|
||||||
@@ -1154,6 +1154,15 @@ class AxolotlInputConfig(
|
|||||||
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_offload_grad_checkpointing(self):
|
||||||
|
if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
|
||||||
|
LOG.warning(
|
||||||
|
"`unsloth` is deprecated for gradient_checkpointing, use `offload`"
|
||||||
|
)
|
||||||
|
self.gradient_checkpointing = "offload"
|
||||||
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_better_transformers(self):
|
def check_better_transformers(self):
|
||||||
if self.flash_optimum is True:
|
if self.flash_optimum is True:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from axolotl.utils.gradient_checkpointing.unsloth import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def hf_grad_checkpoint_unsloth_wrapper(
|
def hf_grad_checkpoint_offload_wrapper(
|
||||||
decoder_layer, *args, use_reentrant=None
|
decoder_layer, *args, use_reentrant=None
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ from axolotl.utils.distributed import (
|
|||||||
is_local_main_process,
|
is_local_main_process,
|
||||||
zero_only,
|
zero_only,
|
||||||
)
|
)
|
||||||
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
|
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
|
||||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
|
|
||||||
@@ -493,8 +493,8 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_fa_peft_integration()
|
patch_fa_peft_integration()
|
||||||
|
|
||||||
if self.cfg.gradient_checkpointing == "unsloth":
|
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
||||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
|
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
|
||||||
|
|
||||||
if self.cfg.flash_attention:
|
if self.cfg.flash_attention:
|
||||||
self.patch_attention()
|
self.patch_attention()
|
||||||
|
|||||||
Reference in New Issue
Block a user