refactor inference, warn if model is frozen
This commit is contained in:
@@ -6,9 +6,11 @@ import random
|
|||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
|
import transformers
|
||||||
import yaml
|
import yaml
|
||||||
from attrdict import AttrDefault
|
from attrdict import AttrDefault
|
||||||
|
|
||||||
@@ -46,6 +48,15 @@ def choose_device(cfg):
|
|||||||
cfg.device_map = {"": cfg.device}
|
cfg.device_map = {"": cfg.device}
|
||||||
|
|
||||||
|
|
||||||
|
def get_multi_line_input() -> Optional[str]:
|
||||||
|
print("Give me an instruction (Ctrl + Z to finish): ")
|
||||||
|
instruction = ""
|
||||||
|
for line in sys.stdin:
|
||||||
|
instruction += line
|
||||||
|
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
|
||||||
|
return instruction
|
||||||
|
|
||||||
|
|
||||||
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
||||||
tokenizer.add_special_tokens({"unk_token": "<unk>"})
|
tokenizer.add_special_tokens({"unk_token": "<unk>"})
|
||||||
tokenizer.add_special_tokens({"bos_token": "<s>"})
|
tokenizer.add_special_tokens({"bos_token": "<s>"})
|
||||||
@@ -55,8 +66,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
# support for multiline inputs
|
# support for multiline inputs
|
||||||
print("Give me an instruction (Ctrl + D to finish): ")
|
instruction = get_multi_line_input()
|
||||||
instruction = pathlib.Path("/proc/self/fd/0").read_text()
|
|
||||||
if not instruction:
|
if not instruction:
|
||||||
return
|
return
|
||||||
prompt = prompter_module().build_prompt(instruction=instruction)
|
prompt = prompter_module().build_prompt(instruction=instruction)
|
||||||
@@ -66,7 +76,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# gc = GenerationConfig() # TODO swap out and use this
|
# gc = GenerationConfig() # TODO swap out and use this
|
||||||
generated = model.generate(
|
generated = model.generate(
|
||||||
inputs=batch["input_ids"].to("cuda"),
|
inputs=batch["input_ids"].to(cfg.device),
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
repetition_penalty=1.1,
|
repetition_penalty=1.1,
|
||||||
|
|||||||
@@ -183,6 +183,12 @@ def load_model(
|
|||||||
model.is_parallelizable = True
|
model.is_parallelizable = True
|
||||||
model.model_parallel = True
|
model.model_parallel = True
|
||||||
|
|
||||||
|
requires_grad = []
|
||||||
|
for name, param in model.named_parameters(recurse=True):
|
||||||
|
if param.requires_grad:
|
||||||
|
requires_grad.append(f"{name}: {param.requires_grad}")
|
||||||
|
if len(requires_grad) == 0:
|
||||||
|
logging.warning("there are no parameters that require gradient updates")
|
||||||
|
|
||||||
# TODO resume_from_checkpoint handling
|
# TODO resume_from_checkpoint handling
|
||||||
return model, tokenizer, lora_config
|
return model, tokenizer, lora_config
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
|
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
|
||||||
optim=cfg.optimizer if cfg.optimizer else None,
|
optim=cfg.optimizer if cfg.optimizer else None,
|
||||||
lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine",
|
lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine",
|
||||||
weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
|
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
||||||
**training_arguments_kwargs,
|
**training_arguments_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user