Fix: bf16 support for inference (#981)
* Fix: use the bf16 torch dtype when casting the model for inference
* Simplify casting to device and dtype into a single `.to()` call

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -103,7 +103,7 @@ def do_inference(
         importlib.import_module("axolotl.prompters"), prompter
     )

-    model = model.to(cfg.device)
+    model = model.to(cfg.device, dtype=cfg.torch_dtype)

     while True:
         print("=" * 80)
@@ -168,7 +168,7 @@ def do_inference_gradio(
         importlib.import_module("axolotl.prompters"), prompter
    )

-    model = model.to(cfg.device)
+    model = model.to(cfg.device, dtype=cfg.torch_dtype)

     def generate(instruction):
         if not instruction:
Reference in New Issue
Block a user