From 3985ec2f67145a7ced8978325c4ae419bdb8c68f Mon Sep 17 00:00:00 2001 From: madScientist10 <42779409+madScientist10@users.noreply.github.com> Date: Mon, 13 Apr 2026 03:50:37 +0300 Subject: [PATCH] feat: add FineGrainedFP8Config support for model quantization (#3587) [skip ci] Allow loading FP8-quantized models (e.g. Mistral-Small-4-119B) with FineGrainedFP8Config and optional dequantize kwarg for full fine-tuning. Made-with: Cursor --- src/axolotl/loaders/model.py | 10 ++++++++++ src/axolotl/utils/schemas/model.py | 8 +++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 83b6452dc..4f5779327 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -547,6 +547,16 @@ class ModelLoader: mxfp4_kwargs = self.cfg.model_quantization_config_kwargs self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs) + if self.cfg.model_quantization_config == "FineGrainedFP8Config": + from transformers import FineGrainedFP8Config + + fp8_kwargs = {} + if self.cfg.model_quantization_config_kwargs: + fp8_kwargs = self.cfg.model_quantization_config_kwargs + self.model_kwargs["quantization_config"] = FineGrainedFP8Config( + **fp8_kwargs + ) + if self.cfg.gptq: if not hasattr(self.model_config, "quantization_config"): LOG.warning( diff --git a/src/axolotl/utils/schemas/model.py b/src/axolotl/utils/schemas/model.py index 02b971c1d..3c5dfc6e3 100644 --- a/src/axolotl/utils/schemas/model.py +++ b/src/axolotl/utils/schemas/model.py @@ -87,9 +87,11 @@ class ModelInputConfig(BaseModel): json_schema_extra={"description": "Use custom kernels, e.g. 
MegaBlocks."}, ) - model_quantization_config: Literal["Mxfp4Config"] | None = Field( - default=None, - json_schema_extra={"description": "Model loading quantization config"}, + model_quantization_config: Literal["Mxfp4Config", "FineGrainedFP8Config"] | None = ( + Field( + default=None, + json_schema_extra={"description": "Model loading quantization config"}, + ) ) model_quantization_config_kwargs: dict[str, Any] | None = Field( default=None,