diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index e6537ad05..85f6b358a 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -103,7 +103,7 @@ def do_inference( importlib.import_module("axolotl.prompters"), prompter ) - model = model.to(cfg.device) + model = model.to(cfg.device, dtype=cfg.torch_dtype) while True: print("=" * 80) @@ -168,7 +168,7 @@ def do_inference_gradio( importlib.import_module("axolotl.prompters"), prompter ) - model = model.to(cfg.device) + model = model.to(cfg.device, dtype=cfg.torch_dtype) def generate(instruction): if not instruction: