|
|
|
@@ -9,6 +9,10 @@ from torchao.quantization import quantize_
|
|
|
|
from torchao.quantization.qat import (
|
|
|
|
from torchao.quantization.qat import (
|
|
|
|
QATConfig,
|
|
|
|
QATConfig,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
from torchao.quantization.qat import fake_quantizer
|
|
|
|
|
|
|
|
from torchao.quantization.qat.fake_quantizer import (
|
|
|
|
|
|
|
|
Int4WeightFakeQuantizer as AoInt4WeightFakeQuantizer,
|
|
|
|
|
|
|
|
)
|
|
|
|
from torchao.quantization.quant_api import (
|
|
|
|
from torchao.quantization.quant_api import (
|
|
|
|
Float8DynamicActivationFloat8WeightConfig,
|
|
|
|
Float8DynamicActivationFloat8WeightConfig,
|
|
|
|
Float8DynamicActivationInt4WeightConfig,
|
|
|
|
Float8DynamicActivationInt4WeightConfig,
|
|
|
|
@@ -17,6 +21,27 @@ from torchao.quantization.quant_api import (
|
|
|
|
|
|
|
|
|
|
|
|
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
|
|
|
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Int4WeightFakeQuantizer(AoInt4WeightFakeQuantizer):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Adds 'enabled' attribute to Int4WeightFakeQuantizer (removed in torchao 0.15).
|
|
|
|
|
|
|
|
Allows toggling fake quantization on/off for fake_quant_after_n_steps.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, config):
|
|
|
|
|
|
|
|
super().__init__(config)
|
|
|
|
|
|
|
|
self.enabled = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, w: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
|
|
|
if not self.enabled:
|
|
|
|
|
|
|
|
return w
|
|
|
|
|
|
|
|
return super().forward(w)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Replace the original Int4WeightFakeQuantizer in the fake_quantizer module
|
|
|
|
|
|
|
|
# so that torchao's quantize_() function will use our version
|
|
|
|
|
|
|
|
fake_quantizer.Int4WeightFakeQuantizer = Int4WeightFakeQuantizer
|
|
|
|
|
|
|
|
|
|
|
|
quantization_config_to_str = {
|
|
|
|
quantization_config_to_str = {
|
|
|
|
Int8DynamicActivationInt4WeightConfig: "int8int4",
|
|
|
|
Int8DynamicActivationInt4WeightConfig: "int8int4",
|
|
|
|
Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
|
|
|
|
Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
|
|
|
|
|