diff --git a/scripts/finetune.py b/scripts/finetune.py
index b79079e26..0c8727401 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -176,6 +176,7 @@ def train(
     if "merge_lora" in kwargs and cfg.adapter is not None:
         logging.info("running merge of LoRA with base model")
         model = model.merge_and_unload()
+        model.to(dtype=torch.float16)

         if cfg.local_rank == 0:
             logging.info("saving merged model")
diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index d2cb572f3..babf246f5 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -1,3 +1,6 @@
+import logging
+
+
 def validate_config(cfg):
     if cfg.adapter == "qlora":
         if cfg.merge_lora:
@@ -9,6 +12,9 @@ def validate_config(cfg):
         assert cfg.load_in_8bit is False
         assert cfg.load_4bit is False
         assert cfg.load_in_4bit is True
+    if not cfg.load_in_8bit and cfg.adapter == "lora":
+        logging.warning("we recommend setting `load_in_8bit: true`")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
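
For context: the first hunk casts the merged model to fp16 before saving, which halves the size of the serialized checkpoint relative to fp32 and matches the precision typically used at inference. The second hunk's warning condition (as corrected here, firing when `load_in_8bit` is *not* set for a LoRA run) can be exercised in isolation. A minimal sketch, assuming an attribute-style config object; the `SimpleNamespace` below is a stand-in for axolotl's actual config class, not its real API:

```python
# Minimal sketch exercising the new LoRA validation warning.
# `SimpleNamespace` is a hypothetical stand-in for axolotl's
# attribute-style config object.
import logging
from types import SimpleNamespace

from axolotl.utils.validation import validate_config

logging.basicConfig(level=logging.WARNING)

# LoRA adapter without 8-bit loading should trigger the recommendation.
cfg = SimpleNamespace(adapter="lora", load_in_8bit=False)
validate_config(cfg)
# expected: WARNING:root:we recommend setting `load_in_8bit: true`
```

Note the warning branch is deliberately a `logging.warning` rather than an assertion: unlike the hard qlora constraints above it, running LoRA without 8-bit loading is valid, just not the recommended configuration.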