improve handling of warmup/logging steps
This commit is contained in:
@@ -25,7 +25,7 @@ import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Type, Union
|
||||
from typing import Any, Dict, Type, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
@@ -231,24 +231,34 @@ class TrainerBuilderBase(abc.ABC):
|
||||
return trainer
|
||||
|
||||
def _set_base_training_args(self, total_num_steps) -> dict[str, Any]:
|
||||
training_args_kwargs = {}
|
||||
training_args_kwargs: Dict[str, Any] = {}
|
||||
|
||||
warmup_steps = None
|
||||
if self.cfg.warmup_steps is not None:
|
||||
warmup_steps = 0
|
||||
warmup_ratio = 0.0
|
||||
if self.cfg.warmup_steps:
|
||||
warmup_steps = self.cfg.warmup_steps
|
||||
elif self.cfg.warmup_ratio is not None:
|
||||
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
||||
else:
|
||||
elif self.cfg.warmup_ratio:
|
||||
if total_num_steps:
|
||||
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
||||
else:
|
||||
warmup_ratio = self.cfg.warmup_ratio
|
||||
elif total_num_steps:
|
||||
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
||||
else:
|
||||
warmup_ratio = 0.03
|
||||
|
||||
if warmup_steps == 1:
|
||||
warmup_steps = 2
|
||||
|
||||
logging_steps = (
|
||||
self.cfg.logging_steps
|
||||
if self.cfg.logging_steps is not None
|
||||
else None
|
||||
if not total_num_steps
|
||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||
)
|
||||
|
||||
training_args_kwargs["warmup_ratio"] = warmup_ratio
|
||||
training_args_kwargs["warmup_steps"] = warmup_steps
|
||||
training_args_kwargs["logging_steps"] = logging_steps
|
||||
|
||||
|
||||
Reference in New Issue
Block a user