This commit is contained in:
Salman Mohammadi
2025-03-18 11:23:23 +00:00
parent 57b0ad1467
commit 690908cf2f

View File

@@ -126,8 +126,12 @@ class DeprecatedParameters(BaseModel):
class RemappedParameters(BaseModel):
"""parameters that have been remapped to other names"""
overrides_of_model_config: Optional[Dict[str, Any]] = Field(default=None, alias="model_config")
overrides_of_model_kwargs: Optional[Dict[str, Any]] = Field(default=None, alias="model_kwargs")
overrides_of_model_config: Optional[Dict[str, Any]] = Field(
default=None, alias="model_config"
)
overrides_of_model_kwargs: Optional[Dict[str, Any]] = Field(
default=None, alias="model_kwargs"
)
type_of_model: Optional[str] = Field(default=None, alias="model_type")
revision_of_model: Optional[str] = Field(default=None, alias="model_revision")
@@ -196,8 +200,12 @@ class SFTDataset(BaseModel):
field_human: Optional[str] = None
field_model: Optional[str] = None
field_messages: Optional[str] = None
message_field_role: Optional[str] = None # deprecated, use message_property_mappings
message_field_content: Optional[str] = None # deprecated, use message_property_mappings
message_field_role: Optional[
str
] = None # deprecated, use message_property_mappings
message_field_content: Optional[
str
] = None # deprecated, use message_property_mappings
message_property_mappings: Optional[Dict[str, str]] = None
message_field_training: Optional[str] = None
message_field_training_detail: Optional[str] = None
@@ -227,8 +235,12 @@ class SFTDataset(BaseModel):
data["chat_template"] = ChatTemplate.tokenizer_default
# if chat_template is set to jinja, chat_template_jinja is required
if data.get("chat_template") == ChatTemplate.jinja and not data.get("chat_template_jinja"):
raise ValueError("chat_template_jinja is required when chat_template is set to jinja")
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
"chat_template_jinja"
):
raise ValueError(
"chat_template_jinja is required when chat_template is set to jinja"
)
# If chat_template_jinja is set, set chat_template to jinja
if data.get("chat_template_jinja") and not data.get("chat_template"):
@@ -300,7 +312,9 @@ DatasetConfig = Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedData
class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
loftq_bits: int = Field(default=4, json_schema_extra={"description": "Quantization bits for LoftQ"})
loftq_bits: int = Field(
default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
)
# loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
@@ -345,7 +359,9 @@ class LoraConfig(BaseModel):
qlora_sharded_model_loading: Optional[bool] = Field(
default=False,
json_schema_extra={"description": "load qlora model in sharded format for FSDP using answer.ai technique."},
json_schema_extra={
"description": "load qlora model in sharded format for FSDP using answer.ai technique."
},
)
lora_on_cpu: Optional[bool] = None
gptq: Optional[bool] = None
@@ -353,11 +369,15 @@ class LoraConfig(BaseModel):
loraplus_lr_ratio: Optional[float] = Field(
default=None,
json_schema_extra={"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."},
json_schema_extra={
"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
},
)
loraplus_lr_embedding: Optional[float] = Field(
default=1e-6,
json_schema_extra={"description": "loraplus learning rate for lora embedding layers."},
json_schema_extra={
"description": "loraplus learning rate for lora embedding layers."
},
)
merge_lora: Optional[bool] = None
@@ -465,11 +485,15 @@ class HyperparametersConfig(BaseModel):
)
batch_size: Optional[int] = Field(
default=None,
json_schema_extra={"description": "Total batch size, we do not recommended setting this manually"},
json_schema_extra={
"description": "Total batch size, we do not recommended setting this manually"
},
)
eval_batch_size: Optional[int] = Field(
default=None,
json_schema_extra={"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"},
json_schema_extra={
"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
},
)
auto_find_batch_size: Optional[bool] = None
@@ -481,7 +505,9 @@ class HyperparametersConfig(BaseModel):
embedding_lr: Optional[float] = None
embedding_lr_scale: Optional[float] = None
weight_decay: Optional[float] = 0.0
optimizer: Optional[Union[OptimizerNames, CustomSupportedOptimizers]] = OptimizerNames.ADAMW_HF
optimizer: Optional[
Union[OptimizerNames, CustomSupportedOptimizers]
] = OptimizerNames.ADAMW_HF
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None,
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
@@ -493,7 +519,9 @@ class HyperparametersConfig(BaseModel):
},
)
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]] = SchedulerType.COSINE
lr_scheduler: Optional[
Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]
] = SchedulerType.COSINE
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
lr_quadratic_warmup: Optional[bool] = None
cosine_min_lr_ratio: Optional[float] = None
@@ -617,15 +645,21 @@ class RayConfig(BaseModel):
use_ray: bool = Field(default=False)
ray_run_name: Optional[str] = Field(
default=None,
json_schema_extra={"help": "The training results will be saved at `saves/ray_run_name`."},
json_schema_extra={
"help": "The training results will be saved at `saves/ray_run_name`."
},
)
ray_num_workers: int = Field(
default=1,
json_schema_extra={"help": "The number of workers for Ray training. Default is 1 worker."},
json_schema_extra={
"help": "The number of workers for Ray training. Default is 1 worker."
},
)
resources_per_worker: dict = Field(
default_factory=lambda: {"GPU": 1},
json_schema_extra={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
json_schema_extra={
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
},
)
@@ -665,9 +699,9 @@ class AxolotlInputConfig(
reward_model: Optional[bool] = None
process_reward_model: Optional[bool] = None
num_labels: Optional[int] = None
dpo_use_weighting: Optional[bool] = (
None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
)
dpo_use_weighting: Optional[
bool
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
dpo_use_logits_to_keep: Optional[bool] = None
datasets: Optional[
@@ -689,7 +723,9 @@ class AxolotlInputConfig(
dataset_shard_idx: Optional[int] = None
skip_prepare_dataset: Optional[bool] = False
pretraining_dataset: Optional[Annotated[list[Union[PretrainingDataset, SFTDataset]], MinLen(1)]] = Field(
pretraining_dataset: Optional[
Annotated[list[Union[PretrainingDataset, SFTDataset]], MinLen(1)]
] = Field(
default=None,
json_schema_extra={"description": "streaming dataset to use for pretraining"},
)
@@ -744,7 +780,9 @@ class AxolotlInputConfig(
# torch_dtype: Optional[torch.dtype]
gradient_checkpointing: Optional[Union[Literal["unsloth", "offload"], bool]] = Field(default=False)
gradient_checkpointing: Optional[
Union[Literal["unsloth", "offload"], bool]
] = Field(default=False)
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
unfrozen_parameters: Optional[List[str]] = None
@@ -811,7 +849,9 @@ class AxolotlInputConfig(
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
fsdp: Optional[List[str]] = None
fsdp_config: Optional[Dict[str, Any]] = None
fsdp_final_state_dict_type: Optional[Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]] = None
fsdp_final_state_dict_type: Optional[
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
] = None
val_set_size: Optional[float] = Field(default=0.0)
@@ -821,7 +861,9 @@ class AxolotlInputConfig(
torch_compile: Optional[Union[Literal["auto"], bool]] = None
torch_compile_backend: Optional[str] = None
torch_compile_mode: Optional[Literal["default", "reduce-overhead", "max-autotune"]] = None
torch_compile_mode: Optional[
Literal["default", "reduce-overhead", "max-autotune"]
] = None
max_steps: Optional[int] = None
warmup_steps: Optional[int] = None
@@ -852,7 +894,9 @@ class AxolotlInputConfig(
kto_undesirable_weight: Optional[float] = None
rl_beta: Optional[float] = None
max_memory: Optional[Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]] = None
max_memory: Optional[
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
] = None
gpu_memory_limit: Optional[Union[int, str]] = None
low_cpu_mem_usage: Optional[bool] = None
@@ -888,7 +932,11 @@ class AxolotlInputConfig(
def deprecate_sharegpt_datasets(cls, datasets):
for _, ds_cfg in enumerate(datasets):
# Handle both dict and pydantic model cases
ds_type = ds_cfg.get("type") if isinstance(ds_cfg, dict) else getattr(ds_cfg, "type", None)
ds_type = (
ds_cfg.get("type")
if isinstance(ds_cfg, dict)
else getattr(ds_cfg, "type", None)
)
if not ds_type:
continue
@@ -897,12 +945,16 @@ class AxolotlInputConfig(
continue
if isinstance(ds_type, str) and ds_type.startswith("sharegpt"):
raise ValueError("`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead.")
raise ValueError(
"`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead."
)
return datasets
@field_serializer("datasets")
def datasets_serializer(self, ds_configs: Optional[List[DatasetConfig]]) -> Optional[List[Dict[str, Any]]]:
def datasets_serializer(
self, ds_configs: Optional[List[DatasetConfig]]
) -> Optional[List[Dict[str, Any]]]:
if ds_configs:
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
return None
@@ -968,7 +1020,9 @@ class AxolotlInputConfig(
@classmethod
def check_sample_packing_w_xformers(cls, data):
if data.get("sample_packing") and data.get("xformers_attention"):
raise ValueError("sample_packing not compatible with xformers_attention. Use flash_attention")
raise ValueError(
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
return data
@@ -976,8 +1030,12 @@ class AxolotlInputConfig(
@classmethod
def check_chat_template_config(cls, data):
# if chat_template is set to jinja, chat_template_jinja is required
if data.get("chat_template") == ChatTemplate.jinja and not data.get("chat_template_jinja"):
raise ValueError("chat_template_jinja is required when chat_template is set to jinja")
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
"chat_template_jinja"
):
raise ValueError(
"chat_template_jinja is required when chat_template is set to jinja"
)
# If chat_template_jinja is set, set chat_template to jinja
if data.get("chat_template_jinja") and not data.get("chat_template"):
@@ -988,8 +1046,14 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_sample_packing_wo_flash(cls, data):
if data.get("sample_packing") and not data.get("flash_attention") and not data.get("sdp_attention"):
LOG.warning("sample_packing without flash_attention or sdp_attention does not handle cross-attention.")
if (
data.get("sample_packing")
and not data.get("flash_attention")
and not data.get("sdp_attention")
):
LOG.warning(
"sample_packing without flash_attention or sdp_attention does not handle cross-attention."
)
return data
@@ -1028,14 +1092,18 @@ class AxolotlInputConfig(
@classmethod
def hint_sample_packing_padding(cls, data):
if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
LOG.warning("`pad_to_sequence_len: true` is recommended when using sample_packing")
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
return data
@model_validator(mode="before")
@classmethod
def hint_reward_model_pad(cls, data):
if data.get("reward_model") and not data.get("pad_to_sequence_len"):
LOG.warning("`pad_to_sequence_len: true` is recommended when using reward_model")
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using reward_model"
)
if data.get("pad_to_sequence_len") is None:
data["pad_to_sequence_len"] = True
return data
@@ -1044,7 +1112,9 @@ class AxolotlInputConfig(
@classmethod
def check_gas_bsz(cls, data):
if data.get("gradient_accumulation_steps") and data.get("batch_size"):
raise ValueError("please set only one of gradient_accumulation_steps or batch_size")
raise ValueError(
"please set only one of gradient_accumulation_steps or batch_size"
)
return data
@model_validator(mode="before")
@@ -1055,14 +1125,21 @@ class AxolotlInputConfig(
and data.get("micro_batch_size")
and data.get("eval_batch_size") != data.get("micro_batch_size")
):
LOG.warning("eval_batch_size != micro_batch_size. This can lead to VRAM instability.")
LOG.warning(
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
)
return data
@model_validator(mode="before")
@classmethod
def check_push_ds_auth(cls, data):
if data.get("push_dataset_to_hub") and data.get("hf_use_auth_token") is not True:
raise ValueError("Require cfg.hf_use_auth_token to be True for push_dataset_to_hub")
if (
data.get("push_dataset_to_hub")
and data.get("hf_use_auth_token") is not True
):
raise ValueError(
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
)
return data
@model_validator(mode="after")
@@ -1073,14 +1150,18 @@ class AxolotlInputConfig(
@model_validator(mode="after")
def check_mpt_checkpointing(self):
if (self.base_model and "mpt" in self.base_model.lower()) and self.gradient_checkpointing:
if (
self.base_model and "mpt" in self.base_model.lower()
) and self.gradient_checkpointing:
raise ValueError("gradient_checkpointing is not supported for MPT models")
return self
@model_validator(mode="after")
def check_offload_grad_checkpointing(self):
if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
LOG.warning("`unsloth` is deprecated for gradient_checkpointing, use `offload`")
LOG.warning(
"`unsloth` is deprecated for gradient_checkpointing, use `offload`"
)
self.gradient_checkpointing = "offload"
return self
@@ -1088,7 +1169,9 @@ class AxolotlInputConfig(
def check_better_transformers(self):
if self.flash_optimum is True:
if self.adapter:
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
LOG.warning(
"BetterTransformers probably doesn't work with PEFT adapters"
)
if self.fp16 or self.bf16:
raise ValueError("AMP is not supported with BetterTransformer")
if self.float16 is not True and self.bfloat16 is not True:
@@ -1116,25 +1199,39 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_saves(cls, data):
if data.get("save_strategy") and data.get("save_steps") and data.get("save_strategy") != "steps":
if (
data.get("save_strategy")
and data.get("save_steps")
and data.get("save_strategy") != "steps"
):
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)
if data.get("saves_per_epoch") and data.get("save_steps"):
raise ValueError("save_steps and saves_per_epoch are mutually exclusive and cannot be used together.")
raise ValueError(
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
)
return data
@model_validator(mode="before")
@classmethod
def check_push_save(cls, data):
if data.get("hub_model_id") and (data.get("save_strategy") not in ["steps", "epoch", None]):
LOG.warning("hub_model_id is set without any models being saved. To save a model, set save_strategy.")
if data.get("hub_model_id") and (
data.get("save_strategy") not in ["steps", "epoch", None]
):
LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set save_strategy."
)
return data
@model_validator(mode="before")
@classmethod
def check_evals(cls, data):
if data.get("eval_strategy") and data.get("eval_steps") and data.get("eval_strategy") != "steps":
if (
data.get("eval_strategy")
and data.get("eval_steps")
and data.get("eval_strategy") != "steps"
):
raise ValueError(
"eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps."
)
@@ -1144,21 +1241,41 @@ class AxolotlInputConfig(
and (data.get("eval_steps") or data.get("eval_strategy"))
and not data.get("test_datasets")
):
raise ValueError("eval_steps and eval_strategy are not supported with val_set_size == 0")
raise ValueError(
"eval_steps and eval_strategy are not supported with val_set_size == 0"
)
if data.get("evals_per_epoch") and data.get("eval_steps"):
raise ValueError("eval_steps and evals_per_epoch are mutually exclusive and cannot be used together.")
if data.get("evals_per_epoch") and data.get("eval_strategy") and data.get("eval_strategy") != "steps":
raise ValueError("eval_strategy must be empty or set to `steps` when used with evals_per_epoch.")
raise ValueError(
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
)
if (
data.get("evals_per_epoch")
and data.get("eval_strategy")
and data.get("eval_strategy") != "steps"
):
raise ValueError(
"eval_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
if data.get("do_bench_eval") and not (data.get("evals_per_epoch") or data.get("eval_steps")):
raise ValueError("do_bench_eval requires evals_per_epoch or eval_steps to be set.")
if data.get("do_bench_eval") and not (
data.get("evals_per_epoch") or data.get("eval_steps")
):
raise ValueError(
"do_bench_eval requires evals_per_epoch or eval_steps to be set."
)
return data
@model_validator(mode="before")
@classmethod
def check_test_datasets_bench(cls, data):
if data.get("do_bench_eval") and not data.get("test_datasets") and not data.get("val_set_size"):
LOG.warning("`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset.")
if (
data.get("do_bench_eval")
and not data.get("test_datasets")
and not data.get("val_set_size")
):
LOG.warning(
"`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset."
)
data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}]
return data
@@ -1167,12 +1284,22 @@ class AxolotlInputConfig(
def check_eval_packing(cls, data):
# TODO also should check test_datasets and val_set_size as we can skip
# if there are no eval datasets/splits
if data.get("sample_packing") and data.get("eval_table_size") and data.get("eval_sample_packing") is not False:
if (
data.get("sample_packing")
and data.get("eval_table_size")
and data.get("eval_sample_packing") is not False
):
raise ValueError(
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
)
if data.get("sample_packing") and data.get("eval_sample_packing") is None and not data.get("eval_table_size"):
LOG.info("explicitly setting `eval_sample_packing` to match `sample_packing`")
if (
data.get("sample_packing")
and data.get("eval_sample_packing") is None
and not data.get("eval_table_size")
):
LOG.info(
"explicitly setting `eval_sample_packing` to match `sample_packing`"
)
data["eval_sample_packing"] = True
if (
@@ -1192,7 +1319,9 @@ class AxolotlInputConfig(
def check_mm_prepare(cls, data):
if data.get("skip_prepare_dataset"):
if data.get("remove_unused_columns") is None:
LOG.info("setting `remove_unused_columns: false` for skip_prepare_dataset")
LOG.info(
"setting `remove_unused_columns: false` for skip_prepare_dataset"
)
data["remove_unused_columns"] = False
return data
@@ -1233,13 +1362,19 @@ class AxolotlInputConfig(
@model_validator(mode="after")
def check_simpo_warmup(self):
if self.rl == "simpo" and self.warmup_ratio:
raise ValueError("warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead")
raise ValueError(
"warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
)
return self
@model_validator(mode="before")
@classmethod
def check_frozen(cls, data):
if data.get("adapter") and data.get("peft_layers_to_transform") and data.get("unfrozen_parameters"):
if (
data.get("adapter")
and data.get("peft_layers_to_transform")
and data.get("unfrozen_parameters")
):
raise ValueError(
"`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
)
@@ -1250,7 +1385,9 @@ class AxolotlInputConfig(
@classmethod
def check_peft_layers_pattern(cls, data):
if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"):
raise ValueError("peft_layers_pattern requires peft_layers_to_transform to be set")
raise ValueError(
"peft_layers_pattern requires peft_layers_to_transform to be set"
)
return data
@model_validator(mode="after")
@@ -1273,13 +1410,17 @@ class AxolotlInputConfig(
@model_validator(mode="after")
def check_fused_lora(self):
if self.adapter in ["lora", "qlora"] and (self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp):
if self.adapter in ["lora", "qlora"] and (
self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp
):
raise ValueError("Fused modules are not supported with LoRA/QLoRA")
return self
@model_validator(mode="after")
def hint_lora_8bit(self):
loftq = self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits
loftq = (
self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits
)
if not self.load_in_8bit and self.adapter == "lora" and not loftq:
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
return self
@@ -1292,7 +1433,9 @@ class AxolotlInputConfig(
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
)
if self.save_steps % self.eval_steps != 0:
raise ValueError("`early_stopping_patience` requires that eval_steps should evenly divide save_steps.")
raise ValueError(
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
)
return self
@model_validator(mode="after")
@@ -1308,7 +1451,9 @@ class AxolotlInputConfig(
raise ValueError("deepspeed not supported with ReLoRA")
if self.lr_scheduler == "one_cycle":
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
raise ValueError(
"ReLoRA is not compatible with the one_cycle scheduler"
)
if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with ReLoRA")
@@ -1317,8 +1462,13 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_mem_mismatch(cls, data):
if data.get("max_memory") is not None and data.get("gpu_memory_limit") is not None:
raise ValueError("max_memory and gpu_memory_limit are mutually exclusive and cannot be used together.")
if (
data.get("max_memory") is not None
and data.get("gpu_memory_limit") is not None
):
raise ValueError(
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
)
return data
@model_validator(mode="before")
@@ -1327,10 +1477,13 @@ class AxolotlInputConfig(
if (
data.get("unfrozen_parameters")
and data.get("gradient_checkpointing_kwargs")
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") is True
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
is True
):
# https://github.com/huggingface/transformers/issues/21381
raise ValueError("`use_reentrant` must be false when used with partially frozen model.")
raise ValueError(
"`use_reentrant` must be false when used with partially frozen model."
)
return data
@model_validator(mode="before")
@@ -1339,7 +1492,8 @@ class AxolotlInputConfig(
if (
data.get("adapter") == "qlora"
and data.get("gradient_checkpointing_kwargs", {})
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") is False
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
is False
and data.get("deepspeed", "") is not None
and "zero3" in data.get("deepspeed", "")
):
@@ -1356,14 +1510,21 @@ class AxolotlInputConfig(
@classmethod
def check_val_w_test_datasets(cls, data):
if data.get("test_datasets") and data.get("val_set_size"):
raise ValueError("non-zero val_set_size should not be used with test_datasets configuration")
raise ValueError(
"non-zero val_set_size should not be used with test_datasets configuration"
)
return data
@model_validator(mode="before")
@classmethod
def check_eval_strategy(cls, data):
if data.get("evaluation_strategy") is not None and data.get("eval_strategy") is None:
LOG.info("explicitly setting `eval_strategy` from the `evaluation_strategy`")
if (
data.get("evaluation_strategy") is not None
and data.get("eval_strategy") is None
):
LOG.info(
"explicitly setting `eval_strategy` from the `evaluation_strategy`"
)
data["eval_strategy"] = data.get("evaluation_strategy")
return data
@@ -1376,7 +1537,9 @@ class AxolotlInputConfig(
and data.get("fsdp_config")
and data["fsdp_config"].get("fsdp_offload_params")
):
raise ValueError(f"FSDP Offload not compatible with {data.get('optimizer')}")
raise ValueError(
f"FSDP Offload not compatible with {data.get('optimizer')}"
)
return data
@model_validator(mode="before")
@@ -1388,21 +1551,27 @@ class AxolotlInputConfig(
and data.get("fsdp_config")
and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT"
):
raise ValueError("FSDP SHARDED_STATE_DICT not compatible with save_safetensors")
raise ValueError(
"FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
)
return data
@model_validator(mode="before")
@classmethod
def check_causal_lm_evals(cls, data):
if data.get("do_causal_lm_eval") and data.get("eval_sample_packing"):
raise ValueError("do_causal_lm_eval is enabled, eval_sample_packing must be set to False")
raise ValueError(
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
)
if data.get("eval_causal_lm_metrics"):
if not isinstance(data.get("eval_causal_lm_metrics"), list):
raise ValueError("eval_causal_lm_metrics must be a list")
# only ["sacrebleu", "comet", "ter", "chrf"] supported
if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS:
raise ValueError(f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}")
raise ValueError(
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
)
return data
@model_validator(mode="before")
@@ -1415,14 +1584,22 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_xentropy_patch_conflicts(cls, data):
if data.get("flash_attn_cross_entropy") and data.get("unsloth_cross_entropy_loss"):
raise ValueError("flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled")
if data.get("flash_attn_cross_entropy") and data.get(
"unsloth_cross_entropy_loss"
):
raise ValueError(
"flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
)
return data
@model_validator(mode="before")
@classmethod
def check_qlora_unsloth(cls, data):
if data.get("unsloth_lora_mlp") or data.get("unsloth_lora_qkv") or data.get("unsloth_lora_o"):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
raise ValueError(
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
@@ -1432,7 +1609,11 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_lora_8bit(cls, data):
if data.get("lora_mlp_kernel") or data.get("lora_qkv_kernel") or data.get("lora_o_kernel"):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
):
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
raise ValueError(
"lora_mlp_kernel, lora_mlp_kernel, and lora_mlp_kernel are not compatible with 8-bit LoRA"
@@ -1442,8 +1623,13 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_lora_axolotl_unsloth(cls, data):
is_lora_kernel = any(data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"])
is_unsloth_lora = any(data.get(k) for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"])
is_lora_kernel = any(
data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
)
is_unsloth_lora = any(
data.get(k)
for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
)
if is_lora_kernel and is_unsloth_lora:
raise ValueError(
"both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)"
@@ -1454,7 +1640,9 @@ class AxolotlInputConfig(
@classmethod
def check_torch_compile_deepspeed(cls, data):
if data.get("deepspeed") and data.get("torch_compile"):
raise ValueError("torch_compile should be set within your deepspeed config file")
raise ValueError(
"torch_compile should be set within your deepspeed config file"
)
return data
@model_validator(mode="before")
@@ -1538,9 +1726,15 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
def check_bf16(self):
if self.capabilities.bf16:
if not self.bf16 and not self.bfloat16:
LOG.info("bf16 support detected, but not enabled for this configuration.")
LOG.info(
"bf16 support detected, but not enabled for this configuration."
)
else:
if not self.merge_lora and not self.is_preprocess and (self.bf16 is True or self.bfloat16 is True):
if (
not self.merge_lora
and not self.is_preprocess
and (self.bf16 is True or self.bfloat16 is True)
):
raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
)
@@ -1549,7 +1743,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_sdpa_bf16(cls, data):
is_sm_90: bool = data["capabilities"] and data["capabilities"].get("compute_capability") == "sm_90"
is_sm_90: bool = (
data["capabilities"]
and data["capabilities"].get("compute_capability") == "sm_90"
)
if (
data.get("sample_packing")
and data.get("sdp_attention")
@@ -1574,7 +1771,11 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
@model_validator(mode="before")
@classmethod
def check_multigpu_unsloth(cls, data):
if data.get("unsloth_lora_mlp") or data.get("unsloth_lora_qkv") or data.get("unsloth_lora_o"):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
capabilities = data.get("capabilities")
if capabilities and capabilities.get("n_gpu", 0) > 1:
raise ValueError(
@@ -1585,7 +1786,11 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
@model_validator(mode="before")
@classmethod
def check_multigpu_lora_kernels(cls, data):
if data.get("lora_mlp_kernel") or data.get("lora_qkv_kernel") or data.get("lora_o_kernel"):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
):
capabilities = data.get("capabilities")
is_fsdp = data.get("fsdp") is not None
is_deepspeed = data.get("deepspeed") is not None
@@ -1614,7 +1819,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
if version.parse(torch_version) < version.parse("2.5.1"):
raise ValueError("ADOPT optimizer is incompatible with torch version < 2.5.1")
raise ValueError(
"ADOPT optimizer is incompatible with torch version < 2.5.1"
)
return data
@model_validator(mode="before")
@@ -1623,8 +1830,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("torch_compile") == "auto":
env_capabilities = data.get("env_capabilities", {})
if env_capabilities.get("torch_version"):
if version.parse(env_capabilities.get("torch_version")) >= version.parse("2.5.1"):
LOG.info("torch.compile is available, setting torch_compile to True")
if version.parse(
env_capabilities.get("torch_version")
) >= version.parse("2.5.1"):
LOG.info(
"torch.compile is available, setting torch_compile to True"
)
data["torch_compile"] = True
else:
data["torch_compile"] = False
@@ -1689,13 +1900,16 @@ def handle_legacy_message_fields_logic(data: dict) -> dict:
)
if (
"content" in data["message_property_mappings"]
and data["message_property_mappings"]["content"] != data["message_field_content"]
and data["message_property_mappings"]["content"]
!= data["message_field_content"]
):
raise ValueError(
f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
)
data["message_property_mappings"]["content"] = data["message_field_content"] or "content"
data["message_property_mappings"]["content"] = (
data["message_field_content"] or "content"
)
del data["message_field_content"]
elif "content" not in data["message_property_mappings"]: