checks
This commit is contained in:
@@ -80,7 +80,7 @@ def setup_model_and_tokenizer(
|
||||
|
||||
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
|
||||
model, peft_config = model_loader.load()
|
||||
if model.generation_config is not None:
|
||||
if hasattr(model, "generation_config") and model.generation_config is not None:
|
||||
model.generation_config.do_sample = True
|
||||
|
||||
# Apply freezing if specified
|
||||
@@ -90,7 +90,10 @@ def setup_model_and_tokenizer(
|
||||
any(embed in param for embed in ["lm_head", "embed_tokens"])
|
||||
for param in cfg.unfrozen_parameters
|
||||
):
|
||||
model.enable_input_require_grads()
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
LOG.warning("Model does not have enable_input_require_grads method, skipping")
|
||||
|
||||
return model, tokenizer, peft_config, processor
|
||||
|
||||
@@ -246,9 +249,12 @@ def save_trained_model(
|
||||
LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")
|
||||
|
||||
# Post training module hooks
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, "_post_training"):
|
||||
module._post_training(model, name) # pylint: disable=protected-access
|
||||
if hasattr(model, "named_modules"):
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, "_post_training"):
|
||||
module._post_training(model, name) # pylint: disable=protected-access
|
||||
else:
|
||||
LOG.warning("Model does not have named_modules attribute, skipping post training hooks")
|
||||
|
||||
# handle QAT
|
||||
if cfg.qat:
|
||||
@@ -308,11 +314,17 @@ def save_trained_model(
|
||||
model = BetterTransformer.reverse(model)
|
||||
|
||||
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||
trainer.model.save_pretrained(
|
||||
cfg.output_dir, safe_serialization=safe_serialization
|
||||
)
|
||||
if hasattr(trainer.model, "save_pretrained"):
|
||||
trainer.model.save_pretrained(
|
||||
cfg.output_dir, safe_serialization=safe_serialization
|
||||
)
|
||||
else:
|
||||
LOG.warning("Trainer model does not have save_pretrained method, skipping save")
|
||||
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
if hasattr(model, "save_pretrained"):
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
else:
|
||||
LOG.warning("Model does not have save_pretrained method, skipping save")
|
||||
|
||||
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
|
||||
# TODO: add integration support so this can be implemented completely within the plugin
|
||||
@@ -398,7 +410,10 @@ def save_initial_configs(
|
||||
tokenizer.save_pretrained(str(output_dir))
|
||||
if hasattr(model, "config"):
|
||||
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
||||
model.config.save_pretrained(str(output_dir))
|
||||
if hasattr(model.config, "save_pretrained"):
|
||||
model.config.save_pretrained(str(output_dir))
|
||||
else:
|
||||
LOG.warning("Model config does not have save_pretrained method, skipping config save")
|
||||
|
||||
if processor:
|
||||
LOG.info(f"Pre-saving processor to {cfg.output_dir}...")
|
||||
@@ -461,9 +476,12 @@ def handle_untrained_tokens_fix(
|
||||
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
||||
)
|
||||
if hasattr(model, "save_pretrained"):
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
||||
)
|
||||
else:
|
||||
LOG.warning("Model does not have save_pretrained method, skipping save")
|
||||
|
||||
|
||||
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
|
||||
|
||||
Reference in New Issue
Block a user