Compare commits

...

1 Commits

Author SHA1 Message Date
Salman Mohammadi
7a08e4117a wip ao upgrade 2026-01-05 18:23:33 +00:00
2 changed files with 26 additions and 1 deletion

View File

@@ -63,7 +63,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.13.0
torchao==0.15.0
openenv-core==0.1.0
schedulefree==1.4.1

View File

@@ -9,6 +9,10 @@ from torchao.quantization import quantize_
from torchao.quantization.qat import (
QATConfig,
)
from torchao.quantization.qat import fake_quantizer
from torchao.quantization.qat.fake_quantizer import (
Int4WeightFakeQuantizer as AoInt4WeightFakeQuantizer,
)
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
@@ -17,6 +21,27 @@ from torchao.quantization.quant_api import (
from axolotl.utils.schemas.enums import TorchAOQuantDType
class Int4WeightFakeQuantizer(AoInt4WeightFakeQuantizer):
    """Restore the runtime `enabled` toggle dropped from torchao 0.15.

    Subclasses torchao's ``Int4WeightFakeQuantizer`` and re-adds the
    ``enabled`` attribute so fake quantization can be switched on and off
    (used by ``fake_quant_after_n_steps``).
    """

    def __init__(self, config):
        super().__init__(config)
        # Fake quantization is active until a caller flips this off.
        self.enabled = True

    def forward(self, w: torch.Tensor) -> torch.Tensor:
        # While disabled, weights pass through untouched.
        return w if not self.enabled else super().forward(w)
# Monkey-patch: replace the original Int4WeightFakeQuantizer in torchao's
# fake_quantizer module with our subclass (which re-adds `enabled`) so that
# torchao's quantize_() picks up our version when it builds fake quantizers.
# NOTE(review): module-level side effect — runs once at import time.
fake_quantizer.Int4WeightFakeQuantizer = Int4WeightFakeQuantizer
quantization_config_to_str = {
Int8DynamicActivationInt4WeightConfig: "int8int4",
Float8DynamicActivationFloat8WeightConfig: "fp8fp8",