ensure merged model matches the training dtype (#902)

* ensure merged model matches the training dtype

* Update src/axolotl/cli/__init__.py

* Update src/axolotl/cli/__init__.py
This commit is contained in:
Wing Lian
2023-11-29 09:55:19 -05:00
committed by GitHub
parent 71b7ea3c05
commit 1d21aa6b0a

View File

@@ -72,7 +72,7 @@ def do_merge_lora(
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
model.to(dtype=cfg.torch_dtype)
if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")