refactor inference, warn if model is frozen

This commit is contained in:
Wing Lian
2023-05-07 01:53:30 -04:00
parent cb9a887047
commit 247825bd57
3 changed files with 20 additions and 4 deletions

View File

@@ -6,9 +6,11 @@ import random
import signal import signal
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Optional
import fire import fire
import torch import torch
import transformers
import yaml import yaml
from attrdict import AttrDefault from attrdict import AttrDefault
@@ -46,6 +48,15 @@ def choose_device(cfg):
cfg.device_map = {"": cfg.device} cfg.device_map = {"": cfg.device}
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + Z to finish): ")
instruction = ""
for line in sys.stdin:
instruction += line
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
return instruction
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
tokenizer.add_special_tokens({"unk_token": "<unk>"}) tokenizer.add_special_tokens({"unk_token": "<unk>"})
tokenizer.add_special_tokens({"bos_token": "<s>"}) tokenizer.add_special_tokens({"bos_token": "<s>"})
@@ -55,8 +66,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
while True: while True:
# support for multiline inputs # support for multiline inputs
print("Give me an instruction (Ctrl + D to finish): ") instruction = get_multi_line_input()
instruction = pathlib.Path("/proc/self/fd/0").read_text()
if not instruction: if not instruction:
return return
prompt = prompter_module().build_prompt(instruction=instruction) prompt = prompter_module().build_prompt(instruction=instruction)
@@ -66,7 +76,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
with torch.no_grad(): with torch.no_grad():
# gc = GenerationConfig() # TODO swap out and use this # gc = GenerationConfig() # TODO swap out and use this
generated = model.generate( generated = model.generate(
inputs=batch["input_ids"].to("cuda"), inputs=batch["input_ids"].to(cfg.device),
do_sample=True, do_sample=True,
use_cache=True, use_cache=True,
repetition_penalty=1.1, repetition_penalty=1.1,

View File

@@ -183,6 +183,12 @@ def load_model(
model.is_parallelizable = True model.is_parallelizable = True
model.model_parallel = True model.model_parallel = True
requires_grad = []
for name, param in model.named_parameters(recurse=True):
if param.requires_grad:
requires_grad.append(f"{name}: {param.requires_grad}")
if len(requires_grad) == 0:
logging.warning("there are no parameters that require gradient updates")
# TODO resume_from_checkpoint handling # TODO resume_from_checkpoint handling
return model, tokenizer, lora_config return model, tokenizer, lora_config

View File

@@ -105,7 +105,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
run_name=cfg.wandb_run_id if cfg.use_wandb else None, run_name=cfg.wandb_run_id if cfg.use_wandb else None,
optim=cfg.optimizer if cfg.optimizer else None, optim=cfg.optimizer if cfg.optimizer else None,
lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine", lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine",
weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0, weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
**training_arguments_kwargs, **training_arguments_kwargs,
) )