Num epochs float (#2282) [skip ci]

* Change num_epochs type to float

* Handle float value for num_epochs in trainer.py
This commit is contained in:
mashdragon
2025-01-29 04:23:26 +00:00
committed by GitHub
parent 067b442596
commit c015a76a23
2 changed files with 2 additions and 2 deletions

View File

@@ -489,7 +489,7 @@ class HyperparametersConfig(BaseModel):
adam_beta1: Optional[float] = None
adam_beta2: Optional[float] = None
max_grad_norm: Optional[float] = None
-num_epochs: int = Field(default=1)
+num_epochs: float = Field(default=1.0)
@field_validator("batch_size")
@classmethod

View File

@@ -374,7 +374,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
if cfg.sample_packing_eff_est:
total_num_steps = (
# match count to len est in dataloader
-(
+int(
math.floor(
0.99
* cfg.total_num_tokens