diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 58e0e97ec..b5d5124cb 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -402,6 +402,7 @@ def load_lora(model, cfg): model = PeftModel.from_pretrained( model, cfg.lora_model_dir, + is_trainable=True, device_map=cfg.device_map, # torch_dtype=torch.float16, )