Feed cfg.inference

This commit is contained in:
Angainor Development
2023-06-09 08:59:05 +02:00
committed by GitHub
parent 813cfa4c14
commit bd3b537344

View File

@@ -182,6 +182,9 @@ def train(
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
# Store inference mode into cfg when passed via args
cfg.inference = True if "inference" in kwargs else cfg.get("inference", False)
# load the tokenizer first
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
@@ -189,8 +192,8 @@ def train(
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
if check_not_in(
["inference", "shard", "merge_lora"], kwargs
): # don't need to load dataset for these
["shard", "merge_lora"], kwargs
) and not cfg.inference: # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
@@ -216,8 +219,7 @@ def train(
cfg.model_type,
tokenizer,
cfg,
adapter=cfg.adapter,
inference=("inference" in kwargs),
adapter=cfg.adapter
)
if "merge_lora" in kwargs and cfg.adapter is not None:
@@ -230,7 +232,7 @@ def train(
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return
if "inference" in kwargs:
if cfg.inference:
logging.info("calling do_inference function")
do_inference(cfg, model, tokenizer)
return