optionally be able to specify alpaca or chat style prompts

Wing Lian
2023-05-20 18:16:22 -04:00
parent fa8bd14be4
commit 1d5ab84486
6 changed files with 223 additions and 53 deletions
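
The hunks below only cover the training-script side of this change; the prompt-style handling named in the commit title lives in the other changed files, which are not shown here. As a rough, assumed illustration of the idea (the function, style names, and templates here are hypothetical, not the project's actual API), a loader might dispatch on a per-dataset style flag like this:

# Hypothetical sketch of dispatching on a prompt style flag; the real
# templates live in the repo's prompter modules, which this diff omits.
def build_prompt(style: str, instruction: str, response: str = "") -> str:
    if style == "alpaca":
        # classic alpaca-style instruction template
        return (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            f"### Instruction:\n{instruction}\n\n### Response:\n{response}"
        )
    if style == "chat":
        # simple user/assistant chat-style template
        return f"USER: {instruction}\nASSISTANT: {response}"
    raise ValueError(f"unknown prompt style: {style}")


print(build_prompt("alpaca", "Summarize the diff."))
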


@@ -31,7 +31,7 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
 def choose_device(cfg):
     def get_device():
         if torch.cuda.is_available():
-            return "cuda"
+            return f"cuda:{cfg.local_rank}"
         else:
             try:
                 if torch.backends.mps.is_available():
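
This hunk swaps the bare "cuda" device string for a rank-qualified one. Under multi-GPU launchers such as torchrun, each worker process must pin itself to its own GPU, otherwise every rank lands on GPU 0. A minimal standalone sketch of the effect, with `local_rank` read from the environment rather than from `cfg`:

import os

import torch


def get_device(local_rank: int) -> str:
    # "cuda" alone resolves to the current device (GPU 0 by default);
    # "cuda:N" pins this process to its own card under DDP.
    if torch.cuda.is_available():
        return f"cuda:{local_rank}"
    return "cpu"


# torchrun exports LOCAL_RANK per worker process
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
device = torch.device(get_device(local_rank))
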
@@ -131,7 +131,8 @@ def train(
     # then overwrite the value
     cfg_keys = dict(cfg).keys()
     for k in kwargs:
-        if k in cfg_keys:
+        # if not strict, allow writing to cfg even if it's not in the yml already
+        if k in cfg_keys or cfg.strict is False:
             # handle booleans
             if isinstance(cfg[k], bool):
                 cfg[k] = bool(kwargs[k])
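
This hunk loosens the CLI-override loop: with `strict` unset or false, a command-line kwarg may introduce a key the YAML never declared instead of being silently dropped. A standalone sketch of the same logic over a plain dict (the project wraps the YAML in an attribute-style dict, but the flow is the same):

def apply_overrides(cfg: dict, strict: bool = True, **kwargs) -> dict:
    cfg_keys = dict(cfg).keys()
    for k in kwargs:
        # if not strict, allow writing to cfg even if it's not in the yml already
        if k in cfg_keys or strict is False:
            # handle booleans: a CLI value like 0 or 1 is coerced to bool
            if isinstance(cfg.get(k), bool):
                cfg[k] = bool(kwargs[k])
            else:
                cfg[k] = kwargs[k]
    return cfg


cfg = {"learning_rate": 3e-4, "group_by_length": False}
apply_overrides(cfg, strict=False, group_by_length=1, merge_lora=True)
# group_by_length is coerced to True; merge_lora is accepted because strict=False
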
@@ -169,6 +170,15 @@ def train(
         inference=("inference" in kwargs),
     )
 
+    if "merge_lora" in kwargs and cfg.adapter is not None:
+        print("running merge of LoRA with base model")
+        model = model.merge_and_unload()
+
+        if cfg.local_rank == 0:
+            print("saving merged model")
+            model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
+        return
+
     if "inference" in kwargs:
         logging.info("calling do_inference function")
         do_inference(cfg, model, tokenizer)
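
This hunk adds a `merge_lora` entry point: PEFT's `merge_and_unload()` folds the trained LoRA deltas back into the base weights, so the result can be saved and served as a plain transformers model with no adapter machinery. Roughly how the same call looks outside the script (model IDs and paths below are placeholders):

from pathlib import Path

from peft import PeftModel
from transformers import AutoModelForCausalLM

# placeholders: any causal LM plus a trained LoRA adapter directory
base = AutoModelForCausalLM.from_pretrained("base-model-id")
model = PeftModel.from_pretrained(base, "path/to/lora-adapter")

# fold the adapter weights into the base model and drop the PEFT wrappers
merged = model.merge_and_unload()
merged.save_pretrained(str(Path("output_dir") / "merged"))
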
@@ -216,6 +226,8 @@ def train(
     )
     logging.info("Starting trainer...")
+    if cfg.group_by_length:
+        logging.info("hang tight... sorting dataset for group_by_length")
     resume_from_checkpoint = cfg.resume_from_checkpoint
     if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
         possible_checkpoints = [
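
The last hunk is cut off mid-statement, so the body of `possible_checkpoints` is not visible here. (`group_by_length` itself is a standard transformers TrainingArguments flag that batches samples of similar length to cut padding, hence the sorting notice.) The HF Trainer writes checkpoints as `checkpoint-<step>` directories under `output_dir`, so an auto-resume pass typically globs for those and picks the highest step; a sketch under that assumption (helper name and glob pattern are mine, not necessarily the project's):

from pathlib import Path


def find_latest_checkpoint(output_dir: str) -> str | None:
    # HF Trainer checkpoints look like output_dir/checkpoint-500, checkpoint-1000, ...
    candidates = [
        p for p in Path(output_dir).glob("checkpoint-*")
        if p.is_dir() and p.name.split("-")[-1].isdigit()
    ]
    if not candidates:
        return None
    return str(max(candidates, key=lambda p: int(p.name.split("-")[-1])))
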