feat: add FineGrainedFP8Config support for model quantization (#3587) [skip ci]
Allow loading FP8-quantized models (e.g. Mistral-Small-4-119B) with FineGrainedFP8Config and an optional `dequantize` kwarg for full fine-tuning. Made-with: Cursor
This commit is contained in:
@@ -547,6 +547,16 @@ class ModelLoader:
|
||||
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
|
||||
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
|
||||
|
||||
if self.cfg.model_quantization_config == "FineGrainedFP8Config":
|
||||
from transformers import FineGrainedFP8Config
|
||||
|
||||
fp8_kwargs = {}
|
||||
if self.cfg.model_quantization_config_kwargs:
|
||||
fp8_kwargs = self.cfg.model_quantization_config_kwargs
|
||||
self.model_kwargs["quantization_config"] = FineGrainedFP8Config(
|
||||
**fp8_kwargs
|
||||
)
|
||||
|
||||
if self.cfg.gptq:
|
||||
if not hasattr(self.model_config, "quantization_config"):
|
||||
LOG.warning(
|
||||
|
||||
@@ -87,9 +87,11 @@ class ModelInputConfig(BaseModel):
|
||||
json_schema_extra={"description": "Use custom kernels, e.g. MegaBlocks."},
|
||||
)
|
||||
|
||||
model_quantization_config: Literal["Mxfp4Config"] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Model loading quantization config"},
|
||||
model_quantization_config: Literal["Mxfp4Config", "FineGrainedFP8Config"] | None = (
|
||||
Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Model loading quantization config"},
|
||||
)
|
||||
)
|
||||
model_quantization_config_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
|
||||
Reference in New Issue
Block a user