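Improve handling of warmup/logging steps when the total number of training steps is unknown (fall back to warmup_ratio and leave logging_steps unset)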

This commit is contained in:
Wing Lian
2025-02-23 12:41:42 -05:00
committed by NanoCode012
parent a6ce7d7522
commit 675561e745

View File

@@ -25,7 +25,7 @@ import os
import sys import sys
from abc import abstractmethod from abc import abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Type, Union from typing import Any, Dict, Type, Union
import torch import torch
import transformers import transformers
@@ -231,24 +231,34 @@ class TrainerBuilderBase(abc.ABC):
return trainer return trainer
def _set_base_training_args(self, total_num_steps) -> dict[str, Any]: def _set_base_training_args(self, total_num_steps) -> dict[str, Any]:
training_args_kwargs = {} training_args_kwargs: Dict[str, Any] = {}
warmup_steps = None warmup_steps = 0
if self.cfg.warmup_steps is not None: warmup_ratio = 0.0
if self.cfg.warmup_steps:
warmup_steps = self.cfg.warmup_steps warmup_steps = self.cfg.warmup_steps
elif self.cfg.warmup_ratio is not None: elif self.cfg.warmup_ratio:
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0) if total_num_steps:
else: warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
else:
warmup_ratio = self.cfg.warmup_ratio
elif total_num_steps:
warmup_steps = min(int(0.03 * total_num_steps), 100) warmup_steps = min(int(0.03 * total_num_steps), 100)
else:
warmup_ratio = 0.03
if warmup_steps == 1: if warmup_steps == 1:
warmup_steps = 2 warmup_steps = 2
logging_steps = ( logging_steps = (
self.cfg.logging_steps self.cfg.logging_steps
if self.cfg.logging_steps is not None if self.cfg.logging_steps is not None
else None
if not total_num_steps
else max(min(int(0.005 * total_num_steps), 10), 1) else max(min(int(0.005 * total_num_steps), 10), 1)
) )
training_args_kwargs["warmup_ratio"] = warmup_ratio
training_args_kwargs["warmup_steps"] = warmup_steps training_args_kwargs["warmup_steps"] = warmup_steps
training_args_kwargs["logging_steps"] = logging_steps training_args_kwargs["logging_steps"] = logging_steps