Fix: bf16 support for inference (#981)
* Fix: bf16 torch dtype
* simplify casting to device and dtype

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -103,7 +103,7 @@ def do_inference(
|
|||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
|
|
||||||
model = model.to(cfg.device)
|
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
@@ -168,7 +168,7 @@ def do_inference_gradio(
|
|||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
|
|
||||||
model = model.to(cfg.device)
|
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
|
|
||||||
def generate(instruction):
|
def generate(instruction):
|
||||||
if not instruction:
|
if not instruction:
|
||||||
|
|||||||
Reference in New Issue
Block a user