roundup_power2_divisions not needed with newer pytorch versions (#3540)
* roundup_power2_divisions not needed with newer pytorch versions * remove typo * update qwen3.5 moe 35b-a3b yaml for 5090 * more bug fixes * fix tests to match updated trainer * don't use fa2 for hooks test * reset plugins on the instance * retry download * fix references to renamed axolotl_cfg property on trainer * Fix ref to trainer cfg
This commit is contained in:
@@ -4,6 +4,7 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import httpcore
|
||||
import httpx
|
||||
from accelerate.commands.config import config_args
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||
@@ -48,7 +49,7 @@ def check_user_token() -> bool:
|
||||
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
||||
)
|
||||
return False
|
||||
except (HTTPError, httpcore.ConnectError):
|
||||
except (HTTPError, httpcore.ConnectError, httpx.ConnectError):
|
||||
LOG.warning(
|
||||
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
|
||||
)
|
||||
|
||||
@@ -36,7 +36,9 @@ class DiffusionGenerationCallback(TrainerCallback):
|
||||
"""Generate samples at specified intervals."""
|
||||
if (
|
||||
state.global_step > 0
|
||||
and state.global_step % self.trainer.cfg.diffusion.generation_interval == 0
|
||||
and state.global_step
|
||||
% self.trainer.axolotl_cfg.diffusion.generation_interval
|
||||
== 0
|
||||
):
|
||||
if not self.trainer.state.is_world_process_zero:
|
||||
return
|
||||
@@ -52,7 +54,7 @@ class DiffusionGenerationCallback(TrainerCallback):
|
||||
dataloader = self.trainer.get_train_dataloader()
|
||||
|
||||
# Generate samples
|
||||
diffusion_cfg = self.trainer.cfg.diffusion
|
||||
diffusion_cfg = self.trainer.axolotl_cfg.diffusion
|
||||
samples = generate_samples(
|
||||
model=self.trainer.model,
|
||||
tokenizer=self.trainer.processing_class,
|
||||
@@ -142,7 +144,7 @@ class DiffusionGenerationCallback(TrainerCallback):
|
||||
|
||||
logger.info("=" * 60)
|
||||
|
||||
if self.trainer.cfg.use_wandb:
|
||||
if self.trainer.axolotl_cfg.use_wandb:
|
||||
if wandb.run is not None:
|
||||
wandb.log(
|
||||
{
|
||||
|
||||
@@ -38,4 +38,6 @@ class DiffusionPlugin(BasePlugin):
|
||||
|
||||
def post_trainer_create(self, cfg: DictDefault, trainer: DiffusionTrainer):
|
||||
"""Configure trainer after creation."""
|
||||
trainer.set_config(cfg)
|
||||
if hasattr(trainer, "axolotl_cfg"):
|
||||
trainer.axolotl_cfg = cfg
|
||||
trainer.post_set_axolotl_cfg()
|
||||
|
||||
@@ -7,7 +7,6 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .callbacks import DiffusionGenerationCallback
|
||||
@@ -21,19 +20,17 @@ class DiffusionTrainer(AxolotlTrainer):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.cfg = None
|
||||
self._special_token_ids = None
|
||||
|
||||
def set_config(self, config: DictDefault):
|
||||
def post_set_axolotl_cfg(self):
|
||||
"""Set config for diffusion training."""
|
||||
self.cfg = config
|
||||
self._cache_special_token_ids()
|
||||
self._resolve_mask_token_id()
|
||||
|
||||
token_id = int(getattr(self.cfg.diffusion, "mask_token_id", 0))
|
||||
token_id = int(getattr(self.axolotl_cfg.diffusion, "mask_token_id", 0))
|
||||
LOG.info(f"Diffusion: using mask_token_id={token_id}")
|
||||
|
||||
if getattr(config.diffusion, "generate_samples", True):
|
||||
if getattr(self.axolotl_cfg.diffusion, "generate_samples", True):
|
||||
generation_callback = DiffusionGenerationCallback(self)
|
||||
self.add_callback(generation_callback)
|
||||
|
||||
@@ -41,18 +38,20 @@ class DiffusionTrainer(AxolotlTrainer):
|
||||
"""Ensure mask_token_id is valid for the current tokenizer."""
|
||||
from .utils import resolve_mask_token_id
|
||||
|
||||
assert self.axolotl_cfg is not None, "axolotl_cfg is not set yet"
|
||||
|
||||
tokenizer = getattr(self, "processing_class", None)
|
||||
if tokenizer is None:
|
||||
return
|
||||
|
||||
mid = resolve_mask_token_id(
|
||||
tokenizer,
|
||||
self.cfg,
|
||||
self.axolotl_cfg,
|
||||
allow_add=True,
|
||||
model=getattr(self, "model", None),
|
||||
)
|
||||
try:
|
||||
self.cfg.diffusion.mask_token_id = int(mid)
|
||||
self.axolotl_cfg.diffusion.mask_token_id = int(mid)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -150,7 +149,7 @@ class DiffusionTrainer(AxolotlTrainer):
|
||||
masked_indices = masked_indices & answer_mask
|
||||
|
||||
# Create masked input
|
||||
mask_token_id = int(self.cfg.diffusion.mask_token_id)
|
||||
mask_token_id = int(self.axolotl_cfg.diffusion.mask_token_id)
|
||||
mask_value = torch.full_like(input_ids, mask_token_id)
|
||||
noisy_batch = torch.where(masked_indices, mask_value, input_ids)
|
||||
|
||||
@@ -194,12 +193,12 @@ class DiffusionTrainer(AxolotlTrainer):
|
||||
|
||||
# Apply forward process
|
||||
noisy_batch, masked_indices, p_mask = self._forward_process(
|
||||
input_ids, attention_mask, labels, self.cfg.diffusion.eps
|
||||
input_ids, attention_mask, labels, self.axolotl_cfg.diffusion.eps
|
||||
)
|
||||
|
||||
# Create bidirectional attention mask
|
||||
bidirectional_mask = create_bidirectional_attention_mask(
|
||||
input_ids, attention_mask, sample_packing=self.cfg.sample_packing
|
||||
input_ids, attention_mask, sample_packing=self.axolotl_cfg.sample_packing
|
||||
)
|
||||
|
||||
# Forward pass
|
||||
@@ -222,7 +221,7 @@ class DiffusionTrainer(AxolotlTrainer):
|
||||
masked_logits.float(), masked_targets, reduction="none"
|
||||
)
|
||||
|
||||
if self.cfg.diffusion.importance_weighting:
|
||||
if self.axolotl_cfg.diffusion.importance_weighting:
|
||||
masked_p_mask = masked_p_mask.float()
|
||||
weighted_loss = token_loss / masked_p_mask
|
||||
else:
|
||||
@@ -251,7 +250,7 @@ class DiffusionTrainer(AxolotlTrainer):
|
||||
# Non-SFT: when importance weighting is enabled, use unbiased estimator
|
||||
# (sum(loss/p) / total_tokens). Otherwise, average over masked tokens
|
||||
# for stable scaling across varying mask ratios.
|
||||
if self.cfg.diffusion.importance_weighting:
|
||||
if self.axolotl_cfg.diffusion.importance_weighting:
|
||||
loss = weighted_loss.sum() / (
|
||||
input_ids.shape[0] * input_ids.shape[1]
|
||||
)
|
||||
@@ -283,7 +282,7 @@ class DiffusionTrainer(AxolotlTrainer):
|
||||
}
|
||||
|
||||
# If doing SFT training, log answer-specific metrics
|
||||
if self.cfg.datasets is not None:
|
||||
if self.axolotl_cfg.datasets is not None:
|
||||
with torch.no_grad():
|
||||
answer_mask = labels != -100
|
||||
answer_lengths = answer_mask.sum(dim=1).float() # type: ignore
|
||||
@@ -292,7 +291,7 @@ class DiffusionTrainer(AxolotlTrainer):
|
||||
metrics["answer_ratio"] = total_answer_tokens / max(total_tokens, 1)
|
||||
metrics["avg_answer_length"] = answer_lengths.mean().item()
|
||||
|
||||
if self.cfg.diffusion.importance_weighting:
|
||||
if self.axolotl_cfg.diffusion.importance_weighting:
|
||||
metrics["importance_weight_avg"] = (1.0 / masked_p_mask).mean().item()
|
||||
|
||||
train_eval: Literal["train", "eval"] = "train" if model.training else "eval"
|
||||
|
||||
@@ -48,7 +48,8 @@ def set_pytorch_cuda_alloc_conf():
|
||||
"""Set up CUDA allocation config"""
|
||||
torch_version = torch.__version__.split(".")
|
||||
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
|
||||
config_value = "expandable_segments:True,roundup_power2_divisions:16"
|
||||
config_value = "expandable_segments:True"
|
||||
config_older_suffix = ",roundup_power2_divisions:16"
|
||||
if (
|
||||
torch_major == 2
|
||||
and torch_minor >= 9
|
||||
@@ -60,7 +61,7 @@ def set_pytorch_cuda_alloc_conf():
|
||||
and torch_minor >= 2
|
||||
and os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None
|
||||
):
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = config_value
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = config_value + config_older_suffix
|
||||
|
||||
|
||||
def set_misc_env():
|
||||
|
||||
Reference in New Issue
Block a user