pre-commit
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
"""Main Axolotl input configuration Pydantic models"""
|
"""Main Axolotl input configuration Pydantic models"""
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -91,24 +92,30 @@ class AxolotlInputConfig(
|
|||||||
dpo_use_weighting: bool | None = None
|
dpo_use_weighting: bool | None = None
|
||||||
dpo_use_logits_to_keep: bool | None = None
|
dpo_use_logits_to_keep: bool | None = None
|
||||||
|
|
||||||
datasets: Annotated[
|
datasets: (
|
||||||
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
|
Annotated[
|
||||||
MinLen(1),
|
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
|
||||||
] | None = None
|
MinLen(1),
|
||||||
|
]
|
||||||
|
| None
|
||||||
|
) = None
|
||||||
|
|
||||||
test_datasets: Annotated[
|
test_datasets: (
|
||||||
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
|
Annotated[
|
||||||
MinLen(1),
|
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
|
||||||
] | None = None
|
MinLen(1),
|
||||||
|
]
|
||||||
|
| None
|
||||||
|
) = None
|
||||||
shuffle_merged_datasets: bool | None = True
|
shuffle_merged_datasets: bool | None = True
|
||||||
dataset_prepared_path: str | None = None
|
dataset_prepared_path: str | None = None
|
||||||
dataset_shard_num: int | None = None
|
dataset_shard_num: int | None = None
|
||||||
dataset_shard_idx: int | None = None
|
dataset_shard_idx: int | None = None
|
||||||
skip_prepare_dataset: bool | None = False
|
skip_prepare_dataset: bool | None = False
|
||||||
|
|
||||||
pretraining_dataset: Annotated[
|
pretraining_dataset: (
|
||||||
list[PretrainingDataset | SFTDataset], MinLen(1)
|
Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None
|
||||||
] | None = Field(
|
) = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
||||||
)
|
)
|
||||||
@@ -232,9 +239,9 @@ class AxolotlInputConfig(
|
|||||||
deepspeed: str | dict[str, Any] | None = None
|
deepspeed: str | dict[str, Any] | None = None
|
||||||
fsdp: list[str] | None = None
|
fsdp: list[str] | None = None
|
||||||
fsdp_config: dict[str, Any] | None = None
|
fsdp_config: dict[str, Any] | None = None
|
||||||
fsdp_final_state_dict_type: Literal[
|
fsdp_final_state_dict_type: (
|
||||||
"FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"
|
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
|
||||||
] | None = None
|
) = None
|
||||||
|
|
||||||
val_set_size: float | None = Field(default=0.0)
|
val_set_size: float | None = Field(default=0.0)
|
||||||
|
|
||||||
@@ -244,9 +251,9 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
torch_compile: Literal["auto"] | bool | None = None
|
torch_compile: Literal["auto"] | bool | None = None
|
||||||
torch_compile_backend: str | None = None
|
torch_compile_backend: str | None = None
|
||||||
torch_compile_mode: Literal[
|
torch_compile_mode: Literal["default", "reduce-overhead", "max-autotune"] | None = (
|
||||||
"default", "reduce-overhead", "max-autotune"
|
None
|
||||||
] | None = None
|
)
|
||||||
|
|
||||||
max_steps: int | None = None
|
max_steps: int | None = None
|
||||||
warmup_steps: int | None = None
|
warmup_steps: int | None = None
|
||||||
|
|||||||
@@ -50,9 +50,9 @@ class HyperparametersConfig(BaseModel):
|
|||||||
embedding_lr: float | None = None
|
embedding_lr: float | None = None
|
||||||
embedding_lr_scale: float | None = None
|
embedding_lr_scale: float | None = None
|
||||||
weight_decay: float | None = 0.0
|
weight_decay: float | None = 0.0
|
||||||
optimizer: (
|
optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = (
|
||||||
OptimizerNames | CustomSupportedOptimizers
|
OptimizerNames.ADAMW_TORCH_FUSED
|
||||||
) | None = OptimizerNames.ADAMW_TORCH_FUSED
|
)
|
||||||
optim_args: (str | dict[str, Any]) | None = Field(
|
optim_args: (str | dict[str, Any]) | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
|
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
|
||||||
@@ -64,9 +64,9 @@ class HyperparametersConfig(BaseModel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
torchdistx_path: str | None = None
|
torchdistx_path: str | None = None
|
||||||
lr_scheduler: (
|
lr_scheduler: (SchedulerType | Literal["one_cycle"] | Literal["rex"]) | None = (
|
||||||
SchedulerType | Literal["one_cycle"] | Literal["rex"]
|
SchedulerType.COSINE
|
||||||
) | None = SchedulerType.COSINE
|
)
|
||||||
lr_scheduler_kwargs: dict[str, Any] | None = None
|
lr_scheduler_kwargs: dict[str, Any] | None = None
|
||||||
lr_quadratic_warmup: bool | None = None
|
lr_quadratic_warmup: bool | None = None
|
||||||
cosine_min_lr_ratio: float | None = None
|
cosine_min_lr_ratio: float | None = None
|
||||||
|
|||||||
Reference in New Issue
Block a user