Feat: Update validate_config and add tests

This commit is contained in:
NanoCode012
2023-05-29 00:25:54 +09:00
parent 88889590ec
commit 52dd92a0cd
2 changed files with 119 additions and 9 deletions

View File

@@ -3,24 +3,39 @@ import logging
def validate_config(cfg):
if cfg.load_4bit:
raise ValueError("cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq")
raise ValueError(
"cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
)
if cfg.adapter == "qlora":
if cfg.merge_lora:
# can't merge qlora if loaded in 8bit or 4bit
assert cfg.load_in_8bit is not True
assert cfg.gptq is not True
assert cfg.load_in_4bit is not True
if cfg.load_in_8bit:
raise ValueError("Can't merge qlora if loaded in 8bit")
if cfg.gptq:
raise ValueError("Can't merge qlora if gptq")
if cfg.load_in_4bit:
raise ValueError("Can't merge qlora if loaded in 4bit")
else:
assert cfg.load_in_8bit is not True
assert cfg.gptq is not True
assert cfg.load_in_4bit is True
if cfg.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")
if cfg.gptq:
raise ValueError("Can't load qlora if gptq")
if not cfg.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
if not cfg.load_in_8bit and cfg.adapter == "lora":
logging.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
if cfg.trust_remote_code:
logging.warning("`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model.")
logging.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
)
# TODO
# MPT 7b