fix token state json and mistral tokenizer issue (#3522) [skip ci]

* fix token state json and mistral tokenizer issue

* centralize constants

* forgot to commit constants file

* Fix weakref in pickling relora state dict

* make curl a bit quieter so it doesn't log 2K lines

* fix path traversal for olmoe test

* more test fixes that weren't flagged previously

* chore: lint

* skip tests that fail b/c of OutOfResources

* scattermoe as slow tests

* update fbgemm-genai for torch 2.10
This commit is contained in:
Wing Lian
2026-03-21 22:46:10 -04:00
committed by GitHub
parent 2c05847a5f
commit 0ee98a0309
22 changed files with 249 additions and 57 deletions

View File

@@ -208,7 +208,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.eval_dataset:
trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
if (
self.cfg.adapter
and self.peft_config
and self.cfg.rl not in (RLType.GRPO, RLType.ORPO)
):
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
trainer_kwargs["precompute_ref_log_probs"] = (

View File

@@ -29,6 +29,7 @@ from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
from trl.experimental.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
from axolotl.core.trainers.mixins import (
ActivationOffloadingMixin,
CheckpointSaveMixin,
@@ -51,8 +52,6 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__)
TOKENS_STATE_FILE = "tokens_state."
REDUCTION_FNS = {
"mean": torch.mean,
"min": torch.min,

View File

@@ -0,0 +1 @@
TOKENS_STATE_FILE = "tokens_state.json"

View File

@@ -2,7 +2,8 @@
Axolotl specific DPO args
"""
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional
from trl import DPOConfig
@@ -16,3 +17,4 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
dpo_norm_loss: bool | None = False
rpo_alpha: Optional[float] = field(default=None)

View File

@@ -221,6 +221,14 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, "<|endoftext|>")
# Generic fallback: if tokenizer still has no pad_token, use eos_token
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
tokenizer.pad_token = tokenizer.eos_token
LOG.warning(
"Tokenizer does not have a pad_token, falling back to eos_token: %s",
tokenizer.eos_token,
)
additional_special_tokens = None
if cfg.special_tokens:
special_tokens = cfg.special_tokens.to_dict()

View File

@@ -82,7 +82,7 @@ def setup_model_and_tokenizer(
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
model, peft_config = model_loader.load()
if model.generation_config is not None:
if getattr(model, "generation_config", None) is not None:
model.generation_config.do_sample = True
model_properties = model.config.to_dict()

View File

@@ -25,9 +25,11 @@ def toggle_fake_quant(mod: nn.Module, enable: bool):
if (
isinstance(mod, FakeQuantizedLinear)
and mod.activation_fake_quantizer is not None
and hasattr(mod.activation_fake_quantizer, "enabled")
):
mod.activation_fake_quantizer.enabled = enable
mod.weight_fake_quantizer.enabled = enable
if hasattr(mod.weight_fake_quantizer, "enabled"):
mod.weight_fake_quantizer.enabled = enable
class QATCallback(TrainerCallback):

View File

@@ -12,12 +12,11 @@ from transformers import (
TrainingArguments,
)
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
TOKENS_STATE_FILE = "tokens_state.json"
class TokensPerSecondCallback(TrainerCallback):
"""

View File

@@ -10,9 +10,11 @@ from torchao.quantization import quantize_
from torchao.quantization.qat import (
QATConfig,
)
from torchao.quantization.qat.fake_quantize_config import Int4WeightFakeQuantizeConfig
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
)
@@ -173,6 +175,70 @@ def quantize_model(
)
def _make_qat_config(
base_config: AOBaseConfig,
weight_dtype: TorchAOQuantDType,
activation_dtype: TorchAOQuantDType | None,
group_size: int | None,
) -> QATConfig:
"""Build a QATConfig, explicitly constructing fake quantize configs to ensure
group_size and other params are properly propagated (torchao's QATConfig(base_config)
does not always map these correctly)."""
from torchao.quantization.qat.fake_quantize_config import (
Float8FakeQuantizeConfig,
IntxFakeQuantizeConfig,
)
if isinstance(base_config, MXFakeQuantizeConfig):
return QATConfig(
activation_config=base_config,
weight_config=base_config,
)
# Build explicit weight config
weight_fq_config: (
Int4WeightFakeQuantizeConfig
| IntxFakeQuantizeConfig
| Float8FakeQuantizeConfig
| None
) = None
if weight_dtype == TorchAOQuantDType.int4:
gs = (
group_size
if group_size is not None
else getattr(base_config, "group_size", 128)
)
activation_dt = None
if activation_dtype == TorchAOQuantDType.int8:
activation_dt = torch.bfloat16
elif activation_dtype == TorchAOQuantDType.float8_e4m3fn:
activation_dt = torch.float8_e4m3fn
kwargs = {"group_size": gs}
if activation_dt is not None:
kwargs["activation_dtype"] = activation_dt
weight_fq_config = Int4WeightFakeQuantizeConfig(**kwargs)
elif weight_dtype == TorchAOQuantDType.float8_e4m3fn:
weight_fq_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn)
# Build explicit activation config
activation_fq_config = None
if activation_dtype == TorchAOQuantDType.int8:
activation_fq_config = IntxFakeQuantizeConfig(
dtype=torch.int8, granularity="per_token", is_symmetric=False
)
elif activation_dtype == TorchAOQuantDType.float8_e4m3fn:
activation_fq_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn)
if weight_fq_config is not None:
return QATConfig(
weight_config=weight_fq_config,
activation_config=activation_fq_config,
)
# Fallback to base_config for unhandled combos
return QATConfig(base_config)
def prepare_model_for_qat(
model,
weight_dtype: TorchAOQuantDType,
@@ -200,13 +266,9 @@ def prepare_model_for_qat(
activation_dtype=activation_dtype,
group_size=group_size,
)
if isinstance(base_config, MXFakeQuantizeConfig):
qat_config = QATConfig(
activation_config=base_config,
weight_config=base_config,
)
else:
qat_config = QATConfig(base_config)
qat_config = _make_qat_config(
base_config, weight_dtype, activation_dtype, group_size
)
quantize_(model, qat_config)
if quantize_embedding:
# activation fake quantization is not supported for embedding layers
@@ -215,12 +277,9 @@ def prepare_model_for_qat(
activation_dtype=None,
group_size=group_size,
)
if isinstance(embedding_base_config, MXFakeQuantizeConfig):
embedding_qat_config = QATConfig(
weight_config=embedding_base_config,
)
else:
embedding_qat_config = QATConfig(embedding_base_config)
embedding_qat_config = _make_qat_config(
embedding_base_config, weight_dtype, None, group_size
)
quantize_(
model,
embedding_qat_config,

View File

@@ -2,7 +2,7 @@
import math
from functools import partial
from typing import Sequence
from typing import Any, Sequence
from torch import Tensor
from torch.optim import Optimizer
@@ -340,3 +340,19 @@ class JaggedLRRestartScheduler(LRScheduler):
return [lr * scale for lr in original]
return original * scale
def state_dict(self) -> dict[str, Any]:
"""Return serializable state, saving inner_schedule as its own state_dict."""
state = {
key: value
for key, value in self.__dict__.items()
if key not in ("optimizer", "inner_schedule")
}
state["inner_schedule_state"] = self.inner_schedule.state_dict()
return state
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Restore state, including inner_schedule."""
inner_state = state_dict.pop("inner_schedule_state")
self.__dict__.update(state_dict)
self.inner_schedule.load_state_dict(inner_state)