refactor inference, warn if model is frozen

This commit is contained in:
Wing Lian
2023-05-07 01:53:30 -04:00
parent cb9a887047
commit 247825bd57
3 changed files with 20 additions and 4 deletions

View File

@@ -183,6 +183,12 @@ def load_model(
model.is_parallelizable = True
model.model_parallel = True
requires_grad = []
for name, param in model.named_parameters(recurse=True):
if param.requires_grad:
requires_grad.append(f"{name}: {param.requires_grad}")
if len(requires_grad) == 0:
logging.warning("there are no parameters that require gradient updates")
# TODO resume_from_checkpoint handling
return model, tokenizer, lora_config

View File

@@ -105,7 +105,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
optim=cfg.optimizer if cfg.optimizer else None,
lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine",
weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
**training_arguments_kwargs,
)