Fix strict-mode config check and lint/formatting issues

This commit is contained in:
AngainorDev
2023-06-11 15:23:38 +02:00
parent a808bf913f
commit b565ecf0a1
2 changed files with 7 additions and 12 deletions

View File

@@ -158,7 +158,7 @@ def train(
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or cfg.strict is False:
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
@@ -198,9 +198,9 @@ def train(
logging.info(f"loading tokenizer... {tokenizer_config}")
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
if check_not_in(
["shard", "merge_lora"], kwargs
) and not cfg.inference: # don't need to load dataset for these
if (
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
): # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
@@ -226,7 +226,7 @@ def train(
cfg.model_type,
tokenizer,
cfg,
adapter=cfg.adapter
adapter=cfg.adapter,
)
if "merge_lora" in kwargs and cfg.adapter is not None:

View File

@@ -77,14 +77,9 @@ def load_tokenizer(
def load_model(
base_model,
base_model_config,
model_type,
tokenizer,
cfg,
adapter="lora"
base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
):
# type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
# type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
"""
Load a model from a base model and a model type.
"""