diff --git a/requirements.txt b/requirements.txt
index ae6689680..e45037fc0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,10 @@
 peft @ git+https://github.com/huggingface/peft.git
 transformers @ git+https://github.com/huggingface/transformers.git
+bitsandbytes @ git+https://github.com/TimDettmers/bitsandbytes.git
 attrdict
 fire
 PyYAML==6.0
 black
-bitsandbytes==0.37.2
 datasets
 accelerate>=0.19.0
 sentencepiece
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 934f2f74c..992abf3ed 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -6,11 +6,12 @@ from typing import Optional, Tuple, TYPE_CHECKING
 
 import torch
 import transformers
+from torch import nn
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     PreTrainedModel,
-    AutoConfig,
+    AutoConfig, BitsAndBytesConfig,
 )
 
 try:
@@ -81,6 +82,16 @@ def load_model(
         logging.exception(e)
         raise e
 
+    model_kwargs = {}
+    if cfg.adapter == "qlora":
+        model_kwargs["quantization_config"] = BitsAndBytesConfig(
+            load_in_4bit=True,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+        )
     try:
         if cfg.load_4bit and is_llama_derived_model:
             from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
@@ -125,6 +136,7 @@ def load_model(
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
+                **model_kwargs,
             )
         # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
         # This is a WIP, still an issue with the backward pass
@@ -159,6 +171,7 @@ def load_model(
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
                 trust_remote_code=True if cfg.trust_remote_code is True else False,
+                **model_kwargs,
             )
         else:
             config = AutoConfig.from_pretrained(
@@ -172,6 +185,7 @@ def load_model(
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
                 trust_remote_code=True if cfg.trust_remote_code is True else False,
+                **model_kwargs,
             )
     except Exception as e:
         logging.error(
@@ -184,8 +198,24 @@ def load_model(
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
             trust_remote_code=True if cfg.trust_remote_code is True else False,
+            **model_kwargs,
         )
 
+    """### Post-processing on the model
+    Finally, we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in `float32` for stability. We also cast the output of the last layer in `float32` for the same reasons.
+    """
+    if cfg.adapter == "qlora":
+        for param in model.parameters():
+            param.requires_grad = False  # freeze the model - train adapters later
+            if param.ndim == 1:
+                # cast the small parameters (e.g. layernorm) to fp32 for stability
+                param.data = param.data.to(torch.float32)
+        class CastOutputToFloat(nn.Sequential):
+            def forward(self, x):
+                return super().forward(x).to(torch.float32)
+
+        model.lm_head = CastOutputToFloat(model.lm_head)
+
     if not tokenizer:
         try:
             if is_llama_derived_model and "LlamaTokenizer" in globals():
@@ -270,7 +300,7 @@ def load_adapter(model, cfg, adapter):
 
     if adapter is None:
         return model, None
-    if adapter == "lora":
+    if adapter == "lora" or adapter == "qlora":
         return load_lora(model, cfg)
     if adapter == "llama-adapter":
         return load_llama_adapter(model, cfg)
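
For reference, a minimal standalone sketch of the QLoRA loading pattern this patch wires into load_model when cfg.adapter == "qlora". The checkpoint name and device_map="auto" placement below are illustrative assumptions, not part of the change; the BitsAndBytesConfig arguments and the post-processing mirror the hunks above.

# Hedged sketch of the QLoRA load path added by this patch; checkpoint name is a placeholder.
import torch
from torch import nn
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # 4-bit base weights via bitsandbytes
    bnb_4bit_compute_dtype=torch.float16,  # matmuls computed in fp16
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization
)

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",                 # assumption: any causal LM checkpoint
    quantization_config=bnb_config,
    device_map="auto",
)

# Same post-processing the patch applies before LoRA adapters are attached:
for param in model.parameters():
    param.requires_grad = False            # freeze the base model
    if param.ndim == 1:
        param.data = param.data.to(torch.float32)  # keep norms in fp32 for stability

class CastOutputToFloat(nn.Sequential):
    # cast the lm_head output back to fp32
    def forward(self, x):
        return super().forward(x).to(torch.float32)

model.lm_head = CastOutputToFloat(model.lm_head)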