linting
This commit is contained in:
@@ -126,8 +126,12 @@ class DeprecatedParameters(BaseModel):
|
|||||||
class RemappedParameters(BaseModel):
|
class RemappedParameters(BaseModel):
|
||||||
"""parameters that have been remapped to other names"""
|
"""parameters that have been remapped to other names"""
|
||||||
|
|
||||||
overrides_of_model_config: Optional[Dict[str, Any]] = Field(default=None, alias="model_config")
|
overrides_of_model_config: Optional[Dict[str, Any]] = Field(
|
||||||
overrides_of_model_kwargs: Optional[Dict[str, Any]] = Field(default=None, alias="model_kwargs")
|
default=None, alias="model_config"
|
||||||
|
)
|
||||||
|
overrides_of_model_kwargs: Optional[Dict[str, Any]] = Field(
|
||||||
|
default=None, alias="model_kwargs"
|
||||||
|
)
|
||||||
type_of_model: Optional[str] = Field(default=None, alias="model_type")
|
type_of_model: Optional[str] = Field(default=None, alias="model_type")
|
||||||
revision_of_model: Optional[str] = Field(default=None, alias="model_revision")
|
revision_of_model: Optional[str] = Field(default=None, alias="model_revision")
|
||||||
|
|
||||||
@@ -196,8 +200,12 @@ class SFTDataset(BaseModel):
|
|||||||
field_human: Optional[str] = None
|
field_human: Optional[str] = None
|
||||||
field_model: Optional[str] = None
|
field_model: Optional[str] = None
|
||||||
field_messages: Optional[str] = None
|
field_messages: Optional[str] = None
|
||||||
message_field_role: Optional[str] = None # deprecated, use message_property_mappings
|
message_field_role: Optional[
|
||||||
message_field_content: Optional[str] = None # deprecated, use message_property_mappings
|
str
|
||||||
|
] = None # deprecated, use message_property_mappings
|
||||||
|
message_field_content: Optional[
|
||||||
|
str
|
||||||
|
] = None # deprecated, use message_property_mappings
|
||||||
message_property_mappings: Optional[Dict[str, str]] = None
|
message_property_mappings: Optional[Dict[str, str]] = None
|
||||||
message_field_training: Optional[str] = None
|
message_field_training: Optional[str] = None
|
||||||
message_field_training_detail: Optional[str] = None
|
message_field_training_detail: Optional[str] = None
|
||||||
@@ -227,8 +235,12 @@ class SFTDataset(BaseModel):
|
|||||||
data["chat_template"] = ChatTemplate.tokenizer_default
|
data["chat_template"] = ChatTemplate.tokenizer_default
|
||||||
|
|
||||||
# if chat_template is set to jinja, chat_template_jinja is required
|
# if chat_template is set to jinja, chat_template_jinja is required
|
||||||
if data.get("chat_template") == ChatTemplate.jinja and not data.get("chat_template_jinja"):
|
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
||||||
raise ValueError("chat_template_jinja is required when chat_template is set to jinja")
|
"chat_template_jinja"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"chat_template_jinja is required when chat_template is set to jinja"
|
||||||
|
)
|
||||||
|
|
||||||
# If chat_template_jinja is set, set chat_template to jinja
|
# If chat_template_jinja is set, set chat_template to jinja
|
||||||
if data.get("chat_template_jinja") and not data.get("chat_template"):
|
if data.get("chat_template_jinja") and not data.get("chat_template"):
|
||||||
@@ -300,7 +312,9 @@ DatasetConfig = Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedData
|
|||||||
class LoftQConfig(BaseModel):
|
class LoftQConfig(BaseModel):
|
||||||
"""LoftQ configuration subset"""
|
"""LoftQ configuration subset"""
|
||||||
|
|
||||||
loftq_bits: int = Field(default=4, json_schema_extra={"description": "Quantization bits for LoftQ"})
|
loftq_bits: int = Field(
|
||||||
|
default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
|
||||||
|
)
|
||||||
# loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
|
# loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
|
||||||
|
|
||||||
|
|
||||||
@@ -345,7 +359,9 @@ class LoraConfig(BaseModel):
|
|||||||
|
|
||||||
qlora_sharded_model_loading: Optional[bool] = Field(
|
qlora_sharded_model_loading: Optional[bool] = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={"description": "load qlora model in sharded format for FSDP using answer.ai technique."},
|
json_schema_extra={
|
||||||
|
"description": "load qlora model in sharded format for FSDP using answer.ai technique."
|
||||||
|
},
|
||||||
)
|
)
|
||||||
lora_on_cpu: Optional[bool] = None
|
lora_on_cpu: Optional[bool] = None
|
||||||
gptq: Optional[bool] = None
|
gptq: Optional[bool] = None
|
||||||
@@ -353,11 +369,15 @@ class LoraConfig(BaseModel):
|
|||||||
|
|
||||||
loraplus_lr_ratio: Optional[float] = Field(
|
loraplus_lr_ratio: Optional[float] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."},
|
json_schema_extra={
|
||||||
|
"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
|
||||||
|
},
|
||||||
)
|
)
|
||||||
loraplus_lr_embedding: Optional[float] = Field(
|
loraplus_lr_embedding: Optional[float] = Field(
|
||||||
default=1e-6,
|
default=1e-6,
|
||||||
json_schema_extra={"description": "loraplus learning rate for lora embedding layers."},
|
json_schema_extra={
|
||||||
|
"description": "loraplus learning rate for lora embedding layers."
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
merge_lora: Optional[bool] = None
|
merge_lora: Optional[bool] = None
|
||||||
@@ -465,11 +485,15 @@ class HyperparametersConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
batch_size: Optional[int] = Field(
|
batch_size: Optional[int] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Total batch size, we do not recommended setting this manually"},
|
json_schema_extra={
|
||||||
|
"description": "Total batch size, we do not recommended setting this manually"
|
||||||
|
},
|
||||||
)
|
)
|
||||||
eval_batch_size: Optional[int] = Field(
|
eval_batch_size: Optional[int] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"},
|
json_schema_extra={
|
||||||
|
"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
auto_find_batch_size: Optional[bool] = None
|
auto_find_batch_size: Optional[bool] = None
|
||||||
@@ -481,7 +505,9 @@ class HyperparametersConfig(BaseModel):
|
|||||||
embedding_lr: Optional[float] = None
|
embedding_lr: Optional[float] = None
|
||||||
embedding_lr_scale: Optional[float] = None
|
embedding_lr_scale: Optional[float] = None
|
||||||
weight_decay: Optional[float] = 0.0
|
weight_decay: Optional[float] = 0.0
|
||||||
optimizer: Optional[Union[OptimizerNames, CustomSupportedOptimizers]] = OptimizerNames.ADAMW_HF
|
optimizer: Optional[
|
||||||
|
Union[OptimizerNames, CustomSupportedOptimizers]
|
||||||
|
] = OptimizerNames.ADAMW_HF
|
||||||
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
|
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
|
||||||
@@ -493,7 +519,9 @@ class HyperparametersConfig(BaseModel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
torchdistx_path: Optional[str] = None
|
torchdistx_path: Optional[str] = None
|
||||||
lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]] = SchedulerType.COSINE
|
lr_scheduler: Optional[
|
||||||
|
Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]
|
||||||
|
] = SchedulerType.COSINE
|
||||||
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
||||||
lr_quadratic_warmup: Optional[bool] = None
|
lr_quadratic_warmup: Optional[bool] = None
|
||||||
cosine_min_lr_ratio: Optional[float] = None
|
cosine_min_lr_ratio: Optional[float] = None
|
||||||
@@ -617,15 +645,21 @@ class RayConfig(BaseModel):
|
|||||||
use_ray: bool = Field(default=False)
|
use_ray: bool = Field(default=False)
|
||||||
ray_run_name: Optional[str] = Field(
|
ray_run_name: Optional[str] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"help": "The training results will be saved at `saves/ray_run_name`."},
|
json_schema_extra={
|
||||||
|
"help": "The training results will be saved at `saves/ray_run_name`."
|
||||||
|
},
|
||||||
)
|
)
|
||||||
ray_num_workers: int = Field(
|
ray_num_workers: int = Field(
|
||||||
default=1,
|
default=1,
|
||||||
json_schema_extra={"help": "The number of workers for Ray training. Default is 1 worker."},
|
json_schema_extra={
|
||||||
|
"help": "The number of workers for Ray training. Default is 1 worker."
|
||||||
|
},
|
||||||
)
|
)
|
||||||
resources_per_worker: dict = Field(
|
resources_per_worker: dict = Field(
|
||||||
default_factory=lambda: {"GPU": 1},
|
default_factory=lambda: {"GPU": 1},
|
||||||
json_schema_extra={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
|
json_schema_extra={
|
||||||
|
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -665,9 +699,9 @@ class AxolotlInputConfig(
|
|||||||
reward_model: Optional[bool] = None
|
reward_model: Optional[bool] = None
|
||||||
process_reward_model: Optional[bool] = None
|
process_reward_model: Optional[bool] = None
|
||||||
num_labels: Optional[int] = None
|
num_labels: Optional[int] = None
|
||||||
dpo_use_weighting: Optional[bool] = (
|
dpo_use_weighting: Optional[
|
||||||
None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
bool
|
||||||
)
|
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
||||||
dpo_use_logits_to_keep: Optional[bool] = None
|
dpo_use_logits_to_keep: Optional[bool] = None
|
||||||
|
|
||||||
datasets: Optional[
|
datasets: Optional[
|
||||||
@@ -689,7 +723,9 @@ class AxolotlInputConfig(
|
|||||||
dataset_shard_idx: Optional[int] = None
|
dataset_shard_idx: Optional[int] = None
|
||||||
skip_prepare_dataset: Optional[bool] = False
|
skip_prepare_dataset: Optional[bool] = False
|
||||||
|
|
||||||
pretraining_dataset: Optional[Annotated[list[Union[PretrainingDataset, SFTDataset]], MinLen(1)]] = Field(
|
pretraining_dataset: Optional[
|
||||||
|
Annotated[list[Union[PretrainingDataset, SFTDataset]], MinLen(1)]
|
||||||
|
] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
||||||
)
|
)
|
||||||
@@ -744,7 +780,9 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
# torch_dtype: Optional[torch.dtype]
|
# torch_dtype: Optional[torch.dtype]
|
||||||
|
|
||||||
gradient_checkpointing: Optional[Union[Literal["unsloth", "offload"], bool]] = Field(default=False)
|
gradient_checkpointing: Optional[
|
||||||
|
Union[Literal["unsloth", "offload"], bool]
|
||||||
|
] = Field(default=False)
|
||||||
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
unfrozen_parameters: Optional[List[str]] = None
|
unfrozen_parameters: Optional[List[str]] = None
|
||||||
@@ -811,7 +849,9 @@ class AxolotlInputConfig(
|
|||||||
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
|
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
|
||||||
fsdp: Optional[List[str]] = None
|
fsdp: Optional[List[str]] = None
|
||||||
fsdp_config: Optional[Dict[str, Any]] = None
|
fsdp_config: Optional[Dict[str, Any]] = None
|
||||||
fsdp_final_state_dict_type: Optional[Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]] = None
|
fsdp_final_state_dict_type: Optional[
|
||||||
|
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
|
||||||
|
] = None
|
||||||
|
|
||||||
val_set_size: Optional[float] = Field(default=0.0)
|
val_set_size: Optional[float] = Field(default=0.0)
|
||||||
|
|
||||||
@@ -821,7 +861,9 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
torch_compile: Optional[Union[Literal["auto"], bool]] = None
|
torch_compile: Optional[Union[Literal["auto"], bool]] = None
|
||||||
torch_compile_backend: Optional[str] = None
|
torch_compile_backend: Optional[str] = None
|
||||||
torch_compile_mode: Optional[Literal["default", "reduce-overhead", "max-autotune"]] = None
|
torch_compile_mode: Optional[
|
||||||
|
Literal["default", "reduce-overhead", "max-autotune"]
|
||||||
|
] = None
|
||||||
|
|
||||||
max_steps: Optional[int] = None
|
max_steps: Optional[int] = None
|
||||||
warmup_steps: Optional[int] = None
|
warmup_steps: Optional[int] = None
|
||||||
@@ -852,7 +894,9 @@ class AxolotlInputConfig(
|
|||||||
kto_undesirable_weight: Optional[float] = None
|
kto_undesirable_weight: Optional[float] = None
|
||||||
rl_beta: Optional[float] = None
|
rl_beta: Optional[float] = None
|
||||||
|
|
||||||
max_memory: Optional[Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]] = None
|
max_memory: Optional[
|
||||||
|
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
||||||
|
] = None
|
||||||
gpu_memory_limit: Optional[Union[int, str]] = None
|
gpu_memory_limit: Optional[Union[int, str]] = None
|
||||||
low_cpu_mem_usage: Optional[bool] = None
|
low_cpu_mem_usage: Optional[bool] = None
|
||||||
|
|
||||||
@@ -888,7 +932,11 @@ class AxolotlInputConfig(
|
|||||||
def deprecate_sharegpt_datasets(cls, datasets):
|
def deprecate_sharegpt_datasets(cls, datasets):
|
||||||
for _, ds_cfg in enumerate(datasets):
|
for _, ds_cfg in enumerate(datasets):
|
||||||
# Handle both dict and pydantic model cases
|
# Handle both dict and pydantic model cases
|
||||||
ds_type = ds_cfg.get("type") if isinstance(ds_cfg, dict) else getattr(ds_cfg, "type", None)
|
ds_type = (
|
||||||
|
ds_cfg.get("type")
|
||||||
|
if isinstance(ds_cfg, dict)
|
||||||
|
else getattr(ds_cfg, "type", None)
|
||||||
|
)
|
||||||
if not ds_type:
|
if not ds_type:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -897,12 +945,16 @@ class AxolotlInputConfig(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(ds_type, str) and ds_type.startswith("sharegpt"):
|
if isinstance(ds_type, str) and ds_type.startswith("sharegpt"):
|
||||||
raise ValueError("`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead.")
|
raise ValueError(
|
||||||
|
"`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead."
|
||||||
|
)
|
||||||
|
|
||||||
return datasets
|
return datasets
|
||||||
|
|
||||||
@field_serializer("datasets")
|
@field_serializer("datasets")
|
||||||
def datasets_serializer(self, ds_configs: Optional[List[DatasetConfig]]) -> Optional[List[Dict[str, Any]]]:
|
def datasets_serializer(
|
||||||
|
self, ds_configs: Optional[List[DatasetConfig]]
|
||||||
|
) -> Optional[List[Dict[str, Any]]]:
|
||||||
if ds_configs:
|
if ds_configs:
|
||||||
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
|
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
|
||||||
return None
|
return None
|
||||||
@@ -968,7 +1020,9 @@ class AxolotlInputConfig(
|
|||||||
@classmethod
|
@classmethod
|
||||||
def check_sample_packing_w_xformers(cls, data):
|
def check_sample_packing_w_xformers(cls, data):
|
||||||
if data.get("sample_packing") and data.get("xformers_attention"):
|
if data.get("sample_packing") and data.get("xformers_attention"):
|
||||||
raise ValueError("sample_packing not compatible with xformers_attention. Use flash_attention")
|
raise ValueError(
|
||||||
|
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
||||||
|
)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -976,8 +1030,12 @@ class AxolotlInputConfig(
|
|||||||
@classmethod
|
@classmethod
|
||||||
def check_chat_template_config(cls, data):
|
def check_chat_template_config(cls, data):
|
||||||
# if chat_template is set to jinja, chat_template_jinja is required
|
# if chat_template is set to jinja, chat_template_jinja is required
|
||||||
if data.get("chat_template") == ChatTemplate.jinja and not data.get("chat_template_jinja"):
|
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
||||||
raise ValueError("chat_template_jinja is required when chat_template is set to jinja")
|
"chat_template_jinja"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"chat_template_jinja is required when chat_template is set to jinja"
|
||||||
|
)
|
||||||
|
|
||||||
# If chat_template_jinja is set, set chat_template to jinja
|
# If chat_template_jinja is set, set chat_template to jinja
|
||||||
if data.get("chat_template_jinja") and not data.get("chat_template"):
|
if data.get("chat_template_jinja") and not data.get("chat_template"):
|
||||||
@@ -988,8 +1046,14 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_sample_packing_wo_flash(cls, data):
|
def check_sample_packing_wo_flash(cls, data):
|
||||||
if data.get("sample_packing") and not data.get("flash_attention") and not data.get("sdp_attention"):
|
if (
|
||||||
LOG.warning("sample_packing without flash_attention or sdp_attention does not handle cross-attention.")
|
data.get("sample_packing")
|
||||||
|
and not data.get("flash_attention")
|
||||||
|
and not data.get("sdp_attention")
|
||||||
|
):
|
||||||
|
LOG.warning(
|
||||||
|
"sample_packing without flash_attention or sdp_attention does not handle cross-attention."
|
||||||
|
)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -1028,14 +1092,18 @@ class AxolotlInputConfig(
|
|||||||
@classmethod
|
@classmethod
|
||||||
def hint_sample_packing_padding(cls, data):
|
def hint_sample_packing_padding(cls, data):
|
||||||
if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
|
if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
|
||||||
LOG.warning("`pad_to_sequence_len: true` is recommended when using sample_packing")
|
LOG.warning(
|
||||||
|
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def hint_reward_model_pad(cls, data):
|
def hint_reward_model_pad(cls, data):
|
||||||
if data.get("reward_model") and not data.get("pad_to_sequence_len"):
|
if data.get("reward_model") and not data.get("pad_to_sequence_len"):
|
||||||
LOG.warning("`pad_to_sequence_len: true` is recommended when using reward_model")
|
LOG.warning(
|
||||||
|
"`pad_to_sequence_len: true` is recommended when using reward_model"
|
||||||
|
)
|
||||||
if data.get("pad_to_sequence_len") is None:
|
if data.get("pad_to_sequence_len") is None:
|
||||||
data["pad_to_sequence_len"] = True
|
data["pad_to_sequence_len"] = True
|
||||||
return data
|
return data
|
||||||
@@ -1044,7 +1112,9 @@ class AxolotlInputConfig(
|
|||||||
@classmethod
|
@classmethod
|
||||||
def check_gas_bsz(cls, data):
|
def check_gas_bsz(cls, data):
|
||||||
if data.get("gradient_accumulation_steps") and data.get("batch_size"):
|
if data.get("gradient_accumulation_steps") and data.get("batch_size"):
|
||||||
raise ValueError("please set only one of gradient_accumulation_steps or batch_size")
|
raise ValueError(
|
||||||
|
"please set only one of gradient_accumulation_steps or batch_size"
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@@ -1055,14 +1125,21 @@ class AxolotlInputConfig(
|
|||||||
and data.get("micro_batch_size")
|
and data.get("micro_batch_size")
|
||||||
and data.get("eval_batch_size") != data.get("micro_batch_size")
|
and data.get("eval_batch_size") != data.get("micro_batch_size")
|
||||||
):
|
):
|
||||||
LOG.warning("eval_batch_size != micro_batch_size. This can lead to VRAM instability.")
|
LOG.warning(
|
||||||
|
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_push_ds_auth(cls, data):
|
def check_push_ds_auth(cls, data):
|
||||||
if data.get("push_dataset_to_hub") and data.get("hf_use_auth_token") is not True:
|
if (
|
||||||
raise ValueError("Require cfg.hf_use_auth_token to be True for push_dataset_to_hub")
|
data.get("push_dataset_to_hub")
|
||||||
|
and data.get("hf_use_auth_token") is not True
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
@@ -1073,14 +1150,18 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_mpt_checkpointing(self):
|
def check_mpt_checkpointing(self):
|
||||||
if (self.base_model and "mpt" in self.base_model.lower()) and self.gradient_checkpointing:
|
if (
|
||||||
|
self.base_model and "mpt" in self.base_model.lower()
|
||||||
|
) and self.gradient_checkpointing:
|
||||||
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_offload_grad_checkpointing(self):
|
def check_offload_grad_checkpointing(self):
|
||||||
if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
|
if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
|
||||||
LOG.warning("`unsloth` is deprecated for gradient_checkpointing, use `offload`")
|
LOG.warning(
|
||||||
|
"`unsloth` is deprecated for gradient_checkpointing, use `offload`"
|
||||||
|
)
|
||||||
self.gradient_checkpointing = "offload"
|
self.gradient_checkpointing = "offload"
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -1088,7 +1169,9 @@ class AxolotlInputConfig(
|
|||||||
def check_better_transformers(self):
|
def check_better_transformers(self):
|
||||||
if self.flash_optimum is True:
|
if self.flash_optimum is True:
|
||||||
if self.adapter:
|
if self.adapter:
|
||||||
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
|
LOG.warning(
|
||||||
|
"BetterTransformers probably doesn't work with PEFT adapters"
|
||||||
|
)
|
||||||
if self.fp16 or self.bf16:
|
if self.fp16 or self.bf16:
|
||||||
raise ValueError("AMP is not supported with BetterTransformer")
|
raise ValueError("AMP is not supported with BetterTransformer")
|
||||||
if self.float16 is not True and self.bfloat16 is not True:
|
if self.float16 is not True and self.bfloat16 is not True:
|
||||||
@@ -1116,25 +1199,39 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_saves(cls, data):
|
def check_saves(cls, data):
|
||||||
if data.get("save_strategy") and data.get("save_steps") and data.get("save_strategy") != "steps":
|
if (
|
||||||
|
data.get("save_strategy")
|
||||||
|
and data.get("save_steps")
|
||||||
|
and data.get("save_strategy") != "steps"
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||||
)
|
)
|
||||||
if data.get("saves_per_epoch") and data.get("save_steps"):
|
if data.get("saves_per_epoch") and data.get("save_steps"):
|
||||||
raise ValueError("save_steps and saves_per_epoch are mutually exclusive and cannot be used together.")
|
raise ValueError(
|
||||||
|
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_push_save(cls, data):
|
def check_push_save(cls, data):
|
||||||
if data.get("hub_model_id") and (data.get("save_strategy") not in ["steps", "epoch", None]):
|
if data.get("hub_model_id") and (
|
||||||
LOG.warning("hub_model_id is set without any models being saved. To save a model, set save_strategy.")
|
data.get("save_strategy") not in ["steps", "epoch", None]
|
||||||
|
):
|
||||||
|
LOG.warning(
|
||||||
|
"hub_model_id is set without any models being saved. To save a model, set save_strategy."
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_evals(cls, data):
|
def check_evals(cls, data):
|
||||||
if data.get("eval_strategy") and data.get("eval_steps") and data.get("eval_strategy") != "steps":
|
if (
|
||||||
|
data.get("eval_strategy")
|
||||||
|
and data.get("eval_steps")
|
||||||
|
and data.get("eval_strategy") != "steps"
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps."
|
"eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps."
|
||||||
)
|
)
|
||||||
@@ -1144,21 +1241,41 @@ class AxolotlInputConfig(
|
|||||||
and (data.get("eval_steps") or data.get("eval_strategy"))
|
and (data.get("eval_steps") or data.get("eval_strategy"))
|
||||||
and not data.get("test_datasets")
|
and not data.get("test_datasets")
|
||||||
):
|
):
|
||||||
raise ValueError("eval_steps and eval_strategy are not supported with val_set_size == 0")
|
raise ValueError(
|
||||||
|
"eval_steps and eval_strategy are not supported with val_set_size == 0"
|
||||||
|
)
|
||||||
if data.get("evals_per_epoch") and data.get("eval_steps"):
|
if data.get("evals_per_epoch") and data.get("eval_steps"):
|
||||||
raise ValueError("eval_steps and evals_per_epoch are mutually exclusive and cannot be used together.")
|
raise ValueError(
|
||||||
if data.get("evals_per_epoch") and data.get("eval_strategy") and data.get("eval_strategy") != "steps":
|
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
||||||
raise ValueError("eval_strategy must be empty or set to `steps` when used with evals_per_epoch.")
|
)
|
||||||
|
if (
|
||||||
|
data.get("evals_per_epoch")
|
||||||
|
and data.get("eval_strategy")
|
||||||
|
and data.get("eval_strategy") != "steps"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"eval_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||||
|
)
|
||||||
|
|
||||||
if data.get("do_bench_eval") and not (data.get("evals_per_epoch") or data.get("eval_steps")):
|
if data.get("do_bench_eval") and not (
|
||||||
raise ValueError("do_bench_eval requires evals_per_epoch or eval_steps to be set.")
|
data.get("evals_per_epoch") or data.get("eval_steps")
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"do_bench_eval requires evals_per_epoch or eval_steps to be set."
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_test_datasets_bench(cls, data):
|
def check_test_datasets_bench(cls, data):
|
||||||
if data.get("do_bench_eval") and not data.get("test_datasets") and not data.get("val_set_size"):
|
if (
|
||||||
LOG.warning("`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset.")
|
data.get("do_bench_eval")
|
||||||
|
and not data.get("test_datasets")
|
||||||
|
and not data.get("val_set_size")
|
||||||
|
):
|
||||||
|
LOG.warning(
|
||||||
|
"`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset."
|
||||||
|
)
|
||||||
data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}]
|
data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}]
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -1167,12 +1284,22 @@ class AxolotlInputConfig(
|
|||||||
def check_eval_packing(cls, data):
|
def check_eval_packing(cls, data):
|
||||||
# TODO also should check test_datasets and val_set_size as we can skip
|
# TODO also should check test_datasets and val_set_size as we can skip
|
||||||
# if there are no eval datasets/splits
|
# if there are no eval datasets/splits
|
||||||
if data.get("sample_packing") and data.get("eval_table_size") and data.get("eval_sample_packing") is not False:
|
if (
|
||||||
|
data.get("sample_packing")
|
||||||
|
and data.get("eval_table_size")
|
||||||
|
and data.get("eval_sample_packing") is not False
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
|
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
|
||||||
)
|
)
|
||||||
if data.get("sample_packing") and data.get("eval_sample_packing") is None and not data.get("eval_table_size"):
|
if (
|
||||||
LOG.info("explicitly setting `eval_sample_packing` to match `sample_packing`")
|
data.get("sample_packing")
|
||||||
|
and data.get("eval_sample_packing") is None
|
||||||
|
and not data.get("eval_table_size")
|
||||||
|
):
|
||||||
|
LOG.info(
|
||||||
|
"explicitly setting `eval_sample_packing` to match `sample_packing`"
|
||||||
|
)
|
||||||
data["eval_sample_packing"] = True
|
data["eval_sample_packing"] = True
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -1192,7 +1319,9 @@ class AxolotlInputConfig(
|
|||||||
def check_mm_prepare(cls, data):
|
def check_mm_prepare(cls, data):
|
||||||
if data.get("skip_prepare_dataset"):
|
if data.get("skip_prepare_dataset"):
|
||||||
if data.get("remove_unused_columns") is None:
|
if data.get("remove_unused_columns") is None:
|
||||||
LOG.info("setting `remove_unused_columns: false` for skip_prepare_dataset")
|
LOG.info(
|
||||||
|
"setting `remove_unused_columns: false` for skip_prepare_dataset"
|
||||||
|
)
|
||||||
data["remove_unused_columns"] = False
|
data["remove_unused_columns"] = False
|
||||||
|
|
||||||
return data
|
return data
|
||||||
@@ -1233,13 +1362,19 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_simpo_warmup(self):
|
def check_simpo_warmup(self):
|
||||||
if self.rl == "simpo" and self.warmup_ratio:
|
if self.rl == "simpo" and self.warmup_ratio:
|
||||||
raise ValueError("warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead")
|
raise ValueError(
|
||||||
|
"warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_frozen(cls, data):
|
def check_frozen(cls, data):
|
||||||
if data.get("adapter") and data.get("peft_layers_to_transform") and data.get("unfrozen_parameters"):
|
if (
|
||||||
|
data.get("adapter")
|
||||||
|
and data.get("peft_layers_to_transform")
|
||||||
|
and data.get("unfrozen_parameters")
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
|
"`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
|
||||||
)
|
)
|
||||||
@@ -1250,7 +1385,9 @@ class AxolotlInputConfig(
|
|||||||
@classmethod
|
@classmethod
|
||||||
def check_peft_layers_pattern(cls, data):
|
def check_peft_layers_pattern(cls, data):
|
||||||
if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"):
|
if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"):
|
||||||
raise ValueError("peft_layers_pattern requires peft_layers_to_transform to be set")
|
raise ValueError(
|
||||||
|
"peft_layers_pattern requires peft_layers_to_transform to be set"
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
@@ -1273,13 +1410,17 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_fused_lora(self):
|
def check_fused_lora(self):
|
||||||
if self.adapter in ["lora", "qlora"] and (self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp):
|
if self.adapter in ["lora", "qlora"] and (
|
||||||
|
self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp
|
||||||
|
):
|
||||||
raise ValueError("Fused modules are not supported with LoRA/QLoRA")
|
raise ValueError("Fused modules are not supported with LoRA/QLoRA")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def hint_lora_8bit(self):
|
def hint_lora_8bit(self):
|
||||||
loftq = self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits
|
loftq = (
|
||||||
|
self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits
|
||||||
|
)
|
||||||
if not self.load_in_8bit and self.adapter == "lora" and not loftq:
|
if not self.load_in_8bit and self.adapter == "lora" and not loftq:
|
||||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||||
return self
|
return self
|
||||||
@@ -1292,7 +1433,9 @@ class AxolotlInputConfig(
|
|||||||
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
|
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
|
||||||
)
|
)
|
||||||
if self.save_steps % self.eval_steps != 0:
|
if self.save_steps % self.eval_steps != 0:
|
||||||
raise ValueError("`early_stopping_patience` requires that eval_steps should evenly divide save_steps.")
|
raise ValueError(
|
||||||
|
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
@@ -1308,7 +1451,9 @@ class AxolotlInputConfig(
|
|||||||
raise ValueError("deepspeed not supported with ReLoRA")
|
raise ValueError("deepspeed not supported with ReLoRA")
|
||||||
|
|
||||||
if self.lr_scheduler == "one_cycle":
|
if self.lr_scheduler == "one_cycle":
|
||||||
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
|
raise ValueError(
|
||||||
|
"ReLoRA is not compatible with the one_cycle scheduler"
|
||||||
|
)
|
||||||
|
|
||||||
if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp:
|
if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp:
|
||||||
raise ValueError("Fused modules are not supported with ReLoRA")
|
raise ValueError("Fused modules are not supported with ReLoRA")
|
||||||
@@ -1317,8 +1462,13 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_mem_mismatch(cls, data):
|
def check_mem_mismatch(cls, data):
|
||||||
if data.get("max_memory") is not None and data.get("gpu_memory_limit") is not None:
|
if (
|
||||||
raise ValueError("max_memory and gpu_memory_limit are mutually exclusive and cannot be used together.")
|
data.get("max_memory") is not None
|
||||||
|
and data.get("gpu_memory_limit") is not None
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@@ -1327,10 +1477,13 @@ class AxolotlInputConfig(
|
|||||||
if (
|
if (
|
||||||
data.get("unfrozen_parameters")
|
data.get("unfrozen_parameters")
|
||||||
and data.get("gradient_checkpointing_kwargs")
|
and data.get("gradient_checkpointing_kwargs")
|
||||||
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") is True
|
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
|
||||||
|
is True
|
||||||
):
|
):
|
||||||
# https://github.com/huggingface/transformers/issues/21381
|
# https://github.com/huggingface/transformers/issues/21381
|
||||||
raise ValueError("`use_reentrant` must be false when used with partially frozen model.")
|
raise ValueError(
|
||||||
|
"`use_reentrant` must be false when used with partially frozen model."
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@@ -1339,7 +1492,8 @@ class AxolotlInputConfig(
|
|||||||
if (
|
if (
|
||||||
data.get("adapter") == "qlora"
|
data.get("adapter") == "qlora"
|
||||||
and data.get("gradient_checkpointing_kwargs", {})
|
and data.get("gradient_checkpointing_kwargs", {})
|
||||||
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") is False
|
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
|
||||||
|
is False
|
||||||
and data.get("deepspeed", "") is not None
|
and data.get("deepspeed", "") is not None
|
||||||
and "zero3" in data.get("deepspeed", "")
|
and "zero3" in data.get("deepspeed", "")
|
||||||
):
|
):
|
||||||
@@ -1356,14 +1510,21 @@ class AxolotlInputConfig(
|
|||||||
@classmethod
|
@classmethod
|
||||||
def check_val_w_test_datasets(cls, data):
|
def check_val_w_test_datasets(cls, data):
|
||||||
if data.get("test_datasets") and data.get("val_set_size"):
|
if data.get("test_datasets") and data.get("val_set_size"):
|
||||||
raise ValueError("non-zero val_set_size should not be used with test_datasets configuration")
|
raise ValueError(
|
||||||
|
"non-zero val_set_size should not be used with test_datasets configuration"
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_eval_strategy(cls, data):
|
def check_eval_strategy(cls, data):
|
||||||
if data.get("evaluation_strategy") is not None and data.get("eval_strategy") is None:
|
if (
|
||||||
LOG.info("explicitly setting `eval_strategy` from the `evaluation_strategy`")
|
data.get("evaluation_strategy") is not None
|
||||||
|
and data.get("eval_strategy") is None
|
||||||
|
):
|
||||||
|
LOG.info(
|
||||||
|
"explicitly setting `eval_strategy` from the `evaluation_strategy`"
|
||||||
|
)
|
||||||
data["eval_strategy"] = data.get("evaluation_strategy")
|
data["eval_strategy"] = data.get("evaluation_strategy")
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -1376,7 +1537,9 @@ class AxolotlInputConfig(
|
|||||||
and data.get("fsdp_config")
|
and data.get("fsdp_config")
|
||||||
and data["fsdp_config"].get("fsdp_offload_params")
|
and data["fsdp_config"].get("fsdp_offload_params")
|
||||||
):
|
):
|
||||||
raise ValueError(f"FSDP Offload not compatible with {data.get('optimizer')}")
|
raise ValueError(
|
||||||
|
f"FSDP Offload not compatible with {data.get('optimizer')}"
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@@ -1388,21 +1551,27 @@ class AxolotlInputConfig(
|
|||||||
and data.get("fsdp_config")
|
and data.get("fsdp_config")
|
||||||
and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT"
|
and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT"
|
||||||
):
|
):
|
||||||
raise ValueError("FSDP SHARDED_STATE_DICT not compatible with save_safetensors")
|
raise ValueError(
|
||||||
|
"FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_causal_lm_evals(cls, data):
|
def check_causal_lm_evals(cls, data):
|
||||||
if data.get("do_causal_lm_eval") and data.get("eval_sample_packing"):
|
if data.get("do_causal_lm_eval") and data.get("eval_sample_packing"):
|
||||||
raise ValueError("do_causal_lm_eval is enabled, eval_sample_packing must be set to False")
|
raise ValueError(
|
||||||
|
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
|
||||||
|
)
|
||||||
|
|
||||||
if data.get("eval_causal_lm_metrics"):
|
if data.get("eval_causal_lm_metrics"):
|
||||||
if not isinstance(data.get("eval_causal_lm_metrics"), list):
|
if not isinstance(data.get("eval_causal_lm_metrics"), list):
|
||||||
raise ValueError("eval_causal_lm_metrics must be a list")
|
raise ValueError("eval_causal_lm_metrics must be a list")
|
||||||
# only ["sacrebleu", "comet", "ter", "chrf"] supported
|
# only ["sacrebleu", "comet", "ter", "chrf"] supported
|
||||||
if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS:
|
if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS:
|
||||||
raise ValueError(f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}")
|
raise ValueError(
|
||||||
|
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@@ -1415,14 +1584,22 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_xentropy_patch_conflicts(cls, data):
|
def check_xentropy_patch_conflicts(cls, data):
|
||||||
if data.get("flash_attn_cross_entropy") and data.get("unsloth_cross_entropy_loss"):
|
if data.get("flash_attn_cross_entropy") and data.get(
|
||||||
raise ValueError("flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled")
|
"unsloth_cross_entropy_loss"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_qlora_unsloth(cls, data):
|
def check_qlora_unsloth(cls, data):
|
||||||
if data.get("unsloth_lora_mlp") or data.get("unsloth_lora_qkv") or data.get("unsloth_lora_o"):
|
if (
|
||||||
|
data.get("unsloth_lora_mlp")
|
||||||
|
or data.get("unsloth_lora_qkv")
|
||||||
|
or data.get("unsloth_lora_o")
|
||||||
|
):
|
||||||
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
|
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
|
||||||
@@ -1432,7 +1609,11 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_lora_8bit(cls, data):
|
def check_lora_8bit(cls, data):
|
||||||
if data.get("lora_mlp_kernel") or data.get("lora_qkv_kernel") or data.get("lora_o_kernel"):
|
if (
|
||||||
|
data.get("lora_mlp_kernel")
|
||||||
|
or data.get("lora_qkv_kernel")
|
||||||
|
or data.get("lora_o_kernel")
|
||||||
|
):
|
||||||
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"lora_mlp_kernel, lora_mlp_kernel, and lora_mlp_kernel are not compatible with 8-bit LoRA"
|
"lora_mlp_kernel, lora_mlp_kernel, and lora_mlp_kernel are not compatible with 8-bit LoRA"
|
||||||
@@ -1442,8 +1623,13 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_lora_axolotl_unsloth(cls, data):
|
def check_lora_axolotl_unsloth(cls, data):
|
||||||
is_lora_kernel = any(data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"])
|
is_lora_kernel = any(
|
||||||
is_unsloth_lora = any(data.get(k) for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"])
|
data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
|
||||||
|
)
|
||||||
|
is_unsloth_lora = any(
|
||||||
|
data.get(k)
|
||||||
|
for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
|
||||||
|
)
|
||||||
if is_lora_kernel and is_unsloth_lora:
|
if is_lora_kernel and is_unsloth_lora:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)"
|
"both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)"
|
||||||
@@ -1454,7 +1640,9 @@ class AxolotlInputConfig(
|
|||||||
@classmethod
|
@classmethod
|
||||||
def check_torch_compile_deepspeed(cls, data):
|
def check_torch_compile_deepspeed(cls, data):
|
||||||
if data.get("deepspeed") and data.get("torch_compile"):
|
if data.get("deepspeed") and data.get("torch_compile"):
|
||||||
raise ValueError("torch_compile should be set within your deepspeed config file")
|
raise ValueError(
|
||||||
|
"torch_compile should be set within your deepspeed config file"
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@@ -1538,9 +1726,15 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
def check_bf16(self):
|
def check_bf16(self):
|
||||||
if self.capabilities.bf16:
|
if self.capabilities.bf16:
|
||||||
if not self.bf16 and not self.bfloat16:
|
if not self.bf16 and not self.bfloat16:
|
||||||
LOG.info("bf16 support detected, but not enabled for this configuration.")
|
LOG.info(
|
||||||
|
"bf16 support detected, but not enabled for this configuration."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if not self.merge_lora and not self.is_preprocess and (self.bf16 is True or self.bfloat16 is True):
|
if (
|
||||||
|
not self.merge_lora
|
||||||
|
and not self.is_preprocess
|
||||||
|
and (self.bf16 is True or self.bfloat16 is True)
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
||||||
)
|
)
|
||||||
@@ -1549,7 +1743,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_sample_packing_w_sdpa_bf16(cls, data):
|
def check_sample_packing_w_sdpa_bf16(cls, data):
|
||||||
is_sm_90: bool = data["capabilities"] and data["capabilities"].get("compute_capability") == "sm_90"
|
is_sm_90: bool = (
|
||||||
|
data["capabilities"]
|
||||||
|
and data["capabilities"].get("compute_capability") == "sm_90"
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
data.get("sample_packing")
|
data.get("sample_packing")
|
||||||
and data.get("sdp_attention")
|
and data.get("sdp_attention")
|
||||||
@@ -1574,7 +1771,11 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_multigpu_unsloth(cls, data):
|
def check_multigpu_unsloth(cls, data):
|
||||||
if data.get("unsloth_lora_mlp") or data.get("unsloth_lora_qkv") or data.get("unsloth_lora_o"):
|
if (
|
||||||
|
data.get("unsloth_lora_mlp")
|
||||||
|
or data.get("unsloth_lora_qkv")
|
||||||
|
or data.get("unsloth_lora_o")
|
||||||
|
):
|
||||||
capabilities = data.get("capabilities")
|
capabilities = data.get("capabilities")
|
||||||
if capabilities and capabilities.get("n_gpu", 0) > 1:
|
if capabilities and capabilities.get("n_gpu", 0) > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -1585,7 +1786,11 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_multigpu_lora_kernels(cls, data):
|
def check_multigpu_lora_kernels(cls, data):
|
||||||
if data.get("lora_mlp_kernel") or data.get("lora_qkv_kernel") or data.get("lora_o_kernel"):
|
if (
|
||||||
|
data.get("lora_mlp_kernel")
|
||||||
|
or data.get("lora_qkv_kernel")
|
||||||
|
or data.get("lora_o_kernel")
|
||||||
|
):
|
||||||
capabilities = data.get("capabilities")
|
capabilities = data.get("capabilities")
|
||||||
is_fsdp = data.get("fsdp") is not None
|
is_fsdp = data.get("fsdp") is not None
|
||||||
is_deepspeed = data.get("deepspeed") is not None
|
is_deepspeed = data.get("deepspeed") is not None
|
||||||
@@ -1614,7 +1819,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
||||||
|
|
||||||
if version.parse(torch_version) < version.parse("2.5.1"):
|
if version.parse(torch_version) < version.parse("2.5.1"):
|
||||||
raise ValueError("ADOPT optimizer is incompatible with torch version < 2.5.1")
|
raise ValueError(
|
||||||
|
"ADOPT optimizer is incompatible with torch version < 2.5.1"
|
||||||
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@@ -1623,8 +1830,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
if data.get("torch_compile") == "auto":
|
if data.get("torch_compile") == "auto":
|
||||||
env_capabilities = data.get("env_capabilities", {})
|
env_capabilities = data.get("env_capabilities", {})
|
||||||
if env_capabilities.get("torch_version"):
|
if env_capabilities.get("torch_version"):
|
||||||
if version.parse(env_capabilities.get("torch_version")) >= version.parse("2.5.1"):
|
if version.parse(
|
||||||
LOG.info("torch.compile is available, setting torch_compile to True")
|
env_capabilities.get("torch_version")
|
||||||
|
) >= version.parse("2.5.1"):
|
||||||
|
LOG.info(
|
||||||
|
"torch.compile is available, setting torch_compile to True"
|
||||||
|
)
|
||||||
data["torch_compile"] = True
|
data["torch_compile"] = True
|
||||||
else:
|
else:
|
||||||
data["torch_compile"] = False
|
data["torch_compile"] = False
|
||||||
@@ -1689,13 +1900,16 @@ def handle_legacy_message_fields_logic(data: dict) -> dict:
|
|||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
"content" in data["message_property_mappings"]
|
"content" in data["message_property_mappings"]
|
||||||
and data["message_property_mappings"]["content"] != data["message_field_content"]
|
and data["message_property_mappings"]["content"]
|
||||||
|
!= data["message_field_content"]
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
|
f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
|
||||||
f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
|
f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
|
||||||
)
|
)
|
||||||
data["message_property_mappings"]["content"] = data["message_field_content"] or "content"
|
data["message_property_mappings"]["content"] = (
|
||||||
|
data["message_field_content"] or "content"
|
||||||
|
)
|
||||||
|
|
||||||
del data["message_field_content"]
|
del data["message_field_content"]
|
||||||
elif "content" not in data["message_property_mappings"]:
|
elif "content" not in data["message_property_mappings"]:
|
||||||
|
|||||||
Reference in New Issue
Block a user