diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index b99c95158..67d0aaac2 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -576,6 +576,7 @@ def load_ia3(model, cfg, inference=False): target_modules=cfg.ia3_target_modules, feedforward_modules=cfg.ia3_feedforward_modules, modules_to_save=cfg.ia3_modules_to_save, + task_type="CAUSAL_LM", **ia3_config_kwargs, )