Compare commits
1 Commits
diffusion-
...
quantize-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5a51852af1 |
@@ -43,7 +43,7 @@ def do_quantize(
|
|||||||
"No quantization configuration found. Please specify either qat or quantization in your config file."
|
"No quantization configuration found. Please specify either qat or quantization in your config file."
|
||||||
)
|
)
|
||||||
|
|
||||||
model_path = cli_args.get("model_path") or cfg.output_dir
|
model_path = cli_args.get("base_model") or cfg.output_dir
|
||||||
if weight_dtype := cli_args.get("weight_dtype"):
|
if weight_dtype := cli_args.get("weight_dtype"):
|
||||||
weight_dtype = TorchIntDType[weight_dtype]
|
weight_dtype = TorchIntDType[weight_dtype]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from torchao.quantization.quant_api import (
|
|||||||
UIntXWeightOnlyConfig,
|
UIntXWeightOnlyConfig,
|
||||||
_is_linear,
|
_is_linear,
|
||||||
)
|
)
|
||||||
|
from transformers import TorchAoConfig
|
||||||
|
|
||||||
from axolotl.utils.schemas.enums import TorchIntDType
|
from axolotl.utils.schemas.enums import TorchIntDType
|
||||||
|
|
||||||
@@ -149,7 +150,9 @@ def quantize_model_for_ptq(
|
|||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
)
|
)
|
||||||
quantize_(model, linear_ptq_config)
|
quantize_(model, linear_ptq_config)
|
||||||
|
quantization_config = TorchAoConfig(linear_ptq_config)
|
||||||
if quantize_embedding:
|
if quantize_embedding:
|
||||||
|
quantization_config.include_input_output_embeddings = True
|
||||||
embedding_quantize_config = get_ptq_config(
|
embedding_quantize_config = get_ptq_config(
|
||||||
weight_dtype=weight_dtype,
|
weight_dtype=weight_dtype,
|
||||||
activation_dtype=None,
|
activation_dtype=None,
|
||||||
@@ -160,6 +163,7 @@ def quantize_model_for_ptq(
|
|||||||
embedding_quantize_config,
|
embedding_quantize_config,
|
||||||
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
|
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
|
||||||
)
|
)
|
||||||
|
model.config.quantization_config = quantization_config
|
||||||
|
|
||||||
|
|
||||||
def convert_qat_model_for_ptq(
|
def convert_qat_model_for_ptq(
|
||||||
|
|||||||
Reference in New Issue
Block a user