diff --git a/requirements.txt b/requirements.txt
index f69135902..a1c2082f0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -63,7 +63,7 @@
 langdetect==1.0.9
 immutabledict==4.2.0
 antlr4-python3-runtime==4.13.2
-torchao==0.13.0
+torchao==0.15.0
 openenv-core==0.1.0
 schedulefree==1.4.1
diff --git a/src/axolotl/utils/quantization.py b/src/axolotl/utils/quantization.py
index 6c29a5442..cae507ecc 100644
--- a/src/axolotl/utils/quantization.py
+++ b/src/axolotl/utils/quantization.py
@@ -9,6 +9,10 @@ from torchao.quantization import quantize_
 from torchao.quantization.qat import (
     QATConfig,
 )
+from torchao.quantization.qat import fake_quantizer
+from torchao.quantization.qat.fake_quantizer import (
+    Int4WeightFakeQuantizer as AoInt4WeightFakeQuantizer,
+)
 from torchao.quantization.quant_api import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8DynamicActivationInt4WeightConfig,
@@ -17,6 +21,27 @@ from torchao.quantization.quant_api import (
 
 from axolotl.utils.schemas.enums import TorchAOQuantDType
 
+
+class Int4WeightFakeQuantizer(AoInt4WeightFakeQuantizer):
+    """
+    Restores the `enabled` attribute removed from Int4WeightFakeQuantizer in torchao 0.15,
+    allowing fake quantization to be toggled on/off for fake_quant_after_n_steps.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.enabled = True
+
+    def forward(self, w: torch.Tensor) -> torch.Tensor:
+        # Pass weights through untouched while fake quantization is disabled
+        if not self.enabled:
+            return w
+        return super().forward(w)
+
+
+# Replace the original Int4WeightFakeQuantizer in the fake_quantizer module
+# so that torchao's quantize_() function will use our version
+fake_quantizer.Int4WeightFakeQuantizer = Int4WeightFakeQuantizer
+
 quantization_config_to_str = {
     Int8DynamicActivationInt4WeightConfig: "int8int4",
     Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
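
Reviewer note: a minimal sketch (not part of this diff) of how the restored `enabled` flag can drive `fake_quant_after_n_steps`. The helper name `set_fake_quant_enabled` and the `weight_fake_quantizer` attribute lookup are illustrative assumptions about the modules that torchao's `quantize_()` swaps in during QAT preparation, not code from this PR.

```python
import torch.nn as nn


def set_fake_quant_enabled(model: nn.Module, enabled: bool) -> None:
    """Toggle weight fake quantization on every QAT-prepared module.

    Assumes quantize_() has replaced linear layers with fake-quantized
    variants whose `weight_fake_quantizer` is the patched class above.
    """
    for module in model.modules():
        quantizer = getattr(module, "weight_fake_quantizer", None)
        if quantizer is not None and hasattr(quantizer, "enabled"):
            quantizer.enabled = enabled


# Example: disable fake quantization at the start of training, then
# re-enable it once fake_quant_after_n_steps steps have elapsed.
# set_fake_quant_enabled(model, False)  # before step 0
# set_fake_quant_enabled(model, True)   # at step fake_quant_after_n_steps
```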