From 8b9c15b17f875edcfdf0ab170b8284937c0a8ad1 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 7 May 2024 22:52:57 +0900 Subject: [PATCH] feat: exclude mamba blocks for jamba (#1578) --- src/axolotl/utils/models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 8537b7e75..e94a0f6b8 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1,4 +1,5 @@ """Module for models and model loading""" + # pylint: disable=too-many-lines import logging @@ -504,6 +505,9 @@ def load_model( bnb_config = { "load_in_8bit": True, } + # Exclude mamba blocks from int8 quantization for jamba + if cfg.model_config_type == "jamba": + bnb_config["llm_int8_skip_modules"] = ["mamba"] model_kwargs["quantization_config"] = BitsAndBytesConfig( **bnb_config, )