diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index 9bef37406..d2cb572f3 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -1,9 +1,14 @@ def validate_config(cfg): if cfg.adapter == "qlora": - assert cfg.load_in_8bit is False - assert cfg.load_4bit is False - assert cfg.load_in_4bit is True - pass + if cfg.merge_lora: + # can't merge qlora if loaded in 8bit or 4bit + assert cfg.load_in_8bit is False + assert cfg.load_4bit is False + assert cfg.load_in_4bit is False + else: + assert cfg.load_in_8bit is False + assert cfg.load_4bit is False + assert cfg.load_in_4bit is True # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25