upgrade torchao to 0.17.0 (#3569)
* upgrade to torchao 0.17.0
* upgrade mistral-common too
* chore: lint
* patch fix for torchao low bit optimizers
* fix up
* propagate dtype
* fix test for ao change
* address PR comments
@@ -66,7 +66,7 @@ langdetect==1.0.9
 immutabledict==4.2.0
 antlr4-python3-runtime==4.13.2
-torchao==0.16.0
+torchao==0.17.0
 openenv-core==0.1.0
 schedulefree==1.4.1
@@ -75,4 +75,4 @@ axolotl-contribs-mit==0.0.6
 # telemetry
 posthog==6.7.11
-mistral-common==1.10.0
+mistral-common==1.11.0
@@ -329,7 +329,7 @@ class TrainerBuilderBase(abc.ABC):
             optimizer_cls = AdamW
             optimizer_kwargs.update(adam_kwargs)
         elif self.cfg.optimizer == "ao_adamw_fp8":
-            from torchao.prototype.low_bit_optim import AdamWFp8
+            from torchao.optim.adam import AdamWFp8

             optimizer_cls = AdamWFp8
             optimizer_kwargs.update(adam_kwargs)
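Note that the fp8 AdamW optimizer now comes from torchao.optim.adam rather than the old torchao.prototype.low_bit_optim location. This diff targets torchao 0.17.0 directly; a version-tolerant import is only a hedged sketch of how one could support both sides of the move, not what the trainer builder does:

# Hedged sketch: prefer the promoted torchao.optim location, and fall back to
# the pre-0.17 prototype module on older torchao releases.
try:
    from torchao.optim.adam import AdamWFp8
except ImportError:
    from torchao.prototype.low_bit_optim import AdamWFp8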
@@ -95,6 +95,7 @@ class PatchManager:
     def apply_pre_model_load_patches(self):
         """Apply pre-model load patches based on config."""
         self._deactivate_hf_async_load()
+        self._apply_torchao_patches()
         self._apply_transformers_patches()
         # self._apply_flex_attention_patches()
         self._apply_flash_attention_patches()
@@ -125,6 +126,12 @@ class PatchManager:
         self._apply_tiled_mlp(self.cfg.model_config_type)
         self._apply_moe_expert_quantization_patch()

+    @staticmethod
+    def _apply_torchao_patches():
+        from axolotl.monkeypatch.torchao_optim import patch_torchao_optim_state_8bit
+
+        patch_torchao_optim_state_8bit()
+
     def _apply_transformers_patches(self):
         from axolotl.monkeypatch.transformers.trainer_loss_calc import (
             patch_evaluation_loop,
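PatchManager registers the torchao optimizer-state patches before the model is loaded, so the dispatch overrides exist before any optimizer state is created. If you build a trainer outside of axolotl's PatchManager, the equivalent manual wiring is simply a direct call into the module added below (a sketch; the function returns early and is harmless if the torchao subclasses are unavailable):

from axolotl.monkeypatch.torchao_optim import patch_torchao_optim_state_8bit

patch_torchao_optim_state_8bit()  # no-op if torchao's optim subclasses can't be imported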
src/axolotl/monkeypatch/torchao_optim.py (new file, 154 lines)
@@ -0,0 +1,154 @@
"""
|
||||
Patch for torchao optim subclasses that crash under torch.compile.
|
||||
|
||||
torchao 0.17.0 PR #3934 added an "appearance dtype" to OptimState{4,8}bit and
|
||||
OptimStateFp8, allowing them to report as e.g. bf16 while internally storing
|
||||
quantized codes. Three issues:
|
||||
|
||||
1. aten.view.default doesn't propagate the appearance dtype, so views (e.g. from
|
||||
DTensor.from_local()) revert to float32 while the base is bf16. torch.compile's
|
||||
fake-tensor metadata check then fails (AssertionError: torch.bfloat16 != torch.float32).
|
||||
|
||||
2. aten._to_copy doesn't clone internal tensors, so same-device dtype changes
|
||||
(e.g. .float()) create an accidental view relationship with the same issue.
|
||||
|
||||
3. aten.view.dtype is unimplemented, so if the dtype-view path IS taken, it crashes
|
||||
with NotImplementedError.
|
||||
|
||||
Fix: propagate dtype in view.default (primary), clone in _to_copy, register view.dtype.
|
||||
|
||||
Upstream fix: https://github.com/pytorch/ao/pull/4216
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.utils._python_dispatch import return_and_correct_aliasing
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
|
||||
def _needs_view_dtype_patch(cls):
|
||||
"""Check if a subclass is missing aten.view.dtype."""
|
||||
op_table = getattr(cls, "_ATEN_OP_TABLE", {}).get(cls, {})
|
||||
return aten.view.dtype not in op_table
|
||||
|
||||
|
||||
def patch_torchao_optim_state_8bit():
|
||||
"""Patch torchao optim subclasses for torch.compile compatibility."""
|
||||
try:
|
||||
from torchao.optim.subclass_8bit import OptimState8bit
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
# Patch view.default to propagate appearance dtype
|
||||
@OptimState8bit.implements(aten.view.default)
|
||||
def _(func, types, args, kwargs):
|
||||
x, shape = args
|
||||
return OptimState8bit(
|
||||
x.codes.view(shape), x.scale, x.qmap, x.signed, dtype=x.dtype
|
||||
)
|
||||
|
||||
# Patch _to_copy to clone internal tensors (breaks accidental view)
|
||||
@OptimState8bit.implements(aten._to_copy.default)
|
||||
def _(func, types, args, kwargs):
|
||||
dtype = kwargs.get("dtype", args[0].dtype)
|
||||
device = kwargs.get("device", None)
|
||||
out = OptimState8bit(
|
||||
args[0].codes.to(device=device).clone(),
|
||||
args[0].scale.to(device=device).clone(),
|
||||
args[0].qmap.to(device=device).clone(),
|
||||
args[0].signed,
|
||||
dtype=dtype,
|
||||
)
|
||||
return return_and_correct_aliasing(func, args, kwargs, out)
|
||||
|
||||
if _needs_view_dtype_patch(OptimState8bit):
|
||||
|
||||
@OptimState8bit.implements(aten.view.dtype)
|
||||
def _(func, types, args, kwargs):
|
||||
x, dtype = args
|
||||
return OptimState8bit(x.codes, x.scale, x.qmap, x.signed, dtype=dtype)
|
||||
|
||||
LOG.debug("Patched OptimState8bit for torch.compile compatibility")
|
||||
|
||||
try:
|
||||
from torchao.optim.subclass_4bit import OptimState4bit
|
||||
except ImportError:
|
||||
OptimState4bit = None
|
||||
|
||||
if OptimState4bit is not None:
|
||||
|
||||
@OptimState4bit.implements(aten.view.default)
|
||||
def _(func, types, args, kwargs):
|
||||
x, shape = args
|
||||
if tuple(x.shape) == tuple(shape):
|
||||
return OptimState4bit(
|
||||
x.codes, x.scale, x.qmap, x.signed, x._shape, dtype=x.dtype
|
||||
)
|
||||
if len(shape) == 1 and shape[0] == -1:
|
||||
return OptimState4bit(
|
||||
x.codes, x.scale, x.qmap, x.signed, (x.numel(),), dtype=x.dtype
|
||||
)
|
||||
raise ValueError(
|
||||
f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]"
|
||||
)
|
||||
|
||||
@OptimState4bit.implements(aten._to_copy.default)
|
||||
def _(func, types, args, kwargs):
|
||||
dtype = kwargs.get("dtype", args[0].dtype)
|
||||
device = kwargs.get("device", None)
|
||||
out = OptimState4bit(
|
||||
args[0].codes.to(device=device).clone(),
|
||||
args[0].scale.to(device=device).clone(),
|
||||
args[0].qmap.to(device=device).clone(),
|
||||
args[0].signed,
|
||||
args[0].shape,
|
||||
dtype=dtype,
|
||||
)
|
||||
return return_and_correct_aliasing(func, args, kwargs, out)
|
||||
|
||||
if _needs_view_dtype_patch(OptimState4bit):
|
||||
|
||||
@OptimState4bit.implements(aten.view.dtype)
|
||||
def _(func, types, args, kwargs):
|
||||
x, dtype = args
|
||||
return OptimState4bit(
|
||||
x.codes, x.scale, x.qmap, x.signed, x.shape, dtype=dtype
|
||||
)
|
||||
|
||||
LOG.debug("Patched OptimState4bit for torch.compile compatibility")
|
||||
|
||||
try:
|
||||
from torchao.optim.subclass_fp8 import OptimStateFp8
|
||||
except ImportError:
|
||||
OptimStateFp8 = None
|
||||
|
||||
if OptimStateFp8 is not None:
|
||||
|
||||
@OptimStateFp8.implements(aten.view.default)
|
||||
def _(func, types, args, kwargs):
|
||||
x, shape = args
|
||||
return OptimStateFp8(x.codes.view(shape), x.scale, dtype=x.dtype)
|
||||
|
||||
@OptimStateFp8.implements(aten._to_copy.default)
|
||||
def _(func, types, args, kwargs):
|
||||
dtype = kwargs.get("dtype", args[0].dtype)
|
||||
device = kwargs.get("device", None)
|
||||
out = OptimStateFp8(
|
||||
args[0].codes.to(device=device).clone(),
|
||||
args[0].scale.to(device=device).clone(),
|
||||
dtype=dtype,
|
||||
)
|
||||
return return_and_correct_aliasing(func, args, kwargs, out)
|
||||
|
||||
if _needs_view_dtype_patch(OptimStateFp8):
|
||||
|
||||
@OptimStateFp8.implements(aten.view.dtype)
|
||||
def _(func, types, args, kwargs):
|
||||
x, dtype = args
|
||||
return OptimStateFp8(x.codes, x.scale, dtype=dtype)
|
||||
|
||||
LOG.debug("Patched OptimStateFp8 for torch.compile compatibility")
|
||||
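For context, here is a minimal sketch of the scenario the patch targets: an 8-bit torchao optimizer whose quantized state gets traced by torch.compile. The optimizer name (AdamW8bit from torchao.optim), the toy model, and the CUDA device are illustrative assumptions; the relevant point is that the patch is registered before the optimizer state is created and before compilation.

import torch
from torchao.optim import AdamW8bit  # low-bit optimizer whose state uses OptimState8bit

from axolotl.monkeypatch.torchao_optim import patch_torchao_optim_state_8bit

patch_torchao_optim_state_8bit()  # register the view/_to_copy/view.dtype overrides first

model = torch.nn.Linear(128, 128, dtype=torch.bfloat16, device="cuda")
optim = AdamW8bit(model.parameters())


@torch.compile
def train_step(batch):
    loss = model(batch).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()


train_step(torch.randn(4, 128, dtype=torch.bfloat16, device="cuda"))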
@@ -6,6 +6,7 @@ import torch
 from packaging import version
 from torchao.core.config import AOBaseConfig
 from torchao.quantization import quantize_
+from torchao.quantization.granularity import PerGroup
 from torchao.quantization.qat import (
     QATConfig,
 )
@@ -14,30 +15,22 @@ from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8DynamicActivationInt4WeightConfig,
     Int4WeightOnlyConfig,
+    Int8DynamicActivationIntxWeightConfig,
 )

-try:
-    from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig
-except ImportError:
-    from torchao.quantization.quant_api import (
-        Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig,
-    )
-
 from axolotl.utils.schemas.enums import TorchAOQuantDType

 quantization_config_to_str = {
-    Int8DynamicActivationInt4WeightConfig: "int8int4",
+    Int8DynamicActivationIntxWeightConfig: "int8int4",
     Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
     Float8DynamicActivationInt4WeightConfig: "fp8int4",
 }

 if version.parse(torch.__version__) >= version.parse("2.8.0"):
     try:
-        from torchao.prototype.mx_formats import (
-            NVFP4WeightOnlyConfig as NVFP4InferenceConfig,
-        )
+        from torchao.prototype.mx_formats import NVFP4WeightOnlyConfig

-        quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4"
+        quantization_config_to_str[NVFP4WeightOnlyConfig] = "nvfp4"
     except (ImportError, RuntimeError):
         pass
@@ -108,10 +101,10 @@ def get_quantization_config(
         activation_dtype == TorchAOQuantDType.int8
         and weight_dtype == TorchAOQuantDType.int4
     ):
+        kwargs = {"weight_dtype": torch.int4}
         if group_size is not None:
-            return Int8DynamicActivationInt4WeightConfig(group_size=group_size)
-        else:
-            return Int8DynamicActivationInt4WeightConfig()
+            kwargs["weight_granularity"] = PerGroup(group_size=group_size)
+        return Int8DynamicActivationIntxWeightConfig(**kwargs)
     if (
         activation_dtype == TorchAOQuantDType.float8_e4m3fn
         and weight_dtype == TorchAOQuantDType.float8_e4m3fn
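In other words, the removed Int8DynamicActivationInt4WeightConfig(group_size=...) call is replaced by the Intx config parameterized with an int4 weight dtype and an optional per-group granularity, mirroring the kwargs built above. A standalone hedged sketch of the same construction (the toy model is an assumption):

import torch
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig

model = torch.nn.Sequential(torch.nn.Linear(256, 256))
config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(group_size=32),  # omit to use the default granularity
)
quantize_(model, config)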
@@ -123,13 +116,11 @@
     ):
         return Float8DynamicActivationInt4WeightConfig()
     if weight_dtype == TorchAOQuantDType.nvfp4:
-        from torchao.prototype.mx_formats import (
-            NVFP4WeightOnlyConfig as NVFP4InferenceConfig,
-        )
+        from torchao.prototype.mx_formats import NVFP4WeightOnlyConfig

         if group_size is not None and group_size != 16:
             raise ValueError("NVFP4 quantization must use a group_size of 16")
-        return NVFP4InferenceConfig()
+        return NVFP4WeightOnlyConfig()

     if weight_dtype == TorchAOQuantDType.mxfp4:
         # MXFP4 uses block_size=32 by default (vs NVFP4's 16)
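The NVFP4 path likewise drops the NVFP4InferenceConfig alias and returns NVFP4WeightOnlyConfig() directly. A short hedged usage sketch (bf16 weights and a supported GPU are assumptions for NVFP4):

import torch
from torchao.prototype.mx_formats import NVFP4WeightOnlyConfig
from torchao.quantization import quantize_

model = torch.nn.Sequential(torch.nn.Linear(256, 256, dtype=torch.bfloat16))
quantize_(model, NVFP4WeightOnlyConfig())  # NVFP4 is fixed to 16-element blocks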
@@ -5,20 +5,14 @@ Tests for axolotl.utils.quantization
 import pytest
 import torch
 from torch import nn
-from torchao.quantization import LinearActivationQuantizedTensor
+from torchao.quantization import IntxUnpackedToInt8Tensor
 from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
 from torchao.quantization.qat.linear import FakeQuantizedLinear
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8DynamicActivationInt4WeightConfig,
+    Int8DynamicActivationIntxWeightConfig,
 )
-
-try:
-    from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig
-except ImportError:
-    from torchao.quantization.quant_api import (
-        Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig,
-    )
 from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
 from transformers import AutoModelForCausalLM
 from transformers.trainer_callback import TrainerState
@@ -71,7 +65,7 @@ ptq_config_test_cases = [
         TorchAOQuantDType.int4,
         TorchAOQuantDType.int8,
         None,
-        Int8DynamicActivationInt4WeightConfig,
+        Int8DynamicActivationIntxWeightConfig,
     ),
     (
         TorchAOQuantDType.float8_e4m3fn,
@@ -96,7 +90,7 @@ ptq_test_cases = [
         8,
         False,
         None,
-        LinearActivationQuantizedTensor,
+        IntxUnpackedToInt8Tensor,
     ),
     # (
     #     TorchAOQuantDType.int4,