Feed cfg.inference

This commit is contained in:
Angainor Development
2023-06-09 08:59:05 +02:00
committed by GitHub
parent 813cfa4c14
commit bd3b537344

View File

@@ -182,6 +182,9 @@ def train(
if cfg.bf16: if cfg.bf16:
cfg.fp16 = True cfg.fp16 = True
cfg.bf16 = False cfg.bf16 = False
# Store inference mode into cfg when passed via args
cfg.inference = True if "inference" in kwargs else cfg.get("inference", False)
# load the tokenizer first # load the tokenizer first
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
@@ -189,8 +192,8 @@ def train(
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg) tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
if check_not_in( if check_not_in(
["inference", "shard", "merge_lora"], kwargs ["shard", "merge_lora"], kwargs
): # don't need to load dataset for these ) and not cfg.inference: # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets( train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
) )
@@ -216,8 +219,7 @@ def train(
cfg.model_type, cfg.model_type,
tokenizer, tokenizer,
cfg, cfg,
adapter=cfg.adapter, adapter=cfg.adapter
inference=("inference" in kwargs),
) )
if "merge_lora" in kwargs and cfg.adapter is not None: if "merge_lora" in kwargs and cfg.adapter is not None:
@@ -230,7 +232,7 @@ def train(
model.save_pretrained(str(Path(cfg.output_dir) / "merged")) model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return return
if "inference" in kwargs: if cfg.inference:
logging.info("calling do_inference function") logging.info("calling do_inference function")
do_inference(cfg, model, tokenizer) do_inference(cfg, model, tokenizer)
return return