Feed cfg.inference
This commit is contained in:
committed by
GitHub
parent
813cfa4c14
commit
bd3b537344
@@ -182,6 +182,9 @@ def train(
|
|||||||
if cfg.bf16:
|
if cfg.bf16:
|
||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
cfg.bf16 = False
|
cfg.bf16 = False
|
||||||
|
|
||||||
|
# Store inference mode into cfg when passed via args
|
||||||
|
cfg.inference = True if "inference" in kwargs else cfg.get("inference", False)
|
||||||
|
|
||||||
# load the tokenizer first
|
# load the tokenizer first
|
||||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||||
@@ -189,8 +192,8 @@ def train(
|
|||||||
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
||||||
|
|
||||||
if check_not_in(
|
if check_not_in(
|
||||||
["inference", "shard", "merge_lora"], kwargs
|
["shard", "merge_lora"], kwargs
|
||||||
): # don't need to load dataset for these
|
) and not cfg.inference: # don't need to load dataset for these
|
||||||
train_dataset, eval_dataset = load_prepare_datasets(
|
train_dataset, eval_dataset = load_prepare_datasets(
|
||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||||
)
|
)
|
||||||
@@ -216,8 +219,7 @@ def train(
|
|||||||
cfg.model_type,
|
cfg.model_type,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg,
|
cfg,
|
||||||
adapter=cfg.adapter,
|
adapter=cfg.adapter
|
||||||
inference=("inference" in kwargs),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if "merge_lora" in kwargs and cfg.adapter is not None:
|
if "merge_lora" in kwargs and cfg.adapter is not None:
|
||||||
@@ -230,7 +232,7 @@ def train(
|
|||||||
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||||
return
|
return
|
||||||
|
|
||||||
if "inference" in kwargs:
|
if cfg.inference:
|
||||||
logging.info("calling do_inference function")
|
logging.info("calling do_inference function")
|
||||||
do_inference(cfg, model, tokenizer)
|
do_inference(cfg, model, tokenizer)
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user