From 675561e7451fbc5c8ef25ee487ec193f47c30310 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 23 Feb 2025 12:41:42 -0500 Subject: [PATCH] improve handling of warmup/logging steps --- src/axolotl/core/trainer_builder.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index abced8ceb..2d17fec1a 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -25,7 +25,7 @@ import os import sys from abc import abstractmethod from pathlib import Path -from typing import Any, Type, Union +from typing import Any, Dict, Type, Union import torch import transformers @@ -231,24 +231,34 @@ class TrainerBuilderBase(abc.ABC): return trainer def _set_base_training_args(self, total_num_steps) -> dict[str, Any]: - training_args_kwargs = {} + training_args_kwargs: Dict[str, Any] = {} - warmup_steps = None - if self.cfg.warmup_steps is not None: + warmup_steps = 0 + warmup_ratio = 0.0 + if self.cfg.warmup_steps: warmup_steps = self.cfg.warmup_steps - elif self.cfg.warmup_ratio is not None: - warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0) - else: + elif self.cfg.warmup_ratio: + if total_num_steps: + warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0) + else: + warmup_ratio = self.cfg.warmup_ratio + elif total_num_steps: warmup_steps = min(int(0.03 * total_num_steps), 100) + else: + warmup_ratio = 0.03 + if warmup_steps == 1: warmup_steps = 2 logging_steps = ( self.cfg.logging_steps if self.cfg.logging_steps is not None + else None + if not total_num_steps else max(min(int(0.005 * total_num_steps), 10), 1) ) + training_args_kwargs["warmup_ratio"] = warmup_ratio training_args_kwargs["warmup_steps"] = warmup_steps training_args_kwargs["logging_steps"] = logging_steps