upgrade torchao to 0.17.0 (#3569)
Some checks failed
ci-cd / build-axolotl (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.12, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.12, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 128, 12.8.1, true, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 130, 13.0.0, <nil>, 3.11, 2.9.1) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled

* upgrade to torchao 0.17.0

* upgrade mistral-common too

* chore: lint

* patch fix for torchao low bit optimizers

* fix up

* propagate dtype

* fix test for ao change

* address PR comments
This commit is contained in:
Wing Lian
2026-04-02 10:18:00 -04:00
committed by GitHub
parent 842fa039dd
commit 573726c839
6 changed files with 178 additions and 32 deletions

View File

@@ -66,7 +66,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.16.0
torchao==0.17.0
openenv-core==0.1.0
schedulefree==1.4.1
@@ -75,4 +75,4 @@ axolotl-contribs-mit==0.0.6
# telemetry
posthog==6.7.11
mistral-common==1.10.0
mistral-common==1.11.0

View File

@@ -329,7 +329,7 @@ class TrainerBuilderBase(abc.ABC):
optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
from torchao.optim.adam import AdamWFp8
optimizer_cls = AdamWFp8
optimizer_kwargs.update(adam_kwargs)

View File

@@ -95,6 +95,7 @@ class PatchManager:
def apply_pre_model_load_patches(self):
"""Apply pre-model load patches based on config."""
self._deactivate_hf_async_load()
self._apply_torchao_patches()
self._apply_transformers_patches()
# self._apply_flex_attention_patches()
self._apply_flash_attention_patches()
@@ -125,6 +126,12 @@ class PatchManager:
self._apply_tiled_mlp(self.cfg.model_config_type)
self._apply_moe_expert_quantization_patch()
@staticmethod
def _apply_torchao_patches():
from axolotl.monkeypatch.torchao_optim import patch_torchao_optim_state_8bit
patch_torchao_optim_state_8bit()
def _apply_transformers_patches(self):
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
patch_evaluation_loop,

View File

@@ -0,0 +1,154 @@
"""
Patch for torchao optim subclasses that crash under torch.compile.
torchao 0.17.0 PR #3934 added an "appearance dtype" to OptimState{4,8}bit and
OptimStateFp8, allowing them to report as e.g. bf16 while internally storing
quantized codes. Three issues:
1. aten.view.default doesn't propagate the appearance dtype, so views (e.g. from
DTensor.from_local()) revert to float32 while the base is bf16. torch.compile's
fake-tensor metadata check then fails (AssertionError: torch.bfloat16 != torch.float32).
2. aten._to_copy doesn't clone internal tensors, so same-device dtype changes
(e.g. .float()) create an accidental view relationship with the same issue.
3. aten.view.dtype is unimplemented, so if the dtype-view path IS taken, it crashes
with NotImplementedError.
Fix: propagate dtype in view.default (primary), clone in _to_copy, register view.dtype.
Upstream fix: https://github.com/pytorch/ao/pull/4216
"""
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
aten = torch.ops.aten
def _needs_view_dtype_patch(cls):
"""Check if a subclass is missing aten.view.dtype."""
op_table = getattr(cls, "_ATEN_OP_TABLE", {}).get(cls, {})
return aten.view.dtype not in op_table
def patch_torchao_optim_state_8bit():
"""Patch torchao optim subclasses for torch.compile compatibility."""
try:
from torchao.optim.subclass_8bit import OptimState8bit
except ImportError:
return
# Patch view.default to propagate appearance dtype
@OptimState8bit.implements(aten.view.default)
def _(func, types, args, kwargs):
x, shape = args
return OptimState8bit(
x.codes.view(shape), x.scale, x.qmap, x.signed, dtype=x.dtype
)
# Patch _to_copy to clone internal tensors (breaks accidental view)
@OptimState8bit.implements(aten._to_copy.default)
def _(func, types, args, kwargs):
dtype = kwargs.get("dtype", args[0].dtype)
device = kwargs.get("device", None)
out = OptimState8bit(
args[0].codes.to(device=device).clone(),
args[0].scale.to(device=device).clone(),
args[0].qmap.to(device=device).clone(),
args[0].signed,
dtype=dtype,
)
return return_and_correct_aliasing(func, args, kwargs, out)
if _needs_view_dtype_patch(OptimState8bit):
@OptimState8bit.implements(aten.view.dtype)
def _(func, types, args, kwargs):
x, dtype = args
return OptimState8bit(x.codes, x.scale, x.qmap, x.signed, dtype=dtype)
LOG.debug("Patched OptimState8bit for torch.compile compatibility")
try:
from torchao.optim.subclass_4bit import OptimState4bit
except ImportError:
OptimState4bit = None
if OptimState4bit is not None:
@OptimState4bit.implements(aten.view.default)
def _(func, types, args, kwargs):
x, shape = args
if tuple(x.shape) == tuple(shape):
return OptimState4bit(
x.codes, x.scale, x.qmap, x.signed, x._shape, dtype=x.dtype
)
if len(shape) == 1 and shape[0] == -1:
return OptimState4bit(
x.codes, x.scale, x.qmap, x.signed, (x.numel(),), dtype=x.dtype
)
raise ValueError(
f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]"
)
@OptimState4bit.implements(aten._to_copy.default)
def _(func, types, args, kwargs):
dtype = kwargs.get("dtype", args[0].dtype)
device = kwargs.get("device", None)
out = OptimState4bit(
args[0].codes.to(device=device).clone(),
args[0].scale.to(device=device).clone(),
args[0].qmap.to(device=device).clone(),
args[0].signed,
args[0].shape,
dtype=dtype,
)
return return_and_correct_aliasing(func, args, kwargs, out)
if _needs_view_dtype_patch(OptimState4bit):
@OptimState4bit.implements(aten.view.dtype)
def _(func, types, args, kwargs):
x, dtype = args
return OptimState4bit(
x.codes, x.scale, x.qmap, x.signed, x.shape, dtype=dtype
)
LOG.debug("Patched OptimState4bit for torch.compile compatibility")
try:
from torchao.optim.subclass_fp8 import OptimStateFp8
except ImportError:
OptimStateFp8 = None
if OptimStateFp8 is not None:
@OptimStateFp8.implements(aten.view.default)
def _(func, types, args, kwargs):
x, shape = args
return OptimStateFp8(x.codes.view(shape), x.scale, dtype=x.dtype)
@OptimStateFp8.implements(aten._to_copy.default)
def _(func, types, args, kwargs):
dtype = kwargs.get("dtype", args[0].dtype)
device = kwargs.get("device", None)
out = OptimStateFp8(
args[0].codes.to(device=device).clone(),
args[0].scale.to(device=device).clone(),
dtype=dtype,
)
return return_and_correct_aliasing(func, args, kwargs, out)
if _needs_view_dtype_patch(OptimStateFp8):
@OptimStateFp8.implements(aten.view.dtype)
def _(func, types, args, kwargs):
x, dtype = args
return OptimStateFp8(x.codes, x.scale, dtype=dtype)
LOG.debug("Patched OptimStateFp8 for torch.compile compatibility")

View File

@@ -6,6 +6,7 @@ import torch
from packaging import version
from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.qat import (
QATConfig,
)
@@ -14,30 +15,22 @@ from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationIntxWeightConfig,
)
try:
from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig
except ImportError:
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig,
)
from axolotl.utils.schemas.enums import TorchAOQuantDType
quantization_config_to_str = {
Int8DynamicActivationInt4WeightConfig: "int8int4",
Int8DynamicActivationIntxWeightConfig: "int8int4",
Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
Float8DynamicActivationInt4WeightConfig: "fp8int4",
}
if version.parse(torch.__version__) >= version.parse("2.8.0"):
try:
from torchao.prototype.mx_formats import (
NVFP4WeightOnlyConfig as NVFP4InferenceConfig,
)
from torchao.prototype.mx_formats import NVFP4WeightOnlyConfig
quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4"
quantization_config_to_str[NVFP4WeightOnlyConfig] = "nvfp4"
except (ImportError, RuntimeError):
pass
@@ -108,10 +101,10 @@ def get_quantization_config(
activation_dtype == TorchAOQuantDType.int8
and weight_dtype == TorchAOQuantDType.int4
):
kwargs = {"weight_dtype": torch.int4}
if group_size is not None:
return Int8DynamicActivationInt4WeightConfig(group_size=group_size)
else:
return Int8DynamicActivationInt4WeightConfig()
kwargs["weight_granularity"] = PerGroup(group_size=group_size)
return Int8DynamicActivationIntxWeightConfig(**kwargs)
if (
activation_dtype == TorchAOQuantDType.float8_e4m3fn
and weight_dtype == TorchAOQuantDType.float8_e4m3fn
@@ -123,13 +116,11 @@ def get_quantization_config(
):
return Float8DynamicActivationInt4WeightConfig()
if weight_dtype == TorchAOQuantDType.nvfp4:
from torchao.prototype.mx_formats import (
NVFP4WeightOnlyConfig as NVFP4InferenceConfig,
)
from torchao.prototype.mx_formats import NVFP4WeightOnlyConfig
if group_size is not None and group_size != 16:
raise ValueError("NVFP4 quantization must use a group_size of 16")
return NVFP4InferenceConfig()
return NVFP4WeightOnlyConfig()
if weight_dtype == TorchAOQuantDType.mxfp4:
# MXFP4 uses block_size=32 by default (vs NVFP4's 16)

View File

@@ -5,20 +5,14 @@ Tests for axolotl.utils.quantization
import pytest
import torch
from torch import nn
from torchao.quantization import LinearActivationQuantizedTensor
from torchao.quantization import IntxUnpackedToInt8Tensor
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
Int8DynamicActivationIntxWeightConfig,
)
try:
from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig
except ImportError:
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
from transformers import AutoModelForCausalLM
from transformers.trainer_callback import TrainerState
@@ -71,7 +65,7 @@ ptq_config_test_cases = [
TorchAOQuantDType.int4,
TorchAOQuantDType.int8,
None,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationIntxWeightConfig,
),
(
TorchAOQuantDType.float8_e4m3fn,
@@ -96,7 +90,7 @@ ptq_test_cases = [
8,
False,
None,
LinearActivationQuantizedTensor,
IntxUnpackedToInt8Tensor,
),
# (
# TorchAOQuantDType.int4,