diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index d5184def6..01703cd51 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -24,7 +24,7 @@ jobs:
 
       - name: Install dependencies
         run: |
-          pip install -e .
+          pip install -e .[peft]
           pip install -r requirements-tests.txt
 
      - name: Run tests
diff --git a/setup.py b/setup.py
index 6cacd0c5b..7b99794de 100644
--- a/setup.py
+++ b/setup.py
@@ -32,5 +32,8 @@ setup(
         "extras": [
             "deepspeed",
         ],
+        "peft": [
+            "peft @ git+https://github.com/huggingface/peft.git",
+        ],
     },
 )
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 8d07ce9b6..3d2e029c3 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -62,6 +62,13 @@ def normalize_config(cfg):
     else:
         torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
 
+    if cfg.bf16 or cfg.bfloat16:
+        cfg.torch_dtype = torch.bfloat16
+    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
+        cfg.torch_dtype = torch.float16
+    else:
+        cfg.torch_dtype = torch.float32
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 8e5445fff..522ab3cb4 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -11,6 +11,7 @@ import bitsandbytes as bnb
 import torch
 import transformers
 from optimum.bettertransformer import BetterTransformer
+from peft.tuners.lora import LoraLayer
 from transformers import (  # noqa: F401
     AutoConfig,
     AutoModelForCausalLM,
@@ -146,12 +147,6 @@ def load_model(
         LOG.info("patching _expand_mask")
         hijack_expand_mask()
 
-    if cfg.bf16 or cfg.bfloat16:
-        torch_dtype = torch.bfloat16
-    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
-        torch_dtype = torch.float16
-    else:
-        torch_dtype = torch.float32
     try:
         if cfg.gptq:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@@ -183,7 +178,7 @@ def load_model(
                 load_in_4bit=True,
                 llm_int8_threshold=6.0,
                 llm_int8_has_fp16_weight=False,
-                bnb_4bit_compute_dtype=torch_dtype,
+                bnb_4bit_compute_dtype=cfg.torch_dtype,
                 bnb_4bit_use_double_quant=True,
                 bnb_4bit_quant_type="nf4",
             )
@@ -242,7 +237,7 @@ def load_model(
             device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            torch_dtype=torch_dtype,
+            torch_dtype=cfg.torch_dtype,
             **model_kwargs,
         )
     # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -277,7 +272,7 @@ def load_model(
             device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            torch_dtype=torch_dtype,
+            torch_dtype=cfg.torch_dtype,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
@@ -308,7 +303,7 @@ def load_model(
             device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            torch_dtype=torch_dtype,
+            torch_dtype=cfg.torch_dtype,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
@@ -322,7 +317,7 @@ def load_model(
             device_map=cfg.device_map,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            torch_dtype=torch_dtype,
+            torch_dtype=cfg.torch_dtype,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
@@ -356,16 +351,6 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
 
-        # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
-        # convert them back to fp16/bf16 for flash-attn compatibility.
-        if cfg.flash_attention and cfg.is_llama_derived_model:
-            for name, module in model.named_modules():
-                if "norm" in name:
-                    module.to(torch_dtype)
-                if "lm_head" in name or "embed_tokens" in name:
-                    if hasattr(module, "weight"):
-                        module.to(torch_dtype)
-
     model, lora_config = load_adapter(model, cfg, cfg.adapter)
 
     if cfg.ddp and not load_in_8bit:
@@ -509,6 +494,22 @@ def load_lora(model, cfg):
     else:
         model = get_peft_model(model, lora_config)
 
+    for name, module in model.named_modules():
+        if isinstance(module, LoraLayer):
+            module = module.to(cfg.torch_dtype)
+        if "norm" in name:
+            module = module.to(torch.float32)
+        if "lm_head" in name or "embed_tokens" in name:
+            if hasattr(module, "weight"):
+                module = module.to(cfg.torch_dtype)
+
+    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if cfg.flash_attention and cfg.is_llama_derived_model:
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module = module.to(cfg.torch_dtype)
+
     model.print_trainable_parameters()
 
     return model, lora_config