This commit is contained in:
Salman Mohammadi
2025-03-18 11:23:23 +00:00
parent 57b0ad1467
commit 690908cf2f

View File

@@ -126,8 +126,12 @@ class DeprecatedParameters(BaseModel):
class RemappedParameters(BaseModel): class RemappedParameters(BaseModel):
"""parameters that have been remapped to other names""" """parameters that have been remapped to other names"""
overrides_of_model_config: Optional[Dict[str, Any]] = Field(default=None, alias="model_config") overrides_of_model_config: Optional[Dict[str, Any]] = Field(
overrides_of_model_kwargs: Optional[Dict[str, Any]] = Field(default=None, alias="model_kwargs") default=None, alias="model_config"
)
overrides_of_model_kwargs: Optional[Dict[str, Any]] = Field(
default=None, alias="model_kwargs"
)
type_of_model: Optional[str] = Field(default=None, alias="model_type") type_of_model: Optional[str] = Field(default=None, alias="model_type")
revision_of_model: Optional[str] = Field(default=None, alias="model_revision") revision_of_model: Optional[str] = Field(default=None, alias="model_revision")
@@ -196,8 +200,12 @@ class SFTDataset(BaseModel):
field_human: Optional[str] = None field_human: Optional[str] = None
field_model: Optional[str] = None field_model: Optional[str] = None
field_messages: Optional[str] = None field_messages: Optional[str] = None
message_field_role: Optional[str] = None # deprecated, use message_property_mappings message_field_role: Optional[
message_field_content: Optional[str] = None # deprecated, use message_property_mappings str
] = None # deprecated, use message_property_mappings
message_field_content: Optional[
str
] = None # deprecated, use message_property_mappings
message_property_mappings: Optional[Dict[str, str]] = None message_property_mappings: Optional[Dict[str, str]] = None
message_field_training: Optional[str] = None message_field_training: Optional[str] = None
message_field_training_detail: Optional[str] = None message_field_training_detail: Optional[str] = None
@@ -227,8 +235,12 @@ class SFTDataset(BaseModel):
data["chat_template"] = ChatTemplate.tokenizer_default data["chat_template"] = ChatTemplate.tokenizer_default
# if chat_template is set to jinja, chat_template_jinja is required # if chat_template is set to jinja, chat_template_jinja is required
if data.get("chat_template") == ChatTemplate.jinja and not data.get("chat_template_jinja"): if data.get("chat_template") == ChatTemplate.jinja and not data.get(
raise ValueError("chat_template_jinja is required when chat_template is set to jinja") "chat_template_jinja"
):
raise ValueError(
"chat_template_jinja is required when chat_template is set to jinja"
)
# If chat_template_jinja is set, set chat_template to jinja # If chat_template_jinja is set, set chat_template to jinja
if data.get("chat_template_jinja") and not data.get("chat_template"): if data.get("chat_template_jinja") and not data.get("chat_template"):
@@ -300,7 +312,9 @@ DatasetConfig = Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedData
class LoftQConfig(BaseModel): class LoftQConfig(BaseModel):
"""LoftQ configuration subset""" """LoftQ configuration subset"""
loftq_bits: int = Field(default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}) loftq_bits: int = Field(
default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
)
# loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"}) # loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
@@ -345,7 +359,9 @@ class LoraConfig(BaseModel):
qlora_sharded_model_loading: Optional[bool] = Field( qlora_sharded_model_loading: Optional[bool] = Field(
default=False, default=False,
json_schema_extra={"description": "load qlora model in sharded format for FSDP using answer.ai technique."}, json_schema_extra={
"description": "load qlora model in sharded format for FSDP using answer.ai technique."
},
) )
lora_on_cpu: Optional[bool] = None lora_on_cpu: Optional[bool] = None
gptq: Optional[bool] = None gptq: Optional[bool] = None
@@ -353,11 +369,15 @@ class LoraConfig(BaseModel):
loraplus_lr_ratio: Optional[float] = Field( loraplus_lr_ratio: Optional[float] = Field(
default=None, default=None,
json_schema_extra={"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."}, json_schema_extra={
"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
},
) )
loraplus_lr_embedding: Optional[float] = Field( loraplus_lr_embedding: Optional[float] = Field(
default=1e-6, default=1e-6,
json_schema_extra={"description": "loraplus learning rate for lora embedding layers."}, json_schema_extra={
"description": "loraplus learning rate for lora embedding layers."
},
) )
merge_lora: Optional[bool] = None merge_lora: Optional[bool] = None
@@ -465,11 +485,15 @@ class HyperparametersConfig(BaseModel):
) )
batch_size: Optional[int] = Field( batch_size: Optional[int] = Field(
default=None, default=None,
json_schema_extra={"description": "Total batch size, we do not recommended setting this manually"}, json_schema_extra={
"description": "Total batch size, we do not recommended setting this manually"
},
) )
eval_batch_size: Optional[int] = Field( eval_batch_size: Optional[int] = Field(
default=None, default=None,
json_schema_extra={"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"}, json_schema_extra={
"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
},
) )
auto_find_batch_size: Optional[bool] = None auto_find_batch_size: Optional[bool] = None
@@ -481,7 +505,9 @@ class HyperparametersConfig(BaseModel):
embedding_lr: Optional[float] = None embedding_lr: Optional[float] = None
embedding_lr_scale: Optional[float] = None embedding_lr_scale: Optional[float] = None
weight_decay: Optional[float] = 0.0 weight_decay: Optional[float] = 0.0
optimizer: Optional[Union[OptimizerNames, CustomSupportedOptimizers]] = OptimizerNames.ADAMW_HF optimizer: Optional[
Union[OptimizerNames, CustomSupportedOptimizers]
] = OptimizerNames.ADAMW_HF
optim_args: Optional[Union[str, Dict[str, Any]]] = Field( optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None, default=None,
json_schema_extra={"description": "Optional arguments to supply to optimizer."}, json_schema_extra={"description": "Optional arguments to supply to optimizer."},
@@ -493,7 +519,9 @@ class HyperparametersConfig(BaseModel):
}, },
) )
torchdistx_path: Optional[str] = None torchdistx_path: Optional[str] = None
lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]] = SchedulerType.COSINE lr_scheduler: Optional[
Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]
] = SchedulerType.COSINE
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
lr_quadratic_warmup: Optional[bool] = None lr_quadratic_warmup: Optional[bool] = None
cosine_min_lr_ratio: Optional[float] = None cosine_min_lr_ratio: Optional[float] = None
@@ -617,15 +645,21 @@ class RayConfig(BaseModel):
use_ray: bool = Field(default=False) use_ray: bool = Field(default=False)
ray_run_name: Optional[str] = Field( ray_run_name: Optional[str] = Field(
default=None, default=None,
json_schema_extra={"help": "The training results will be saved at `saves/ray_run_name`."}, json_schema_extra={
"help": "The training results will be saved at `saves/ray_run_name`."
},
) )
ray_num_workers: int = Field( ray_num_workers: int = Field(
default=1, default=1,
json_schema_extra={"help": "The number of workers for Ray training. Default is 1 worker."}, json_schema_extra={
"help": "The number of workers for Ray training. Default is 1 worker."
},
) )
resources_per_worker: dict = Field( resources_per_worker: dict = Field(
default_factory=lambda: {"GPU": 1}, default_factory=lambda: {"GPU": 1},
json_schema_extra={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."}, json_schema_extra={
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
},
) )
@@ -665,9 +699,9 @@ class AxolotlInputConfig(
reward_model: Optional[bool] = None reward_model: Optional[bool] = None
process_reward_model: Optional[bool] = None process_reward_model: Optional[bool] = None
num_labels: Optional[int] = None num_labels: Optional[int] = None
dpo_use_weighting: Optional[bool] = ( dpo_use_weighting: Optional[
None # whether to use weighting in DPO trainer. If none, default is false in the trainer. bool
) ] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
dpo_use_logits_to_keep: Optional[bool] = None dpo_use_logits_to_keep: Optional[bool] = None
datasets: Optional[ datasets: Optional[
@@ -689,7 +723,9 @@ class AxolotlInputConfig(
dataset_shard_idx: Optional[int] = None dataset_shard_idx: Optional[int] = None
skip_prepare_dataset: Optional[bool] = False skip_prepare_dataset: Optional[bool] = False
pretraining_dataset: Optional[Annotated[list[Union[PretrainingDataset, SFTDataset]], MinLen(1)]] = Field( pretraining_dataset: Optional[
Annotated[list[Union[PretrainingDataset, SFTDataset]], MinLen(1)]
] = Field(
default=None, default=None,
json_schema_extra={"description": "streaming dataset to use for pretraining"}, json_schema_extra={"description": "streaming dataset to use for pretraining"},
) )
@@ -744,7 +780,9 @@ class AxolotlInputConfig(
# torch_dtype: Optional[torch.dtype] # torch_dtype: Optional[torch.dtype]
gradient_checkpointing: Optional[Union[Literal["unsloth", "offload"], bool]] = Field(default=False) gradient_checkpointing: Optional[
Union[Literal["unsloth", "offload"], bool]
] = Field(default=False)
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
unfrozen_parameters: Optional[List[str]] = None unfrozen_parameters: Optional[List[str]] = None
@@ -811,7 +849,9 @@ class AxolotlInputConfig(
deepspeed: Optional[Union[str, Dict[str, Any]]] = None deepspeed: Optional[Union[str, Dict[str, Any]]] = None
fsdp: Optional[List[str]] = None fsdp: Optional[List[str]] = None
fsdp_config: Optional[Dict[str, Any]] = None fsdp_config: Optional[Dict[str, Any]] = None
fsdp_final_state_dict_type: Optional[Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]] = None fsdp_final_state_dict_type: Optional[
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
] = None
val_set_size: Optional[float] = Field(default=0.0) val_set_size: Optional[float] = Field(default=0.0)
@@ -821,7 +861,9 @@ class AxolotlInputConfig(
torch_compile: Optional[Union[Literal["auto"], bool]] = None torch_compile: Optional[Union[Literal["auto"], bool]] = None
torch_compile_backend: Optional[str] = None torch_compile_backend: Optional[str] = None
torch_compile_mode: Optional[Literal["default", "reduce-overhead", "max-autotune"]] = None torch_compile_mode: Optional[
Literal["default", "reduce-overhead", "max-autotune"]
] = None
max_steps: Optional[int] = None max_steps: Optional[int] = None
warmup_steps: Optional[int] = None warmup_steps: Optional[int] = None
@@ -852,7 +894,9 @@ class AxolotlInputConfig(
kto_undesirable_weight: Optional[float] = None kto_undesirable_weight: Optional[float] = None
rl_beta: Optional[float] = None rl_beta: Optional[float] = None
max_memory: Optional[Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]] = None max_memory: Optional[
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
] = None
gpu_memory_limit: Optional[Union[int, str]] = None gpu_memory_limit: Optional[Union[int, str]] = None
low_cpu_mem_usage: Optional[bool] = None low_cpu_mem_usage: Optional[bool] = None
@@ -888,7 +932,11 @@ class AxolotlInputConfig(
def deprecate_sharegpt_datasets(cls, datasets): def deprecate_sharegpt_datasets(cls, datasets):
for _, ds_cfg in enumerate(datasets): for _, ds_cfg in enumerate(datasets):
# Handle both dict and pydantic model cases # Handle both dict and pydantic model cases
ds_type = ds_cfg.get("type") if isinstance(ds_cfg, dict) else getattr(ds_cfg, "type", None) ds_type = (
ds_cfg.get("type")
if isinstance(ds_cfg, dict)
else getattr(ds_cfg, "type", None)
)
if not ds_type: if not ds_type:
continue continue
@@ -897,12 +945,16 @@ class AxolotlInputConfig(
continue continue
if isinstance(ds_type, str) and ds_type.startswith("sharegpt"): if isinstance(ds_type, str) and ds_type.startswith("sharegpt"):
raise ValueError("`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead.") raise ValueError(
"`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead."
)
return datasets return datasets
@field_serializer("datasets") @field_serializer("datasets")
def datasets_serializer(self, ds_configs: Optional[List[DatasetConfig]]) -> Optional[List[Dict[str, Any]]]: def datasets_serializer(
self, ds_configs: Optional[List[DatasetConfig]]
) -> Optional[List[Dict[str, Any]]]:
if ds_configs: if ds_configs:
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs] return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
return None return None
@@ -968,7 +1020,9 @@ class AxolotlInputConfig(
@classmethod @classmethod
def check_sample_packing_w_xformers(cls, data): def check_sample_packing_w_xformers(cls, data):
if data.get("sample_packing") and data.get("xformers_attention"): if data.get("sample_packing") and data.get("xformers_attention"):
raise ValueError("sample_packing not compatible with xformers_attention. Use flash_attention") raise ValueError(
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
return data return data
@@ -976,8 +1030,12 @@ class AxolotlInputConfig(
@classmethod @classmethod
def check_chat_template_config(cls, data): def check_chat_template_config(cls, data):
# if chat_template is set to jinja, chat_template_jinja is required # if chat_template is set to jinja, chat_template_jinja is required
if data.get("chat_template") == ChatTemplate.jinja and not data.get("chat_template_jinja"): if data.get("chat_template") == ChatTemplate.jinja and not data.get(
raise ValueError("chat_template_jinja is required when chat_template is set to jinja") "chat_template_jinja"
):
raise ValueError(
"chat_template_jinja is required when chat_template is set to jinja"
)
# If chat_template_jinja is set, set chat_template to jinja # If chat_template_jinja is set, set chat_template to jinja
if data.get("chat_template_jinja") and not data.get("chat_template"): if data.get("chat_template_jinja") and not data.get("chat_template"):
@@ -988,8 +1046,14 @@ class AxolotlInputConfig(
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_sample_packing_wo_flash(cls, data): def check_sample_packing_wo_flash(cls, data):
if data.get("sample_packing") and not data.get("flash_attention") and not data.get("sdp_attention"): if (
LOG.warning("sample_packing without flash_attention or sdp_attention does not handle cross-attention.") data.get("sample_packing")
and not data.get("flash_attention")
and not data.get("sdp_attention")
):
LOG.warning(
"sample_packing without flash_attention or sdp_attention does not handle cross-attention."
)
return data return data
@@ -1028,14 +1092,18 @@ class AxolotlInputConfig(
@classmethod @classmethod
def hint_sample_packing_padding(cls, data): def hint_sample_packing_padding(cls, data):
if data.get("sample_packing") and not data.get("pad_to_sequence_len"): if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
LOG.warning("`pad_to_sequence_len: true` is recommended when using sample_packing") LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
return data return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def hint_reward_model_pad(cls, data): def hint_reward_model_pad(cls, data):
if data.get("reward_model") and not data.get("pad_to_sequence_len"): if data.get("reward_model") and not data.get("pad_to_sequence_len"):
LOG.warning("`pad_to_sequence_len: true` is recommended when using reward_model") LOG.warning(
"`pad_to_sequence_len: true` is recommended when using reward_model"
)
if data.get("pad_to_sequence_len") is None: if data.get("pad_to_sequence_len") is None:
data["pad_to_sequence_len"] = True data["pad_to_sequence_len"] = True
return data return data
@@ -1044,7 +1112,9 @@ class AxolotlInputConfig(
@classmethod @classmethod
def check_gas_bsz(cls, data): def check_gas_bsz(cls, data):
if data.get("gradient_accumulation_steps") and data.get("batch_size"): if data.get("gradient_accumulation_steps") and data.get("batch_size"):
raise ValueError("please set only one of gradient_accumulation_steps or batch_size") raise ValueError(
"please set only one of gradient_accumulation_steps or batch_size"
)
return data return data
@model_validator(mode="before") @model_validator(mode="before")
@@ -1055,14 +1125,21 @@ class AxolotlInputConfig(
and data.get("micro_batch_size") and data.get("micro_batch_size")
and data.get("eval_batch_size") != data.get("micro_batch_size") and data.get("eval_batch_size") != data.get("micro_batch_size")
): ):
LOG.warning("eval_batch_size != micro_batch_size. This can lead to VRAM instability.") LOG.warning(
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
)
return data return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_push_ds_auth(cls, data): def check_push_ds_auth(cls, data):
if data.get("push_dataset_to_hub") and data.get("hf_use_auth_token") is not True: if (
raise ValueError("Require cfg.hf_use_auth_token to be True for push_dataset_to_hub") data.get("push_dataset_to_hub")
and data.get("hf_use_auth_token") is not True
):
raise ValueError(
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
)
return data return data
@model_validator(mode="after") @model_validator(mode="after")
@@ -1073,14 +1150,18 @@ class AxolotlInputConfig(
@model_validator(mode="after") @model_validator(mode="after")
def check_mpt_checkpointing(self): def check_mpt_checkpointing(self):
if (self.base_model and "mpt" in self.base_model.lower()) and self.gradient_checkpointing: if (
self.base_model and "mpt" in self.base_model.lower()
) and self.gradient_checkpointing:
raise ValueError("gradient_checkpointing is not supported for MPT models") raise ValueError("gradient_checkpointing is not supported for MPT models")
return self return self
@model_validator(mode="after") @model_validator(mode="after")
def check_offload_grad_checkpointing(self): def check_offload_grad_checkpointing(self):
if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth": if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
LOG.warning("`unsloth` is deprecated for gradient_checkpointing, use `offload`") LOG.warning(
"`unsloth` is deprecated for gradient_checkpointing, use `offload`"
)
self.gradient_checkpointing = "offload" self.gradient_checkpointing = "offload"
return self return self
@@ -1088,7 +1169,9 @@ class AxolotlInputConfig(
def check_better_transformers(self): def check_better_transformers(self):
if self.flash_optimum is True: if self.flash_optimum is True:
if self.adapter: if self.adapter:
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters") LOG.warning(
"BetterTransformers probably doesn't work with PEFT adapters"
)
if self.fp16 or self.bf16: if self.fp16 or self.bf16:
raise ValueError("AMP is not supported with BetterTransformer") raise ValueError("AMP is not supported with BetterTransformer")
if self.float16 is not True and self.bfloat16 is not True: if self.float16 is not True and self.bfloat16 is not True:
@@ -1116,25 +1199,39 @@ class AxolotlInputConfig(
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_saves(cls, data): def check_saves(cls, data):
if data.get("save_strategy") and data.get("save_steps") and data.get("save_strategy") != "steps": if (
data.get("save_strategy")
and data.get("save_steps")
and data.get("save_strategy") != "steps"
):
raise ValueError( raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps." "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
) )
if data.get("saves_per_epoch") and data.get("save_steps"): if data.get("saves_per_epoch") and data.get("save_steps"):
raise ValueError("save_steps and saves_per_epoch are mutually exclusive and cannot be used together.") raise ValueError(
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
)
return data return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_push_save(cls, data): def check_push_save(cls, data):
if data.get("hub_model_id") and (data.get("save_strategy") not in ["steps", "epoch", None]): if data.get("hub_model_id") and (
LOG.warning("hub_model_id is set without any models being saved. To save a model, set save_strategy.") data.get("save_strategy") not in ["steps", "epoch", None]
):
LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set save_strategy."
)
return data return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_evals(cls, data): def check_evals(cls, data):
if data.get("eval_strategy") and data.get("eval_steps") and data.get("eval_strategy") != "steps": if (
data.get("eval_strategy")
and data.get("eval_steps")
and data.get("eval_strategy") != "steps"
):
raise ValueError( raise ValueError(
"eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps." "eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps."
) )
@@ -1144,21 +1241,41 @@ class AxolotlInputConfig(
and (data.get("eval_steps") or data.get("eval_strategy")) and (data.get("eval_steps") or data.get("eval_strategy"))
and not data.get("test_datasets") and not data.get("test_datasets")
): ):
raise ValueError("eval_steps and eval_strategy are not supported with val_set_size == 0") raise ValueError(
"eval_steps and eval_strategy are not supported with val_set_size == 0"
)
if data.get("evals_per_epoch") and data.get("eval_steps"): if data.get("evals_per_epoch") and data.get("eval_steps"):
raise ValueError("eval_steps and evals_per_epoch are mutually exclusive and cannot be used together.") raise ValueError(
if data.get("evals_per_epoch") and data.get("eval_strategy") and data.get("eval_strategy") != "steps": "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
raise ValueError("eval_strategy must be empty or set to `steps` when used with evals_per_epoch.") )
if (
data.get("evals_per_epoch")
and data.get("eval_strategy")
and data.get("eval_strategy") != "steps"
):
raise ValueError(
"eval_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
if data.get("do_bench_eval") and not (data.get("evals_per_epoch") or data.get("eval_steps")): if data.get("do_bench_eval") and not (
raise ValueError("do_bench_eval requires evals_per_epoch or eval_steps to be set.") data.get("evals_per_epoch") or data.get("eval_steps")
):
raise ValueError(
"do_bench_eval requires evals_per_epoch or eval_steps to be set."
)
return data return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_test_datasets_bench(cls, data): def check_test_datasets_bench(cls, data):
if data.get("do_bench_eval") and not data.get("test_datasets") and not data.get("val_set_size"): if (
LOG.warning("`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset.") data.get("do_bench_eval")
and not data.get("test_datasets")
and not data.get("val_set_size")
):
LOG.warning(
"`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset."
)
data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}] data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}]
return data return data
@@ -1167,12 +1284,22 @@ class AxolotlInputConfig(
def check_eval_packing(cls, data): def check_eval_packing(cls, data):
# TODO also should check test_datasets and val_set_size as we can skip # TODO also should check test_datasets and val_set_size as we can skip
# if there are no eval datasets/splits # if there are no eval datasets/splits
if data.get("sample_packing") and data.get("eval_table_size") and data.get("eval_sample_packing") is not False: if (
data.get("sample_packing")
and data.get("eval_table_size")
and data.get("eval_sample_packing") is not False
):
raise ValueError( raise ValueError(
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false." "eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
) )
if data.get("sample_packing") and data.get("eval_sample_packing") is None and not data.get("eval_table_size"): if (
LOG.info("explicitly setting `eval_sample_packing` to match `sample_packing`") data.get("sample_packing")
and data.get("eval_sample_packing") is None
and not data.get("eval_table_size")
):
LOG.info(
"explicitly setting `eval_sample_packing` to match `sample_packing`"
)
data["eval_sample_packing"] = True data["eval_sample_packing"] = True
if ( if (
@@ -1192,7 +1319,9 @@ class AxolotlInputConfig(
def check_mm_prepare(cls, data): def check_mm_prepare(cls, data):
if data.get("skip_prepare_dataset"): if data.get("skip_prepare_dataset"):
if data.get("remove_unused_columns") is None: if data.get("remove_unused_columns") is None:
LOG.info("setting `remove_unused_columns: false` for skip_prepare_dataset") LOG.info(
"setting `remove_unused_columns: false` for skip_prepare_dataset"
)
data["remove_unused_columns"] = False data["remove_unused_columns"] = False
return data return data
@@ -1233,13 +1362,19 @@ class AxolotlInputConfig(
@model_validator(mode="after") @model_validator(mode="after")
def check_simpo_warmup(self): def check_simpo_warmup(self):
if self.rl == "simpo" and self.warmup_ratio: if self.rl == "simpo" and self.warmup_ratio:
raise ValueError("warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead") raise ValueError(
"warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
)
return self return self
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_frozen(cls, data): def check_frozen(cls, data):
if data.get("adapter") and data.get("peft_layers_to_transform") and data.get("unfrozen_parameters"): if (
data.get("adapter")
and data.get("peft_layers_to_transform")
and data.get("unfrozen_parameters")
):
raise ValueError( raise ValueError(
"`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior." "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
) )
@@ -1250,7 +1385,9 @@ class AxolotlInputConfig(
@classmethod @classmethod
def check_peft_layers_pattern(cls, data): def check_peft_layers_pattern(cls, data):
if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"): if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"):
raise ValueError("peft_layers_pattern requires peft_layers_to_transform to be set") raise ValueError(
"peft_layers_pattern requires peft_layers_to_transform to be set"
)
return data return data
@model_validator(mode="after") @model_validator(mode="after")
@@ -1273,13 +1410,17 @@ class AxolotlInputConfig(
@model_validator(mode="after") @model_validator(mode="after")
def check_fused_lora(self): def check_fused_lora(self):
if self.adapter in ["lora", "qlora"] and (self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp): if self.adapter in ["lora", "qlora"] and (
self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp
):
raise ValueError("Fused modules are not supported with LoRA/QLoRA") raise ValueError("Fused modules are not supported with LoRA/QLoRA")
return self return self
@model_validator(mode="after") @model_validator(mode="after")
def hint_lora_8bit(self): def hint_lora_8bit(self):
loftq = self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits loftq = (
self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits
)
if not self.load_in_8bit and self.adapter == "lora" and not loftq: if not self.load_in_8bit and self.adapter == "lora" and not loftq:
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
return self return self
@@ -1292,7 +1433,9 @@ class AxolotlInputConfig(
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps." "`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
) )
if self.save_steps % self.eval_steps != 0: if self.save_steps % self.eval_steps != 0:
raise ValueError("`early_stopping_patience` requires that eval_steps should evenly divide save_steps.") raise ValueError(
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
)
return self return self
@model_validator(mode="after") @model_validator(mode="after")
@@ -1308,7 +1451,9 @@ class AxolotlInputConfig(
raise ValueError("deepspeed not supported with ReLoRA") raise ValueError("deepspeed not supported with ReLoRA")
if self.lr_scheduler == "one_cycle": if self.lr_scheduler == "one_cycle":
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler") raise ValueError(
"ReLoRA is not compatible with the one_cycle scheduler"
)
if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp: if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with ReLoRA") raise ValueError("Fused modules are not supported with ReLoRA")
@@ -1317,8 +1462,13 @@ class AxolotlInputConfig(
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_mem_mismatch(cls, data): def check_mem_mismatch(cls, data):
if data.get("max_memory") is not None and data.get("gpu_memory_limit") is not None: if (
raise ValueError("max_memory and gpu_memory_limit are mutually exclusive and cannot be used together.") data.get("max_memory") is not None
and data.get("gpu_memory_limit") is not None
):
raise ValueError(
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
)
return data return data
@model_validator(mode="before") @model_validator(mode="before")
@@ -1327,10 +1477,13 @@ class AxolotlInputConfig(
if ( if (
data.get("unfrozen_parameters") data.get("unfrozen_parameters")
and data.get("gradient_checkpointing_kwargs") and data.get("gradient_checkpointing_kwargs")
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") is True and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
is True
): ):
# https://github.com/huggingface/transformers/issues/21381 # https://github.com/huggingface/transformers/issues/21381
raise ValueError("`use_reentrant` must be false when used with partially frozen model.") raise ValueError(
"`use_reentrant` must be false when used with partially frozen model."
)
return data return data
@model_validator(mode="before") @model_validator(mode="before")
@@ -1339,7 +1492,8 @@ class AxolotlInputConfig(
if ( if (
data.get("adapter") == "qlora" data.get("adapter") == "qlora"
and data.get("gradient_checkpointing_kwargs", {}) and data.get("gradient_checkpointing_kwargs", {})
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") is False and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
is False
and data.get("deepspeed", "") is not None and data.get("deepspeed", "") is not None
and "zero3" in data.get("deepspeed", "") and "zero3" in data.get("deepspeed", "")
): ):
@@ -1356,14 +1510,21 @@ class AxolotlInputConfig(
@classmethod @classmethod
def check_val_w_test_datasets(cls, data): def check_val_w_test_datasets(cls, data):
if data.get("test_datasets") and data.get("val_set_size"): if data.get("test_datasets") and data.get("val_set_size"):
raise ValueError("non-zero val_set_size should not be used with test_datasets configuration") raise ValueError(
"non-zero val_set_size should not be used with test_datasets configuration"
)
return data return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_eval_strategy(cls, data): def check_eval_strategy(cls, data):
if data.get("evaluation_strategy") is not None and data.get("eval_strategy") is None: if (
LOG.info("explicitly setting `eval_strategy` from the `evaluation_strategy`") data.get("evaluation_strategy") is not None
and data.get("eval_strategy") is None
):
LOG.info(
"explicitly setting `eval_strategy` from the `evaluation_strategy`"
)
data["eval_strategy"] = data.get("evaluation_strategy") data["eval_strategy"] = data.get("evaluation_strategy")
return data return data
@@ -1376,7 +1537,9 @@ class AxolotlInputConfig(
and data.get("fsdp_config") and data.get("fsdp_config")
and data["fsdp_config"].get("fsdp_offload_params") and data["fsdp_config"].get("fsdp_offload_params")
): ):
raise ValueError(f"FSDP Offload not compatible with {data.get('optimizer')}") raise ValueError(
f"FSDP Offload not compatible with {data.get('optimizer')}"
)
return data return data
@model_validator(mode="before") @model_validator(mode="before")
@@ -1388,21 +1551,27 @@ class AxolotlInputConfig(
and data.get("fsdp_config") and data.get("fsdp_config")
and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT" and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT"
): ):
raise ValueError("FSDP SHARDED_STATE_DICT not compatible with save_safetensors") raise ValueError(
"FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
)
return data return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_causal_lm_evals(cls, data): def check_causal_lm_evals(cls, data):
if data.get("do_causal_lm_eval") and data.get("eval_sample_packing"): if data.get("do_causal_lm_eval") and data.get("eval_sample_packing"):
raise ValueError("do_causal_lm_eval is enabled, eval_sample_packing must be set to False") raise ValueError(
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
)
if data.get("eval_causal_lm_metrics"): if data.get("eval_causal_lm_metrics"):
if not isinstance(data.get("eval_causal_lm_metrics"), list): if not isinstance(data.get("eval_causal_lm_metrics"), list):
raise ValueError("eval_causal_lm_metrics must be a list") raise ValueError("eval_causal_lm_metrics must be a list")
# only ["sacrebleu", "comet", "ter", "chrf"] supported # only ["sacrebleu", "comet", "ter", "chrf"] supported
if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS: if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS:
raise ValueError(f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}") raise ValueError(
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
)
return data return data
@model_validator(mode="before") @model_validator(mode="before")
@@ -1415,14 +1584,22 @@ class AxolotlInputConfig(
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_xentropy_patch_conflicts(cls, data): def check_xentropy_patch_conflicts(cls, data):
if data.get("flash_attn_cross_entropy") and data.get("unsloth_cross_entropy_loss"): if data.get("flash_attn_cross_entropy") and data.get(
raise ValueError("flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled") "unsloth_cross_entropy_loss"
):
raise ValueError(
"flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
)
return data return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_qlora_unsloth(cls, data): def check_qlora_unsloth(cls, data):
if data.get("unsloth_lora_mlp") or data.get("unsloth_lora_qkv") or data.get("unsloth_lora_o"): if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
if data.get("adapter") == "lora" and data.get("load_in_8bit"): if data.get("adapter") == "lora" and data.get("load_in_8bit"):
raise ValueError( raise ValueError(
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA" "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
@@ -1432,7 +1609,11 @@ class AxolotlInputConfig(
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_lora_8bit(cls, data): def check_lora_8bit(cls, data):
if data.get("lora_mlp_kernel") or data.get("lora_qkv_kernel") or data.get("lora_o_kernel"): if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
):
if data.get("adapter") == "lora" and data.get("load_in_8bit"): if data.get("adapter") == "lora" and data.get("load_in_8bit"):
raise ValueError( raise ValueError(
"lora_mlp_kernel, lora_mlp_kernel, and lora_mlp_kernel are not compatible with 8-bit LoRA" "lora_mlp_kernel, lora_mlp_kernel, and lora_mlp_kernel are not compatible with 8-bit LoRA"
@@ -1442,8 +1623,13 @@ class AxolotlInputConfig(
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_lora_axolotl_unsloth(cls, data): def check_lora_axolotl_unsloth(cls, data):
is_lora_kernel = any(data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]) is_lora_kernel = any(
is_unsloth_lora = any(data.get(k) for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]) data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
)
is_unsloth_lora = any(
data.get(k)
for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
)
if is_lora_kernel and is_unsloth_lora: if is_lora_kernel and is_unsloth_lora:
raise ValueError( raise ValueError(
"both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)" "both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)"
@@ -1454,7 +1640,9 @@ class AxolotlInputConfig(
@classmethod @classmethod
def check_torch_compile_deepspeed(cls, data): def check_torch_compile_deepspeed(cls, data):
if data.get("deepspeed") and data.get("torch_compile"): if data.get("deepspeed") and data.get("torch_compile"):
raise ValueError("torch_compile should be set within your deepspeed config file") raise ValueError(
"torch_compile should be set within your deepspeed config file"
)
return data return data
@model_validator(mode="before") @model_validator(mode="before")
@@ -1538,9 +1726,15 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
def check_bf16(self): def check_bf16(self):
if self.capabilities.bf16: if self.capabilities.bf16:
if not self.bf16 and not self.bfloat16: if not self.bf16 and not self.bfloat16:
LOG.info("bf16 support detected, but not enabled for this configuration.") LOG.info(
"bf16 support detected, but not enabled for this configuration."
)
else: else:
if not self.merge_lora and not self.is_preprocess and (self.bf16 is True or self.bfloat16 is True): if (
not self.merge_lora
and not self.is_preprocess
and (self.bf16 is True or self.bfloat16 is True)
):
raise ValueError( raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above." "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
) )
@@ -1549,7 +1743,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_sample_packing_w_sdpa_bf16(cls, data): def check_sample_packing_w_sdpa_bf16(cls, data):
is_sm_90: bool = data["capabilities"] and data["capabilities"].get("compute_capability") == "sm_90" is_sm_90: bool = (
data["capabilities"]
and data["capabilities"].get("compute_capability") == "sm_90"
)
if ( if (
data.get("sample_packing") data.get("sample_packing")
and data.get("sdp_attention") and data.get("sdp_attention")
@@ -1574,7 +1771,11 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_multigpu_unsloth(cls, data): def check_multigpu_unsloth(cls, data):
if data.get("unsloth_lora_mlp") or data.get("unsloth_lora_qkv") or data.get("unsloth_lora_o"): if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
capabilities = data.get("capabilities") capabilities = data.get("capabilities")
if capabilities and capabilities.get("n_gpu", 0) > 1: if capabilities and capabilities.get("n_gpu", 0) > 1:
raise ValueError( raise ValueError(
@@ -1585,7 +1786,11 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_multigpu_lora_kernels(cls, data): def check_multigpu_lora_kernels(cls, data):
if data.get("lora_mlp_kernel") or data.get("lora_qkv_kernel") or data.get("lora_o_kernel"): if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
):
capabilities = data.get("capabilities") capabilities = data.get("capabilities")
is_fsdp = data.get("fsdp") is not None is_fsdp = data.get("fsdp") is not None
is_deepspeed = data.get("deepspeed") is not None is_deepspeed = data.get("deepspeed") is not None
@@ -1614,7 +1819,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
torch_version = str(torch.__version__).split("+", maxsplit=1)[0] torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
if version.parse(torch_version) < version.parse("2.5.1"): if version.parse(torch_version) < version.parse("2.5.1"):
raise ValueError("ADOPT optimizer is incompatible with torch version < 2.5.1") raise ValueError(
"ADOPT optimizer is incompatible with torch version < 2.5.1"
)
return data return data
@model_validator(mode="before") @model_validator(mode="before")
@@ -1623,8 +1830,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("torch_compile") == "auto": if data.get("torch_compile") == "auto":
env_capabilities = data.get("env_capabilities", {}) env_capabilities = data.get("env_capabilities", {})
if env_capabilities.get("torch_version"): if env_capabilities.get("torch_version"):
if version.parse(env_capabilities.get("torch_version")) >= version.parse("2.5.1"): if version.parse(
LOG.info("torch.compile is available, setting torch_compile to True") env_capabilities.get("torch_version")
) >= version.parse("2.5.1"):
LOG.info(
"torch.compile is available, setting torch_compile to True"
)
data["torch_compile"] = True data["torch_compile"] = True
else: else:
data["torch_compile"] = False data["torch_compile"] = False
@@ -1689,13 +1900,16 @@ def handle_legacy_message_fields_logic(data: dict) -> dict:
) )
if ( if (
"content" in data["message_property_mappings"] "content" in data["message_property_mappings"]
and data["message_property_mappings"]["content"] != data["message_field_content"] and data["message_property_mappings"]["content"]
!= data["message_field_content"]
): ):
raise ValueError( raise ValueError(
f"Conflicting message content fields: message_field_content='{data['message_field_content']}' " f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'" f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
) )
data["message_property_mappings"]["content"] = data["message_field_content"] or "content" data["message_property_mappings"]["content"] = (
data["message_field_content"] or "content"
)
del data["message_field_content"] del data["message_field_content"]
elif "content" not in data["message_property_mappings"]: elif "content" not in data["message_property_mappings"]: