pre-commit

This commit is contained in:
Dan Saunders
2025-03-21 11:23:39 -04:00
parent ddd84d7c65
commit 94c00c1d04
2 changed files with 30 additions and 23 deletions

View File

@@ -1,4 +1,5 @@
"""Main Axolotl input configuration Pydantic models"""
# pylint: disable=too-many-lines
import logging
@@ -91,24 +92,30 @@ class AxolotlInputConfig(
dpo_use_weighting: bool | None = None
dpo_use_logits_to_keep: bool | None = None
datasets: Annotated[
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
MinLen(1),
] | None = None
datasets: (
Annotated[
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
MinLen(1),
]
| None
) = None
test_datasets: Annotated[
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
MinLen(1),
] | None = None
test_datasets: (
Annotated[
list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
MinLen(1),
]
| None
) = None
shuffle_merged_datasets: bool | None = True
dataset_prepared_path: str | None = None
dataset_shard_num: int | None = None
dataset_shard_idx: int | None = None
skip_prepare_dataset: bool | None = False
pretraining_dataset: Annotated[
list[PretrainingDataset | SFTDataset], MinLen(1)
] | None = Field(
pretraining_dataset: (
Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None
) = Field(
default=None,
json_schema_extra={"description": "streaming dataset to use for pretraining"},
)
@@ -232,9 +239,9 @@ class AxolotlInputConfig(
deepspeed: str | dict[str, Any] | None = None
fsdp: list[str] | None = None
fsdp_config: dict[str, Any] | None = None
fsdp_final_state_dict_type: Literal[
"FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"
] | None = None
fsdp_final_state_dict_type: (
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
) = None
val_set_size: float | None = Field(default=0.0)
@@ -244,9 +251,9 @@ class AxolotlInputConfig(
torch_compile: Literal["auto"] | bool | None = None
torch_compile_backend: str | None = None
torch_compile_mode: Literal[
"default", "reduce-overhead", "max-autotune"
] | None = None
torch_compile_mode: Literal["default", "reduce-overhead", "max-autotune"] | None = (
None
)
max_steps: int | None = None
warmup_steps: int | None = None

View File

@@ -50,9 +50,9 @@ class HyperparametersConfig(BaseModel):
embedding_lr: float | None = None
embedding_lr_scale: float | None = None
weight_decay: float | None = 0.0
optimizer: (
OptimizerNames | CustomSupportedOptimizers
) | None = OptimizerNames.ADAMW_TORCH_FUSED
optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = (
OptimizerNames.ADAMW_TORCH_FUSED
)
optim_args: (str | dict[str, Any]) | None = Field(
default=None,
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
@@ -64,9 +64,9 @@ class HyperparametersConfig(BaseModel):
},
)
torchdistx_path: str | None = None
lr_scheduler: (
SchedulerType | Literal["one_cycle"] | Literal["rex"]
) | None = SchedulerType.COSINE
lr_scheduler: (SchedulerType | Literal["one_cycle"] | Literal["rex"]) | None = (
SchedulerType.COSINE
)
lr_scheduler_kwargs: dict[str, Any] | None = None
lr_quadratic_warmup: bool | None = None
cosine_min_lr_ratio: float | None = None