Fix: bf16 support for inference (#981)

* Fix: cast the model to the configured bf16 torch dtype during inference

* Simplify by casting to device and dtype in a single `.to()` call

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Tazik Shahjahan
2023-12-29 14:15:53 -08:00
committed by GitHub
parent f8ae59b0a8
commit 3678a6c41d

View File

@@ -103,7 +103,7 @@ def do_inference(
importlib.import_module("axolotl.prompters"), prompter
)
model = model.to(cfg.device)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
while True:
print("=" * 80)
@@ -168,7 +168,7 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)
model = model.to(cfg.device)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
def generate(instruction):
if not instruction: