wip ao upgrade
This commit is contained in:
@@ -63,7 +63,7 @@ langdetect==1.0.9
|
|||||||
immutabledict==4.2.0
|
immutabledict==4.2.0
|
||||||
antlr4-python3-runtime==4.13.2
|
antlr4-python3-runtime==4.13.2
|
||||||
|
|
||||||
torchao==0.13.0
|
torchao==0.15.0
|
||||||
openenv-core==0.1.0
|
openenv-core==0.1.0
|
||||||
schedulefree==1.4.1
|
schedulefree==1.4.1
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,10 @@ from torchao.quantization import quantize_
|
|||||||
from torchao.quantization.qat import (
|
from torchao.quantization.qat import (
|
||||||
QATConfig,
|
QATConfig,
|
||||||
)
|
)
|
||||||
|
from torchao.quantization.qat import fake_quantizer
|
||||||
|
from torchao.quantization.qat.fake_quantizer import (
|
||||||
|
Int4WeightFakeQuantizer as AoInt4WeightFakeQuantizer,
|
||||||
|
)
|
||||||
from torchao.quantization.quant_api import (
|
from torchao.quantization.quant_api import (
|
||||||
Float8DynamicActivationFloat8WeightConfig,
|
Float8DynamicActivationFloat8WeightConfig,
|
||||||
Float8DynamicActivationInt4WeightConfig,
|
Float8DynamicActivationInt4WeightConfig,
|
||||||
@@ -17,6 +21,27 @@ from torchao.quantization.quant_api import (
|
|||||||
|
|
||||||
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
||||||
|
|
||||||
|
|
||||||
|
class Int4WeightFakeQuantizer(AoInt4WeightFakeQuantizer):
|
||||||
|
"""
|
||||||
|
Adds 'enabled' attribute to Int4WeightFakeQuantizer (removed in torchao 0.15).
|
||||||
|
Allows toggling fake quantization on/off for fake_quant_after_n_steps.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.enabled = True
|
||||||
|
|
||||||
|
def forward(self, w: torch.Tensor) -> torch.Tensor:
|
||||||
|
if not self.enabled:
|
||||||
|
return w
|
||||||
|
return super().forward(w)
|
||||||
|
|
||||||
|
|
||||||
|
# Replace the original Int4WeightFakeQuantizer in the fake_quantizer module
|
||||||
|
# so that torchao's quantize_() function will use our version
|
||||||
|
fake_quantizer.Int4WeightFakeQuantizer = Int4WeightFakeQuantizer
|
||||||
|
|
||||||
quantization_config_to_str = {
|
quantization_config_to_str = {
|
||||||
Int8DynamicActivationInt4WeightConfig: "int8int4",
|
Int8DynamicActivationInt4WeightConfig: "int8int4",
|
||||||
Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
|
Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
|
||||||
|
|||||||
Reference in New Issue
Block a user