Compare commits
1 Commits
dynamic-sf
...
upgrade-to
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a08e4117a |
@@ -63,7 +63,7 @@ langdetect==1.0.9
|
||||
immutabledict==4.2.0
|
||||
antlr4-python3-runtime==4.13.2
|
||||
|
||||
torchao==0.13.0
|
||||
torchao==0.15.0
|
||||
openenv-core==0.1.0
|
||||
schedulefree==1.4.1
|
||||
|
||||
|
||||
@@ -9,6 +9,10 @@ from torchao.quantization import quantize_
|
||||
from torchao.quantization.qat import (
|
||||
QATConfig,
|
||||
)
|
||||
from torchao.quantization.qat import fake_quantizer
|
||||
from torchao.quantization.qat.fake_quantizer import (
|
||||
Int4WeightFakeQuantizer as AoInt4WeightFakeQuantizer,
|
||||
)
|
||||
from torchao.quantization.quant_api import (
|
||||
Float8DynamicActivationFloat8WeightConfig,
|
||||
Float8DynamicActivationInt4WeightConfig,
|
||||
@@ -17,6 +21,27 @@ from torchao.quantization.quant_api import (
|
||||
|
||||
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
||||
|
||||
|
||||
class Int4WeightFakeQuantizer(AoInt4WeightFakeQuantizer):
|
||||
"""
|
||||
Adds 'enabled' attribute to Int4WeightFakeQuantizer (removed in torchao 0.15).
|
||||
Allows toggling fake quantization on/off for fake_quant_after_n_steps.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.enabled = True
|
||||
|
||||
def forward(self, w: torch.Tensor) -> torch.Tensor:
|
||||
if not self.enabled:
|
||||
return w
|
||||
return super().forward(w)
|
||||
|
||||
|
||||
# Replace the original Int4WeightFakeQuantizer in the fake_quantizer module
|
||||
# so that torchao's quantize_() function will use our version
|
||||
fake_quantizer.Int4WeightFakeQuantizer = Int4WeightFakeQuantizer
|
||||
|
||||
quantization_config_to_str = {
|
||||
Int8DynamicActivationInt4WeightConfig: "int8int4",
|
||||
Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
|
||||
|
||||
Reference in New Issue
Block a user