From 3678a6c41d051ca6376d013c11c948e55b4c8b4f Mon Sep 17 00:00:00 2001
From: Tazik Shahjahan <35576188+taziksh@users.noreply.github.com>
Date: Fri, 29 Dec 2023 14:15:53 -0800
Subject: [PATCH] Fix: bf16 support for inference (#981)

* Fix: bf16 torch dtype

* simplify casting to device and dtype

---------

Co-authored-by: Wing Lian
---
 src/axolotl/cli/__init__.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index e6537ad05..85f6b358a 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -103,7 +103,7 @@ def do_inference(
         importlib.import_module("axolotl.prompters"), prompter
     )
 
-    model = model.to(cfg.device)
+    model = model.to(cfg.device, dtype=cfg.torch_dtype)
 
     while True:
         print("=" * 80)
@@ -168,7 +168,7 @@ def do_inference_gradio(
         importlib.import_module("axolotl.prompters"), prompter
     )
 
-    model = model.to(cfg.device)
+    model = model.to(cfg.device, dtype=cfg.torch_dtype)
 
     def generate(instruction):
         if not instruction:
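
A minimal sketch (not part of the patch) of the PyTorch behavior the fix relies on: torch.nn.Module.to() accepts a device and a dtype together, so a single call both moves the model and casts every parameter and buffer to bf16. The module and variable names below are illustrative stand-ins for the loaded model and for axolotl's cfg.device / cfg.torch_dtype, not the project's actual objects.

    import torch
    import torch.nn as nn

    # Stand-in for the loaded model; any nn.Module behaves the same way.
    model = nn.Linear(16, 4)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.bfloat16  # assumed value of cfg.torch_dtype when bf16 is enabled

    # Mirrors the patched line: move and cast in one .to() call.
    model = model.to(device, dtype=torch_dtype)

    print(next(model.parameters()).dtype)  # torch.bfloat16

Before the patch, only the device was passed, so a model loaded in bf16 could end up being run with mismatched or default weights' dtype at inference time; passing dtype explicitly keeps inference consistent with the configured precision.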