make sure to capture non-null defaults from config validation (#1415)

This commit is contained in:
Wing Lian
2024-03-26 12:18:47 -07:00
committed by GitHub
parent ff939d8a64
commit 601b77bc9d
3 changed files with 26 additions and 18 deletions

View File

@@ -208,11 +208,11 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
dict(
AxolotlConfigWCapabilities(
**cfg.to_dict(), capabilities=capabilities
).model_dump(exclude_unset=True)
).model_dump(exclude_none=True)
)
)
return DictDefault(
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_unset=True))
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
)

View File

@@ -151,12 +151,6 @@ class PeftConfig(BaseModel):
loftq_config: Optional[LoftQConfig] = None
class AutoType(str, Enum):
"""auto type string configuration subset - used for bf16"""
AUTO = "auto"
class SpecialTokensConfig(BaseModel):
"""Special tokens configuration subset"""
@@ -307,12 +301,14 @@ class HyperparametersConfig(BaseModel):
},
)
train_on_inputs: Optional[bool] = None
train_on_inputs: Optional[bool] = False
group_by_length: Optional[bool] = None
learning_rate: Union[str, float]
weight_decay: Optional[float] = None
optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
weight_decay: Optional[float] = 0.0
optimizer: Optional[
Union[OptimizerNames, Literal["lion_pytorch"]]
] = OptimizerNames.ADAMW_HF.value
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
)
@@ -323,7 +319,7 @@ class HyperparametersConfig(BaseModel):
},
)
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[SchedulerType] = None
lr_scheduler: Optional[SchedulerType] = "cosine"
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
lr_quadratic_warmup: Optional[bool] = None
cosine_min_lr_ratio: Optional[float] = None
@@ -473,7 +469,7 @@ class AxolotlInputConfig(
loss_watchdog_threshold: Optional[float] = None
loss_watchdog_patience: Optional[int] = None
bf16: Optional[Union[AutoType, bool]] = AutoType.AUTO
bf16: Optional[Union[Literal["auto"], bool]] = "auto"
fp16: Optional[bool] = None
bfloat16: Optional[bool] = None # for non-AMP cases
float16: Optional[bool] = None # for non-AMP cases
@@ -487,7 +483,7 @@ class AxolotlInputConfig(
unfrozen_parameters: Optional[List[str]] = None
sequence_len: int = Field(default=1024)
sequence_len: int = Field(default=512)
sample_packing: Optional[bool] = None
eval_sample_packing: Optional[bool] = None
pad_to_sequence_len: Optional[bool] = None
@@ -548,10 +544,10 @@ class AxolotlInputConfig(
sample_packing_eff_est: Optional[float] = None
axolotl_config_path: Optional[str] = None
is_falcon_derived_model: Optional[bool] = Field(default=False)
is_llama_derived_model: Optional[bool] = Field(default=False)
is_mistral_derived_model: Optional[bool] = Field(default=False)
is_qwen_derived_model: Optional[bool] = Field(default=False)
is_falcon_derived_model: Optional[bool] = Field(default=None)
is_llama_derived_model: Optional[bool] = Field(default=None)
is_mistral_derived_model: Optional[bool] = Field(default=None)
is_qwen_derived_model: Optional[bool] = Field(default=None)
@field_validator("datasets", mode="before")
@classmethod

View File

@@ -54,6 +54,18 @@ class TestValidation(BaseValidation):
Test the validation module
"""
def test_defaults(self, minimal_cfg):
# Checks what validate_config returns for a field that was never set versus
# a field that was explicitly set to None.
#
# `minimal_cfg` is a fixture defined outside this view; presumably it
# supplies the minimum keys needed to pass validation -- TODO confirm.
# It is the right-hand operand of `|`, so with plain dict-union semantics
# its entries would win on any key collision.
test_cfg = DictDefault(
{
"weight_decay": None,
}
| minimal_cfg
)
cfg = validate_config(test_cfg)
# train_on_inputs is not set in the literal above; assuming minimal_cfg does
# not set it either, the validated config is expected to carry False.
assert cfg.train_on_inputs is False
# weight_decay was explicitly passed as None and is expected to still read
# back as None after validation.
# NOTE(review): presumably the None entry is omitted from the dumped model
# and DictDefault returns None for a missing key -- verify.
assert cfg.weight_decay is None
def test_datasets_min_length(self):
cfg = DictDefault(
{