pre-commit

This commit is contained in:
Dan Saunders
2025-03-21 11:23:39 -04:00
parent ddd84d7c65
commit 94c00c1d04
2 changed files with 30 additions and 23 deletions

View File

@@ -1,4 +1,5 @@
"""Main Axolotl input configuration Pydantic models""" """Main Axolotl input configuration Pydantic models"""
# pylint: disable=too-many-lines # pylint: disable=too-many-lines
import logging import logging
@@ -91,24 +92,30 @@ class AxolotlInputConfig(
dpo_use_weighting: bool | None = None dpo_use_weighting: bool | None = None
dpo_use_logits_to_keep: bool | None = None dpo_use_logits_to_keep: bool | None = None
datasets: Annotated[ datasets: (
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset], Annotated[
MinLen(1), list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
] | None = None MinLen(1),
]
| None
) = None
test_datasets: Annotated[ test_datasets: (
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset], Annotated[
MinLen(1), list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
] | None = None MinLen(1),
]
| None
) = None
shuffle_merged_datasets: bool | None = True shuffle_merged_datasets: bool | None = True
dataset_prepared_path: str | None = None dataset_prepared_path: str | None = None
dataset_shard_num: int | None = None dataset_shard_num: int | None = None
dataset_shard_idx: int | None = None dataset_shard_idx: int | None = None
skip_prepare_dataset: bool | None = False skip_prepare_dataset: bool | None = False
pretraining_dataset: Annotated[ pretraining_dataset: (
list[PretrainingDataset | SFTDataset], MinLen(1) Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None
] | None = Field( ) = Field(
default=None, default=None,
json_schema_extra={"description": "streaming dataset to use for pretraining"}, json_schema_extra={"description": "streaming dataset to use for pretraining"},
) )
@@ -232,9 +239,9 @@ class AxolotlInputConfig(
deepspeed: str | dict[str, Any] | None = None deepspeed: str | dict[str, Any] | None = None
fsdp: list[str] | None = None fsdp: list[str] | None = None
fsdp_config: dict[str, Any] | None = None fsdp_config: dict[str, Any] | None = None
fsdp_final_state_dict_type: Literal[ fsdp_final_state_dict_type: (
"FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT" Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
] | None = None ) = None
val_set_size: float | None = Field(default=0.0) val_set_size: float | None = Field(default=0.0)
@@ -244,9 +251,9 @@ class AxolotlInputConfig(
torch_compile: Literal["auto"] | bool | None = None torch_compile: Literal["auto"] | bool | None = None
torch_compile_backend: str | None = None torch_compile_backend: str | None = None
torch_compile_mode: Literal[ torch_compile_mode: Literal["default", "reduce-overhead", "max-autotune"] | None = (
"default", "reduce-overhead", "max-autotune" None
] | None = None )
max_steps: int | None = None max_steps: int | None = None
warmup_steps: int | None = None warmup_steps: int | None = None

View File

@@ -50,9 +50,9 @@ class HyperparametersConfig(BaseModel):
embedding_lr: float | None = None embedding_lr: float | None = None
embedding_lr_scale: float | None = None embedding_lr_scale: float | None = None
weight_decay: float | None = 0.0 weight_decay: float | None = 0.0
optimizer: ( optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = (
OptimizerNames | CustomSupportedOptimizers OptimizerNames.ADAMW_TORCH_FUSED
) | None = OptimizerNames.ADAMW_TORCH_FUSED )
optim_args: (str | dict[str, Any]) | None = Field( optim_args: (str | dict[str, Any]) | None = Field(
default=None, default=None,
json_schema_extra={"description": "Optional arguments to supply to optimizer."}, json_schema_extra={"description": "Optional arguments to supply to optimizer."},
@@ -64,9 +64,9 @@ class HyperparametersConfig(BaseModel):
}, },
) )
torchdistx_path: str | None = None torchdistx_path: str | None = None
lr_scheduler: ( lr_scheduler: (SchedulerType | Literal["one_cycle"] | Literal["rex"]) | None = (
SchedulerType | Literal["one_cycle"] | Literal["rex"] SchedulerType.COSINE
) | None = SchedulerType.COSINE )
lr_scheduler_kwargs: dict[str, Any] | None = None lr_scheduler_kwargs: dict[str, Any] | None = None
lr_quadratic_warmup: bool | None = None lr_quadratic_warmup: bool | None = None
cosine_min_lr_ratio: float | None = None cosine_min_lr_ratio: float | None = None