Match gradient checkpointing to the config setting when using LoRA

Wing Lian
2023-06-11 09:20:40 -04:00
parent e944311442
commit fe0b76854e


@@ -305,7 +305,9 @@ def load_model(
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
     ):
         logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
-        model = prepare_model_for_kbit_training(model)
+        model = prepare_model_for_kbit_training(
+            model, use_gradient_checkpointing=cfg.gradient_checkpointing
+        )
     model, lora_config = load_adapter(model, cfg, adapter)
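
PEFT's prepare_model_for_kbit_training defaults to use_gradient_checkpointing=True, so before this change checkpointing was enabled regardless of the user's config; forwarding cfg.gradient_checkpointing keeps the two in sync. Below is a minimal sketch of how the call fits into a 4-bit LoRA load path, assuming current transformers/peft APIs; the checkpoint name and the Cfg stand-in are illustrative, not the repo's code:

```python
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig


class Cfg:
    # hypothetical stand-in for the YAML-backed config object
    gradient_checkpointing = True


cfg = Cfg()

# Load the base model quantized to 4 bits (the QLoRA case guarded above).
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",  # assumed example checkpoint
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)

# Forward the config flag so checkpointing is enabled only when the user
# asked for it, rather than always falling back to PEFT's default of True.
model = prepare_model_for_kbit_training(
    model, use_gradient_checkpointing=cfg.gradient_checkpointing
)

# Attach LoRA adapters as usual; hyperparameters here are placeholders.
lora_config = LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM")
model = get_peft_model(model, lora_config)
```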