upgrade torchao to 0.17.0 (#3569)
Some checks failed
ci-cd / build-axolotl (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.12, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.12, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 128, 12.8.1, true, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 130, 13.0.0, <nil>, 3.11, 2.9.1) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled

* upgrade to torchao 0.17.0

* upgrade mistral-common too

* chore: lint

* patch fix for torchao low bit optimizers

* fix up

* propagate dtype

* fix test for ao change

* address PR comments
This commit is contained in:
Wing Lian
2026-04-02 10:18:00 -04:00
committed by GitHub
parent 842fa039dd
commit 573726c839
6 changed files with 178 additions and 32 deletions

View File

@@ -66,7 +66,7 @@ langdetect==1.0.9
immutabledict==4.2.0 immutabledict==4.2.0
antlr4-python3-runtime==4.13.2 antlr4-python3-runtime==4.13.2
torchao==0.16.0 torchao==0.17.0
openenv-core==0.1.0 openenv-core==0.1.0
schedulefree==1.4.1 schedulefree==1.4.1
@@ -75,4 +75,4 @@ axolotl-contribs-mit==0.0.6
# telemetry # telemetry
posthog==6.7.11 posthog==6.7.11
mistral-common==1.10.0 mistral-common==1.11.0

View File

@@ -329,7 +329,7 @@ class TrainerBuilderBase(abc.ABC):
optimizer_cls = AdamW optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs) optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8": elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8 from torchao.optim.adam import AdamWFp8
optimizer_cls = AdamWFp8 optimizer_cls = AdamWFp8
optimizer_kwargs.update(adam_kwargs) optimizer_kwargs.update(adam_kwargs)

View File

@@ -95,6 +95,7 @@ class PatchManager:
def apply_pre_model_load_patches(self): def apply_pre_model_load_patches(self):
"""Apply pre-model load patches based on config.""" """Apply pre-model load patches based on config."""
self._deactivate_hf_async_load() self._deactivate_hf_async_load()
self._apply_torchao_patches()
self._apply_transformers_patches() self._apply_transformers_patches()
# self._apply_flex_attention_patches() # self._apply_flex_attention_patches()
self._apply_flash_attention_patches() self._apply_flash_attention_patches()
@@ -125,6 +126,12 @@ class PatchManager:
self._apply_tiled_mlp(self.cfg.model_config_type) self._apply_tiled_mlp(self.cfg.model_config_type)
self._apply_moe_expert_quantization_patch() self._apply_moe_expert_quantization_patch()
@staticmethod
def _apply_torchao_patches():
    """Install the torchao optimizer-state monkeypatches.

    The monkeypatch module is imported inside the method body, so it is only
    loaded when this method is actually called.
    """
    from axolotl.monkeypatch import torchao_optim

    torchao_optim.patch_torchao_optim_state_8bit()
def _apply_transformers_patches(self): def _apply_transformers_patches(self):
from axolotl.monkeypatch.transformers.trainer_loss_calc import ( from axolotl.monkeypatch.transformers.trainer_loss_calc import (
patch_evaluation_loop, patch_evaluation_loop,

View File

@@ -0,0 +1,154 @@
"""
Patch for torchao optim subclasses that crash under torch.compile.
torchao 0.17.0 PR #3934 added an "appearance dtype" to OptimState{4,8}bit and
OptimStateFp8, allowing them to report as e.g. bf16 while internally storing
quantized codes. Three issues:
1. aten.view.default doesn't propagate the appearance dtype, so views (e.g. from
DTensor.from_local()) revert to float32 while the base is bf16. torch.compile's
fake-tensor metadata check then fails (AssertionError: torch.bfloat16 != torch.float32).
2. aten._to_copy doesn't clone internal tensors, so same-device dtype changes
(e.g. .float()) create an accidental view relationship with the same issue.
3. aten.view.dtype is unimplemented, so if the dtype-view path IS taken, it crashes
with NotImplementedError.
Fix: propagate dtype in view.default (primary), clone in _to_copy, register view.dtype.
Upstream fix: https://github.com/pytorch/ao/pull/4216
"""
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
aten = torch.ops.aten
def _needs_view_dtype_patch(cls):
    """Return True when ``cls`` has no handler registered for ``aten.view.dtype``.

    The lookup goes through ``cls._ATEN_OP_TABLE[cls]``; a class without that
    attribute, or without an entry for itself in it, is treated as having no
    handlers at all and therefore as needing the patch.
    """
    registry = getattr(cls, "_ATEN_OP_TABLE", {})
    handlers = registry.get(cls, {})
    already_registered = aten.view.dtype in handlers
    return not already_registered
def patch_torchao_optim_state_8bit():
    """Patch torchao optim-state subclasses for torch.compile compatibility.

    Despite the name, this covers ``OptimState8bit``, ``OptimState4bit`` and
    ``OptimStateFp8``. For every class that can be imported it registers:

    * ``aten.view.default`` so a view reports the same dtype as its source,
    * ``aten._to_copy.default`` so the copy owns cloned internal tensors,
    * ``aten.view.dtype`` when the class has no handler for it yet.

    Each class is handled independently: a failed import skips only that
    class, so a missing 8-bit module no longer prevents the 4-bit and FP8
    patches from being applied.
    """
    _patch_optim_state_8bit()
    _patch_optim_state_4bit()
    _patch_optim_state_fp8()


def _copy_target_dtype(x, kwargs):
    """Return the dtype that the result of a ``_to_copy`` call should report.

    Falls back to ``x.dtype`` both when ``dtype`` is absent from ``kwargs`` and
    when it is present but ``None``. Forwarding ``None`` would presumably make
    the subclass constructor use its own default dtype and drop the source's
    dtype -- TODO confirm against the torchao constructors.
    """
    dtype = kwargs.get("dtype")
    return x.dtype if dtype is None else dtype


def _patch_optim_state_8bit():
    """Register the view / copy handlers on ``OptimState8bit`` if importable."""
    try:
        from torchao.optim.subclass_8bit import OptimState8bit
    except ImportError:
        return

    # Propagate the source dtype to the view instead of letting it reset.
    @OptimState8bit.implements(aten.view.default)
    def _(func, types, args, kwargs):
        x, shape = args
        return OptimState8bit(
            x.codes.view(shape), x.scale, x.qmap, x.signed, dtype=x.dtype
        )

    # Clone the internal tensors so the result is not an accidental view of
    # the source when only the dtype changes on the same device.
    @OptimState8bit.implements(aten._to_copy.default)
    def _(func, types, args, kwargs):
        x = args[0]
        device = kwargs.get("device", None)
        out = OptimState8bit(
            x.codes.to(device=device).clone(),
            x.scale.to(device=device).clone(),
            x.qmap.to(device=device).clone(),
            x.signed,
            dtype=_copy_target_dtype(x, kwargs),
        )
        return return_and_correct_aliasing(func, args, kwargs, out)

    if _needs_view_dtype_patch(OptimState8bit):

        @OptimState8bit.implements(aten.view.dtype)
        def _(func, types, args, kwargs):
            x, dtype = args
            return OptimState8bit(x.codes, x.scale, x.qmap, x.signed, dtype=dtype)

    LOG.debug("Patched OptimState8bit for torch.compile compatibility")


def _patch_optim_state_4bit():
    """Register the view / copy handlers on ``OptimState4bit`` if importable."""
    try:
        from torchao.optim.subclass_4bit import OptimState4bit
    except ImportError:
        return

    # Only identity views and flattening to [-1] are supported; both reuse the
    # existing codes and carry the source dtype over.
    @OptimState4bit.implements(aten.view.default)
    def _(func, types, args, kwargs):
        x, shape = args
        if tuple(x.shape) == tuple(shape):
            return OptimState4bit(
                x.codes, x.scale, x.qmap, x.signed, x._shape, dtype=x.dtype
            )
        if len(shape) == 1 and shape[0] == -1:
            return OptimState4bit(
                x.codes, x.scale, x.qmap, x.signed, (x.numel(),), dtype=x.dtype
            )
        raise ValueError(
            f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]"
        )

    @OptimState4bit.implements(aten._to_copy.default)
    def _(func, types, args, kwargs):
        x = args[0]
        device = kwargs.get("device", None)
        out = OptimState4bit(
            x.codes.to(device=device).clone(),
            x.scale.to(device=device).clone(),
            x.qmap.to(device=device).clone(),
            x.signed,
            x.shape,
            dtype=_copy_target_dtype(x, kwargs),
        )
        return return_and_correct_aliasing(func, args, kwargs, out)

    if _needs_view_dtype_patch(OptimState4bit):

        @OptimState4bit.implements(aten.view.dtype)
        def _(func, types, args, kwargs):
            x, dtype = args
            return OptimState4bit(
                x.codes, x.scale, x.qmap, x.signed, x.shape, dtype=dtype
            )

    LOG.debug("Patched OptimState4bit for torch.compile compatibility")


def _patch_optim_state_fp8():
    """Register the view / copy handlers on ``OptimStateFp8`` if importable."""
    try:
        from torchao.optim.subclass_fp8 import OptimStateFp8
    except ImportError:
        return

    @OptimStateFp8.implements(aten.view.default)
    def _(func, types, args, kwargs):
        x, shape = args
        return OptimStateFp8(x.codes.view(shape), x.scale, dtype=x.dtype)

    @OptimStateFp8.implements(aten._to_copy.default)
    def _(func, types, args, kwargs):
        x = args[0]
        device = kwargs.get("device", None)
        out = OptimStateFp8(
            x.codes.to(device=device).clone(),
            x.scale.to(device=device).clone(),
            dtype=_copy_target_dtype(x, kwargs),
        )
        return return_and_correct_aliasing(func, args, kwargs, out)

    if _needs_view_dtype_patch(OptimStateFp8):

        @OptimStateFp8.implements(aten.view.dtype)
        def _(func, types, args, kwargs):
            x, dtype = args
            return OptimStateFp8(x.codes, x.scale, dtype=dtype)

    LOG.debug("Patched OptimStateFp8 for torch.compile compatibility")

View File

@@ -6,6 +6,7 @@ import torch
from packaging import version from packaging import version
from torchao.core.config import AOBaseConfig from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_ from torchao.quantization import quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.qat import ( from torchao.quantization.qat import (
QATConfig, QATConfig,
) )
@@ -14,30 +15,22 @@ from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig, Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig, Float8DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig, Int4WeightOnlyConfig,
Int8DynamicActivationIntxWeightConfig,
) )
try:
from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig
except ImportError:
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig,
)
from axolotl.utils.schemas.enums import TorchAOQuantDType from axolotl.utils.schemas.enums import TorchAOQuantDType
quantization_config_to_str = { quantization_config_to_str = {
Int8DynamicActivationInt4WeightConfig: "int8int4", Int8DynamicActivationIntxWeightConfig: "int8int4",
Float8DynamicActivationFloat8WeightConfig: "fp8fp8", Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
Float8DynamicActivationInt4WeightConfig: "fp8int4", Float8DynamicActivationInt4WeightConfig: "fp8int4",
} }
if version.parse(torch.__version__) >= version.parse("2.8.0"): if version.parse(torch.__version__) >= version.parse("2.8.0"):
try: try:
from torchao.prototype.mx_formats import ( from torchao.prototype.mx_formats import NVFP4WeightOnlyConfig
NVFP4WeightOnlyConfig as NVFP4InferenceConfig,
)
quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4" quantization_config_to_str[NVFP4WeightOnlyConfig] = "nvfp4"
except (ImportError, RuntimeError): except (ImportError, RuntimeError):
pass pass
@@ -108,10 +101,10 @@ def get_quantization_config(
activation_dtype == TorchAOQuantDType.int8 activation_dtype == TorchAOQuantDType.int8
and weight_dtype == TorchAOQuantDType.int4 and weight_dtype == TorchAOQuantDType.int4
): ):
kwargs = {"weight_dtype": torch.int4}
if group_size is not None: if group_size is not None:
return Int8DynamicActivationInt4WeightConfig(group_size=group_size) kwargs["weight_granularity"] = PerGroup(group_size=group_size)
else: return Int8DynamicActivationIntxWeightConfig(**kwargs)
return Int8DynamicActivationInt4WeightConfig()
if ( if (
activation_dtype == TorchAOQuantDType.float8_e4m3fn activation_dtype == TorchAOQuantDType.float8_e4m3fn
and weight_dtype == TorchAOQuantDType.float8_e4m3fn and weight_dtype == TorchAOQuantDType.float8_e4m3fn
@@ -123,13 +116,11 @@ def get_quantization_config(
): ):
return Float8DynamicActivationInt4WeightConfig() return Float8DynamicActivationInt4WeightConfig()
if weight_dtype == TorchAOQuantDType.nvfp4: if weight_dtype == TorchAOQuantDType.nvfp4:
from torchao.prototype.mx_formats import ( from torchao.prototype.mx_formats import NVFP4WeightOnlyConfig
NVFP4WeightOnlyConfig as NVFP4InferenceConfig,
)
if group_size is not None and group_size != 16: if group_size is not None and group_size != 16:
raise ValueError("NVFP4 quantization must use a group_size of 16") raise ValueError("NVFP4 quantization must use a group_size of 16")
return NVFP4InferenceConfig() return NVFP4WeightOnlyConfig()
if weight_dtype == TorchAOQuantDType.mxfp4: if weight_dtype == TorchAOQuantDType.mxfp4:
# MXFP4 uses block_size=32 by default (vs NVFP4's 16) # MXFP4 uses block_size=32 by default (vs NVFP4's 16)

View File

@@ -5,20 +5,14 @@ Tests for axolotl.utils.quantization
import pytest import pytest
import torch import torch
from torch import nn from torch import nn
from torchao.quantization import LinearActivationQuantizedTensor from torchao.quantization import IntxUnpackedToInt8Tensor
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear from torchao.quantization.qat.linear import FakeQuantizedLinear
from torchao.quantization.quant_api import ( from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig, Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig, Float8DynamicActivationInt4WeightConfig,
Int8DynamicActivationIntxWeightConfig,
) )
try:
from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig
except ImportError:
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
from transformers.trainer_callback import TrainerState from transformers.trainer_callback import TrainerState
@@ -71,7 +65,7 @@ ptq_config_test_cases = [
TorchAOQuantDType.int4, TorchAOQuantDType.int4,
TorchAOQuantDType.int8, TorchAOQuantDType.int8,
None, None,
Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationIntxWeightConfig,
), ),
( (
TorchAOQuantDType.float8_e4m3fn, TorchAOQuantDType.float8_e4m3fn,
@@ -96,7 +90,7 @@ ptq_test_cases = [
8, 8,
False, False,
None, None,
LinearActivationQuantizedTensor, IntxUnpackedToInt8Tensor,
), ),
# ( # (
# TorchAOQuantDType.int4, # TorchAOQuantDType.int4,