Rank 0-only logging (#2608)
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
# pylint: disable=too-many-lines
|
||||
"""Module for testing the validation module"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional
|
||||
@@ -13,12 +12,15 @@ from axolotl.loaders.utils import check_model_config
|
||||
from axolotl.utils import is_comet_available
|
||||
from axolotl.utils.config import validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
warnings.filterwarnings("error")
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture(name="minimal_cfg")
|
||||
def fixture_cfg():
|
||||
@@ -80,7 +82,7 @@ class TestValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(test_cfg)
|
||||
assert (
|
||||
"qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
|
||||
@@ -218,7 +220,7 @@ class TestValidation(BaseValidation):
|
||||
}
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert "batch_size is not recommended" in self._caplog.records[0].message
|
||||
|
||||
@@ -513,7 +515,7 @@ class TestValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"BetterTransformers probably doesn't work with PEFT adapters"
|
||||
@@ -531,7 +533,7 @@ class TestValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"probably set bfloat16 or float16" in record.message
|
||||
@@ -577,7 +579,7 @@ class TestValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"adamw hyperparameters found, but no adamw optimizer set"
|
||||
@@ -595,7 +597,7 @@ class TestValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"adamw hyperparameters found, but no adamw optimizer set"
|
||||
@@ -654,7 +656,7 @@ class TestValidation(BaseValidation):
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
||||
@@ -673,7 +675,7 @@ class TestValidation(BaseValidation):
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
with self._caplog.at_level(logging.INFO):
|
||||
with self._caplog.at_level("INFO"):
|
||||
cfg = validate_config(cfg)
|
||||
assert any(
|
||||
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
|
||||
@@ -1109,7 +1111,7 @@ class TestValidation(BaseValidation):
|
||||
def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg):
|
||||
cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 1
|
||||
|
||||
@@ -1118,7 +1120,7 @@ class TestValidation(BaseValidation):
|
||||
DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 1
|
||||
|
||||
@@ -1128,7 +1130,7 @@ class TestValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 0
|
||||
|
||||
@@ -1138,28 +1140,28 @@ class TestValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 0
|
||||
|
||||
def test_hub_model_id_save_value_none(self, minimal_cfg):
|
||||
cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 0
|
||||
|
||||
def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
|
||||
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 0
|
||||
|
||||
def test_dpo_beta_deprecation(self, minimal_cfg):
|
||||
cfg = DictDefault({"dpo_beta": 0.2}) | minimal_cfg
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg["rl_beta"] == 0.2
|
||||
assert new_cfg["dpo_beta"] is None
|
||||
@@ -1175,7 +1177,7 @@ class TestValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.eval_strategy == "steps"
|
||||
assert (
|
||||
@@ -1455,7 +1457,7 @@ class TestValidationWandb(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
new_cfg = validate_config(cfg)
|
||||
assert any(
|
||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||
|
||||
Reference in New Issue
Block a user