linting
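This commit only re-wraps lines that exceed the linter's length limit; no behavior changes. As a rough sketch of the transformation (assuming a black-style formatter with its default 88-character limit — the exact tool and settings are not recorded in this commit):

    import black  # assumes the `black` formatter; not confirmed by this commit

    src = 'overrides_of_model_config: Optional[Dict[str, Any]] = Field(default=None, alias="model_config")\n'
    print(black.format_str(src, mode=black.Mode(line_length=88)))
    # over-long Field(...) calls come back wrapped across multiple lines,
    # exactly the shape seen in the hunks below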
@@ -126,8 +126,12 @@ class DeprecatedParameters(BaseModel):
 class RemappedParameters(BaseModel):
     """parameters that have been remapped to other names"""
 
-    overrides_of_model_config: Optional[Dict[str, Any]] = Field(default=None, alias="model_config")
-    overrides_of_model_kwargs: Optional[Dict[str, Any]] = Field(default=None, alias="model_kwargs")
+    overrides_of_model_config: Optional[Dict[str, Any]] = Field(
+        default=None, alias="model_config"
+    )
+    overrides_of_model_kwargs: Optional[Dict[str, Any]] = Field(
+        default=None, alias="model_kwargs"
+    )
     type_of_model: Optional[str] = Field(default=None, alias="model_type")
     revision_of_model: Optional[str] = Field(default=None, alias="model_revision")
 
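The `alias` remaps exist because `model_config` is a reserved attribute name on pydantic v2 models, so the field must carry a different name while still accepting the original key. A minimal sketch of the pattern (class name and input values are illustrative, not from this commit):

    from typing import Any, Dict, Optional
    from pydantic import BaseModel, Field

    class RemapDemo(BaseModel):
        # the field is renamed, but input dicts may still use the "model_config" key
        overrides_of_model_config: Optional[Dict[str, Any]] = Field(
            default=None, alias="model_config"
        )

    demo = RemapDemo.model_validate({"model_config": {"trust_remote_code": True}})
    print(demo.overrides_of_model_config)  # {'trust_remote_code': True}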
@@ -196,8 +200,12 @@ class SFTDataset(BaseModel):
     field_human: Optional[str] = None
     field_model: Optional[str] = None
     field_messages: Optional[str] = None
-    message_field_role: Optional[str] = None  # deprecated, use message_property_mappings
-    message_field_content: Optional[str] = None  # deprecated, use message_property_mappings
+    message_field_role: Optional[
+        str
+    ] = None  # deprecated, use message_property_mappings
+    message_field_content: Optional[
+        str
+    ] = None  # deprecated, use message_property_mappings
     message_property_mappings: Optional[Dict[str, str]] = None
     message_field_training: Optional[str] = None
     message_field_training_detail: Optional[str] = None
@@ -227,8 +235,12 @@ class SFTDataset(BaseModel):
             data["chat_template"] = ChatTemplate.tokenizer_default
 
         # if chat_template is set to jinja, chat_template_jinja is required
-        if data.get("chat_template") == ChatTemplate.jinja and not data.get("chat_template_jinja"):
-            raise ValueError("chat_template_jinja is required when chat_template is set to jinja")
+        if data.get("chat_template") == ChatTemplate.jinja and not data.get(
+            "chat_template_jinja"
+        ):
+            raise ValueError(
+                "chat_template_jinja is required when chat_template is set to jinja"
+            )
 
         # If chat_template_jinja is set, set chat_template to jinja
         if data.get("chat_template_jinja") and not data.get("chat_template"):
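The rule being re-wrapped here: `chat_template: jinja` is only valid when a template string is also supplied. Illustrative inputs (hypothetical values, assuming the validator sees the raw config dict):

    bad = {"chat_template": "jinja"}  # would raise: jinja selected, no template given
    ok = {"chat_template": "jinja", "chat_template_jinja": "{{ messages }}"}  # passes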
@@ -300,7 +312,9 @@ DatasetConfig = Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedData
 class LoftQConfig(BaseModel):
     """LoftQ configuration subset"""
 
-    loftq_bits: int = Field(default=4, json_schema_extra={"description": "Quantization bits for LoftQ"})
+    loftq_bits: int = Field(
+        default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
+    )
     # loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
 
 
@@ -345,7 +359,9 @@ class LoraConfig(BaseModel):
 
     qlora_sharded_model_loading: Optional[bool] = Field(
         default=False,
-        json_schema_extra={"description": "load qlora model in sharded format for FSDP using answer.ai technique."},
+        json_schema_extra={
+            "description": "load qlora model in sharded format for FSDP using answer.ai technique."
+        },
     )
     lora_on_cpu: Optional[bool] = None
     gptq: Optional[bool] = None
@@ -353,11 +369,15 @@ class LoraConfig(BaseModel):
 
     loraplus_lr_ratio: Optional[float] = Field(
         default=None,
-        json_schema_extra={"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."},
+        json_schema_extra={
+            "description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
+        },
     )
     loraplus_lr_embedding: Optional[float] = Field(
         default=1e-6,
-        json_schema_extra={"description": "loraplus learning rate for lora embedding layers."},
+        json_schema_extra={
+            "description": "loraplus learning rate for lora embedding layers."
+        },
     )
 
     merge_lora: Optional[bool] = None
@@ -465,11 +485,15 @@ class HyperparametersConfig(BaseModel):
     )
     batch_size: Optional[int] = Field(
         default=None,
-        json_schema_extra={"description": "Total batch size, we do not recommended setting this manually"},
+        json_schema_extra={
+            "description": "Total batch size, we do not recommended setting this manually"
+        },
     )
     eval_batch_size: Optional[int] = Field(
         default=None,
-        json_schema_extra={"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"},
+        json_schema_extra={
+            "description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
+        },
     )
 
     auto_find_batch_size: Optional[bool] = None
@@ -481,7 +505,9 @@ class HyperparametersConfig(BaseModel):
     embedding_lr: Optional[float] = None
     embedding_lr_scale: Optional[float] = None
     weight_decay: Optional[float] = 0.0
-    optimizer: Optional[Union[OptimizerNames, CustomSupportedOptimizers]] = OptimizerNames.ADAMW_HF
+    optimizer: Optional[
+        Union[OptimizerNames, CustomSupportedOptimizers]
+    ] = OptimizerNames.ADAMW_HF
     optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
         default=None,
         json_schema_extra={"description": "Optional arguments to supply to optimizer."},
@@ -493,7 +519,9 @@ class HyperparametersConfig(BaseModel):
         },
     )
     torchdistx_path: Optional[str] = None
-    lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]] = SchedulerType.COSINE
+    lr_scheduler: Optional[
+        Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]
+    ] = SchedulerType.COSINE
     lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
     lr_quadratic_warmup: Optional[bool] = None
     cosine_min_lr_ratio: Optional[float] = None
@@ -617,15 +645,21 @@ class RayConfig(BaseModel):
     use_ray: bool = Field(default=False)
     ray_run_name: Optional[str] = Field(
         default=None,
-        json_schema_extra={"help": "The training results will be saved at `saves/ray_run_name`."},
+        json_schema_extra={
+            "help": "The training results will be saved at `saves/ray_run_name`."
+        },
     )
     ray_num_workers: int = Field(
         default=1,
-        json_schema_extra={"help": "The number of workers for Ray training. Default is 1 worker."},
+        json_schema_extra={
+            "help": "The number of workers for Ray training. Default is 1 worker."
+        },
     )
     resources_per_worker: dict = Field(
         default_factory=lambda: {"GPU": 1},
-        json_schema_extra={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
+        json_schema_extra={
+            "help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
+        },
     )
 
 
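Aside on `resources_per_worker`: it uses `default_factory` rather than a literal dict default. In plain Python a literal mutable default would be shared between call sites; the factory guarantees a fresh dict per instance (pydantic also copies plain defaults defensively, so here it is mainly the idiomatic choice). A quick sketch (demo class name is illustrative):

    from pydantic import BaseModel, Field

    class WorkerDemo(BaseModel):
        # each instance gets its own dict, so mutating one does not affect others
        resources: dict = Field(default_factory=lambda: {"GPU": 1})

    a, b = WorkerDemo(), WorkerDemo()
    a.resources["GPU"] = 8
    print(b.resources)  # {'GPU': 1}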
@@ -665,9 +699,9 @@ class AxolotlInputConfig(
     reward_model: Optional[bool] = None
     process_reward_model: Optional[bool] = None
     num_labels: Optional[int] = None
-    dpo_use_weighting: Optional[bool] = (
-        None  # whether to use weighting in DPO trainer. If none, default is false in the trainer.
-    )
+    dpo_use_weighting: Optional[
+        bool
+    ] = None  # whether to use weighting in DPO trainer. If none, default is false in the trainer.
     dpo_use_logits_to_keep: Optional[bool] = None
 
     datasets: Optional[
@@ -689,7 +723,9 @@ class AxolotlInputConfig(
     dataset_shard_idx: Optional[int] = None
     skip_prepare_dataset: Optional[bool] = False
 
-    pretraining_dataset: Optional[Annotated[list[Union[PretrainingDataset, SFTDataset]], MinLen(1)]] = Field(
+    pretraining_dataset: Optional[
+        Annotated[list[Union[PretrainingDataset, SFTDataset]], MinLen(1)]
+    ] = Field(
         default=None,
         json_schema_extra={"description": "streaming dataset to use for pretraining"},
     )
@@ -744,7 +780,9 @@ class AxolotlInputConfig(
 
     # torch_dtype: Optional[torch.dtype]
 
-    gradient_checkpointing: Optional[Union[Literal["unsloth", "offload"], bool]] = Field(default=False)
+    gradient_checkpointing: Optional[
+        Union[Literal["unsloth", "offload"], bool]
+    ] = Field(default=False)
     gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
 
     unfrozen_parameters: Optional[List[str]] = None
@@ -811,7 +849,9 @@ class AxolotlInputConfig(
     deepspeed: Optional[Union[str, Dict[str, Any]]] = None
     fsdp: Optional[List[str]] = None
     fsdp_config: Optional[Dict[str, Any]] = None
-    fsdp_final_state_dict_type: Optional[Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]] = None
+    fsdp_final_state_dict_type: Optional[
+        Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
+    ] = None
 
     val_set_size: Optional[float] = Field(default=0.0)
 
@@ -821,7 +861,9 @@ class AxolotlInputConfig(
 
     torch_compile: Optional[Union[Literal["auto"], bool]] = None
     torch_compile_backend: Optional[str] = None
-    torch_compile_mode: Optional[Literal["default", "reduce-overhead", "max-autotune"]] = None
+    torch_compile_mode: Optional[
+        Literal["default", "reduce-overhead", "max-autotune"]
+    ] = None
 
     max_steps: Optional[int] = None
     warmup_steps: Optional[int] = None
@@ -852,7 +894,9 @@ class AxolotlInputConfig(
     kto_undesirable_weight: Optional[float] = None
     rl_beta: Optional[float] = None
 
-    max_memory: Optional[Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]] = None
+    max_memory: Optional[
+        Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
+    ] = None
     gpu_memory_limit: Optional[Union[int, str]] = None
     low_cpu_mem_usage: Optional[bool] = None
 
@@ -888,7 +932,11 @@ class AxolotlInputConfig(
     def deprecate_sharegpt_datasets(cls, datasets):
         for _, ds_cfg in enumerate(datasets):
             # Handle both dict and pydantic model cases
-            ds_type = ds_cfg.get("type") if isinstance(ds_cfg, dict) else getattr(ds_cfg, "type", None)
+            ds_type = (
+                ds_cfg.get("type")
+                if isinstance(ds_cfg, dict)
+                else getattr(ds_cfg, "type", None)
+            )
             if not ds_type:
                 continue
 
@@ -897,12 +945,16 @@ class AxolotlInputConfig(
                 continue
 
             if isinstance(ds_type, str) and ds_type.startswith("sharegpt"):
-                raise ValueError("`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead.")
+                raise ValueError(
+                    "`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead."
+                )
 
         return datasets
 
     @field_serializer("datasets")
-    def datasets_serializer(self, ds_configs: Optional[List[DatasetConfig]]) -> Optional[List[Dict[str, Any]]]:
+    def datasets_serializer(
+        self, ds_configs: Optional[List[DatasetConfig]]
+    ) -> Optional[List[Dict[str, Any]]]:
         if ds_configs:
             return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
         return None
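The deprecation being re-wrapped: any dataset whose `type` starts with `sharegpt` is rejected with a pointer to `chat_template`, and the `ds_type` lookup above tolerates both dict and pydantic-model entries. Illustrative dataset entries (paths are hypothetical):

    raises = {"path": "user/convos", "type": "sharegpt"}   # -> ValueError, use chat_template
    ok = {"path": "user/convos", "type": "chat_template"}  # accepted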
@@ -968,7 +1020,9 @@ class AxolotlInputConfig(
     @classmethod
     def check_sample_packing_w_xformers(cls, data):
         if data.get("sample_packing") and data.get("xformers_attention"):
-            raise ValueError("sample_packing not compatible with xformers_attention. Use flash_attention")
+            raise ValueError(
+                "sample_packing not compatible with xformers_attention. Use flash_attention"
+            )
 
         return data
 
@@ -976,8 +1030,12 @@ class AxolotlInputConfig(
     @classmethod
     def check_chat_template_config(cls, data):
         # if chat_template is set to jinja, chat_template_jinja is required
-        if data.get("chat_template") == ChatTemplate.jinja and not data.get("chat_template_jinja"):
-            raise ValueError("chat_template_jinja is required when chat_template is set to jinja")
+        if data.get("chat_template") == ChatTemplate.jinja and not data.get(
+            "chat_template_jinja"
+        ):
+            raise ValueError(
+                "chat_template_jinja is required when chat_template is set to jinja"
+            )
 
         # If chat_template_jinja is set, set chat_template to jinja
         if data.get("chat_template_jinja") and not data.get("chat_template"):
@@ -988,8 +1046,14 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_sample_packing_wo_flash(cls, data):
-        if data.get("sample_packing") and not data.get("flash_attention") and not data.get("sdp_attention"):
-            LOG.warning("sample_packing without flash_attention or sdp_attention does not handle cross-attention.")
+        if (
+            data.get("sample_packing")
+            and not data.get("flash_attention")
+            and not data.get("sdp_attention")
+        ):
+            LOG.warning(
+                "sample_packing without flash_attention or sdp_attention does not handle cross-attention."
+            )
 
         return data
 
@@ -1028,14 +1092,18 @@ class AxolotlInputConfig(
     @classmethod
     def hint_sample_packing_padding(cls, data):
         if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
-            LOG.warning("`pad_to_sequence_len: true` is recommended when using sample_packing")
+            LOG.warning(
+                "`pad_to_sequence_len: true` is recommended when using sample_packing"
+            )
         return data
 
     @model_validator(mode="before")
     @classmethod
     def hint_reward_model_pad(cls, data):
         if data.get("reward_model") and not data.get("pad_to_sequence_len"):
-            LOG.warning("`pad_to_sequence_len: true` is recommended when using reward_model")
+            LOG.warning(
+                "`pad_to_sequence_len: true` is recommended when using reward_model"
+            )
             if data.get("pad_to_sequence_len") is None:
                 data["pad_to_sequence_len"] = True
         return data
@@ -1044,7 +1112,9 @@ class AxolotlInputConfig(
     @classmethod
     def check_gas_bsz(cls, data):
         if data.get("gradient_accumulation_steps") and data.get("batch_size"):
-            raise ValueError("please set only one of gradient_accumulation_steps or batch_size")
+            raise ValueError(
+                "please set only one of gradient_accumulation_steps or batch_size"
+            )
         return data
 
     @model_validator(mode="before")
@@ -1055,14 +1125,21 @@ class AxolotlInputConfig(
             and data.get("micro_batch_size")
             and data.get("eval_batch_size") != data.get("micro_batch_size")
         ):
-            LOG.warning("eval_batch_size != micro_batch_size. This can lead to VRAM instability.")
+            LOG.warning(
+                "eval_batch_size != micro_batch_size. This can lead to VRAM instability."
+            )
         return data
 
     @model_validator(mode="before")
     @classmethod
     def check_push_ds_auth(cls, data):
-        if data.get("push_dataset_to_hub") and data.get("hf_use_auth_token") is not True:
-            raise ValueError("Require cfg.hf_use_auth_token to be True for push_dataset_to_hub")
+        if (
+            data.get("push_dataset_to_hub")
+            and data.get("hf_use_auth_token") is not True
+        ):
+            raise ValueError(
+                "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
+            )
         return data
 
     @model_validator(mode="after")
@@ -1073,14 +1150,18 @@ class AxolotlInputConfig(
 
     @model_validator(mode="after")
     def check_mpt_checkpointing(self):
-        if (self.base_model and "mpt" in self.base_model.lower()) and self.gradient_checkpointing:
+        if (
+            self.base_model and "mpt" in self.base_model.lower()
+        ) and self.gradient_checkpointing:
             raise ValueError("gradient_checkpointing is not supported for MPT models")
         return self
 
     @model_validator(mode="after")
     def check_offload_grad_checkpointing(self):
         if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
-            LOG.warning("`unsloth` is deprecated for gradient_checkpointing, use `offload`")
+            LOG.warning(
+                "`unsloth` is deprecated for gradient_checkpointing, use `offload`"
+            )
             self.gradient_checkpointing = "offload"
         return self
 
@@ -1088,7 +1169,9 @@ class AxolotlInputConfig(
     def check_better_transformers(self):
         if self.flash_optimum is True:
             if self.adapter:
-                LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
+                LOG.warning(
+                    "BetterTransformers probably doesn't work with PEFT adapters"
+                )
             if self.fp16 or self.bf16:
                 raise ValueError("AMP is not supported with BetterTransformer")
             if self.float16 is not True and self.bfloat16 is not True:
@@ -1116,25 +1199,39 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_saves(cls, data):
-        if data.get("save_strategy") and data.get("save_steps") and data.get("save_strategy") != "steps":
+        if (
+            data.get("save_strategy")
+            and data.get("save_steps")
+            and data.get("save_strategy") != "steps"
+        ):
             raise ValueError(
                 "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
             )
         if data.get("saves_per_epoch") and data.get("save_steps"):
-            raise ValueError("save_steps and saves_per_epoch are mutually exclusive and cannot be used together.")
+            raise ValueError(
+                "save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
+            )
         return data
 
     @model_validator(mode="before")
     @classmethod
     def check_push_save(cls, data):
-        if data.get("hub_model_id") and (data.get("save_strategy") not in ["steps", "epoch", None]):
-            LOG.warning("hub_model_id is set without any models being saved. To save a model, set save_strategy.")
+        if data.get("hub_model_id") and (
+            data.get("save_strategy") not in ["steps", "epoch", None]
+        ):
+            LOG.warning(
+                "hub_model_id is set without any models being saved. To save a model, set save_strategy."
+            )
         return data
 
     @model_validator(mode="before")
     @classmethod
     def check_evals(cls, data):
-        if data.get("eval_strategy") and data.get("eval_steps") and data.get("eval_strategy") != "steps":
+        if (
+            data.get("eval_strategy")
+            and data.get("eval_steps")
+            and data.get("eval_strategy") != "steps"
+        ):
             raise ValueError(
                 "eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps."
             )
@@ -1144,21 +1241,41 @@ class AxolotlInputConfig(
             and (data.get("eval_steps") or data.get("eval_strategy"))
             and not data.get("test_datasets")
         ):
-            raise ValueError("eval_steps and eval_strategy are not supported with val_set_size == 0")
+            raise ValueError(
+                "eval_steps and eval_strategy are not supported with val_set_size == 0"
+            )
         if data.get("evals_per_epoch") and data.get("eval_steps"):
-            raise ValueError("eval_steps and evals_per_epoch are mutually exclusive and cannot be used together.")
-        if data.get("evals_per_epoch") and data.get("eval_strategy") and data.get("eval_strategy") != "steps":
-            raise ValueError("eval_strategy must be empty or set to `steps` when used with evals_per_epoch.")
+            raise ValueError(
+                "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
+            )
+        if (
+            data.get("evals_per_epoch")
+            and data.get("eval_strategy")
+            and data.get("eval_strategy") != "steps"
+        ):
+            raise ValueError(
+                "eval_strategy must be empty or set to `steps` when used with evals_per_epoch."
+            )
 
-        if data.get("do_bench_eval") and not (data.get("evals_per_epoch") or data.get("eval_steps")):
-            raise ValueError("do_bench_eval requires evals_per_epoch or eval_steps to be set.")
+        if data.get("do_bench_eval") and not (
+            data.get("evals_per_epoch") or data.get("eval_steps")
+        ):
+            raise ValueError(
+                "do_bench_eval requires evals_per_epoch or eval_steps to be set."
+            )
         return data
 
     @model_validator(mode="before")
     @classmethod
     def check_test_datasets_bench(cls, data):
-        if data.get("do_bench_eval") and not data.get("test_datasets") and not data.get("val_set_size"):
-            LOG.warning("`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset.")
+        if (
+            data.get("do_bench_eval")
+            and not data.get("test_datasets")
+            and not data.get("val_set_size")
+        ):
+            LOG.warning(
+                "`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset."
+            )
             data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}]
         return data
 
@@ -1167,12 +1284,22 @@ class AxolotlInputConfig(
     def check_eval_packing(cls, data):
         # TODO also should check test_datasets and val_set_size as we can skip
        # if there are no eval datasets/splits
-        if data.get("sample_packing") and data.get("eval_table_size") and data.get("eval_sample_packing") is not False:
+        if (
+            data.get("sample_packing")
+            and data.get("eval_table_size")
+            and data.get("eval_sample_packing") is not False
+        ):
             raise ValueError(
                 "eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
             )
-        if data.get("sample_packing") and data.get("eval_sample_packing") is None and not data.get("eval_table_size"):
-            LOG.info("explicitly setting `eval_sample_packing` to match `sample_packing`")
+        if (
+            data.get("sample_packing")
+            and data.get("eval_sample_packing") is None
+            and not data.get("eval_table_size")
+        ):
+            LOG.info(
+                "explicitly setting `eval_sample_packing` to match `sample_packing`"
+            )
             data["eval_sample_packing"] = True
 
         if (
@@ -1192,7 +1319,9 @@ class AxolotlInputConfig(
     def check_mm_prepare(cls, data):
         if data.get("skip_prepare_dataset"):
             if data.get("remove_unused_columns") is None:
-                LOG.info("setting `remove_unused_columns: false` for skip_prepare_dataset")
+                LOG.info(
+                    "setting `remove_unused_columns: false` for skip_prepare_dataset"
+                )
                 data["remove_unused_columns"] = False
 
         return data
@@ -1233,13 +1362,19 @@ class AxolotlInputConfig(
     @model_validator(mode="after")
     def check_simpo_warmup(self):
         if self.rl == "simpo" and self.warmup_ratio:
-            raise ValueError("warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead")
+            raise ValueError(
+                "warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
+            )
         return self
 
     @model_validator(mode="before")
     @classmethod
     def check_frozen(cls, data):
-        if data.get("adapter") and data.get("peft_layers_to_transform") and data.get("unfrozen_parameters"):
+        if (
+            data.get("adapter")
+            and data.get("peft_layers_to_transform")
+            and data.get("unfrozen_parameters")
+        ):
             raise ValueError(
                 "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
             )
@@ -1250,7 +1385,9 @@ class AxolotlInputConfig(
     @classmethod
     def check_peft_layers_pattern(cls, data):
         if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"):
-            raise ValueError("peft_layers_pattern requires peft_layers_to_transform to be set")
+            raise ValueError(
+                "peft_layers_pattern requires peft_layers_to_transform to be set"
+            )
         return data
 
     @model_validator(mode="after")
@@ -1273,13 +1410,17 @@ class AxolotlInputConfig(
 
     @model_validator(mode="after")
     def check_fused_lora(self):
-        if self.adapter in ["lora", "qlora"] and (self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp):
+        if self.adapter in ["lora", "qlora"] and (
+            self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp
+        ):
             raise ValueError("Fused modules are not supported with LoRA/QLoRA")
         return self
 
     @model_validator(mode="after")
     def hint_lora_8bit(self):
-        loftq = self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits
+        loftq = (
+            self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits
+        )
         if not self.load_in_8bit and self.adapter == "lora" and not loftq:
             LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
         return self
@@ -1292,7 +1433,9 @@ class AxolotlInputConfig(
                 "`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
             )
         if self.save_steps % self.eval_steps != 0:
-            raise ValueError("`early_stopping_patience` requires that eval_steps should evenly divide save_steps.")
+            raise ValueError(
+                "`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
+            )
         return self
 
     @model_validator(mode="after")
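The divisibility check re-wrapped above exists because early stopping compares metrics at save points, so every save step must coincide with an eval step. A worked example (values are illustrative):

    save_steps, eval_steps = 100, 25
    assert save_steps % eval_steps == 0  # ok: evals at 25/50/75/100 cover each save

    save_steps, eval_steps = 100, 30
    # 100 % 30 == 10 -> the validator raises: step 100 would save without a fresh eval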
@@ -1308,7 +1451,9 @@ class AxolotlInputConfig(
             raise ValueError("deepspeed not supported with ReLoRA")
 
         if self.lr_scheduler == "one_cycle":
-            raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
+            raise ValueError(
+                "ReLoRA is not compatible with the one_cycle scheduler"
+            )
 
         if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp:
             raise ValueError("Fused modules are not supported with ReLoRA")
@@ -1317,8 +1462,13 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_mem_mismatch(cls, data):
-        if data.get("max_memory") is not None and data.get("gpu_memory_limit") is not None:
-            raise ValueError("max_memory and gpu_memory_limit are mutually exclusive and cannot be used together.")
+        if (
+            data.get("max_memory") is not None
+            and data.get("gpu_memory_limit") is not None
+        ):
+            raise ValueError(
+                "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
+            )
         return data
 
     @model_validator(mode="before")
@@ -1327,10 +1477,13 @@ class AxolotlInputConfig(
         if (
             data.get("unfrozen_parameters")
             and data.get("gradient_checkpointing_kwargs")
-            and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") is True
+            and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
+            is True
         ):
             # https://github.com/huggingface/transformers/issues/21381
-            raise ValueError("`use_reentrant` must be false when used with partially frozen model.")
+            raise ValueError(
+                "`use_reentrant` must be false when used with partially frozen model."
+            )
         return data
 
     @model_validator(mode="before")
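Context for the re-wrapped check: with the reentrant checkpointing implementation, frozen (requires_grad=False) parameters can break the backward pass, hence the linked transformers issue and the hard error. The accepted shape of the config looks like this (keys from this schema; values illustrative):

    ok = {
        "unfrozen_parameters": ["lm_head.*"],
        "gradient_checkpointing": True,
        "gradient_checkpointing_kwargs": {"use_reentrant": False},  # must be False here
    }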
@@ -1339,7 +1492,8 @@ class AxolotlInputConfig(
         if (
             data.get("adapter") == "qlora"
             and data.get("gradient_checkpointing_kwargs", {})
-            and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") is False
+            and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
+            is False
             and data.get("deepspeed", "") is not None
             and "zero3" in data.get("deepspeed", "")
         ):
@@ -1356,14 +1510,21 @@ class AxolotlInputConfig(
     @classmethod
     def check_val_w_test_datasets(cls, data):
         if data.get("test_datasets") and data.get("val_set_size"):
-            raise ValueError("non-zero val_set_size should not be used with test_datasets configuration")
+            raise ValueError(
+                "non-zero val_set_size should not be used with test_datasets configuration"
+            )
         return data
 
     @model_validator(mode="before")
     @classmethod
     def check_eval_strategy(cls, data):
-        if data.get("evaluation_strategy") is not None and data.get("eval_strategy") is None:
-            LOG.info("explicitly setting `eval_strategy` from the `evaluation_strategy`")
+        if (
+            data.get("evaluation_strategy") is not None
+            and data.get("eval_strategy") is None
+        ):
+            LOG.info(
+                "explicitly setting `eval_strategy` from the `evaluation_strategy`"
+            )
             data["eval_strategy"] = data.get("evaluation_strategy")
         return data
 
@@ -1376,7 +1537,9 @@ class AxolotlInputConfig(
             and data.get("fsdp_config")
             and data["fsdp_config"].get("fsdp_offload_params")
         ):
-            raise ValueError(f"FSDP Offload not compatible with {data.get('optimizer')}")
+            raise ValueError(
+                f"FSDP Offload not compatible with {data.get('optimizer')}"
+            )
         return data
 
     @model_validator(mode="before")
@@ -1388,21 +1551,27 @@ class AxolotlInputConfig(
             and data.get("fsdp_config")
             and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT"
         ):
-            raise ValueError("FSDP SHARDED_STATE_DICT not compatible with save_safetensors")
+            raise ValueError(
+                "FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
+            )
         return data
 
     @model_validator(mode="before")
     @classmethod
     def check_causal_lm_evals(cls, data):
         if data.get("do_causal_lm_eval") and data.get("eval_sample_packing"):
-            raise ValueError("do_causal_lm_eval is enabled, eval_sample_packing must be set to False")
+            raise ValueError(
+                "do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
+            )
 
         if data.get("eval_causal_lm_metrics"):
             if not isinstance(data.get("eval_causal_lm_metrics"), list):
                 raise ValueError("eval_causal_lm_metrics must be a list")
             # only ["sacrebleu", "comet", "ter", "chrf"] supported
             if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS:
-                raise ValueError(f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}")
+                raise ValueError(
+                    f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
+                )
         return data
 
     @model_validator(mode="before")
@@ -1415,14 +1584,22 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_xentropy_patch_conflicts(cls, data):
-        if data.get("flash_attn_cross_entropy") and data.get("unsloth_cross_entropy_loss"):
-            raise ValueError("flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled")
+        if data.get("flash_attn_cross_entropy") and data.get(
+            "unsloth_cross_entropy_loss"
+        ):
+            raise ValueError(
+                "flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
+            )
         return data
 
     @model_validator(mode="before")
     @classmethod
     def check_qlora_unsloth(cls, data):
-        if data.get("unsloth_lora_mlp") or data.get("unsloth_lora_qkv") or data.get("unsloth_lora_o"):
+        if (
+            data.get("unsloth_lora_mlp")
+            or data.get("unsloth_lora_qkv")
+            or data.get("unsloth_lora_o")
+        ):
             if data.get("adapter") == "lora" and data.get("load_in_8bit"):
                 raise ValueError(
                     "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
@@ -1432,7 +1609,11 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_lora_8bit(cls, data):
-        if data.get("lora_mlp_kernel") or data.get("lora_qkv_kernel") or data.get("lora_o_kernel"):
+        if (
+            data.get("lora_mlp_kernel")
+            or data.get("lora_qkv_kernel")
+            or data.get("lora_o_kernel")
+        ):
             if data.get("adapter") == "lora" and data.get("load_in_8bit"):
                 raise ValueError(
                     "lora_mlp_kernel, lora_mlp_kernel, and lora_mlp_kernel are not compatible with 8-bit LoRA"
@@ -1442,8 +1623,13 @@ class AxolotlInputConfig(
     @model_validator(mode="before")
     @classmethod
     def check_lora_axolotl_unsloth(cls, data):
-        is_lora_kernel = any(data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"])
-        is_unsloth_lora = any(data.get(k) for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"])
+        is_lora_kernel = any(
+            data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
+        )
+        is_unsloth_lora = any(
+            data.get(k)
+            for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
+        )
         if is_lora_kernel and is_unsloth_lora:
             raise ValueError(
                 "both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)"
@@ -1454,7 +1640,9 @@ class AxolotlInputConfig(
     @classmethod
     def check_torch_compile_deepspeed(cls, data):
         if data.get("deepspeed") and data.get("torch_compile"):
-            raise ValueError("torch_compile should be set within your deepspeed config file")
+            raise ValueError(
+                "torch_compile should be set within your deepspeed config file"
+            )
         return data
 
     @model_validator(mode="before")
@@ -1538,9 +1726,15 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
     def check_bf16(self):
         if self.capabilities.bf16:
             if not self.bf16 and not self.bfloat16:
-                LOG.info("bf16 support detected, but not enabled for this configuration.")
+                LOG.info(
+                    "bf16 support detected, but not enabled for this configuration."
+                )
         else:
-            if not self.merge_lora and not self.is_preprocess and (self.bf16 is True or self.bfloat16 is True):
+            if (
+                not self.merge_lora
+                and not self.is_preprocess
+                and (self.bf16 is True or self.bfloat16 is True)
+            ):
                 raise ValueError(
                     "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
                 )
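For reference, `capabilities.bf16` corresponds to hardware support that can be probed at runtime; one plausible probe (not necessarily how axolotl populates the flag):

    import torch

    # True only on GPUs with native bfloat16 support (Ampere and newer)
    bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()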
@@ -1549,7 +1743,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
     @model_validator(mode="before")
     @classmethod
     def check_sample_packing_w_sdpa_bf16(cls, data):
-        is_sm_90: bool = data["capabilities"] and data["capabilities"].get("compute_capability") == "sm_90"
+        is_sm_90: bool = (
+            data["capabilities"]
+            and data["capabilities"].get("compute_capability") == "sm_90"
+        )
         if (
             data.get("sample_packing")
             and data.get("sdp_attention")
@@ -1574,7 +1771,11 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
     @model_validator(mode="before")
     @classmethod
     def check_multigpu_unsloth(cls, data):
-        if data.get("unsloth_lora_mlp") or data.get("unsloth_lora_qkv") or data.get("unsloth_lora_o"):
+        if (
+            data.get("unsloth_lora_mlp")
+            or data.get("unsloth_lora_qkv")
+            or data.get("unsloth_lora_o")
+        ):
             capabilities = data.get("capabilities")
             if capabilities and capabilities.get("n_gpu", 0) > 1:
                 raise ValueError(
@@ -1585,7 +1786,11 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
     @model_validator(mode="before")
     @classmethod
     def check_multigpu_lora_kernels(cls, data):
-        if data.get("lora_mlp_kernel") or data.get("lora_qkv_kernel") or data.get("lora_o_kernel"):
+        if (
+            data.get("lora_mlp_kernel")
+            or data.get("lora_qkv_kernel")
+            or data.get("lora_o_kernel")
+        ):
             capabilities = data.get("capabilities")
             is_fsdp = data.get("fsdp") is not None
             is_deepspeed = data.get("deepspeed") is not None
@@ -1614,7 +1819,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
 
         if version.parse(torch_version) < version.parse("2.5.1"):
-            raise ValueError("ADOPT optimizer is incompatible with torch version < 2.5.1")
+            raise ValueError(
+                "ADOPT optimizer is incompatible with torch version < 2.5.1"
+            )
         return data
 
     @model_validator(mode="before")
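Note the guard first strips local build suffixes (e.g. a hypothetical "2.5.1+cu121") via `split("+", maxsplit=1)`, then uses `packaging.version` for a real semantic comparison rather than string comparison. A small sketch of why that matters:

    from packaging import version

    assert version.parse("2.10.0") > version.parse("2.5.1")  # correct numeric ordering
    assert "2.10.0" < "2.5.1"  # naive string comparison gets this wrong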
@@ -1623,8 +1830,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         if data.get("torch_compile") == "auto":
             env_capabilities = data.get("env_capabilities", {})
             if env_capabilities.get("torch_version"):
-                if version.parse(env_capabilities.get("torch_version")) >= version.parse("2.5.1"):
-                    LOG.info("torch.compile is available, setting torch_compile to True")
+                if version.parse(
+                    env_capabilities.get("torch_version")
+                ) >= version.parse("2.5.1"):
+                    LOG.info(
+                        "torch.compile is available, setting torch_compile to True"
+                    )
                     data["torch_compile"] = True
             else:
                 data["torch_compile"] = False
@@ -1689,13 +1900,16 @@ def handle_legacy_message_fields_logic(data: dict) -> dict:
         )
         if (
             "content" in data["message_property_mappings"]
-            and data["message_property_mappings"]["content"] != data["message_field_content"]
+            and data["message_property_mappings"]["content"]
+            != data["message_field_content"]
         ):
             raise ValueError(
                 f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
                 f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
             )
-        data["message_property_mappings"]["content"] = data["message_field_content"] or "content"
+        data["message_property_mappings"]["content"] = (
+            data["message_field_content"] or "content"
+        )
 
         del data["message_field_content"]
     elif "content" not in data["message_property_mappings"]:
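Net effect of the branch re-wrapped above, sketched on a hypothetical legacy config:

    legacy = {"message_field_content": "text", "message_property_mappings": {}}
    # handle_legacy_message_fields_logic folds the legacy key into the mapping:
    #   {"message_property_mappings": {"content": "text"}}
    # and raises if the mapping already names a *different* content field.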