roundup_power2_divisions not needed with newer pytorch versions (#3540)

* roundup_power2_divisions not needed with newer pytorch versions

* remove typo

* update qwen3.5 moe 35b-a3b yaml for 5090

* more bug fixes

* fix tests to match updated trainer

* don't use fa2 for hooks test

* reset plugins on the instance

* retry download

* fix references to renamed axolotl_cfg property on trainer

* Fix ref to trainer cfg
This commit is contained in:
Wing Lian
2026-03-24 15:40:05 -04:00
committed by GitHub
parent 86be9f329e
commit e412370877
14 changed files with 100 additions and 60 deletions

View File

@@ -1,5 +1,6 @@
"""Shared pytest fixtures"""
import collections
import functools
import importlib
import logging
@@ -473,6 +474,18 @@ def temp_dir() -> Generator[str, None, None]:
shutil.rmtree(_temp_dir)
@pytest.fixture(scope="function", autouse=True)
def reset_plugin_manager():
from axolotl.integrations.base import PluginManager
yield
PluginManager._cfg = None
# Don't reset _instance to None — module-level PLUGIN_MANAGER references
# in train.py, model.py, etc. would become stale
if PluginManager._instance is not None:
PluginManager._instance.plugins = collections.OrderedDict()
@pytest.fixture(scope="function", autouse=True)
def torch_manual_seed():
torch.manual_seed(42)

View File

@@ -2,7 +2,7 @@
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists
@@ -62,6 +62,7 @@ class TestDiffusion:
}
)
prepare_plugins(cfg)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
@@ -119,6 +120,7 @@ class TestDiffusion:
}
)
prepare_plugins(cfg)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)

View File

@@ -42,7 +42,7 @@ def diffusion_trainer_instance(mock_tokenizer, diffusion_config):
"""Create a diffusion trainer instance for testing methods directly."""
# Create a minimal trainer instance just for testing methods
trainer = object.__new__(DiffusionTrainer) # Bypass __init__
trainer.cfg = diffusion_config
trainer.axolotl_cfg = diffusion_config
trainer._special_token_ids = {0, 1, 2} # pad, bos, eos
trainer.processing_class = mock_tokenizer
trainer.store_metrics = Mock() # Mock metrics storage
@@ -70,7 +70,7 @@ class TestDiffusionTrainer:
assert not masked_indices[special_token_positions].any()
# Check that mask token is applied
mask_token_id = diffusion_trainer_instance.cfg.diffusion.mask_token_id
mask_token_id = diffusion_trainer_instance.axolotl_cfg.diffusion.mask_token_id
masked_positions = masked_indices
if masked_positions.any():
assert (noisy_batch[masked_positions] == mask_token_id).all()
@@ -132,7 +132,7 @@ class TestDiffusionTrainer:
self, diffusion_trainer_instance
):
"""Test bidirectional attention mask with sample packing."""
diffusion_trainer_instance.cfg.sample_packing = True
diffusion_trainer_instance.axolotl_cfg.sample_packing = True
input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
# Sample IDs: first sample (1), second sample (2)
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
@@ -184,7 +184,7 @@ class TestDiffusionTrainer:
mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
mock_model.return_value = mock_outputs
mock_model.training = True
diffusion_trainer_instance.cfg.datasets = Mock()
diffusion_trainer_instance.axolotl_cfg.datasets = Mock()
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)

View File

@@ -13,7 +13,7 @@ class DummyTrainer:
def __init__(self, use_eval: bool):
# Config used by callback
self.cfg = SimpleNamespace(
self.axolotl_cfg = SimpleNamespace(
diffusion=SimpleNamespace(
generation_interval=1,
num_generation_samples=1,

View File

@@ -1176,7 +1176,7 @@ class TestSwanLabProfiling:
# Mock trainer with SwanLab enabled
mock_trainer = MagicMock()
mock_trainer.cfg = MagicMock(use_swanlab=True)
mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True)
mock_trainer.__class__.__name__ = "TestTrainer"
with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log:
@@ -1199,7 +1199,7 @@ class TestSwanLabProfiling:
from axolotl.integrations.swanlab.profiling import swanlab_profiling_context
mock_trainer = MagicMock()
mock_trainer.cfg = MagicMock(use_swanlab=False) # Disabled
mock_trainer.axolotl_cfg = MagicMock(use_swanlab=False) # Disabled
with patch("swanlab.log") as mock_log:
with swanlab_profiling_context(mock_trainer, "test_function"):
@@ -1213,7 +1213,7 @@ class TestSwanLabProfiling:
from axolotl.integrations.swanlab.profiling import swanlab_profiling_context
mock_trainer = MagicMock()
mock_trainer.cfg = MagicMock(use_swanlab=True)
mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True)
with (
patch("swanlab.get_run", return_value=None),
@@ -1294,7 +1294,7 @@ class TestSwanLabProfiling:
)
mock_trainer = MagicMock()
mock_trainer.cfg = MagicMock(use_swanlab=True)
mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True)
mock_trainer.__class__.__name__ = "TestTrainer"
# Config that filters out very fast operations
@@ -1320,7 +1320,7 @@ class TestSwanLabProfiling:
from axolotl.integrations.swanlab.profiling import swanlab_profiling_context
mock_trainer = MagicMock()
mock_trainer.cfg = MagicMock(use_swanlab=True)
mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True)
mock_trainer.__class__.__name__ = "TestTrainer"
with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log: