mhenrhcsen
2025-07-16 21:30:01 +02:00
parent f9bdf1fb44
commit f40e8caa28


@@ -80,7 +80,7 @@ def setup_model_and_tokenizer(
     model_loader = ModelLoader(cfg, tokenizer, processor=processor)
     model, peft_config = model_loader.load()
 
-    if model.generation_config is not None:
+    if hasattr(model, "generation_config") and model.generation_config is not None:
         model.generation_config.do_sample = True
 
     # Apply freezing if specified
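
Note: the change above is the template for the whole commit: guard the attribute access with `hasattr` before dereferencing, so model wrappers that lack the attribute fail soft instead of raising `AttributeError`. A minimal standalone sketch of the pattern (`DummyModel` is hypothetical, not from this repo):

```python
class DummyModel:  # hypothetical stand-in for a model without HF attributes
    pass

model = DummyModel()
# hasattr() short-circuits the `and`, so the second operand is never
# evaluated on objects that lack a generation_config attribute.
if hasattr(model, "generation_config") and model.generation_config is not None:
    model.generation_config.do_sample = True
else:
    print("no generation_config; skipping")  # this branch runs for DummyModel
```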
@@ -90,7 +90,10 @@ def setup_model_and_tokenizer(
         any(embed in param for embed in ["lm_head", "embed_tokens"])
         for param in cfg.unfrozen_parameters
     ):
-        model.enable_input_require_grads()
+        if hasattr(model, "enable_input_require_grads"):
+            model.enable_input_require_grads()
+        else:
+            LOG.warning("Model does not have enable_input_require_grads method, skipping")
 
     return model, tokenizer, peft_config, processor
 
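Note: an equivalent guard can be written with `getattr` plus a `callable` check, which also rejects a same-named non-method attribute. A sketch of that alternative (the helper name is hypothetical, not what the commit uses):

```python
import logging

LOG = logging.getLogger(__name__)

def enable_input_grads_if_possible(model) -> None:
    # getattr with a default collapses hasattr + attribute lookup into one step.
    fn = getattr(model, "enable_input_require_grads", None)
    if callable(fn):
        fn()
    else:
        LOG.warning("Model does not have enable_input_require_grads method, skipping")
```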
@@ -246,9 +249,12 @@ def save_trained_model(
     LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")
 
     # Post training module hooks
-    for name, module in model.named_modules():
-        if hasattr(module, "_post_training"):
-            module._post_training(model, name)  # pylint: disable=protected-access
+    if hasattr(model, "named_modules"):
+        for name, module in model.named_modules():
+            if hasattr(module, "_post_training"):
+                module._post_training(model, name)  # pylint: disable=protected-access
+    else:
+        LOG.warning("Model does not have named_modules attribute, skipping post training hooks")
 
     # handle QAT
     if cfg.qat:
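
Note: for readers unfamiliar with the hook being guarded here: any submodule that defines `_post_training(model, name)` is called once after training finishes. A hypothetical module satisfying that contract, as a sketch:

```python
import torch.nn as nn

class MergeOnFinish(nn.Linear):  # hypothetical example, not from this repo
    def _post_training(self, model, name):
        # A real hook might merge auxiliary weights into self.weight; this
        # sketch only shows the signature the named_modules() loop expects.
        print(f"post-training hook fired for {name}")
```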
@@ -308,11 +314,17 @@ def save_trained_model(
         model = BetterTransformer.reverse(model)
 
     if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
-        trainer.model.save_pretrained(
-            cfg.output_dir, safe_serialization=safe_serialization
-        )
+        if hasattr(trainer.model, "save_pretrained"):
+            trainer.model.save_pretrained(
+                cfg.output_dir, safe_serialization=safe_serialization
+            )
+        else:
+            LOG.warning("Trainer model does not have save_pretrained method, skipping save")
     else:
-        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+        if hasattr(model, "save_pretrained"):
+            model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+        else:
+            LOG.warning("Model does not have save_pretrained method, skipping save")
 
     if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
         # TODO: add integration support so this can be implemented completely within the plugin
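
Note: the `hasattr`-around-`save_pretrained` guard now appears three times in this file. A hypothetical helper (sketch only, not part of this commit) could factor out the repetition:

```python
import logging

LOG = logging.getLogger(__name__)

def save_pretrained_if_possible(obj, output_dir, safe_serialization, label="Model"):
    # One place for the guard + warning pattern used in save_trained_model.
    if hasattr(obj, "save_pretrained"):
        obj.save_pretrained(output_dir, safe_serialization=safe_serialization)
    else:
        LOG.warning("%s does not have save_pretrained method, skipping save", label)
```

Call sites would then read `save_pretrained_if_possible(trainer.model, cfg.output_dir, safe_serialization, label="Trainer model")`.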
@@ -398,7 +410,10 @@ def save_initial_configs(
         tokenizer.save_pretrained(str(output_dir))
     if hasattr(model, "config"):
         LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
-        model.config.save_pretrained(str(output_dir))
+        if hasattr(model.config, "save_pretrained"):
+            model.config.save_pretrained(str(output_dir))
+        else:
+            LOG.warning("Model config does not have save_pretrained method, skipping config save")
 
     if processor:
         LOG.info(f"Pre-saving processor to {cfg.output_dir}...")
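
Note: the guard assumes the usual `transformers` behavior, where `PretrainedConfig.save_pretrained` writes `config.json` into the target directory; a config object without that method is now skipped with a warning instead of crashing. A toy sketch of the assumed API:

```python
from transformers import PretrainedConfig

config = PretrainedConfig(hidden_size=16)  # toy values, illustrative only
config.save_pretrained("/tmp/example_out")  # writes /tmp/example_out/config.json
```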
@@ -461,9 +476,12 @@ def handle_untrained_tokens_fix(
     fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
 
     if cfg.local_rank == 0:
-        model.save_pretrained(
-            str(Path(cfg.output_dir)), safe_serialization=safe_serialization
-        )
+        if hasattr(model, "save_pretrained"):
+            model.save_pretrained(
+                str(Path(cfg.output_dir)), safe_serialization=safe_serialization
+            )
+        else:
+            LOG.warning("Model does not have save_pretrained method, skipping save")
 
 
 def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
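
Note: the `cfg.local_rank == 0` check wrapping the save above is the standard single-writer rule for multi-process training. Its `torch.distributed` analogue (using global rank rather than local rank), as a sketch:

```python
import torch.distributed as dist

def is_main_process() -> bool:
    # Only one process should write the checkpoint so concurrent ranks
    # don't clobber the same files; single-process runs always qualify.
    return not dist.is_initialized() or dist.get_rank() == 0
```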