Rank 0-only logging (#2608)

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
salman
2025-05-28 14:57:30 +01:00
committed by GitHub
parent 5fca214108
commit 65c5481120
135 changed files with 454 additions and 378 deletions

View File

@@ -1,7 +1,6 @@
# pylint: disable=too-many-lines
"""Module for testing the validation module"""
import logging
import os
import warnings
from typing import Optional
@@ -13,12 +12,15 @@ from axolotl.loaders.utils import check_model_config
from axolotl.utils import is_comet_available
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
from axolotl.utils.wandb_ import setup_wandb_env_vars
warnings.filterwarnings("error")
LOG = get_logger(__name__)
@pytest.fixture(name="minimal_cfg")
def fixture_cfg():
@@ -80,7 +82,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(test_cfg)
assert (
"qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
@@ -218,7 +220,7 @@ class TestValidation(BaseValidation):
}
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert "batch_size is not recommended" in self._caplog.records[0].message
@@ -513,7 +515,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert any(
"BetterTransformers probably doesn't work with PEFT adapters"
@@ -531,7 +533,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert any(
"probably set bfloat16 or float16" in record.message
@@ -577,7 +579,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert any(
"adamw hyperparameters found, but no adamw optimizer set"
@@ -595,7 +597,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert any(
"adamw hyperparameters found, but no adamw optimizer set"
@@ -654,7 +656,7 @@ class TestValidation(BaseValidation):
)
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert any(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
@@ -673,7 +675,7 @@ class TestValidation(BaseValidation):
)
| minimal_cfg
)
with self._caplog.at_level(logging.INFO):
with self._caplog.at_level("INFO"):
cfg = validate_config(cfg)
assert any(
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
@@ -1109,7 +1111,7 @@ class TestValidation(BaseValidation):
def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert len(self._caplog.records) == 1
@@ -1118,7 +1120,7 @@ class TestValidation(BaseValidation):
DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert len(self._caplog.records) == 1
@@ -1128,7 +1130,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert len(self._caplog.records) == 0
@@ -1138,28 +1140,28 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_hub_model_id_save_value_none(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_dpo_beta_deprecation(self, minimal_cfg):
cfg = DictDefault({"dpo_beta": 0.2}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
new_cfg = validate_config(cfg)
assert new_cfg["rl_beta"] == 0.2
assert new_cfg["dpo_beta"] is None
@@ -1175,7 +1177,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
new_cfg = validate_config(cfg)
assert new_cfg.eval_strategy == "steps"
assert (
@@ -1455,7 +1457,7 @@ class TestValidationWandb(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
new_cfg = validate_config(cfg)
assert any(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."