refactor inference, warn if model is frozen
This commit is contained in:
@@ -183,6 +183,12 @@ def load_model(
|
||||
model.is_parallelizable = True
|
||||
model.model_parallel = True
|
||||
|
||||
requires_grad = []
|
||||
for name, param in model.named_parameters(recurse=True):
|
||||
if param.requires_grad:
|
||||
requires_grad.append(f"{name}: {param.requires_grad}")
|
||||
if len(requires_grad) == 0:
|
||||
logging.warning("there are no parameters that require gradient updates")
|
||||
|
||||
# TODO resume_from_checkpoint handling
|
||||
return model, tokenizer, lora_config
|
||||
|
||||
Reference in New Issue
Block a user