improve handling of warmup/logging steps
This commit is contained in:
@@ -25,7 +25,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Type, Union
|
from typing import Any, Dict, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
@@ -231,24 +231,34 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
def _set_base_training_args(self, total_num_steps) -> dict[str, Any]:
|
def _set_base_training_args(self, total_num_steps) -> dict[str, Any]:
|
||||||
training_args_kwargs = {}
|
training_args_kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
warmup_steps = None
|
warmup_steps = 0
|
||||||
if self.cfg.warmup_steps is not None:
|
warmup_ratio = 0.0
|
||||||
|
if self.cfg.warmup_steps:
|
||||||
warmup_steps = self.cfg.warmup_steps
|
warmup_steps = self.cfg.warmup_steps
|
||||||
elif self.cfg.warmup_ratio is not None:
|
elif self.cfg.warmup_ratio:
|
||||||
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
if total_num_steps:
|
||||||
else:
|
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
||||||
|
else:
|
||||||
|
warmup_ratio = self.cfg.warmup_ratio
|
||||||
|
elif total_num_steps:
|
||||||
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
||||||
|
else:
|
||||||
|
warmup_ratio = 0.03
|
||||||
|
|
||||||
if warmup_steps == 1:
|
if warmup_steps == 1:
|
||||||
warmup_steps = 2
|
warmup_steps = 2
|
||||||
|
|
||||||
logging_steps = (
|
logging_steps = (
|
||||||
self.cfg.logging_steps
|
self.cfg.logging_steps
|
||||||
if self.cfg.logging_steps is not None
|
if self.cfg.logging_steps is not None
|
||||||
|
else None
|
||||||
|
if not total_num_steps
|
||||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
training_args_kwargs["warmup_ratio"] = warmup_ratio
|
||||||
training_args_kwargs["warmup_steps"] = warmup_steps
|
training_args_kwargs["warmup_steps"] = warmup_steps
|
||||||
training_args_kwargs["logging_steps"] = logging_steps
|
training_args_kwargs["logging_steps"] = logging_steps
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user