From 41353d2ea04db3478d2f6f9069b7d0adb1f30ae8 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Fri, 29 Dec 2023 18:16:26 +0900
Subject: [PATCH] feat: expose bnb kwargs (#1018)

* feat: expose bnb kwargs

* chore: added examples and link per suggestion

* Uncomment defaults per suggestion for readability

Co-authored-by: Hamel Husain

---------

Co-authored-by: Hamel Husain
---
 README.md                   |  8 ++++++++
 src/axolotl/utils/models.py | 19 +++++++++++++------
 2 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index bbed3e10d..d15c4c001 100644
--- a/README.md
+++ b/README.md
@@ -520,6 +520,14 @@ model_config:
     type: # linear | dynamic
     factor: # float
 
+# optional overrides to the bnb 4bit quantization configuration
+# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
+bnb_config_kwargs:
+  # These are default values
+  llm_int8_has_fp16_weight: false
+  bnb_4bit_quant_type: nf4
+  bnb_4bit_use_double_quant: true
+
 # Whether you are training a 4-bit GPTQ quantized model
 gptq: true
 
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 872d530ab..c2b3a758c 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -301,13 +301,20 @@ def load_model(
             **model_config.quantization_config
         )
     if cfg.adapter == "qlora" and cfg.load_in_4bit:
+        bnb_config = {
+            "load_in_4bit": True,
+            "llm_int8_threshold": 6.0,
+            "llm_int8_has_fp16_weight": False,
+            "bnb_4bit_compute_dtype": cfg.torch_dtype,
+            "bnb_4bit_use_double_quant": True,
+            "bnb_4bit_quant_type": "nf4",
+        }
+
+        if cfg.bnb_config_kwargs:
+            bnb_config.update(cfg.bnb_config_kwargs)
+
         model_kwargs["quantization_config"] = BitsAndBytesConfig(
-            load_in_4bit=True,
-            llm_int8_threshold=6.0,
-            llm_int8_has_fp16_weight=False,
-            bnb_4bit_compute_dtype=cfg.torch_dtype,
-            bnb_4bit_use_double_quant=True,
-            bnb_4bit_quant_type="nf4",
+            **bnb_config,
         )
     # sample packing uses custom FA2 patch
     if cfg.flash_attention:
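
A minimal sketch of the override flow this patch implements, runnable outside axolotl: the QLoRA defaults are collected in a plain dict, any user-supplied bnb_config_kwargs are merged on top via dict.update (so user keys win), and the result is splatted into transformers.BitsAndBytesConfig. The defaults below mirror those in the patch; user_overrides is a hypothetical stand-in for the dict axolotl parses from the bnb_config_kwargs YAML block, and torch.bfloat16 stands in for cfg.torch_dtype.

    import torch
    from transformers import BitsAndBytesConfig

    # Defaults as set in load_model() when adapter == "qlora" and load_in_4bit.
    defaults = {
        "load_in_4bit": True,
        "llm_int8_threshold": 6.0,
        "llm_int8_has_fp16_weight": False,
        "bnb_4bit_compute_dtype": torch.bfloat16,  # cfg.torch_dtype in the patch
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_quant_type": "nf4",
    }

    # Hypothetical user overrides, as parsed from a YAML block such as:
    #   bnb_config_kwargs:
    #     bnb_4bit_quant_type: fp4
    #     bnb_4bit_use_double_quant: false
    user_overrides = {
        "bnb_4bit_quant_type": "fp4",
        "bnb_4bit_use_double_quant": False,
    }

    defaults.update(user_overrides)  # user-supplied keys win over defaults
    quantization_config = BitsAndBytesConfig(**defaults)

    print(quantization_config.bnb_4bit_quant_type)        # fp4
    print(quantization_config.bnb_4bit_use_double_quant)  # False

Keeping the defaults in one plain dict and merging with update means any future BitsAndBytesConfig keyword passes through from the YAML config without further code changes, which is the point of exposing the kwargs.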