diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 2c949f8e7..7954e1fbd 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -196,9 +196,9 @@ class TrainerBuilderBase(abc.ABC): ): warmup_steps = 0 warmup_ratio = 0.0 - if self.cfg.warmup_steps: + if self.cfg.warmup_steps is not None: warmup_steps = self.cfg.warmup_steps - elif self.cfg.warmup_ratio: + elif self.cfg.warmup_ratio is not None: if total_num_steps: warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0) else: diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py index 55317151e..e50483e6c 100644 --- a/tests/e2e/integrations/test_liger.py +++ b/tests/e2e/integrations/test_liger.py @@ -3,6 +3,7 @@ Simple end-to-end test for Liger integration """ import pytest + from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, prepare_plugins, validate_config