roundup_power2_divisions not needed with newer pytorch versions (#3540)
* roundup_power2_divisions not needed with newer pytorch versions * remove typo * update qwen3.5 moe 35b-a3b yaml for 5090 * more bug fixes * fix tests to match updated trainer * don't use fa2 for hooks test * reset plugins on the instance * retry download * fix references to renamed axolotl_cfg property on trainer * Fix ref to trainer cfg
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""Shared pytest fixtures"""
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import importlib
|
||||
import logging
|
||||
@@ -473,6 +474,18 @@ def temp_dir() -> Generator[str, None, None]:
|
||||
shutil.rmtree(_temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def reset_plugin_manager():
|
||||
from axolotl.integrations.base import PluginManager
|
||||
|
||||
yield
|
||||
PluginManager._cfg = None
|
||||
# Don't reset _instance to None — module-level PLUGIN_MANAGER references
|
||||
# in train.py, model.py, etc. would become stale
|
||||
if PluginManager._instance is not None:
|
||||
PluginManager._instance.plugins = collections.OrderedDict()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def torch_manual_seed():
|
||||
torch.manual_seed(42)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_model_output_exists
|
||||
@@ -62,6 +62,7 @@ class TestDiffusion:
|
||||
}
|
||||
)
|
||||
|
||||
prepare_plugins(cfg)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
@@ -119,6 +120,7 @@ class TestDiffusion:
|
||||
}
|
||||
)
|
||||
|
||||
prepare_plugins(cfg)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
@@ -42,7 +42,7 @@ def diffusion_trainer_instance(mock_tokenizer, diffusion_config):
|
||||
"""Create a diffusion trainer instance for testing methods directly."""
|
||||
# Create a minimal trainer instance just for testing methods
|
||||
trainer = object.__new__(DiffusionTrainer) # Bypass __init__
|
||||
trainer.cfg = diffusion_config
|
||||
trainer.axolotl_cfg = diffusion_config
|
||||
trainer._special_token_ids = {0, 1, 2} # pad, bos, eos
|
||||
trainer.processing_class = mock_tokenizer
|
||||
trainer.store_metrics = Mock() # Mock metrics storage
|
||||
@@ -70,7 +70,7 @@ class TestDiffusionTrainer:
|
||||
assert not masked_indices[special_token_positions].any()
|
||||
|
||||
# Check that mask token is applied
|
||||
mask_token_id = diffusion_trainer_instance.cfg.diffusion.mask_token_id
|
||||
mask_token_id = diffusion_trainer_instance.axolotl_cfg.diffusion.mask_token_id
|
||||
masked_positions = masked_indices
|
||||
if masked_positions.any():
|
||||
assert (noisy_batch[masked_positions] == mask_token_id).all()
|
||||
@@ -132,7 +132,7 @@ class TestDiffusionTrainer:
|
||||
self, diffusion_trainer_instance
|
||||
):
|
||||
"""Test bidirectional attention mask with sample packing."""
|
||||
diffusion_trainer_instance.cfg.sample_packing = True
|
||||
diffusion_trainer_instance.axolotl_cfg.sample_packing = True
|
||||
input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
|
||||
# Sample IDs: first sample (1), second sample (2)
|
||||
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
|
||||
@@ -184,7 +184,7 @@ class TestDiffusionTrainer:
|
||||
mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
|
||||
mock_model.return_value = mock_outputs
|
||||
mock_model.training = True
|
||||
diffusion_trainer_instance.cfg.datasets = Mock()
|
||||
diffusion_trainer_instance.axolotl_cfg.datasets = Mock()
|
||||
|
||||
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
|
||||
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
|
||||
|
||||
@@ -13,7 +13,7 @@ class DummyTrainer:
|
||||
|
||||
def __init__(self, use_eval: bool):
|
||||
# Config used by callback
|
||||
self.cfg = SimpleNamespace(
|
||||
self.axolotl_cfg = SimpleNamespace(
|
||||
diffusion=SimpleNamespace(
|
||||
generation_interval=1,
|
||||
num_generation_samples=1,
|
||||
|
||||
@@ -1176,7 +1176,7 @@ class TestSwanLabProfiling:
|
||||
|
||||
# Mock trainer with SwanLab enabled
|
||||
mock_trainer = MagicMock()
|
||||
mock_trainer.cfg = MagicMock(use_swanlab=True)
|
||||
mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True)
|
||||
mock_trainer.__class__.__name__ = "TestTrainer"
|
||||
|
||||
with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log:
|
||||
@@ -1199,7 +1199,7 @@ class TestSwanLabProfiling:
|
||||
from axolotl.integrations.swanlab.profiling import swanlab_profiling_context
|
||||
|
||||
mock_trainer = MagicMock()
|
||||
mock_trainer.cfg = MagicMock(use_swanlab=False) # Disabled
|
||||
mock_trainer.axolotl_cfg = MagicMock(use_swanlab=False) # Disabled
|
||||
|
||||
with patch("swanlab.log") as mock_log:
|
||||
with swanlab_profiling_context(mock_trainer, "test_function"):
|
||||
@@ -1213,7 +1213,7 @@ class TestSwanLabProfiling:
|
||||
from axolotl.integrations.swanlab.profiling import swanlab_profiling_context
|
||||
|
||||
mock_trainer = MagicMock()
|
||||
mock_trainer.cfg = MagicMock(use_swanlab=True)
|
||||
mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True)
|
||||
|
||||
with (
|
||||
patch("swanlab.get_run", return_value=None),
|
||||
@@ -1294,7 +1294,7 @@ class TestSwanLabProfiling:
|
||||
)
|
||||
|
||||
mock_trainer = MagicMock()
|
||||
mock_trainer.cfg = MagicMock(use_swanlab=True)
|
||||
mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True)
|
||||
mock_trainer.__class__.__name__ = "TestTrainer"
|
||||
|
||||
# Config that filters out very fast operations
|
||||
@@ -1320,7 +1320,7 @@ class TestSwanLabProfiling:
|
||||
from axolotl.integrations.swanlab.profiling import swanlab_profiling_context
|
||||
|
||||
mock_trainer = MagicMock()
|
||||
mock_trainer.cfg = MagicMock(use_swanlab=True)
|
||||
mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True)
|
||||
mock_trainer.__class__.__name__ = "TestTrainer"
|
||||
|
||||
with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log:
|
||||
|
||||
Reference in New Issue
Block a user