improve handling of warmup/logging steps

This commit is contained in:
Wing Lian
2025-02-23 12:41:42 -05:00
committed by NanoCode012
parent a6ce7d7522
commit 675561e745

View File

@@ -25,7 +25,7 @@ import os
import sys
from abc import abstractmethod
from pathlib import Path
from typing import Any, Type, Union
from typing import Any, Dict, Type, Union
import torch
import transformers
@@ -231,24 +231,34 @@ class TrainerBuilderBase(abc.ABC):
return trainer
def _set_base_training_args(self, total_num_steps) -> dict[str, Any]:
training_args_kwargs = {}
training_args_kwargs: Dict[str, Any] = {}
warmup_steps = None
if self.cfg.warmup_steps is not None:
warmup_steps = 0
warmup_ratio = 0.0
if self.cfg.warmup_steps:
warmup_steps = self.cfg.warmup_steps
elif self.cfg.warmup_ratio is not None:
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
else:
elif self.cfg.warmup_ratio:
if total_num_steps:
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
else:
warmup_ratio = self.cfg.warmup_ratio
elif total_num_steps:
warmup_steps = min(int(0.03 * total_num_steps), 100)
else:
warmup_ratio = 0.03
if warmup_steps == 1:
warmup_steps = 2
logging_steps = (
self.cfg.logging_steps
if self.cfg.logging_steps is not None
else None
if not total_num_steps
else max(min(int(0.005 * total_num_steps), 10), 1)
)
training_args_kwargs["warmup_ratio"] = warmup_ratio
training_args_kwargs["warmup_steps"] = warmup_steps
training_args_kwargs["logging_steps"] = logging_steps