diff --git a/src/axolotl/utils/quantization.py b/src/axolotl/utils/quantization.py index 04b2c6341..9c09778ee 100644 --- a/src/axolotl/utils/quantization.py +++ b/src/axolotl/utils/quantization.py @@ -11,6 +11,7 @@ from torchao.quantization.qat import ( QATConfig, ) from torchao.quantization.qat.fake_quantize_config import Int4WeightFakeQuantizeConfig +from torchao.quantization.granularity import PerGroup from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, Float8DynamicActivationInt4WeightConfig,