Num epochs float (#2282) [skip ci]
* Change num_epochs type to float * Handle float value for num_epochs in trainer.py
This commit is contained in:
@@ -489,7 +489,7 @@ class HyperparametersConfig(BaseModel):
|
||||
adam_beta1: Optional[float] = None
|
||||
adam_beta2: Optional[float] = None
|
||||
max_grad_norm: Optional[float] = None
|
||||
num_epochs: int = Field(default=1)
|
||||
num_epochs: float = Field(default=1.0)
|
||||
|
||||
@field_validator("batch_size")
|
||||
@classmethod
|
||||
|
||||
@@ -374,7 +374,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
if cfg.sample_packing_eff_est:
|
||||
total_num_steps = (
|
||||
# match count to len est in dataloader
|
||||
(
|
||||
int(
|
||||
math.floor(
|
||||
0.99
|
||||
* cfg.total_num_tokens
|
||||
|
||||
Reference in New Issue
Block a user