roundup_power2_divisions not needed with newer pytorch versions (#3540)

* roundup_power2_divisions not needed with newer pytorch versions

* remove typo

* update qwen3.5 moe 35b-a3b yaml for 5090

* more bug fixes

* fix tests to match updated trainer

* don't use fa2 for hooks test

* reset plugins on the instance

* retry download

* fix references to renamed axolotl_cfg property on trainer

* Fix ref to trainer cfg
This commit is contained in:
Wing Lian
2026-03-24 15:40:05 -04:00
committed by GitHub
parent 86be9f329e
commit e412370877
14 changed files with 100 additions and 60 deletions

View File

@@ -68,13 +68,13 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged python_version: ["3.12", "3.14"]
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"] pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
# exclude: exclude:
# - python_version: "3.14" - python_version: "3.14"
# pytorch_version: "2.8.0" pytorch_version: "2.8.0"
# - python_version: "3.14" - python_version: "3.14"
# pytorch_version: "2.9.1" pytorch_version: "2.9.1"
timeout-minutes: 20 timeout-minutes: 20
steps: steps:
@@ -164,13 +164,13 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged python_version: ["3.12", "3.14"]
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"] pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
# exclude: exclude:
# - python_version: "3.14" - python_version: "3.14"
# pytorch_version: "2.8.0" pytorch_version: "2.8.0"
# - python_version: "3.14" - python_version: "3.14"
# pytorch_version: "2.9.1" pytorch_version: "2.9.1"
timeout-minutes: 30 timeout-minutes: 30
steps: steps:

View File

@@ -4,7 +4,17 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__" python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
set -o pipefail set -o pipefail
curl --silent --show-error --fail --retry 3 --retry-delay 5 -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1 for i in 1 2 3; do
if curl --silent --show-error --fail -L \
https://axolotl-ci.b-cdn.net/hf-cache.tar.zst \
| tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1; then
echo "HF cache extracted successfully"
break
fi
echo "Attempt $i failed, cleaning up and retrying in 15s..."
rm -rf "${HF_HOME}/hub/"*
sleep 15
done
# hf download "NousResearch/Meta-Llama-3-8B" # hf download "NousResearch/Meta-Llama-3-8B"
# hf download "NousResearch/Meta-Llama-3-8B-Instruct" # hf download "NousResearch/Meta-Llama-3-8B-Instruct"
# hf download "microsoft/Phi-4-reasoning" # hf download "microsoft/Phi-4-reasoning"

View File

@@ -1,8 +1,18 @@
base_model: Qwen/Qwen3.5-35B-A3B base_model: Qwen/Qwen3.5-35B-A3B-Base
plugins: plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false - axolotl.integrations.kernels.KernelsPlugin
- axolotl.integrations.liger.LigerPlugin
use_kernels: true
use_scattermoe: true
liger_layer_norm: true
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_rms_norm_gated: true
torch_compile: false
chat_template: qwen3_5 chat_template: qwen3_5
datasets: datasets:
@@ -13,6 +23,7 @@ datasets:
message_property_mappings: message_property_mappings:
role: from role: from
content: value content: value
val_set_size: 0.0 val_set_size: 0.0
output_dir: ./outputs/out output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
@@ -36,9 +47,13 @@ lora_target_modules:
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj' # lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts # Target experts
# lora_target_parameters: lora_target_parameters:
# - mlp.experts.gate_up_proj - mlp.experts.gate_up_proj
# - mlp.experts.down_proj - mlp.experts.down_proj
lora_qkv_kernel: true
lora_o_kernel: true
lora_mlp_kernel: false
wandb_project: wandb_project:
wandb_entity: wandb_entity:
@@ -47,22 +62,17 @@ wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
micro_batch_size: 1 micro_batch_size: 4
num_epochs: 1 num_epochs: 1
optimizer: adamw_torch_4bit optimizer: adamw_torch_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
bf16: auto bf16: auto
tf32: true tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_checkpointing: true gradient_checkpointing: true
gradient_checkpointing_kwargs: activation_offloading: true
use_reentrant: false
resume_from_checkpoint: resume_from_checkpoint:
logging_steps: 1 logging_steps: 1
flash_attention: true flash_attention: true

View File

@@ -89,7 +89,7 @@ def parse_requirements(extras_require_map):
] ]
if not install_xformers: if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm==0.17.1"] extras_require_map["vllm"] = ["vllm>=0.17.1"]
elif (major, minor) >= (2, 9): elif (major, minor) >= (2, 9):
extras_require_map.pop("fbgemm-gpu") extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [ extras_require_map["fbgemm-gpu"] = [

View File

@@ -4,6 +4,7 @@ import os
from pathlib import Path from pathlib import Path
import httpcore import httpcore
import httpx
from accelerate.commands.config import config_args from accelerate.commands.config import config_args
from huggingface_hub import HfApi from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError from huggingface_hub.utils import LocalTokenNotFoundError
@@ -48,7 +49,7 @@ def check_user_token() -> bool:
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets." "Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
) )
return False return False
except (HTTPError, httpcore.ConnectError): except (HTTPError, httpcore.ConnectError, httpx.ConnectError):
LOG.warning( LOG.warning(
"Error accessing HuggingFace. This may be due to a network issue or rate limiting." "Error accessing HuggingFace. This may be due to a network issue or rate limiting."
) )

View File

@@ -36,7 +36,9 @@ class DiffusionGenerationCallback(TrainerCallback):
"""Generate samples at specified intervals.""" """Generate samples at specified intervals."""
if ( if (
state.global_step > 0 state.global_step > 0
and state.global_step % self.trainer.cfg.diffusion.generation_interval == 0 and state.global_step
% self.trainer.axolotl_cfg.diffusion.generation_interval
== 0
): ):
if not self.trainer.state.is_world_process_zero: if not self.trainer.state.is_world_process_zero:
return return
@@ -52,7 +54,7 @@ class DiffusionGenerationCallback(TrainerCallback):
dataloader = self.trainer.get_train_dataloader() dataloader = self.trainer.get_train_dataloader()
# Generate samples # Generate samples
diffusion_cfg = self.trainer.cfg.diffusion diffusion_cfg = self.trainer.axolotl_cfg.diffusion
samples = generate_samples( samples = generate_samples(
model=self.trainer.model, model=self.trainer.model,
tokenizer=self.trainer.processing_class, tokenizer=self.trainer.processing_class,
@@ -142,7 +144,7 @@ class DiffusionGenerationCallback(TrainerCallback):
logger.info("=" * 60) logger.info("=" * 60)
if self.trainer.cfg.use_wandb: if self.trainer.axolotl_cfg.use_wandb:
if wandb.run is not None: if wandb.run is not None:
wandb.log( wandb.log(
{ {

View File

@@ -38,4 +38,6 @@ class DiffusionPlugin(BasePlugin):
def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer): def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer):
"""Configure trainer after creation.""" """Configure trainer after creation."""
trainer.set_config(cfg) if hasattr(trainer, "axolotl_cfg"):
trainer.axolotl_cfg = cfg
trainer.post_set_axolotl_cfg()

View File

@@ -7,7 +7,6 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from axolotl.core.trainers.base import AxolotlTrainer from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from .callbacks import DiffusionGenerationCallback from .callbacks import DiffusionGenerationCallback
@@ -21,19 +20,17 @@ class DiffusionTrainer(AxolotlTrainer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.cfg = None
self._special_token_ids = None self._special_token_ids = None
def set_config(self, config: DictDefault): def post_set_axolotl_cfg(self):
"""Set config for diffusion training.""" """Set config for diffusion training."""
self.cfg = config
self._cache_special_token_ids() self._cache_special_token_ids()
self._resolve_mask_token_id() self._resolve_mask_token_id()
token_id = int(getattr(self.cfg.diffusion, "mask_token_id", 0)) token_id = int(getattr(self.axolotl_cfg.diffusion, "mask_token_id", 0))
LOG.info(f"Diffusion: using mask_token_id={token_id}") LOG.info(f"Diffusion: using mask_token_id={token_id}")
if getattr(config.diffusion, "generate_samples", True): if getattr(self.axolotl_cfg.diffusion, "generate_samples", True):
generation_callback = DiffusionGenerationCallback(self) generation_callback = DiffusionGenerationCallback(self)
self.add_callback(generation_callback) self.add_callback(generation_callback)
@@ -41,18 +38,20 @@ class DiffusionTrainer(AxolotlTrainer):
"""Ensure mask_token_id is valid for the current tokenizer.""" """Ensure mask_token_id is valid for the current tokenizer."""
from .utils import resolve_mask_token_id from .utils import resolve_mask_token_id
assert self.axolotl_cfg is not None, "axolotl_cfg is not set yet"
tokenizer = getattr(self, "processing_class", None) tokenizer = getattr(self, "processing_class", None)
if tokenizer is None: if tokenizer is None:
return return
mid = resolve_mask_token_id( mid = resolve_mask_token_id(
tokenizer, tokenizer,
self.cfg, self.axolotl_cfg,
allow_add=True, allow_add=True,
model=getattr(self, "model", None), model=getattr(self, "model", None),
) )
try: try:
self.cfg.diffusion.mask_token_id = int(mid) self.axolotl_cfg.diffusion.mask_token_id = int(mid)
except Exception: except Exception:
pass pass
@@ -150,7 +149,7 @@ class DiffusionTrainer(AxolotlTrainer):
masked_indices = masked_indices & answer_mask masked_indices = masked_indices & answer_mask
# Create masked input # Create masked input
mask_token_id = int(self.cfg.diffusion.mask_token_id) mask_token_id = int(self.axolotl_cfg.diffusion.mask_token_id)
mask_value = torch.full_like(input_ids, mask_token_id) mask_value = torch.full_like(input_ids, mask_token_id)
noisy_batch = torch.where(masked_indices, mask_value, input_ids) noisy_batch = torch.where(masked_indices, mask_value, input_ids)
@@ -194,12 +193,12 @@ class DiffusionTrainer(AxolotlTrainer):
# Apply forward process # Apply forward process
noisy_batch, masked_indices, p_mask = self._forward_process( noisy_batch, masked_indices, p_mask = self._forward_process(
input_ids, attention_mask, labels, self.cfg.diffusion.eps input_ids, attention_mask, labels, self.axolotl_cfg.diffusion.eps
) )
# Create bidirectional attention mask # Create bidirectional attention mask
bidirectional_mask = create_bidirectional_attention_mask( bidirectional_mask = create_bidirectional_attention_mask(
input_ids, attention_mask, sample_packing=self.cfg.sample_packing input_ids, attention_mask, sample_packing=self.axolotl_cfg.sample_packing
) )
# Forward pass # Forward pass
@@ -222,7 +221,7 @@ class DiffusionTrainer(AxolotlTrainer):
masked_logits.float(), masked_targets, reduction="none" masked_logits.float(), masked_targets, reduction="none"
) )
if self.cfg.diffusion.importance_weighting: if self.axolotl_cfg.diffusion.importance_weighting:
masked_p_mask = masked_p_mask.float() masked_p_mask = masked_p_mask.float()
weighted_loss = token_loss / masked_p_mask weighted_loss = token_loss / masked_p_mask
else: else:
@@ -251,7 +250,7 @@ class DiffusionTrainer(AxolotlTrainer):
# Non-SFT: when importance weighting is enabled, use unbiased estimator # Non-SFT: when importance weighting is enabled, use unbiased estimator
# (sum(loss/p) / total_tokens). Otherwise, average over masked tokens # (sum(loss/p) / total_tokens). Otherwise, average over masked tokens
# for stable scaling across varying mask ratios. # for stable scaling across varying mask ratios.
if self.cfg.diffusion.importance_weighting: if self.axolotl_cfg.diffusion.importance_weighting:
loss = weighted_loss.sum() / ( loss = weighted_loss.sum() / (
input_ids.shape[0] * input_ids.shape[1] input_ids.shape[0] * input_ids.shape[1]
) )
@@ -283,7 +282,7 @@ class DiffusionTrainer(AxolotlTrainer):
} }
# If doing SFT training, log answer-specific metrics # If doing SFT training, log answer-specific metrics
if self.cfg.datasets is not None: if self.axolotl_cfg.datasets is not None:
with torch.no_grad(): with torch.no_grad():
answer_mask = labels != -100 answer_mask = labels != -100
answer_lengths = answer_mask.sum(dim=1).float() # type: ignore answer_lengths = answer_mask.sum(dim=1).float() # type: ignore
@@ -292,7 +291,7 @@ class DiffusionTrainer(AxolotlTrainer):
metrics["answer_ratio"] = total_answer_tokens / max(total_tokens, 1) metrics["answer_ratio"] = total_answer_tokens / max(total_tokens, 1)
metrics["avg_answer_length"] = answer_lengths.mean().item() metrics["avg_answer_length"] = answer_lengths.mean().item()
if self.cfg.diffusion.importance_weighting: if self.axolotl_cfg.diffusion.importance_weighting:
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item() metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
train_eval: Literal["train", "eval"] = "train" if model.training else "eval" train_eval: Literal["train", "eval"] = "train" if model.training else "eval"

View File

@@ -48,7 +48,8 @@ def set_pytorch_cuda_alloc_conf():
"""Set up CUDA allocation config""" """Set up CUDA allocation config"""
torch_version = torch.__version__.split(".") torch_version = torch.__version__.split(".")
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1]) torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
config_value = "expandable_segments:True,roundup_power2_divisions:16" config_value = "expandable_segments:True"
config_older_suffix = ",roundup_power2_divisions:16"
if ( if (
torch_major == 2 torch_major == 2
and torch_minor >= 9 and torch_minor >= 9
@@ -60,7 +61,7 @@ def set_pytorch_cuda_alloc_conf():
and torch_minor >= 2 and torch_minor >= 2
and os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None and os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None
): ):
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = config_value os.environ["PYTORCH_CUDA_ALLOC_CONF"] = config_value + config_older_suffix
def set_misc_env(): def set_misc_env():

View File

@@ -1,5 +1,6 @@
"""Shared pytest fixtures""" """Shared pytest fixtures"""
import collections
import functools import functools
import importlib import importlib
import logging import logging
@@ -473,6 +474,18 @@ def temp_dir() -> Generator[str, None, None]:
shutil.rmtree(_temp_dir) shutil.rmtree(_temp_dir)
@pytest.fixture(scope="function", autouse=True)
def reset_plugin_manager():
from axolotl.integrations.base import PluginManager
yield
PluginManager._cfg = None
# Don't reset _instance to None — module-level PLUGIN_MANAGER references
# in train.py, model.py, etc. would become stale
if PluginManager._instance is not None:
PluginManager._instance.plugins = collections.OrderedDict()
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
def torch_manual_seed(): def torch_manual_seed():
torch.manual_seed(42) torch.manual_seed(42)

View File

@@ -2,7 +2,7 @@
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists from tests.e2e.utils import check_model_output_exists
@@ -62,6 +62,7 @@ class TestDiffusion:
} }
) )
prepare_plugins(cfg)
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg) dataset_meta = load_datasets(cfg=cfg)
@@ -119,6 +120,7 @@ class TestDiffusion:
} }
) )
prepare_plugins(cfg)
cfg = validate_config(cfg) cfg = validate_config(cfg)
normalize_config(cfg) normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg) dataset_meta = load_datasets(cfg=cfg)

View File

@@ -42,7 +42,7 @@ def diffusion_trainer_instance(mock_tokenizer, diffusion_config):
"""Create a diffusion trainer instance for testing methods directly.""" """Create a diffusion trainer instance for testing methods directly."""
# Create a minimal trainer instance just for testing methods # Create a minimal trainer instance just for testing methods
trainer = object.__new__(DiffusionTrainer) # Bypass __init__ trainer = object.__new__(DiffusionTrainer) # Bypass __init__
trainer.cfg = diffusion_config trainer.axolotl_cfg = diffusion_config
trainer._special_token_ids = {0, 1, 2} # pad, bos, eos trainer._special_token_ids = {0, 1, 2} # pad, bos, eos
trainer.processing_class = mock_tokenizer trainer.processing_class = mock_tokenizer
trainer.store_metrics = Mock() # Mock metrics storage trainer.store_metrics = Mock() # Mock metrics storage
@@ -70,7 +70,7 @@ class TestDiffusionTrainer:
assert not masked_indices[special_token_positions].any() assert not masked_indices[special_token_positions].any()
# Check that mask token is applied # Check that mask token is applied
mask_token_id = diffusion_trainer_instance.cfg.diffusion.mask_token_id mask_token_id = diffusion_trainer_instance.axolotl_cfg.diffusion.mask_token_id
masked_positions = masked_indices masked_positions = masked_indices
if masked_positions.any(): if masked_positions.any():
assert (noisy_batch[masked_positions] == mask_token_id).all() assert (noisy_batch[masked_positions] == mask_token_id).all()
@@ -132,7 +132,7 @@ class TestDiffusionTrainer:
self, diffusion_trainer_instance self, diffusion_trainer_instance
): ):
"""Test bidirectional attention mask with sample packing.""" """Test bidirectional attention mask with sample packing."""
diffusion_trainer_instance.cfg.sample_packing = True diffusion_trainer_instance.axolotl_cfg.sample_packing = True
input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
# Sample IDs: first sample (1), second sample (2) # Sample IDs: first sample (1), second sample (2)
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long) attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
@@ -184,7 +184,7 @@ class TestDiffusionTrainer:
mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True) mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
mock_model.return_value = mock_outputs mock_model.return_value = mock_outputs
mock_model.training = True mock_model.training = True
diffusion_trainer_instance.cfg.datasets = Mock() diffusion_trainer_instance.axolotl_cfg.datasets = Mock()
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long) input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long) labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)

View File

@@ -13,7 +13,7 @@ class DummyTrainer:
def __init__(self, use_eval: bool): def __init__(self, use_eval: bool):
# Config used by callback # Config used by callback
self.cfg = SimpleNamespace( self.axolotl_cfg = SimpleNamespace(
diffusion=SimpleNamespace( diffusion=SimpleNamespace(
generation_interval=1, generation_interval=1,
num_generation_samples=1, num_generation_samples=1,

View File

@@ -1176,7 +1176,7 @@ class TestSwanLabProfiling:
# Mock trainer with SwanLab enabled # Mock trainer with SwanLab enabled
mock_trainer = MagicMock() mock_trainer = MagicMock()
mock_trainer.cfg = MagicMock(use_swanlab=True) mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True)
mock_trainer.__class__.__name__ = "TestTrainer" mock_trainer.__class__.__name__ = "TestTrainer"
with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log: with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log:
@@ -1199,7 +1199,7 @@ class TestSwanLabProfiling:
from axolotl.integrations.swanlab.profiling import swanlab_profiling_context from axolotl.integrations.swanlab.profiling import swanlab_profiling_context
mock_trainer = MagicMock() mock_trainer = MagicMock()
mock_trainer.cfg = MagicMock(use_swanlab=False) # Disabled mock_trainer.axolotl_cfg = MagicMock(use_swanlab=False) # Disabled
with patch("swanlab.log") as mock_log: with patch("swanlab.log") as mock_log:
with swanlab_profiling_context(mock_trainer, "test_function"): with swanlab_profiling_context(mock_trainer, "test_function"):
@@ -1213,7 +1213,7 @@ class TestSwanLabProfiling:
from axolotl.integrations.swanlab.profiling import swanlab_profiling_context from axolotl.integrations.swanlab.profiling import swanlab_profiling_context
mock_trainer = MagicMock() mock_trainer = MagicMock()
mock_trainer.cfg = MagicMock(use_swanlab=True) mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True)
with ( with (
patch("swanlab.get_run", return_value=None), patch("swanlab.get_run", return_value=None),
@@ -1294,7 +1294,7 @@ class TestSwanLabProfiling:
) )
mock_trainer = MagicMock() mock_trainer = MagicMock()
mock_trainer.cfg = MagicMock(use_swanlab=True) mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True)
mock_trainer.__class__.__name__ = "TestTrainer" mock_trainer.__class__.__name__ = "TestTrainer"
# Config that filters out very fast operations # Config that filters out very fast operations
@@ -1320,7 +1320,7 @@ class TestSwanLabProfiling:
from axolotl.integrations.swanlab.profiling import swanlab_profiling_context from axolotl.integrations.swanlab.profiling import swanlab_profiling_context
mock_trainer = MagicMock() mock_trainer = MagicMock()
mock_trainer.cfg = MagicMock(use_swanlab=True) mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True)
mock_trainer.__class__.__name__ = "TestTrainer" mock_trainer.__class__.__name__ = "TestTrainer"
with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log: with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log: