diff --git a/.pylintrc b/.pylintrc index ed973d285..9f0e453d5 100644 --- a/.pylintrc +++ b/.pylintrc @@ -12,3 +12,4 @@ generated-members=numpy.*, torch.* disable=missing-function-docstring, line-too-long, import-error, too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods, too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation, + too-many-boolean-expressions, diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 07ae116a3..c47b724a1 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -407,8 +407,10 @@ def load_model( module.to(torch.float32) needs_fa2_dtype = cfg.adapter or cfg.fsdp - if (cfg.adapter == "lora" and load_in_8bit) or ( - cfg.adapter == "qlora" and cfg.load_in_4bit + if ( + (cfg.adapter == "lora" and load_in_8bit) + or (cfg.adapter == "qlora" and cfg.load_in_4bit) + or (cfg.adapter == "ia3" and load_in_8bit) ): LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") if cfg.gradient_checkpointing: