Compare commits

1 Commit

streaming-...quantize-p
| Author | SHA1 | Date |
|---|---|---|
|  | 5a51852af1 |  |
```diff
@@ -43,7 +43,7 @@ def do_quantize(
             "No quantization configuration found. Please specify either qat or quantization in your config file."
         )
-    model_path = cli_args.get("model_path") or cfg.output_dir
+    model_path = cli_args.get("base_model") or cfg.output_dir
     if weight_dtype := cli_args.get("weight_dtype"):
         weight_dtype = TorchIntDType[weight_dtype]
     else:
```
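For context on the `weight_dtype` branch above: the walrus operator only performs the enum lookup when the flag is present. A minimal sketch of the pattern with a hypothetical stand-in enum (the real `TorchIntDType` lives in `axolotl.utils.schemas.enums`; its member names are not shown in this diff):

```python
from enum import Enum


class IntDType(Enum):
    # Hypothetical stand-in for axolotl's TorchIntDType; the actual
    # members are assumptions here.
    int4 = "int4"
    int8 = "int8"


def parse_weight_dtype(cli_args: dict):
    # Same shape as the hunk above: convert the CLI string to an enum
    # member only when the flag was provided, otherwise leave it falsy.
    if weight_dtype := cli_args.get("weight_dtype"):
        weight_dtype = IntDType[weight_dtype]  # lookup by name, e.g. "int8"
    return weight_dtype


assert parse_weight_dtype({"weight_dtype": "int8"}) is IntDType.int8
assert parse_weight_dtype({}) is None
```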
```diff
@@ -20,6 +20,7 @@ from torchao.quantization.quant_api import (
     UIntXWeightOnlyConfig,
     _is_linear,
 )
+from transformers import TorchAoConfig
 
 from axolotl.utils.schemas.enums import TorchIntDType
 
```
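The added import supports the hunks below: transformers' `TorchAoConfig` can wrap a torchao quantization config so it travels with the model config. A minimal sketch, assuming a recent torchao/transformers pairing where `TorchAoConfig` accepts an `AOBaseConfig` instance (the `Int8WeightOnlyConfig` choice is a placeholder, not the config this PR builds):

```python
from torchao.quantization import Int8WeightOnlyConfig
from transformers import TorchAoConfig

# Mirrors TorchAoConfig(linear_ptq_config) in the next hunk: the torchao
# config object is passed through as the quantization type.
ao_config = Int8WeightOnlyConfig()
hf_quant_config = TorchAoConfig(ao_config)
```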
```diff
@@ -149,7 +150,9 @@ def quantize_model_for_ptq(
             group_size=group_size,
         )
     quantize_(model, linear_ptq_config)
+    quantization_config = TorchAoConfig(linear_ptq_config)
     if quantize_embedding:
+        quantization_config.include_input_output_embeddings = True
         embedding_quantize_config = get_ptq_config(
             weight_dtype=weight_dtype,
             activation_dtype=None,
```
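Putting the hunk above together: linear layers are quantized first, then embeddings get their own weight-only pass selected by a `filter_fn` (shown in the next hunk). A condensed sketch of that flow; the configs are placeholders, and whether a given torchao config supports `nn.Embedding` depends on the torchao version:

```python
import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_


def quantize_linears_then_embeddings(model, quantize_embedding=True):
    # Pass 1: quantize linear layers (torchao's default filter targets nn.Linear).
    quantize_(model, Int8WeightOnlyConfig())
    # Pass 2: optionally quantize embeddings with a separate config, using
    # the same filter_fn shape as the diff's lambda.
    if quantize_embedding:
        quantize_(
            model,
            Int8WeightOnlyConfig(),
            filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
        )
    return model
```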
```diff
@@ -160,6 +163,7 @@ def quantize_model_for_ptq(
             embedding_quantize_config,
             filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
         )
+    model.config.quantization_config = quantization_config
 
 
 def convert_qat_model_for_ptq(
```
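The point of the added `model.config.quantization_config = quantization_config` line is persistence: `save_pretrained` serializes the attached config into `config.json`, so the checkpoint round-trips as a quantized model. A minimal sketch under that assumption (the path is a placeholder):

```python
from transformers import AutoModelForCausalLM


def save_and_reload(model, path="./ptq-model"):
    # The TorchAoConfig attached to model.config is written to config.json.
    # torchao tensor subclasses are typically not safetensors-compatible,
    # so non-safetensors serialization may be required here.
    model.save_pretrained(path, safe_serialization=False)
    # from_pretrained reads quantization_config back and restores the
    # quantized modules (exact behavior depends on library versions).
    return AutoModelForCausalLM.from_pretrained(path)
```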