Merge pull request #159 from AngainorDev/patch-1
Fix training over existing lora
@@ -165,7 +165,7 @@ def train(
     cfg_keys = cfg.keys()
     for k, _ in kwargs.items():
         # if not strict, allow writing to cfg even if it's not in the yml already
-        if k in cfg_keys or cfg.strict is False:
+        if k in cfg_keys or not cfg.strict:
             # handle booleans
             if isinstance(cfg[k], bool):
                 cfg[k] = bool(kwargs[k])
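The hunk above is the heart of the fix. Keys absent from the YAML come back from the config object as a falsy placeholder (e.g. None) rather than the literal False, so the old identity test `cfg.strict is False` rejected command-line overrides whenever `strict` was simply unset. A minimal sketch of the difference, using a hypothetical ConfigDict as a stand-in for the real config class:

# Hypothetical stand-in for the real config class; the only assumption is
# that keys missing from the YAML resolve to a falsy value (None), not False.
class ConfigDict(dict):
    def __getattr__(self, name):
        return self.get(name)  # missing key -> None

cfg = ConfigDict({"base_model": "some/model"})  # no explicit `strict:` entry

print(cfg.strict is False)  # False: None is not the literal False, so the
                            # old check silently blocked kwarg overrides
print(not cfg.strict)       # True: the new check treats "unset" as non-strict

With `not cfg.strict`, an unset `strict` is treated as permissive, so a kwarg such as `inference` can be written into cfg, which the later hunks rely on.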
@@ -205,8 +205,8 @@ def train(
     logging.info(f"loading tokenizer... {tokenizer_config}")
     tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
 
-    if check_not_in(
-        ["inference", "shard", "merge_lora"], kwargs
+    if (
+        check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
     ):  # don't need to load dataset for these
         train_dataset, eval_dataset = load_prepare_datasets(
             tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
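Here the dataset-loading guard stops probing kwargs for an "inference" key and consults cfg.inference instead, which the override loop above can now populate. A sketch of the guard's intent, assuming check_not_in simply means "none of these keys is present" (the real helper is outside this diff):

# Assumed semantics for check_not_in; not the real implementation.
def check_not_in(keys, mapping):
    return all(k not in mapping for k in keys)

kwargs = {}            # a plain training invocation
cfg_inference = None   # `inference` unset in the config

if check_not_in(["shard", "merge_lora"], kwargs) and not cfg_inference:
    print("load train/eval datasets")  # the training path
else:
    print("skip dataset preparation")  # inference, shard, or merge_lora runs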
@@ -234,7 +234,6 @@ def train(
         tokenizer,
         cfg,
         adapter=cfg.adapter,
-        inference=("inference" in kwargs),
     )
 
     if "merge_lora" in kwargs and cfg.adapter is not None:
@@ -247,7 +246,7 @@ def train(
         model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
         return
 
-    if "inference" in kwargs:
+    if cfg.inference:
         logging.info("calling do_inference function")
         inf_kwargs: Dict[str, Any] = {}
         if "prompter" in kwargs:
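The final two hunks complete the migration: load_model no longer derives an inference flag from kwargs, and the dispatch into do_inference keys off cfg.inference. Combined with the relaxed strict check, a single flag set on the command line now reaches every consumer through cfg. A self-contained sketch of that assumed flow, again with the hypothetical ConfigDict:

# Assumed end-to-end flow: a CLI kwarg is written into cfg by the fixed
# override loop, so the later `if cfg.inference:` branch sees it.
class ConfigDict(dict):  # same hypothetical stand-in as in the first sketch
    def __getattr__(self, name):
        return self.get(name)

cfg = ConfigDict({"adapter": "lora"})  # YAML sets neither `strict` nor `inference`
kwargs = {"inference": True}           # e.g. passed on the command line

for k, _ in kwargs.items():
    if k in cfg.keys() or not cfg.strict:  # unset strict is now permissive
        cfg[k] = kwargs[k]

print(cfg.inference)  # True: one flag drives dataset skipping and do_inference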