Rank 0-only logging (#2608)

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
salman
2025-05-28 14:57:30 +01:00
committed by GitHub
parent 5fca214108
commit 65c5481120
135 changed files with 454 additions and 378 deletions

View File

@@ -2,7 +2,6 @@
E2E tests for multigpu lora tinyllama
"""
import logging
import os
from pathlib import Path
@@ -14,10 +13,11 @@ from transformers.testing_utils import get_torch_dist_unique_port
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent

View File

@@ -2,7 +2,6 @@
E2E tests for multigpu eval
"""
import logging
import os
from pathlib import Path
@@ -11,10 +10,11 @@ from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent

View File

@@ -2,7 +2,6 @@
E2E tests for multigpu lora tinyllama
"""
import logging
import os
from pathlib import Path
@@ -13,10 +12,11 @@ from huggingface_hub import snapshot_download
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from tests.e2e.utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent

View File

@@ -2,7 +2,6 @@
E2E tests for multigpu lora tinyllama
"""
import logging
import os
from pathlib import Path
@@ -15,10 +14,11 @@ from packaging import version
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent

View File

@@ -2,7 +2,6 @@
E2E tests for multigpu qwen2
"""
import logging
import os
from pathlib import Path
@@ -12,8 +11,9 @@ from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for multigpu post-training use Ray Train
"""
import logging
import os
from pathlib import Path
@@ -11,10 +10,11 @@ import yaml
from accelerate.test_utils import execute_subprocess_async
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent

View File

@@ -2,7 +2,6 @@
E2E tests for multipack fft llama using 4d attention masks
"""
import logging
import os
import unittest
@@ -11,10 +10,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for lora llama
"""
import logging
import os
import pytest
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for falcon
"""
import logging
import os
import unittest
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
@@ -14,10 +13,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for llama w/ S2 attn
"""
import logging
import os
import unittest
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
@@ -14,10 +13,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
@@ -11,10 +10,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for mixtral
"""
import logging
import os
import unittest
@@ -11,10 +10,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
@@ -11,10 +10,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for resuming training
"""
import logging
import os
import re
import subprocess
@@ -14,10 +13,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, most_recent_subdir, require_torch_2_6_0
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
e2e tests for unsloth qlora
"""
import logging
import os
import pytest
@@ -12,10 +11,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for packed training w/ flex attention
"""
import logging
import os
import unittest
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_tensorboard, require_torch_2_6_0, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for relora llama
"""
import logging
import os
import unittest
from pathlib import Path
@@ -12,10 +11,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for deepseekv3
"""
import logging
import os
from pathlib import Path
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from tests.hf_offline_utils import enable_hf_offline
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
from pathlib import Path
@@ -14,10 +13,11 @@ from axolotl.common.datasets import load_preference_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for llama pretrain
"""
import logging
import os
import unittest
@@ -11,10 +10,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for falcon
"""
import logging
import os
import unittest
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for gemma2
"""
import logging
import os
from pathlib import Path
@@ -13,8 +12,9 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for gemma3_text
"""
import logging
import os
from pathlib import Path
@@ -13,8 +12,9 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for llama
"""
import logging
import os
from axolotl.cli.args import TrainerCliArgs
@@ -10,10 +9,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from tests.e2e.utils import check_model_output_exists
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for llama pretrain
"""
import logging
import os
import pytest
@@ -12,10 +11,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .utils import check_model_output_exists, check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
@@ -11,10 +10,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
@@ -11,10 +10,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for mixtral
"""
import logging
import os
import unittest
@@ -14,10 +13,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for custom optimizers using Llama
"""
import logging
import os
import unittest
@@ -11,10 +10,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for packed training
"""
import logging
import os
import unittest
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .utils import check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
@@ -11,10 +10,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for process reward model w/ lora llama
"""
import logging
import os
import unittest
@@ -11,10 +10,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for qwen
"""
import logging
import os
from pathlib import Path
@@ -12,8 +11,9 @@ from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl.tests.qwen")
LOG = get_logger("axolotl.tests.qwen")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for reward model lora llama
"""
import logging
import os
import unittest
@@ -11,10 +10,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for custom schedulers using Llama
"""
import logging
import os
import unittest
@@ -11,10 +10,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,8 +2,6 @@
config validation tests for swiglu args
"""
# pylint: disable=duplicate-code
import logging
from typing import Optional
import pytest
@@ -11,6 +9,11 @@ import pytest
from axolotl.utils.config import prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
# pylint: disable=duplicate-code
from axolotl.utils.logging import get_logger
LOG = get_logger("axolotl.integrations.test_liger")
@pytest.fixture(name="minimal_liger_cfg")
def fixture_cfg():
@@ -41,7 +44,7 @@ class TestValidation:
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
caplog.set_level(logging.WARNING)
caplog.set_level("WARNING")
self._caplog = caplog
def test_deprecated_swiglu(self, minimal_liger_cfg):
@@ -52,9 +55,7 @@ class TestValidation:
| minimal_liger_cfg
)
with self._caplog.at_level(
logging.WARNING, logger="axolotl.integrations.liger.args"
):
with self._caplog.at_level("WARNING", logger="axolotl.integrations.liger.args"):
prepare_plugins(test_cfg)
updated_cfg = validate_config(test_cfg)
# TODO this test is brittle in CI

View File

@@ -1,7 +1,6 @@
# pylint: disable=too-many-lines
"""Module for testing the validation module"""
import logging
import os
import warnings
from typing import Optional
@@ -13,12 +12,15 @@ from axolotl.loaders.utils import check_model_config
from axolotl.utils import is_comet_available
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
from axolotl.utils.wandb_ import setup_wandb_env_vars
warnings.filterwarnings("error")
LOG = get_logger(__name__)
@pytest.fixture(name="minimal_cfg")
def fixture_cfg():
@@ -80,7 +82,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(test_cfg)
assert (
"qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
@@ -218,7 +220,7 @@ class TestValidation(BaseValidation):
}
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert "batch_size is not recommended" in self._caplog.records[0].message
@@ -513,7 +515,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert any(
"BetterTransformers probably doesn't work with PEFT adapters"
@@ -531,7 +533,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert any(
"probably set bfloat16 or float16" in record.message
@@ -577,7 +579,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert any(
"adamw hyperparameters found, but no adamw optimizer set"
@@ -595,7 +597,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert any(
"adamw hyperparameters found, but no adamw optimizer set"
@@ -654,7 +656,7 @@ class TestValidation(BaseValidation):
)
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert any(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
@@ -673,7 +675,7 @@ class TestValidation(BaseValidation):
)
| minimal_cfg
)
with self._caplog.at_level(logging.INFO):
with self._caplog.at_level("INFO"):
cfg = validate_config(cfg)
assert any(
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
@@ -1109,7 +1111,7 @@ class TestValidation(BaseValidation):
def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert len(self._caplog.records) == 1
@@ -1118,7 +1120,7 @@ class TestValidation(BaseValidation):
DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert len(self._caplog.records) == 1
@@ -1128,7 +1130,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert len(self._caplog.records) == 0
@@ -1138,28 +1140,28 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_hub_model_id_save_value_none(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_dpo_beta_deprecation(self, minimal_cfg):
cfg = DictDefault({"dpo_beta": 0.2}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
new_cfg = validate_config(cfg)
assert new_cfg["rl_beta"] == 0.2
assert new_cfg["dpo_beta"] is None
@@ -1175,7 +1177,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
new_cfg = validate_config(cfg)
assert new_cfg.eval_strategy == "steps"
assert (
@@ -1455,7 +1457,7 @@ class TestValidationWandb(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
new_cfg = validate_config(cfg)
assert any(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."

View File

@@ -3,14 +3,13 @@ tests for chat_template prompt strategy
"""
# pylint: disable=duplicate-code
import logging
import unittest
from axolotl.prompt_strategies.messages.chat import load
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__, log_level="DEBUG")
class TestMessagesChatLlama3:

View File

@@ -2,7 +2,6 @@
tests for chat_template prompt strategy
"""
import logging
import unittest
from axolotl.prompt_strategies.chat_template import (
@@ -13,9 +12,9 @@ from axolotl.prompt_strategies.chat_template import (
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)
class TestAssistantChatTemplateLlama3:

View File

@@ -4,7 +4,6 @@ tests for chat_template prompt strategy
# pylint: disable=too-many-lines
import logging
from copy import deepcopy
import pytest
@@ -18,11 +17,11 @@ from axolotl.prompt_strategies.chat_template import (
)
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.logging import get_logger
from tests.hf_offline_utils import enable_hf_offline
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)
PARAMETRIZE_KEYS = "tokenizer, chat_template, chat_template_jinja, eos_token"
PARAMETRIZE_PARAMS = [

View File

@@ -2,8 +2,6 @@
Tests for splitting reasoning/thinking from content into separate field
"""
import logging
import pytest
from datasets import Dataset
from transformers import AutoTokenizer
@@ -12,11 +10,11 @@ from axolotl.prompt_strategies.chat_template import (
load,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from tests.hf_offline_utils import enable_hf_offline
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)
@pytest.fixture(name="messages_w_reasoning")

View File

@@ -2,14 +2,12 @@
tests for jinja_template_analyzer
"""
import logging
import pytest
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
from axolotl.utils.logging import get_logger
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__, log_level="DEBUG")
class TestJinjaTemplateAnalyzer:

View File

@@ -1,7 +1,6 @@
"""Module for testing prompt tokenizers."""
import json
import logging
from pathlib import Path
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
@@ -17,10 +16,11 @@ from axolotl.prompt_strategies.orpo.chat_template import load
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from tests.hf_offline_utils import enable_hf_offline
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)
test_data = {
"multi_turn_sys": {