diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index 83b6452dc..4f5779327 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -547,6 +547,16 @@ class ModelLoader:
             mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
             self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
 
+        if self.cfg.model_quantization_config == "FineGrainedFP8Config":
+            from transformers import FineGrainedFP8Config
+
+            fp8_kwargs = {}
+            if self.cfg.model_quantization_config_kwargs:
+                fp8_kwargs = self.cfg.model_quantization_config_kwargs
+            self.model_kwargs["quantization_config"] = FineGrainedFP8Config(
+                **fp8_kwargs
+            )
+
         if self.cfg.gptq:
             if not hasattr(self.model_config, "quantization_config"):
                 LOG.warning(
diff --git a/src/axolotl/utils/schemas/model.py b/src/axolotl/utils/schemas/model.py
index 02b971c1d..3c5dfc6e3 100644
--- a/src/axolotl/utils/schemas/model.py
+++ b/src/axolotl/utils/schemas/model.py
@@ -87,9 +87,11 @@ class ModelInputConfig(BaseModel):
         json_schema_extra={"description": "Use custom kernels, e.g. MegaBlocks."},
     )
 
-    model_quantization_config: Literal["Mxfp4Config"] | None = Field(
-        default=None,
-        json_schema_extra={"description": "Model loading quantization config"},
+    model_quantization_config: Literal["Mxfp4Config", "FineGrainedFP8Config"] | None = (
+        Field(
+            default=None,
+            json_schema_extra={"description": "Model loading quantization config"},
+        )
     )
     model_quantization_config_kwargs: dict[str, Any] | None = Field(
         default=None,