diff --git a/scripts/finetune.py b/scripts/finetune.py
index 7c4d865fa..ab8f068aa 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -182,6 +182,9 @@ def train(
     if cfg.bf16:
         cfg.fp16 = True
         cfg.bf16 = False
+
+    # Store inference mode into cfg when passed via args
+    cfg.inference = True if "inference" in kwargs else cfg.get("inference", False)
 
     # load the tokenizer first
     tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
@@ -189,8 +192,8 @@ def train(
     tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
 
     if check_not_in(
-        ["inference", "shard", "merge_lora"], kwargs
-    ):  # don't need to load dataset for these
+        ["shard", "merge_lora"], kwargs
+    ) and not cfg.inference:  # don't need to load dataset for these
         train_dataset, eval_dataset = load_prepare_datasets(
             tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
         )
@@ -216,8 +219,7 @@ def train(
         cfg.model_type,
         tokenizer,
         cfg,
-        adapter=cfg.adapter,
-        inference=("inference" in kwargs),
+        adapter=cfg.adapter
     )
 
     if "merge_lora" in kwargs and cfg.adapter is not None:
@@ -230,7 +232,7 @@ def train(
         model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
         return
 
-    if "inference" in kwargs:
+    if cfg.inference:
         logging.info("calling do_inference function")
         do_inference(cfg, model, tokenizer)
         return
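
For context, a minimal sketch of the cfg.inference resolution this diff introduces: an explicit inference kwarg from the CLI forces the flag on, otherwise a value already present in the config is preserved, defaulting to False. The DictDefault class below is a toy stand-in for the project's config object, assumed here only because the diff itself uses both attribute access (cfg.inference, cfg.bf16) and dict-style cfg.get(...); it is not the real implementation.

# Hypothetical illustration of the flag-resolution expression from the diff.
class DictDefault(dict):
    """Toy config stand-in: a dict with attribute-style access."""

    def __getattr__(self, key):
        return self.get(key)

    def __setattr__(self, key, value):
        self[key] = value


def resolve_inference(cfg, **kwargs):
    # Same expression as the diff: any inference kwarg switches the flag
    # on; otherwise keep whatever the config set, defaulting to False.
    cfg.inference = True if "inference" in kwargs else cfg.get("inference", False)
    return cfg.inference


assert resolve_inference(DictDefault({}), inference=True) is True   # CLI kwarg wins
assert resolve_inference(DictDefault({"inference": True})) is True  # config value kept
assert resolve_inference(DictDefault({})) is False                  # neither set -> off

Persisting the mode on cfg is what lets the rest of the diff simplify: the dataset-loading guard, the model loader (which previously received inference=("inference" in kwargs) as a separate argument), and the do_inference branch can all read one flag instead of re-checking kwargs at each call site.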