feat: exclude mamba blocks for jamba (#1578)

This commit is contained in:
NanoCode012
2024-05-07 22:52:57 +09:00
committed by GitHub
parent 9e1480e9ca
commit 8b9c15b17f

View File

@@ -1,4 +1,5 @@
"""Module for models and model loading"""
# pylint: disable=too-many-lines
import logging
@@ -504,6 +505,9 @@ def load_model(
bnb_config = {
"load_in_8bit": True,
}
# Exclude mamba blocks from int8 quantization for jamba
if cfg.model_config_type == "jamba":
bnb_config["llm_int8_skip_modules"] = ["mamba"]
model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)