optionally be able to specify alpaca or chat style prompts
@@ -31,7 +31,7 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
 def choose_device(cfg):
     def get_device():
         if torch.cuda.is_available():
-            return "cuda"
+            return f"cuda:{cfg.local_rank}"
         else:
             try:
                 if torch.backends.mps.is_available():
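This change makes get_device return a rank-specific device string rather than the bare "cuda", so each process in a multi-GPU run pins to its own GPU. A minimal sketch of the idea, assuming the local rank comes from the LOCAL_RANK environment variable set by launchers like torchrun (the variable name and the fallback to 0 are assumptions, not part of this commit):

import os
import torch

# Each worker gets LOCAL_RANK from the launcher (e.g. torchrun);
# falling back to 0 keeps single-GPU runs working unchanged.
local_rank = int(os.environ.get("LOCAL_RANK", 0))

if torch.cuda.is_available():
    device = f"cuda:{local_rank}"  # pin this process to its own GPU
    torch.cuda.set_device(local_rank)
else:
    device = "cpu"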
@@ -131,7 +131,8 @@ def train(
     # then overwrite the value
     cfg_keys = dict(cfg).keys()
     for k in kwargs:
-        if k in cfg_keys:
+        # if not strict, allow writing to cfg even if it's not in the yml already
+        if k in cfg_keys or cfg.strict is False:
             # handle booleans
             if isinstance(cfg[k], bool):
                 cfg[k] = bool(kwargs[k])
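The override loop now accepts CLI kwargs for keys that are missing from the YAML whenever cfg.strict is False, while still coercing booleans. A hypothetical illustration of that rule, using a plain dict in place of the real config object (the dict and the sample values are assumptions for the sketch only):

# Plain dict stands in for cfg; only the override rule is illustrated.
cfg = {"strict": False, "micro_batch_size": 2, "debug": False}
kwargs = {"micro_batch_size": 4, "debug": "1", "new_key": "value"}

for k, v in kwargs.items():
    if k in cfg or cfg.get("strict") is False:
        # preserve boolean typing when the existing value is a bool
        if isinstance(cfg.get(k), bool):
            cfg[k] = bool(v)
        else:
            cfg[k] = v

# cfg now contains debug=True, micro_batch_size=4, and the extra new_key,
# because strict is False.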
@@ -169,6 +170,15 @@ def train(
         inference=("inference" in kwargs),
     )

+    if "merge_lora" in kwargs and cfg.adapter is not None:
+        print("running merge of LoRA with base model")
+        model = model.merge_and_unload()
+
+        if cfg.local_rank == 0:
+            print("saving merged model")
+            model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
+        return
+
     if "inference" in kwargs:
         logging.info("calling do_inference function")
         do_inference(cfg, model, tokenizer)
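The new merge_lora path uses PEFT's merge_and_unload() to fold the LoRA weights into the base model and saves the merged weights from rank 0 only. A rough standalone usage sketch with the PEFT library; the model name and paths below are placeholders, not values from this commit:

from transformers import AutoModelForCausalLM
from peft import PeftModel

# Placeholder identifiers; substitute the real base model and adapter dir.
base = AutoModelForCausalLM.from_pretrained("base-model-name")
model = PeftModel.from_pretrained(base, "path/to/lora-adapter")

# Fold the LoRA deltas into the base weights and drop the adapter wrappers,
# leaving a plain transformers model that can be saved or served directly.
merged = model.merge_and_unload()
merged.save_pretrained("path/to/merged")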
@@ -216,6 +226,8 @@ def train(
     )

     logging.info("Starting trainer...")
+    if cfg.group_by_length:
+        logging.info("hang tight... sorting dataset for group_by_length")
     resume_from_checkpoint = cfg.resume_from_checkpoint
     if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
         possible_checkpoints = [
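The extra log line warns that enabling group_by_length triggers a length-based sort of the dataset before training, which can be slow on large datasets. group_by_length is a standard transformers TrainingArguments flag that batches samples of similar length to reduce padding; a minimal sketch, with illustrative argument values only:

from transformers import TrainingArguments

# group_by_length batches similar-length samples together, cutting padding
# waste at the cost of an upfront sort of the training set by length.
args = TrainingArguments(
    output_dir="./out",
    per_device_train_batch_size=2,
    group_by_length=True,
)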