roundup_power2_divisions not needed with newer pytorch versions (#3540)
* roundup_power2_divisions not needed with newer pytorch versions * remove typo * update qwen3.5 moe 35b-a3b yaml for 5090 * more bug fixes * fix tests to match updated trainer * don't use fa2 for hooks test * reset plugins on the instance * retry download * fix references to renamed axolotl_cfg property on trainer * Fix ref to trainer cfg
This commit is contained in:
24
.github/workflows/tests.yml
vendored
24
.github/workflows/tests.yml
vendored
@@ -68,13 +68,13 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
python_version: ["3.12", "3.14"]
|
||||||
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
|
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
|
||||||
# exclude:
|
exclude:
|
||||||
# - python_version: "3.14"
|
- python_version: "3.14"
|
||||||
# pytorch_version: "2.8.0"
|
pytorch_version: "2.8.0"
|
||||||
# - python_version: "3.14"
|
- python_version: "3.14"
|
||||||
# pytorch_version: "2.9.1"
|
pytorch_version: "2.9.1"
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -164,13 +164,13 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
python_version: ["3.12", "3.14"]
|
||||||
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
|
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
|
||||||
# exclude:
|
exclude:
|
||||||
# - python_version: "3.14"
|
- python_version: "3.14"
|
||||||
# pytorch_version: "2.8.0"
|
pytorch_version: "2.8.0"
|
||||||
# - python_version: "3.14"
|
- python_version: "3.14"
|
||||||
# pytorch_version: "2.9.1"
|
pytorch_version: "2.9.1"
|
||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
12
cicd/cicd.sh
12
cicd/cicd.sh
@@ -4,7 +4,17 @@ set -e
|
|||||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||||
|
|
||||||
set -o pipefail
|
set -o pipefail
|
||||||
curl --silent --show-error --fail --retry 3 --retry-delay 5 -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
for i in 1 2 3; do
|
||||||
|
if curl --silent --show-error --fail -L \
|
||||||
|
https://axolotl-ci.b-cdn.net/hf-cache.tar.zst \
|
||||||
|
| tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1; then
|
||||||
|
echo "HF cache extracted successfully"
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
echo "Attempt $i failed, cleaning up and retrying in 15s..."
|
||||||
|
rm -rf "${HF_HOME}/hub/"*
|
||||||
|
sleep 15
|
||||||
|
done
|
||||||
# hf download "NousResearch/Meta-Llama-3-8B"
|
# hf download "NousResearch/Meta-Llama-3-8B"
|
||||||
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||||
# hf download "microsoft/Phi-4-reasoning"
|
# hf download "microsoft/Phi-4-reasoning"
|
||||||
|
|||||||
@@ -1,8 +1,18 @@
|
|||||||
base_model: Qwen/Qwen3.5-35B-A3B
|
base_model: Qwen/Qwen3.5-35B-A3B-Base
|
||||||
|
|
||||||
plugins:
|
plugins:
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
strict: false
|
- axolotl.integrations.kernels.KernelsPlugin
|
||||||
|
- axolotl.integrations.liger.LigerPlugin
|
||||||
|
use_kernels: true
|
||||||
|
use_scattermoe: true
|
||||||
|
liger_layer_norm: true
|
||||||
|
liger_rope: true
|
||||||
|
liger_rms_norm: true
|
||||||
|
liger_glu_activation: true
|
||||||
|
liger_rms_norm_gated: true
|
||||||
|
|
||||||
|
torch_compile: false
|
||||||
|
|
||||||
chat_template: qwen3_5
|
chat_template: qwen3_5
|
||||||
datasets:
|
datasets:
|
||||||
@@ -13,6 +23,7 @@ datasets:
|
|||||||
message_property_mappings:
|
message_property_mappings:
|
||||||
role: from
|
role: from
|
||||||
content: value
|
content: value
|
||||||
|
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./outputs/out
|
output_dir: ./outputs/out
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
@@ -36,9 +47,13 @@ lora_target_modules:
|
|||||||
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||||
|
|
||||||
# Target experts
|
# Target experts
|
||||||
# lora_target_parameters:
|
lora_target_parameters:
|
||||||
# - mlp.experts.gate_up_proj
|
- mlp.experts.gate_up_proj
|
||||||
# - mlp.experts.down_proj
|
- mlp.experts.down_proj
|
||||||
|
|
||||||
|
lora_qkv_kernel: true
|
||||||
|
lora_o_kernel: true
|
||||||
|
lora_mlp_kernel: false
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
@@ -47,22 +62,17 @@ wandb_name:
|
|||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
gradient_accumulation_steps: 2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 4
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
optimizer: adamw_torch_4bit
|
optimizer: adamw_torch_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.0002
|
learning_rate: 0.0002
|
||||||
|
|
||||||
bf16: auto
|
bf16: auto
|
||||||
tf32: true
|
tf32: true
|
||||||
|
|
||||||
lora_mlp_kernel: false
|
|
||||||
lora_qkv_kernel: false
|
|
||||||
lora_o_kernel: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
gradient_checkpointing_kwargs:
|
activation_offloading: true
|
||||||
use_reentrant: false
|
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -89,7 +89,7 @@ def parse_requirements(extras_require_map):
|
|||||||
]
|
]
|
||||||
if not install_xformers:
|
if not install_xformers:
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
extras_require_map["vllm"] = ["vllm==0.17.1"]
|
extras_require_map["vllm"] = ["vllm>=0.17.1"]
|
||||||
elif (major, minor) >= (2, 9):
|
elif (major, minor) >= (2, 9):
|
||||||
extras_require_map.pop("fbgemm-gpu")
|
extras_require_map.pop("fbgemm-gpu")
|
||||||
extras_require_map["fbgemm-gpu"] = [
|
extras_require_map["fbgemm-gpu"] = [
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import httpcore
|
import httpcore
|
||||||
|
import httpx
|
||||||
from accelerate.commands.config import config_args
|
from accelerate.commands.config import config_args
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||||
@@ -48,7 +49,7 @@ def check_user_token() -> bool:
|
|||||||
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
except (HTTPError, httpcore.ConnectError):
|
except (HTTPError, httpcore.ConnectError, httpx.ConnectError):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
|
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,7 +36,9 @@ class DiffusionGenerationCallback(TrainerCallback):
|
|||||||
"""Generate samples at specified intervals."""
|
"""Generate samples at specified intervals."""
|
||||||
if (
|
if (
|
||||||
state.global_step > 0
|
state.global_step > 0
|
||||||
and state.global_step % self.trainer.cfg.diffusion.generation_interval == 0
|
and state.global_step
|
||||||
|
% self.trainer.axolotl_cfg.diffusion.generation_interval
|
||||||
|
== 0
|
||||||
):
|
):
|
||||||
if not self.trainer.state.is_world_process_zero:
|
if not self.trainer.state.is_world_process_zero:
|
||||||
return
|
return
|
||||||
@@ -52,7 +54,7 @@ class DiffusionGenerationCallback(TrainerCallback):
|
|||||||
dataloader = self.trainer.get_train_dataloader()
|
dataloader = self.trainer.get_train_dataloader()
|
||||||
|
|
||||||
# Generate samples
|
# Generate samples
|
||||||
diffusion_cfg = self.trainer.cfg.diffusion
|
diffusion_cfg = self.trainer.axolotl_cfg.diffusion
|
||||||
samples = generate_samples(
|
samples = generate_samples(
|
||||||
model=self.trainer.model,
|
model=self.trainer.model,
|
||||||
tokenizer=self.trainer.processing_class,
|
tokenizer=self.trainer.processing_class,
|
||||||
@@ -142,7 +144,7 @@ class DiffusionGenerationCallback(TrainerCallback):
|
|||||||
|
|
||||||
logger.info("=" * 60)
|
logger.info("=" * 60)
|
||||||
|
|
||||||
if self.trainer.cfg.use_wandb:
|
if self.trainer.axolotl_cfg.use_wandb:
|
||||||
if wandb.run is not None:
|
if wandb.run is not None:
|
||||||
wandb.log(
|
wandb.log(
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -38,4 +38,6 @@ class DiffusionPlugin(BasePlugin):
|
|||||||
|
|
||||||
def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer):
|
def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer):
|
||||||
"""Configure trainer after creation."""
|
"""Configure trainer after creation."""
|
||||||
trainer.set_config(cfg)
|
if hasattr(trainer, "axolotl_cfg"):
|
||||||
|
trainer.axolotl_cfg = cfg
|
||||||
|
trainer.post_set_axolotl_cfg()
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import torch.nn.functional as F
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
from .callbacks import DiffusionGenerationCallback
|
from .callbacks import DiffusionGenerationCallback
|
||||||
@@ -21,19 +20,17 @@ class DiffusionTrainer(AxolotlTrainer):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.cfg = None
|
|
||||||
self._special_token_ids = None
|
self._special_token_ids = None
|
||||||
|
|
||||||
def set_config(self, config: DictDefault):
|
def post_set_axolotl_cfg(self):
|
||||||
"""Set config for diffusion training."""
|
"""Set config for diffusion training."""
|
||||||
self.cfg = config
|
|
||||||
self._cache_special_token_ids()
|
self._cache_special_token_ids()
|
||||||
self._resolve_mask_token_id()
|
self._resolve_mask_token_id()
|
||||||
|
|
||||||
token_id = int(getattr(self.cfg.diffusion, "mask_token_id", 0))
|
token_id = int(getattr(self.axolotl_cfg.diffusion, "mask_token_id", 0))
|
||||||
LOG.info(f"Diffusion: using mask_token_id={token_id}")
|
LOG.info(f"Diffusion: using mask_token_id={token_id}")
|
||||||
|
|
||||||
if getattr(config.diffusion, "generate_samples", True):
|
if getattr(self.axolotl_cfg.diffusion, "generate_samples", True):
|
||||||
generation_callback = DiffusionGenerationCallback(self)
|
generation_callback = DiffusionGenerationCallback(self)
|
||||||
self.add_callback(generation_callback)
|
self.add_callback(generation_callback)
|
||||||
|
|
||||||
@@ -41,18 +38,20 @@ class DiffusionTrainer(AxolotlTrainer):
|
|||||||
"""Ensure mask_token_id is valid for the current tokenizer."""
|
"""Ensure mask_token_id is valid for the current tokenizer."""
|
||||||
from .utils import resolve_mask_token_id
|
from .utils import resolve_mask_token_id
|
||||||
|
|
||||||
|
assert self.axolotl_cfg is not None, "axolotl_cfg is not set yet"
|
||||||
|
|
||||||
tokenizer = getattr(self, "processing_class", None)
|
tokenizer = getattr(self, "processing_class", None)
|
||||||
if tokenizer is None:
|
if tokenizer is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
mid = resolve_mask_token_id(
|
mid = resolve_mask_token_id(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
self.cfg,
|
self.axolotl_cfg,
|
||||||
allow_add=True,
|
allow_add=True,
|
||||||
model=getattr(self, "model", None),
|
model=getattr(self, "model", None),
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
self.cfg.diffusion.mask_token_id = int(mid)
|
self.axolotl_cfg.diffusion.mask_token_id = int(mid)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -150,7 +149,7 @@ class DiffusionTrainer(AxolotlTrainer):
|
|||||||
masked_indices = masked_indices & answer_mask
|
masked_indices = masked_indices & answer_mask
|
||||||
|
|
||||||
# Create masked input
|
# Create masked input
|
||||||
mask_token_id = int(self.cfg.diffusion.mask_token_id)
|
mask_token_id = int(self.axolotl_cfg.diffusion.mask_token_id)
|
||||||
mask_value = torch.full_like(input_ids, mask_token_id)
|
mask_value = torch.full_like(input_ids, mask_token_id)
|
||||||
noisy_batch = torch.where(masked_indices, mask_value, input_ids)
|
noisy_batch = torch.where(masked_indices, mask_value, input_ids)
|
||||||
|
|
||||||
@@ -194,12 +193,12 @@ class DiffusionTrainer(AxolotlTrainer):
|
|||||||
|
|
||||||
# Apply forward process
|
# Apply forward process
|
||||||
noisy_batch, masked_indices, p_mask = self._forward_process(
|
noisy_batch, masked_indices, p_mask = self._forward_process(
|
||||||
input_ids, attention_mask, labels, self.cfg.diffusion.eps
|
input_ids, attention_mask, labels, self.axolotl_cfg.diffusion.eps
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create bidirectional attention mask
|
# Create bidirectional attention mask
|
||||||
bidirectional_mask = create_bidirectional_attention_mask(
|
bidirectional_mask = create_bidirectional_attention_mask(
|
||||||
input_ids, attention_mask, sample_packing=self.cfg.sample_packing
|
input_ids, attention_mask, sample_packing=self.axolotl_cfg.sample_packing
|
||||||
)
|
)
|
||||||
|
|
||||||
# Forward pass
|
# Forward pass
|
||||||
@@ -222,7 +221,7 @@ class DiffusionTrainer(AxolotlTrainer):
|
|||||||
masked_logits.float(), masked_targets, reduction="none"
|
masked_logits.float(), masked_targets, reduction="none"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.diffusion.importance_weighting:
|
if self.axolotl_cfg.diffusion.importance_weighting:
|
||||||
masked_p_mask = masked_p_mask.float()
|
masked_p_mask = masked_p_mask.float()
|
||||||
weighted_loss = token_loss / masked_p_mask
|
weighted_loss = token_loss / masked_p_mask
|
||||||
else:
|
else:
|
||||||
@@ -251,7 +250,7 @@ class DiffusionTrainer(AxolotlTrainer):
|
|||||||
# Non-SFT: when importance weighting is enabled, use unbiased estimator
|
# Non-SFT: when importance weighting is enabled, use unbiased estimator
|
||||||
# (sum(loss/p) / total_tokens). Otherwise, average over masked tokens
|
# (sum(loss/p) / total_tokens). Otherwise, average over masked tokens
|
||||||
# for stable scaling across varying mask ratios.
|
# for stable scaling across varying mask ratios.
|
||||||
if self.cfg.diffusion.importance_weighting:
|
if self.axolotl_cfg.diffusion.importance_weighting:
|
||||||
loss = weighted_loss.sum() / (
|
loss = weighted_loss.sum() / (
|
||||||
input_ids.shape[0] * input_ids.shape[1]
|
input_ids.shape[0] * input_ids.shape[1]
|
||||||
)
|
)
|
||||||
@@ -283,7 +282,7 @@ class DiffusionTrainer(AxolotlTrainer):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# If doing SFT training, log answer-specific metrics
|
# If doing SFT training, log answer-specific metrics
|
||||||
if self.cfg.datasets is not None:
|
if self.axolotl_cfg.datasets is not None:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
answer_mask = labels != -100
|
answer_mask = labels != -100
|
||||||
answer_lengths = answer_mask.sum(dim=1).float() # type: ignore
|
answer_lengths = answer_mask.sum(dim=1).float() # type: ignore
|
||||||
@@ -292,7 +291,7 @@ class DiffusionTrainer(AxolotlTrainer):
|
|||||||
metrics["answer_ratio"] = total_answer_tokens / max(total_tokens, 1)
|
metrics["answer_ratio"] = total_answer_tokens / max(total_tokens, 1)
|
||||||
metrics["avg_answer_length"] = answer_lengths.mean().item()
|
metrics["avg_answer_length"] = answer_lengths.mean().item()
|
||||||
|
|
||||||
if self.cfg.diffusion.importance_weighting:
|
if self.axolotl_cfg.diffusion.importance_weighting:
|
||||||
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
|
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
|
||||||
|
|
||||||
train_eval: Literal["train", "eval"] = "train" if model.training else "eval"
|
train_eval: Literal["train", "eval"] = "train" if model.training else "eval"
|
||||||
|
|||||||
@@ -48,7 +48,8 @@ def set_pytorch_cuda_alloc_conf():
|
|||||||
"""Set up CUDA allocation config"""
|
"""Set up CUDA allocation config"""
|
||||||
torch_version = torch.__version__.split(".")
|
torch_version = torch.__version__.split(".")
|
||||||
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
|
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
|
||||||
config_value = "expandable_segments:True,roundup_power2_divisions:16"
|
config_value = "expandable_segments:True"
|
||||||
|
config_older_suffix = ",roundup_power2_divisions:16"
|
||||||
if (
|
if (
|
||||||
torch_major == 2
|
torch_major == 2
|
||||||
and torch_minor >= 9
|
and torch_minor >= 9
|
||||||
@@ -60,7 +61,7 @@ def set_pytorch_cuda_alloc_conf():
|
|||||||
and torch_minor >= 2
|
and torch_minor >= 2
|
||||||
and os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None
|
and os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None
|
||||||
):
|
):
|
||||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = config_value
|
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = config_value + config_older_suffix
|
||||||
|
|
||||||
|
|
||||||
def set_misc_env():
|
def set_misc_env():
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Shared pytest fixtures"""
|
"""Shared pytest fixtures"""
|
||||||
|
|
||||||
|
import collections
|
||||||
import functools
|
import functools
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
@@ -473,6 +474,18 @@ def temp_dir() -> Generator[str, None, None]:
|
|||||||
shutil.rmtree(_temp_dir)
|
shutil.rmtree(_temp_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
|
def reset_plugin_manager():
|
||||||
|
from axolotl.integrations.base import PluginManager
|
||||||
|
|
||||||
|
yield
|
||||||
|
PluginManager._cfg = None
|
||||||
|
# Don't reset _instance to None — module-level PLUGIN_MANAGER references
|
||||||
|
# in train.py, model.py, etc. would become stale
|
||||||
|
if PluginManager._instance is not None:
|
||||||
|
PluginManager._instance.plugins = collections.OrderedDict()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function", autouse=True)
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
def torch_manual_seed():
|
def torch_manual_seed():
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from tests.e2e.utils import check_model_output_exists
|
from tests.e2e.utils import check_model_output_exists
|
||||||
@@ -62,6 +62,7 @@ class TestDiffusion:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
prepare_plugins(cfg)
|
||||||
cfg = validate_config(cfg)
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
dataset_meta = load_datasets(cfg=cfg)
|
dataset_meta = load_datasets(cfg=cfg)
|
||||||
@@ -119,6 +120,7 @@ class TestDiffusion:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
prepare_plugins(cfg)
|
||||||
cfg = validate_config(cfg)
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
dataset_meta = load_datasets(cfg=cfg)
|
dataset_meta = load_datasets(cfg=cfg)
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ def diffusion_trainer_instance(mock_tokenizer, diffusion_config):
|
|||||||
"""Create a diffusion trainer instance for testing methods directly."""
|
"""Create a diffusion trainer instance for testing methods directly."""
|
||||||
# Create a minimal trainer instance just for testing methods
|
# Create a minimal trainer instance just for testing methods
|
||||||
trainer = object.__new__(DiffusionTrainer) # Bypass __init__
|
trainer = object.__new__(DiffusionTrainer) # Bypass __init__
|
||||||
trainer.cfg = diffusion_config
|
trainer.axolotl_cfg = diffusion_config
|
||||||
trainer._special_token_ids = {0, 1, 2} # pad, bos, eos
|
trainer._special_token_ids = {0, 1, 2} # pad, bos, eos
|
||||||
trainer.processing_class = mock_tokenizer
|
trainer.processing_class = mock_tokenizer
|
||||||
trainer.store_metrics = Mock() # Mock metrics storage
|
trainer.store_metrics = Mock() # Mock metrics storage
|
||||||
@@ -70,7 +70,7 @@ class TestDiffusionTrainer:
|
|||||||
assert not masked_indices[special_token_positions].any()
|
assert not masked_indices[special_token_positions].any()
|
||||||
|
|
||||||
# Check that mask token is applied
|
# Check that mask token is applied
|
||||||
mask_token_id = diffusion_trainer_instance.cfg.diffusion.mask_token_id
|
mask_token_id = diffusion_trainer_instance.axolotl_cfg.diffusion.mask_token_id
|
||||||
masked_positions = masked_indices
|
masked_positions = masked_indices
|
||||||
if masked_positions.any():
|
if masked_positions.any():
|
||||||
assert (noisy_batch[masked_positions] == mask_token_id).all()
|
assert (noisy_batch[masked_positions] == mask_token_id).all()
|
||||||
@@ -132,7 +132,7 @@ class TestDiffusionTrainer:
|
|||||||
self, diffusion_trainer_instance
|
self, diffusion_trainer_instance
|
||||||
):
|
):
|
||||||
"""Test bidirectional attention mask with sample packing."""
|
"""Test bidirectional attention mask with sample packing."""
|
||||||
diffusion_trainer_instance.cfg.sample_packing = True
|
diffusion_trainer_instance.axolotl_cfg.sample_packing = True
|
||||||
input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
|
input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
|
||||||
# Sample IDs: first sample (1), second sample (2)
|
# Sample IDs: first sample (1), second sample (2)
|
||||||
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
|
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
|
||||||
@@ -184,7 +184,7 @@ class TestDiffusionTrainer:
|
|||||||
mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
|
mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
|
||||||
mock_model.return_value = mock_outputs
|
mock_model.return_value = mock_outputs
|
||||||
mock_model.training = True
|
mock_model.training = True
|
||||||
diffusion_trainer_instance.cfg.datasets = Mock()
|
diffusion_trainer_instance.axolotl_cfg.datasets = Mock()
|
||||||
|
|
||||||
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
|
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
|
||||||
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
|
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class DummyTrainer:
|
|||||||
|
|
||||||
def __init__(self, use_eval: bool):
|
def __init__(self, use_eval: bool):
|
||||||
# Config used by callback
|
# Config used by callback
|
||||||
self.cfg = SimpleNamespace(
|
self.axolotl_cfg = SimpleNamespace(
|
||||||
diffusion=SimpleNamespace(
|
diffusion=SimpleNamespace(
|
||||||
generation_interval=1,
|
generation_interval=1,
|
||||||
num_generation_samples=1,
|
num_generation_samples=1,
|
||||||
|
|||||||
@@ -1176,7 +1176,7 @@ class TestSwanLabProfiling:
|
|||||||
|
|
||||||
# Mock trainer with SwanLab enabled
|
# Mock trainer with SwanLab enabled
|
||||||
mock_trainer = MagicMock()
|
mock_trainer = MagicMock()
|
||||||
mock_trainer.cfg = MagicMock(use_swanlab=True)
|
mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True)
|
||||||
mock_trainer.__class__.__name__ = "TestTrainer"
|
mock_trainer.__class__.__name__ = "TestTrainer"
|
||||||
|
|
||||||
with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log:
|
with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log:
|
||||||
@@ -1199,7 +1199,7 @@ class TestSwanLabProfiling:
|
|||||||
from axolotl.integrations.swanlab.profiling import swanlab_profiling_context
|
from axolotl.integrations.swanlab.profiling import swanlab_profiling_context
|
||||||
|
|
||||||
mock_trainer = MagicMock()
|
mock_trainer = MagicMock()
|
||||||
mock_trainer.cfg = MagicMock(use_swanlab=False) # Disabled
|
mock_trainer.axolotl_cfg = MagicMock(use_swanlab=False) # Disabled
|
||||||
|
|
||||||
with patch("swanlab.log") as mock_log:
|
with patch("swanlab.log") as mock_log:
|
||||||
with swanlab_profiling_context(mock_trainer, "test_function"):
|
with swanlab_profiling_context(mock_trainer, "test_function"):
|
||||||
@@ -1213,7 +1213,7 @@ class TestSwanLabProfiling:
|
|||||||
from axolotl.integrations.swanlab.profiling import swanlab_profiling_context
|
from axolotl.integrations.swanlab.profiling import swanlab_profiling_context
|
||||||
|
|
||||||
mock_trainer = MagicMock()
|
mock_trainer = MagicMock()
|
||||||
mock_trainer.cfg = MagicMock(use_swanlab=True)
|
mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("swanlab.get_run", return_value=None),
|
patch("swanlab.get_run", return_value=None),
|
||||||
@@ -1294,7 +1294,7 @@ class TestSwanLabProfiling:
|
|||||||
)
|
)
|
||||||
|
|
||||||
mock_trainer = MagicMock()
|
mock_trainer = MagicMock()
|
||||||
mock_trainer.cfg = MagicMock(use_swanlab=True)
|
mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True)
|
||||||
mock_trainer.__class__.__name__ = "TestTrainer"
|
mock_trainer.__class__.__name__ = "TestTrainer"
|
||||||
|
|
||||||
# Config that filters out very fast operations
|
# Config that filters out very fast operations
|
||||||
@@ -1320,7 +1320,7 @@ class TestSwanLabProfiling:
|
|||||||
from axolotl.integrations.swanlab.profiling import swanlab_profiling_context
|
from axolotl.integrations.swanlab.profiling import swanlab_profiling_context
|
||||||
|
|
||||||
mock_trainer = MagicMock()
|
mock_trainer = MagicMock()
|
||||||
mock_trainer.cfg = MagicMock(use_swanlab=True)
|
mock_trainer.axolotl_cfg = MagicMock(use_swanlab=True)
|
||||||
mock_trainer.__class__.__name__ = "TestTrainer"
|
mock_trainer.__class__.__name__ = "TestTrainer"
|
||||||
|
|
||||||
with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log:
|
with patch("swanlab.get_run") as mock_get_run, patch("swanlab.log") as mock_log:
|
||||||
|
|||||||
Reference in New Issue
Block a user