prepare ia3 for 8bit

Wing Lian
2023-09-18 18:49:44 -04:00
parent 2d7cccfc8e
commit 1da328eb9a
2 changed files with 5 additions and 2 deletions

.pylintrc

@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
 disable=missing-function-docstring, line-too-long, import-error,
         too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
         too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
+        too-many-boolean-expressions,
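(The new disable is presumably needed because the expanded if in the hunk below chains six boolean sub-expressions, which trips pylint's too-many-boolean-expressions check, R0916, whose default limit is five.)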

src/axolotl/utils/models.py

@@ -407,8 +407,10 @@ def load_model(
             module.to(torch.float32)
 
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
-    if (cfg.adapter == "lora" and load_in_8bit) or (
-        cfg.adapter == "qlora" and cfg.load_in_4bit
+    if (
+        (cfg.adapter == "lora" and load_in_8bit)
+        or (cfg.adapter == "qlora" and cfg.load_in_4bit)
+        or (cfg.adapter == "ia3" and load_in_8bit)
     ):
         LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
         if cfg.gradient_checkpointing:
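
For context, a minimal sketch of the path this commit enables: an IA3 adapter on a base model loaded in 8-bit now goes through PEFT's prepare_model_for_kbit_training, the same preparation step lora and qlora already receive. The base model name and IA3 target modules below are illustrative assumptions, not taken from the commit.

import torch
from transformers import AutoModelForCausalLM
from peft import IA3Config, get_peft_model, prepare_model_for_kbit_training

# Hypothetical base model; any causal LM that loads in 8-bit works here.
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",
    load_in_8bit=True,
    device_map="auto",
)

# The call the new branch gates on: casts layer norms to fp32 and enables
# input gradients so the quantized base model trains stably.
model = prepare_model_for_kbit_training(model)

# Illustrative IA3 targets; feedforward_modules must be a subset of
# target_modules.
ia3_config = IA3Config(
    task_type="CAUSAL_LM",
    target_modules=["k_proj", "v_proj", "down_proj"],
    feedforward_modules=["down_proj"],
)
model = get_peft_model(model, ia3_config)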