From 573726c839d6e1f275da75a4cd30447dfefd1fdb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 2 Apr 2026 10:18:00 -0400 Subject: [PATCH] upgrade torchao to 0.17.0 (#3569) * upgrade to torchao 0.17.0 * upgrade mistral-common too * chore: lint * patch fix for torchao low bit optimizers * fix up * propagate dtype * fix test for ao change * address PR comments --- requirements.txt | 4 +- src/axolotl/core/builders/base.py | 2 +- src/axolotl/loaders/patch_manager.py | 7 ++ src/axolotl/monkeypatch/torchao_optim.py | 154 +++++++++++++++++++++++ src/axolotl/utils/quantization.py | 29 ++--- tests/e2e/test_quantization.py | 14 +-- 6 files changed, 178 insertions(+), 32 deletions(-) create mode 100644 src/axolotl/monkeypatch/torchao_optim.py diff --git a/requirements.txt b/requirements.txt index 2bd4c4aeb..fb429df90 100644 --- a/requirements.txt +++ b/requirements.txt @@ -66,7 +66,7 @@ langdetect==1.0.9 immutabledict==4.2.0 antlr4-python3-runtime==4.13.2 -torchao==0.16.0 +torchao==0.17.0 openenv-core==0.1.0 schedulefree==1.4.1 @@ -75,4 +75,4 @@ axolotl-contribs-mit==0.0.6 # telemetry posthog==6.7.11 -mistral-common==1.10.0 +mistral-common==1.11.0 diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 90c813927..9dba48b88 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -329,7 +329,7 @@ class TrainerBuilderBase(abc.ABC): optimizer_cls = AdamW optimizer_kwargs.update(adam_kwargs) elif self.cfg.optimizer == "ao_adamw_fp8": - from torchao.prototype.low_bit_optim import AdamWFp8 + from torchao.optim.adam import AdamWFp8 optimizer_cls = AdamWFp8 optimizer_kwargs.update(adam_kwargs) diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index aa67c3b1e..018ca52a0 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -95,6 +95,7 @@ class PatchManager: def apply_pre_model_load_patches(self): """Apply pre-model load patches based on config.""" self._deactivate_hf_async_load() + self._apply_torchao_patches() self._apply_transformers_patches() # self._apply_flex_attention_patches() self._apply_flash_attention_patches() @@ -125,6 +126,12 @@ class PatchManager: self._apply_tiled_mlp(self.cfg.model_config_type) self._apply_moe_expert_quantization_patch() + @staticmethod + def _apply_torchao_patches(): + from axolotl.monkeypatch.torchao_optim import patch_torchao_optim_state_8bit + + patch_torchao_optim_state_8bit() + def _apply_transformers_patches(self): from axolotl.monkeypatch.transformers.trainer_loss_calc import ( patch_evaluation_loop, diff --git a/src/axolotl/monkeypatch/torchao_optim.py b/src/axolotl/monkeypatch/torchao_optim.py new file mode 100644 index 000000000..98c325a0b --- /dev/null +++ b/src/axolotl/monkeypatch/torchao_optim.py @@ -0,0 +1,154 @@ +""" +Patch for torchao optim subclasses that crash under torch.compile. + +torchao 0.17.0 PR #3934 added an "appearance dtype" to OptimState{4,8}bit and +OptimStateFp8, allowing them to report as e.g. bf16 while internally storing +quantized codes. Three issues: + +1. aten.view.default doesn't propagate the appearance dtype, so views (e.g. from + DTensor.from_local()) revert to float32 while the base is bf16. torch.compile's + fake-tensor metadata check then fails (AssertionError: torch.bfloat16 != torch.float32). + +2. aten._to_copy doesn't clone internal tensors, so same-device dtype changes + (e.g. .float()) create an accidental view relationship with the same issue. + +3. aten.view.dtype is unimplemented, so if the dtype-view path IS taken, it crashes + with NotImplementedError. + +Fix: propagate dtype in view.default (primary), clone in _to_copy, register view.dtype. + +Upstream fix: https://github.com/pytorch/ao/pull/4216 +""" + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + +aten = torch.ops.aten + + +def _needs_view_dtype_patch(cls): + """Check if a subclass is missing aten.view.dtype.""" + op_table = getattr(cls, "_ATEN_OP_TABLE", {}).get(cls, {}) + return aten.view.dtype not in op_table + + +def patch_torchao_optim_state_8bit(): + """Patch torchao optim subclasses for torch.compile compatibility.""" + try: + from torchao.optim.subclass_8bit import OptimState8bit + except ImportError: + return + + # Patch view.default to propagate appearance dtype + @OptimState8bit.implements(aten.view.default) + def _(func, types, args, kwargs): + x, shape = args + return OptimState8bit( + x.codes.view(shape), x.scale, x.qmap, x.signed, dtype=x.dtype + ) + + # Patch _to_copy to clone internal tensors (breaks accidental view) + @OptimState8bit.implements(aten._to_copy.default) + def _(func, types, args, kwargs): + dtype = kwargs.get("dtype", args[0].dtype) + device = kwargs.get("device", None) + out = OptimState8bit( + args[0].codes.to(device=device).clone(), + args[0].scale.to(device=device).clone(), + args[0].qmap.to(device=device).clone(), + args[0].signed, + dtype=dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, out) + + if _needs_view_dtype_patch(OptimState8bit): + + @OptimState8bit.implements(aten.view.dtype) + def _(func, types, args, kwargs): + x, dtype = args + return OptimState8bit(x.codes, x.scale, x.qmap, x.signed, dtype=dtype) + + LOG.debug("Patched OptimState8bit for torch.compile compatibility") + + try: + from torchao.optim.subclass_4bit import OptimState4bit + except ImportError: + OptimState4bit = None + + if OptimState4bit is not None: + + @OptimState4bit.implements(aten.view.default) + def _(func, types, args, kwargs): + x, shape = args + if tuple(x.shape) == tuple(shape): + return OptimState4bit( + x.codes, x.scale, x.qmap, x.signed, x._shape, dtype=x.dtype + ) + if len(shape) == 1 and shape[0] == -1: + return OptimState4bit( + x.codes, x.scale, x.qmap, x.signed, (x.numel(),), dtype=x.dtype + ) + raise ValueError( + f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]" + ) + + @OptimState4bit.implements(aten._to_copy.default) + def _(func, types, args, kwargs): + dtype = kwargs.get("dtype", args[0].dtype) + device = kwargs.get("device", None) + out = OptimState4bit( + args[0].codes.to(device=device).clone(), + args[0].scale.to(device=device).clone(), + args[0].qmap.to(device=device).clone(), + args[0].signed, + args[0].shape, + dtype=dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, out) + + if _needs_view_dtype_patch(OptimState4bit): + + @OptimState4bit.implements(aten.view.dtype) + def _(func, types, args, kwargs): + x, dtype = args + return OptimState4bit( + x.codes, x.scale, x.qmap, x.signed, x.shape, dtype=dtype + ) + + LOG.debug("Patched OptimState4bit for torch.compile compatibility") + + try: + from torchao.optim.subclass_fp8 import OptimStateFp8 + except ImportError: + OptimStateFp8 = None + + if OptimStateFp8 is not None: + + @OptimStateFp8.implements(aten.view.default) + def _(func, types, args, kwargs): + x, shape = args + return OptimStateFp8(x.codes.view(shape), x.scale, dtype=x.dtype) + + @OptimStateFp8.implements(aten._to_copy.default) + def _(func, types, args, kwargs): + dtype = kwargs.get("dtype", args[0].dtype) + device = kwargs.get("device", None) + out = OptimStateFp8( + args[0].codes.to(device=device).clone(), + args[0].scale.to(device=device).clone(), + dtype=dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, out) + + if _needs_view_dtype_patch(OptimStateFp8): + + @OptimStateFp8.implements(aten.view.dtype) + def _(func, types, args, kwargs): + x, dtype = args + return OptimStateFp8(x.codes, x.scale, dtype=dtype) + + LOG.debug("Patched OptimStateFp8 for torch.compile compatibility") diff --git a/src/axolotl/utils/quantization.py b/src/axolotl/utils/quantization.py index 6a479d260..04b2c6341 100644 --- a/src/axolotl/utils/quantization.py +++ b/src/axolotl/utils/quantization.py @@ -6,6 +6,7 @@ import torch from packaging import version from torchao.core.config import AOBaseConfig from torchao.quantization import quantize_ +from torchao.quantization.granularity import PerGroup from torchao.quantization.qat import ( QATConfig, ) @@ -14,30 +15,22 @@ from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, Float8DynamicActivationInt4WeightConfig, Int4WeightOnlyConfig, + Int8DynamicActivationIntxWeightConfig, ) -try: - from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig -except ImportError: - from torchao.quantization.quant_api import ( - Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig, - ) - from axolotl.utils.schemas.enums import TorchAOQuantDType quantization_config_to_str = { - Int8DynamicActivationInt4WeightConfig: "int8int4", + Int8DynamicActivationIntxWeightConfig: "int8int4", Float8DynamicActivationFloat8WeightConfig: "fp8fp8", Float8DynamicActivationInt4WeightConfig: "fp8int4", } if version.parse(torch.__version__) >= version.parse("2.8.0"): try: - from torchao.prototype.mx_formats import ( - NVFP4WeightOnlyConfig as NVFP4InferenceConfig, - ) + from torchao.prototype.mx_formats import NVFP4WeightOnlyConfig - quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4" + quantization_config_to_str[NVFP4WeightOnlyConfig] = "nvfp4" except (ImportError, RuntimeError): pass @@ -108,10 +101,10 @@ def get_quantization_config( activation_dtype == TorchAOQuantDType.int8 and weight_dtype == TorchAOQuantDType.int4 ): + kwargs = {"weight_dtype": torch.int4} if group_size is not None: - return Int8DynamicActivationInt4WeightConfig(group_size=group_size) - else: - return Int8DynamicActivationInt4WeightConfig() + kwargs["weight_granularity"] = PerGroup(group_size=group_size) + return Int8DynamicActivationIntxWeightConfig(**kwargs) if ( activation_dtype == TorchAOQuantDType.float8_e4m3fn and weight_dtype == TorchAOQuantDType.float8_e4m3fn @@ -123,13 +116,11 @@ def get_quantization_config( ): return Float8DynamicActivationInt4WeightConfig() if weight_dtype == TorchAOQuantDType.nvfp4: - from torchao.prototype.mx_formats import ( - NVFP4WeightOnlyConfig as NVFP4InferenceConfig, - ) + from torchao.prototype.mx_formats import NVFP4WeightOnlyConfig if group_size is not None and group_size != 16: raise ValueError("NVFP4 quantization must use a group_size of 16") - return NVFP4InferenceConfig() + return NVFP4WeightOnlyConfig() if weight_dtype == TorchAOQuantDType.mxfp4: # MXFP4 uses block_size=32 by default (vs NVFP4's 16) diff --git a/tests/e2e/test_quantization.py b/tests/e2e/test_quantization.py index 6bbc34949..a70a46194 100644 --- a/tests/e2e/test_quantization.py +++ b/tests/e2e/test_quantization.py @@ -5,20 +5,14 @@ Tests for axolotl.utils.quantization import pytest import torch from torch import nn -from torchao.quantization import LinearActivationQuantizedTensor +from torchao.quantization import IntxUnpackedToInt8Tensor from torchao.quantization.qat.embedding import FakeQuantizedEmbedding from torchao.quantization.qat.linear import FakeQuantizedLinear from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, Float8DynamicActivationInt4WeightConfig, + Int8DynamicActivationIntxWeightConfig, ) - -try: - from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig -except ImportError: - from torchao.quantization.quant_api import ( - Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig, - ) from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor from transformers import AutoModelForCausalLM from transformers.trainer_callback import TrainerState @@ -71,7 +65,7 @@ ptq_config_test_cases = [ TorchAOQuantDType.int4, TorchAOQuantDType.int8, None, - Int8DynamicActivationInt4WeightConfig, + Int8DynamicActivationIntxWeightConfig, ), ( TorchAOQuantDType.float8_e4m3fn, @@ -96,7 +90,7 @@ ptq_test_cases = [ 8, False, None, - LinearActivationQuantizedTensor, + IntxUnpackedToInt8Tensor, ), # ( # TorchAOQuantDType.int4,