upgrade torchao to 0.17.0 (#3569)
Some checks failed
ci-cd / build-axolotl (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.12, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.12, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 128, 12.8.1, true, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 130, 13.0.0, <nil>, 3.11, 2.9.1) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
Some checks failed
ci-cd / build-axolotl (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.12, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.12, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 128, 12.8.1, true, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 130, 13.0.0, <nil>, 3.11, 2.9.1) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
* upgrade to torchao 0.17.0
* upgrade mistral-common too
* chore: lint
* patch fix for torchao low bit optimizers
* fix up
* propagate dtype
* fix test for ao change
* address PR comments
This commit is contained in:
@@ -66,7 +66,7 @@ langdetect==1.0.9
|
|||||||
immutabledict==4.2.0
|
immutabledict==4.2.0
|
||||||
antlr4-python3-runtime==4.13.2
|
antlr4-python3-runtime==4.13.2
|
||||||
|
|
||||||
torchao==0.16.0
|
torchao==0.17.0
|
||||||
openenv-core==0.1.0
|
openenv-core==0.1.0
|
||||||
schedulefree==1.4.1
|
schedulefree==1.4.1
|
||||||
|
|
||||||
@@ -75,4 +75,4 @@ axolotl-contribs-mit==0.0.6
|
|||||||
# telemetry
|
# telemetry
|
||||||
posthog==6.7.11
|
posthog==6.7.11
|
||||||
|
|
||||||
mistral-common==1.10.0
|
mistral-common==1.11.0
|
||||||
|
|||||||
@@ -329,7 +329,7 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
optimizer_cls = AdamW
|
optimizer_cls = AdamW
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
elif self.cfg.optimizer == "ao_adamw_fp8":
|
elif self.cfg.optimizer == "ao_adamw_fp8":
|
||||||
from torchao.prototype.low_bit_optim import AdamWFp8
|
from torchao.optim.adam import AdamWFp8
|
||||||
|
|
||||||
optimizer_cls = AdamWFp8
|
optimizer_cls = AdamWFp8
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
|||||||
@@ -95,6 +95,7 @@ class PatchManager:
|
|||||||
def apply_pre_model_load_patches(self):
|
def apply_pre_model_load_patches(self):
|
||||||
"""Apply pre-model load patches based on config."""
|
"""Apply pre-model load patches based on config."""
|
||||||
self._deactivate_hf_async_load()
|
self._deactivate_hf_async_load()
|
||||||
|
self._apply_torchao_patches()
|
||||||
self._apply_transformers_patches()
|
self._apply_transformers_patches()
|
||||||
# self._apply_flex_attention_patches()
|
# self._apply_flex_attention_patches()
|
||||||
self._apply_flash_attention_patches()
|
self._apply_flash_attention_patches()
|
||||||
@@ -125,6 +126,12 @@ class PatchManager:
|
|||||||
self._apply_tiled_mlp(self.cfg.model_config_type)
|
self._apply_tiled_mlp(self.cfg.model_config_type)
|
||||||
self._apply_moe_expert_quantization_patch()
|
self._apply_moe_expert_quantization_patch()
|
||||||
|
|
||||||
|
@staticmethod
def _apply_torchao_patches():
    """Install the torchao optimizer-state patches.

    The patch module is imported inside the method rather than at the top of
    the file, so it is only loaded when this method actually runs.
    """
    from axolotl.monkeypatch import torchao_optim

    torchao_optim.patch_torchao_optim_state_8bit()
|
||||||
|
|
||||||
def _apply_transformers_patches(self):
|
def _apply_transformers_patches(self):
|
||||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||||
patch_evaluation_loop,
|
patch_evaluation_loop,
|
||||||
|
|||||||
154
src/axolotl/monkeypatch/torchao_optim.py
Normal file
154
src/axolotl/monkeypatch/torchao_optim.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
"""
|
||||||
|
Patch for torchao optim subclasses that crash under torch.compile.
|
||||||
|
|
||||||
|
torchao 0.17.0 PR #3934 added an "appearance dtype" to OptimState{4,8}bit and
|
||||||
|
OptimStateFp8, allowing them to report as e.g. bf16 while internally storing
|
||||||
|
quantized codes. Three issues:
|
||||||
|
|
||||||
|
1. aten.view.default doesn't propagate the appearance dtype, so views (e.g. from
|
||||||
|
DTensor.from_local()) revert to float32 while the base is bf16. torch.compile's
|
||||||
|
fake-tensor metadata check then fails (AssertionError: torch.bfloat16 != torch.float32).
|
||||||
|
|
||||||
|
2. aten._to_copy doesn't clone internal tensors, so same-device dtype changes
|
||||||
|
(e.g. .float()) create an accidental view relationship with the same issue.
|
||||||
|
|
||||||
|
3. aten.view.dtype is unimplemented, so if the dtype-view path IS taken, it crashes
|
||||||
|
with NotImplementedError.
|
||||||
|
|
||||||
|
Fix: propagate dtype in view.default (primary), clone in _to_copy, register view.dtype.
|
||||||
|
|
||||||
|
Upstream fix: https://github.com/pytorch/ao/pull/4216
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils._python_dispatch import return_and_correct_aliasing
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
aten = torch.ops.aten
|
||||||
|
|
||||||
|
|
||||||
|
def _needs_view_dtype_patch(cls):
    """Tell whether ``cls`` still lacks a handler for ``aten.view.dtype``.

    The lookup reads ``cls._ATEN_OP_TABLE`` and then the entry stored under
    ``cls`` itself. A missing attribute or a missing entry is treated as an
    empty table, in which case the op counts as unregistered.

    Returns:
        ``True`` when ``aten.view.dtype`` is not among the ops found for
        ``cls``, ``False`` otherwise.
    """
    tables_by_class = getattr(cls, "_ATEN_OP_TABLE", {})
    registered_ops = tables_by_class.get(cls, {})
    if aten.view.dtype in registered_ops:
        return False
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def patch_torchao_optim_state_8bit():
|
||||||
|
"""Patch torchao optim subclasses for torch.compile compatibility."""
|
||||||
|
try:
|
||||||
|
from torchao.optim.subclass_8bit import OptimState8bit
|
||||||
|
except ImportError:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Patch view.default to propagate appearance dtype
|
||||||
|
@OptimState8bit.implements(aten.view.default)
|
||||||
|
def _(func, types, args, kwargs):
|
||||||
|
x, shape = args
|
||||||
|
return OptimState8bit(
|
||||||
|
x.codes.view(shape), x.scale, x.qmap, x.signed, dtype=x.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# Patch _to_copy to clone internal tensors (breaks accidental view)
|
||||||
|
@OptimState8bit.implements(aten._to_copy.default)
|
||||||
|
def _(func, types, args, kwargs):
|
||||||
|
dtype = kwargs.get("dtype", args[0].dtype)
|
||||||
|
device = kwargs.get("device", None)
|
||||||
|
out = OptimState8bit(
|
||||||
|
args[0].codes.to(device=device).clone(),
|
||||||
|
args[0].scale.to(device=device).clone(),
|
||||||
|
args[0].qmap.to(device=device).clone(),
|
||||||
|
args[0].signed,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
return return_and_correct_aliasing(func, args, kwargs, out)
|
||||||
|
|
||||||
|
if _needs_view_dtype_patch(OptimState8bit):
|
||||||
|
|
||||||
|
@OptimState8bit.implements(aten.view.dtype)
|
||||||
|
def _(func, types, args, kwargs):
|
||||||
|
x, dtype = args
|
||||||
|
return OptimState8bit(x.codes, x.scale, x.qmap, x.signed, dtype=dtype)
|
||||||
|
|
||||||
|
LOG.debug("Patched OptimState8bit for torch.compile compatibility")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torchao.optim.subclass_4bit import OptimState4bit
|
||||||
|
except ImportError:
|
||||||
|
OptimState4bit = None
|
||||||
|
|
||||||
|
if OptimState4bit is not None:
|
||||||
|
|
||||||
|
@OptimState4bit.implements(aten.view.default)
|
||||||
|
def _(func, types, args, kwargs):
|
||||||
|
x, shape = args
|
||||||
|
if tuple(x.shape) == tuple(shape):
|
||||||
|
return OptimState4bit(
|
||||||
|
x.codes, x.scale, x.qmap, x.signed, x._shape, dtype=x.dtype
|
||||||
|
)
|
||||||
|
if len(shape) == 1 and shape[0] == -1:
|
||||||
|
return OptimState4bit(
|
||||||
|
x.codes, x.scale, x.qmap, x.signed, (x.numel(),), dtype=x.dtype
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]"
|
||||||
|
)
|
||||||
|
|
||||||
|
@OptimState4bit.implements(aten._to_copy.default)
|
||||||
|
def _(func, types, args, kwargs):
|
||||||
|
dtype = kwargs.get("dtype", args[0].dtype)
|
||||||
|
device = kwargs.get("device", None)
|
||||||
|
out = OptimState4bit(
|
||||||
|
args[0].codes.to(device=device).clone(),
|
||||||
|
args[0].scale.to(device=device).clone(),
|
||||||
|
args[0].qmap.to(device=device).clone(),
|
||||||
|
args[0].signed,
|
||||||
|
args[0].shape,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
return return_and_correct_aliasing(func, args, kwargs, out)
|
||||||
|
|
||||||
|
if _needs_view_dtype_patch(OptimState4bit):
|
||||||
|
|
||||||
|
@OptimState4bit.implements(aten.view.dtype)
|
||||||
|
def _(func, types, args, kwargs):
|
||||||
|
x, dtype = args
|
||||||
|
return OptimState4bit(
|
||||||
|
x.codes, x.scale, x.qmap, x.signed, x.shape, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.debug("Patched OptimState4bit for torch.compile compatibility")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torchao.optim.subclass_fp8 import OptimStateFp8
|
||||||
|
except ImportError:
|
||||||
|
OptimStateFp8 = None
|
||||||
|
|
||||||
|
if OptimStateFp8 is not None:
|
||||||
|
|
||||||
|
@OptimStateFp8.implements(aten.view.default)
|
||||||
|
def _(func, types, args, kwargs):
|
||||||
|
x, shape = args
|
||||||
|
return OptimStateFp8(x.codes.view(shape), x.scale, dtype=x.dtype)
|
||||||
|
|
||||||
|
@OptimStateFp8.implements(aten._to_copy.default)
|
||||||
|
def _(func, types, args, kwargs):
|
||||||
|
dtype = kwargs.get("dtype", args[0].dtype)
|
||||||
|
device = kwargs.get("device", None)
|
||||||
|
out = OptimStateFp8(
|
||||||
|
args[0].codes.to(device=device).clone(),
|
||||||
|
args[0].scale.to(device=device).clone(),
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
return return_and_correct_aliasing(func, args, kwargs, out)
|
||||||
|
|
||||||
|
if _needs_view_dtype_patch(OptimStateFp8):
|
||||||
|
|
||||||
|
@OptimStateFp8.implements(aten.view.dtype)
|
||||||
|
def _(func, types, args, kwargs):
|
||||||
|
x, dtype = args
|
||||||
|
return OptimStateFp8(x.codes, x.scale, dtype=dtype)
|
||||||
|
|
||||||
|
LOG.debug("Patched OptimStateFp8 for torch.compile compatibility")
|
||||||
@@ -6,6 +6,7 @@ import torch
|
|||||||
from packaging import version
|
from packaging import version
|
||||||
from torchao.core.config import AOBaseConfig
|
from torchao.core.config import AOBaseConfig
|
||||||
from torchao.quantization import quantize_
|
from torchao.quantization import quantize_
|
||||||
|
from torchao.quantization.granularity import PerGroup
|
||||||
from torchao.quantization.qat import (
|
from torchao.quantization.qat import (
|
||||||
QATConfig,
|
QATConfig,
|
||||||
)
|
)
|
||||||
@@ -14,30 +15,22 @@ from torchao.quantization.quant_api import (
|
|||||||
Float8DynamicActivationFloat8WeightConfig,
|
Float8DynamicActivationFloat8WeightConfig,
|
||||||
Float8DynamicActivationInt4WeightConfig,
|
Float8DynamicActivationInt4WeightConfig,
|
||||||
Int4WeightOnlyConfig,
|
Int4WeightOnlyConfig,
|
||||||
|
Int8DynamicActivationIntxWeightConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig
|
|
||||||
except ImportError:
|
|
||||||
from torchao.quantization.quant_api import (
|
|
||||||
Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
||||||
|
|
||||||
quantization_config_to_str = {
|
quantization_config_to_str = {
|
||||||
Int8DynamicActivationInt4WeightConfig: "int8int4",
|
Int8DynamicActivationIntxWeightConfig: "int8int4",
|
||||||
Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
|
Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
|
||||||
Float8DynamicActivationInt4WeightConfig: "fp8int4",
|
Float8DynamicActivationInt4WeightConfig: "fp8int4",
|
||||||
}
|
}
|
||||||
|
|
||||||
if version.parse(torch.__version__) >= version.parse("2.8.0"):
|
if version.parse(torch.__version__) >= version.parse("2.8.0"):
|
||||||
try:
|
try:
|
||||||
from torchao.prototype.mx_formats import (
|
from torchao.prototype.mx_formats import NVFP4WeightOnlyConfig
|
||||||
NVFP4WeightOnlyConfig as NVFP4InferenceConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4"
|
quantization_config_to_str[NVFP4WeightOnlyConfig] = "nvfp4"
|
||||||
except (ImportError, RuntimeError):
|
except (ImportError, RuntimeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -108,10 +101,10 @@ def get_quantization_config(
|
|||||||
activation_dtype == TorchAOQuantDType.int8
|
activation_dtype == TorchAOQuantDType.int8
|
||||||
and weight_dtype == TorchAOQuantDType.int4
|
and weight_dtype == TorchAOQuantDType.int4
|
||||||
):
|
):
|
||||||
|
kwargs = {"weight_dtype": torch.int4}
|
||||||
if group_size is not None:
|
if group_size is not None:
|
||||||
return Int8DynamicActivationInt4WeightConfig(group_size=group_size)
|
kwargs["weight_granularity"] = PerGroup(group_size=group_size)
|
||||||
else:
|
return Int8DynamicActivationIntxWeightConfig(**kwargs)
|
||||||
return Int8DynamicActivationInt4WeightConfig()
|
|
||||||
if (
|
if (
|
||||||
activation_dtype == TorchAOQuantDType.float8_e4m3fn
|
activation_dtype == TorchAOQuantDType.float8_e4m3fn
|
||||||
and weight_dtype == TorchAOQuantDType.float8_e4m3fn
|
and weight_dtype == TorchAOQuantDType.float8_e4m3fn
|
||||||
@@ -123,13 +116,11 @@ def get_quantization_config(
|
|||||||
):
|
):
|
||||||
return Float8DynamicActivationInt4WeightConfig()
|
return Float8DynamicActivationInt4WeightConfig()
|
||||||
if weight_dtype == TorchAOQuantDType.nvfp4:
|
if weight_dtype == TorchAOQuantDType.nvfp4:
|
||||||
from torchao.prototype.mx_formats import (
|
from torchao.prototype.mx_formats import NVFP4WeightOnlyConfig
|
||||||
NVFP4WeightOnlyConfig as NVFP4InferenceConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
if group_size is not None and group_size != 16:
|
if group_size is not None and group_size != 16:
|
||||||
raise ValueError("NVFP4 quantization must use a group_size of 16")
|
raise ValueError("NVFP4 quantization must use a group_size of 16")
|
||||||
return NVFP4InferenceConfig()
|
return NVFP4WeightOnlyConfig()
|
||||||
|
|
||||||
if weight_dtype == TorchAOQuantDType.mxfp4:
|
if weight_dtype == TorchAOQuantDType.mxfp4:
|
||||||
# MXFP4 uses block_size=32 by default (vs NVFP4's 16)
|
# MXFP4 uses block_size=32 by default (vs NVFP4's 16)
|
||||||
|
|||||||
@@ -5,20 +5,14 @@ Tests for axolotl.utils.quantization
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torchao.quantization import LinearActivationQuantizedTensor
|
from torchao.quantization import IntxUnpackedToInt8Tensor
|
||||||
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
|
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
|
||||||
from torchao.quantization.qat.linear import FakeQuantizedLinear
|
from torchao.quantization.qat.linear import FakeQuantizedLinear
|
||||||
from torchao.quantization.quant_api import (
|
from torchao.quantization.quant_api import (
|
||||||
Float8DynamicActivationFloat8WeightConfig,
|
Float8DynamicActivationFloat8WeightConfig,
|
||||||
Float8DynamicActivationInt4WeightConfig,
|
Float8DynamicActivationInt4WeightConfig,
|
||||||
|
Int8DynamicActivationIntxWeightConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig
|
|
||||||
except ImportError:
|
|
||||||
from torchao.quantization.quant_api import (
|
|
||||||
Int8DynamicActivationIntxWeightConfig as Int8DynamicActivationInt4WeightConfig,
|
|
||||||
)
|
|
||||||
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
|
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
from transformers.trainer_callback import TrainerState
|
from transformers.trainer_callback import TrainerState
|
||||||
@@ -71,7 +65,7 @@ ptq_config_test_cases = [
|
|||||||
TorchAOQuantDType.int4,
|
TorchAOQuantDType.int4,
|
||||||
TorchAOQuantDType.int8,
|
TorchAOQuantDType.int8,
|
||||||
None,
|
None,
|
||||||
Int8DynamicActivationInt4WeightConfig,
|
Int8DynamicActivationIntxWeightConfig,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
TorchAOQuantDType.float8_e4m3fn,
|
TorchAOQuantDType.float8_e4m3fn,
|
||||||
@@ -96,7 +90,7 @@ ptq_test_cases = [
|
|||||||
8,
|
8,
|
||||||
False,
|
False,
|
||||||
None,
|
None,
|
||||||
LinearActivationQuantizedTensor,
|
IntxUnpackedToInt8Tensor,
|
||||||
),
|
),
|
||||||
# (
|
# (
|
||||||
# TorchAOQuantDType.int4,
|
# TorchAOQuantDType.int4,
|
||||||
|
|||||||
Reference in New Issue
Block a user