Merge branch 'main' into flash-optimum

This commit is contained in:
Wing Lian
2023-06-12 13:12:15 -04:00
committed by GitHub
36 changed files with 461 additions and 1009 deletions

View File

@@ -72,7 +72,19 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
prompter_module = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
while True:
print("=" * 80)
@@ -80,10 +92,14 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
instruction = get_multi_line_input()
if not instruction:
return
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
with torch.no_grad():
@@ -159,7 +175,7 @@ def train(
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or cfg.strict is False:
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
@@ -199,8 +215,8 @@ def train(
logging.info(f"loading tokenizer... {tokenizer_config}")
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
if check_not_in(
["inference", "shard", "merge_lora"], kwargs
if (
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
): # don't need to load dataset for these
if not cfg.pretraining_dataset:
train_dataset, eval_dataset = load_prepare_datasets(
@@ -239,7 +255,6 @@ def train(
tokenizer,
cfg,
adapter=cfg.adapter,
inference=("inference" in kwargs),
)
if "merge_lora" in kwargs and cfg.adapter is not None:
@@ -252,9 +267,15 @@ def train(
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return
if "inference" in kwargs:
if cfg.inference:
logging.info("calling do_inference function")
do_inference(cfg, model, tokenizer)
inf_kwargs: Dict[str, Any] = {}
if "prompter" in kwargs:
if kwargs["prompter"] == "None":
inf_kwargs["prompter"] = None
else:
inf_kwargs["prompter"] = kwargs["prompter"]
do_inference(cfg, model, tokenizer, **inf_kwargs)
return
if "shard" in kwargs: