remove deprecated wandb env var (#2751)
* remove deprecated wandb env var * remove os.environ wandb setting; unused loggers * remove os.environ wandb setting; unused loggers
This commit is contained in:
@@ -16,6 +16,3 @@ def setup_wandb_env_vars(cfg: DictDefault):
|
|||||||
# Enable wandb if project name is present
|
# Enable wandb if project name is present
|
||||||
if cfg.wandb_project and len(cfg.wandb_project) > 0:
|
if cfg.wandb_project and len(cfg.wandb_project) > 0:
|
||||||
cfg.use_wandb = True
|
cfg.use_wandb = True
|
||||||
os.environ.pop("WANDB_DISABLED", None) # Remove if present
|
|
||||||
else:
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""E2E tests for sequence parallelism"""
|
"""E2E tests for sequence parallelism"""
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -12,8 +11,6 @@ from axolotl.utils.dict import DictDefault
|
|||||||
|
|
||||||
from ...utils import check_tensorboard
|
from ...utils import check_tensorboard
|
||||||
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestSequenceParallelism:
|
class TestSequenceParallelism:
|
||||||
"""Test case for training with sequence parallelism enabled"""
|
"""Test case for training with sequence parallelism enabled"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for multigpu lora tinyllama
|
E2E tests for multigpu lora tinyllama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -13,13 +12,9 @@ from transformers.testing_utils import get_torch_dist_unique_port
|
|||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
|
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e.multigpu")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for multigpu eval
|
E2E tests for multigpu eval
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
@@ -10,13 +9,9 @@ from accelerate.test_utils import execute_subprocess_async
|
|||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from ..utils import check_tensorboard
|
from ..utils import check_tensorboard
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e.multigpu")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for multigpu lora tinyllama
|
E2E tests for multigpu lora tinyllama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -12,13 +11,9 @@ from huggingface_hub import snapshot_download
|
|||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from tests.e2e.utils import check_tensorboard
|
from tests.e2e.utils import check_tensorboard
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e.multigpu")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for multigpu lora tinyllama
|
E2E tests for multigpu lora tinyllama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -14,13 +13,9 @@ from packaging import version
|
|||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
|
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e.multigpu")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for multigpu qwen2
|
E2E tests for multigpu qwen2
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -11,10 +10,6 @@ from accelerate.test_utils import execute_subprocess_async
|
|||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e.multigpu")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestMultiGPUQwen2:
|
class TestMultiGPUQwen2:
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for multigpu post-training use Ray Train
|
E2E tests for multigpu post-training use Ray Train
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -10,13 +9,9 @@ import yaml
|
|||||||
from accelerate.test_utils import execute_subprocess_async
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
|
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for multipack fft llama using 4d attention masks
|
E2E tests for multipack fft llama using 4d attention masks
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
@@ -10,13 +9,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class Test4dMultipackLlama(unittest.TestCase):
|
class Test4dMultipackLlama(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,8 +2,6 @@
|
|||||||
E2E tests for lora llama
|
E2E tests for lora llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
@@ -12,13 +10,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, check_tensorboard
|
from ..utils import check_model_output_exists, check_tensorboard
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestFAXentropyLlama:
|
class TestFAXentropyLlama:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for falcon
|
E2E tests for falcon
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -12,13 +11,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestFalconPatched(unittest.TestCase):
|
class TestFalconPatched(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for lora llama
|
E2E tests for lora llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -13,13 +12,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("FIXME, mostly underused functionality")
|
@pytest.mark.skip("FIXME, mostly underused functionality")
|
||||||
class TestFusedLlama(unittest.TestCase):
|
class TestFusedLlama(unittest.TestCase):
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for llama w/ S2 attn
|
E2E tests for llama w/ S2 attn
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -12,13 +11,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="FIXME?")
|
@pytest.mark.skip(reason="FIXME?")
|
||||||
class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for lora llama
|
E2E tests for lora llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -13,13 +12,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestLoraLlama(unittest.TestCase):
|
class TestLoraLlama(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for lora llama
|
E2E tests for lora llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
@@ -10,13 +9,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestMistral(unittest.TestCase):
|
class TestMistral(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for mixtral
|
E2E tests for mixtral
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
@@ -10,13 +9,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestMixtral(unittest.TestCase):
|
class TestMixtral(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for lora llama
|
E2E tests for lora llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
@@ -10,13 +9,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, with_temp_dir
|
from ..utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestPhiMultipack(unittest.TestCase):
|
class TestPhiMultipack(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for resuming training
|
E2E tests for resuming training
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
@@ -13,13 +12,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, most_recent_subdir, require_torch_2_6_0
|
from ..utils import check_model_output_exists, most_recent_subdir, require_torch_2_6_0
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestResumeLlama:
|
class TestResumeLlama:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,8 +2,6 @@
|
|||||||
e2e tests for unsloth qlora
|
e2e tests for unsloth qlora
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
@@ -11,13 +9,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, check_tensorboard
|
from ..utils import check_model_output_exists, check_tensorboard
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
@pytest.mark.skip(
|
@pytest.mark.skip(
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for packed training w/ flex attention
|
E2E tests for packed training w/ flex attention
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
@@ -12,13 +11,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from ..utils import check_tensorboard, require_torch_2_6_0, with_temp_dir
|
from ..utils import check_tensorboard, require_torch_2_6_0, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestPackedFlex(unittest.TestCase):
|
class TestPackedFlex(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for relora llama
|
E2E tests for relora llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -11,13 +10,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestReLoraLlama(unittest.TestCase):
|
class TestReLoraLlama(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for deepseekv3
|
E2E tests for deepseekv3
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -12,13 +11,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from tests.hf_offline_utils import enable_hf_offline
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestDeepseekV3:
|
class TestDeepseekV3:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for lora llama
|
E2E tests for lora llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -13,13 +12,9 @@ from axolotl.common.datasets import load_preference_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestDPOLlamaLora(unittest.TestCase):
|
class TestDPOLlamaLora(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for llama pretrain
|
E2E tests for llama pretrain
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
@@ -10,13 +9,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestEmbeddingsLrScale(unittest.TestCase):
|
class TestEmbeddingsLrScale(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""E2E smoke test for evaluate CLI command"""
|
"""E2E smoke test for evaluate CLI command"""
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
@@ -9,8 +8,6 @@ from transformers.testing_utils import get_torch_dist_unique_port
|
|||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestE2eEvaluate:
|
class TestE2eEvaluate:
|
||||||
"""Test cases for evaluate CLI"""
|
"""Test cases for evaluate CLI"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for falcon
|
E2E tests for falcon
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -12,13 +11,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestFalcon(unittest.TestCase):
|
class TestFalcon(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for gemma2
|
E2E tests for gemma2
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -12,10 +11,6 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestGemma2:
|
class TestGemma2:
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for gemma3_text
|
E2E tests for gemma3_text
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -12,10 +11,6 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestGemma3Text:
|
class TestGemma3Text:
|
||||||
|
|||||||
@@ -2,20 +2,14 @@
|
|||||||
E2E tests for llama
|
E2E tests for llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from tests.e2e.utils import check_model_output_exists
|
from tests.e2e.utils import check_model_output_exists
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestLlama:
|
class TestLlama:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,8 +2,6 @@
|
|||||||
E2E tests for llama pretrain
|
E2E tests for llama pretrain
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
@@ -11,13 +9,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .utils import check_model_output_exists, check_tensorboard
|
from .utils import check_model_output_exists, check_tensorboard
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestPretrainLlama:
|
class TestPretrainLlama:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for lora llama
|
E2E tests for lora llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
@@ -10,13 +9,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestLlamaVision(unittest.TestCase):
|
class TestLlamaVision(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for lora llama
|
E2E tests for lora llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
@@ -10,13 +9,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestLoraLlama(unittest.TestCase):
|
class TestLoraLlama(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for lora llama
|
E2E tests for lora llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -12,13 +11,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="skipping until upstreamed into transformers")
|
@pytest.mark.skip(reason="skipping until upstreamed into transformers")
|
||||||
class TestMamba(unittest.TestCase):
|
class TestMamba(unittest.TestCase):
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for lora llama
|
E2E tests for lora llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
@@ -12,13 +11,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestMistral(unittest.TestCase):
|
class TestMistral(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for mixtral
|
E2E tests for mixtral
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -13,13 +12,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestMixtral(unittest.TestCase):
|
class TestMixtral(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for custom optimizers using Llama
|
E2E tests for custom optimizers using Llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
@@ -10,13 +9,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir
|
from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestCustomOptimizers(unittest.TestCase):
|
class TestCustomOptimizers(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for packed training
|
E2E tests for packed training
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
@@ -12,13 +11,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .utils import check_tensorboard, with_temp_dir
|
from .utils import check_tensorboard, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestPackedLlama(unittest.TestCase):
|
class TestPackedLlama(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for lora llama
|
E2E tests for lora llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
@@ -10,13 +9,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestPhi(unittest.TestCase):
|
class TestPhi(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for process reward model w/ lora llama
|
E2E tests for process reward model w/ lora llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
@@ -10,13 +9,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestProcessRewardSmolLM2(unittest.TestCase):
|
class TestProcessRewardSmolLM2(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for qwen
|
E2E tests for qwen
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -11,10 +10,6 @@ from accelerate.test_utils import execute_subprocess_async
|
|||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.qwen")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestE2eQwen:
|
class TestE2eQwen:
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for reward model lora llama
|
E2E tests for reward model lora llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
@@ -10,13 +9,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestRewardModelLoraSmolLM2(unittest.TestCase):
|
class TestRewardModelLoraSmolLM2(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
E2E tests for custom schedulers using Llama
|
E2E tests for custom schedulers using Llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
@@ -10,13 +9,9 @@ from axolotl.common.datasets import load_datasets
|
|||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
|
|
||||||
LOG = get_logger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestCustomSchedulers(unittest.TestCase):
|
class TestCustomSchedulers(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -9,12 +9,8 @@ import pytest
|
|||||||
from axolotl.utils.config import prepare_plugins, validate_config
|
from axolotl.utils.config import prepare_plugins, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger("axolotl.integrations.test_liger")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="minimal_liger_cfg")
|
@pytest.fixture(name="minimal_liger_cfg")
|
||||||
def fixture_cfg():
|
def fixture_cfg():
|
||||||
return DictDefault(
|
return DictDefault(
|
||||||
|
|||||||
@@ -12,15 +12,12 @@ from axolotl.loaders.utils import check_model_config
|
|||||||
from axolotl.utils import is_comet_available
|
from axolotl.utils import is_comet_available
|
||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||||
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
|
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
warnings.filterwarnings("error")
|
warnings.filterwarnings("error")
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="minimal_cfg")
|
@pytest.fixture(name="minimal_cfg")
|
||||||
def fixture_cfg():
|
def fixture_cfg():
|
||||||
@@ -1507,7 +1504,6 @@ class TestValidationWandb(BaseValidation):
|
|||||||
assert os.environ.get("WANDB_MODE", "") == "online"
|
assert os.environ.get("WANDB_MODE", "") == "online"
|
||||||
assert os.environ.get("WANDB_WATCH", "") == "false"
|
assert os.environ.get("WANDB_WATCH", "") == "false"
|
||||||
assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
|
assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
|
||||||
assert os.environ.get("WANDB_DISABLED", "") != "true"
|
|
||||||
|
|
||||||
os.environ.pop("WANDB_PROJECT", None)
|
os.environ.pop("WANDB_PROJECT", None)
|
||||||
os.environ.pop("WANDB_NAME", None)
|
os.environ.pop("WANDB_NAME", None)
|
||||||
@@ -1516,16 +1512,12 @@ class TestValidationWandb(BaseValidation):
|
|||||||
os.environ.pop("WANDB_MODE", None)
|
os.environ.pop("WANDB_MODE", None)
|
||||||
os.environ.pop("WANDB_WATCH", None)
|
os.environ.pop("WANDB_WATCH", None)
|
||||||
os.environ.pop("WANDB_LOG_MODEL", None)
|
os.environ.pop("WANDB_LOG_MODEL", None)
|
||||||
os.environ.pop("WANDB_DISABLED", None)
|
|
||||||
|
|
||||||
def test_wandb_set_disabled(self, minimal_cfg):
|
def test_wandb_set_disabled(self, minimal_cfg):
|
||||||
cfg = DictDefault({}) | minimal_cfg
|
cfg = DictDefault({}) | minimal_cfg
|
||||||
|
|
||||||
new_cfg = validate_config(cfg)
|
new_cfg = validate_config(cfg)
|
||||||
|
|
||||||
setup_wandb_env_vars(new_cfg)
|
setup_wandb_env_vars(new_cfg)
|
||||||
|
assert new_cfg.use_wandb is None
|
||||||
assert os.environ.get("WANDB_DISABLED", "") == "true"
|
|
||||||
|
|
||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
@@ -1537,13 +1529,10 @@ class TestValidationWandb(BaseValidation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
new_cfg = validate_config(cfg)
|
new_cfg = validate_config(cfg)
|
||||||
|
|
||||||
setup_wandb_env_vars(new_cfg)
|
setup_wandb_env_vars(new_cfg)
|
||||||
|
assert new_cfg.use_wandb is True
|
||||||
assert os.environ.get("WANDB_DISABLED", "") != "true"
|
|
||||||
|
|
||||||
os.environ.pop("WANDB_PROJECT", None)
|
os.environ.pop("WANDB_PROJECT", None)
|
||||||
os.environ.pop("WANDB_DISABLED", None)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(is_comet_available() is False, reason="comet_ml is not installed")
|
@pytest.mark.skipif(is_comet_available() is False, reason="comet_ml is not installed")
|
||||||
|
|||||||
@@ -10,12 +10,9 @@ from axolotl.prompt_strategies.chat_template import (
|
|||||||
load,
|
load,
|
||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from tests.hf_offline_utils import enable_hf_offline
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="messages_w_reasoning")
|
@pytest.fixture(name="messages_w_reasoning")
|
||||||
def messages_w_reasoning_fixture():
|
def messages_w_reasoning_fixture():
|
||||||
|
|||||||
@@ -16,12 +16,9 @@ from axolotl.prompt_strategies.orpo.chat_template import load
|
|||||||
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
|
||||||
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
from axolotl.prompters import AlpacaPrompter, PromptStyle
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from tests.hf_offline_utils import enable_hf_offline
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
test_data = {
|
test_data = {
|
||||||
"multi_turn_sys": {
|
"multi_turn_sys": {
|
||||||
"conversations": [
|
"conversations": [
|
||||||
|
|||||||
Reference in New Issue
Block a user