This commit is contained in:
mhenrhcsen
2025-07-16 21:30:01 +02:00
parent f9bdf1fb44
commit f40e8caa28

View File

@@ -80,7 +80,7 @@ def setup_model_and_tokenizer(
model_loader = ModelLoader(cfg, tokenizer, processor=processor) model_loader = ModelLoader(cfg, tokenizer, processor=processor)
model, peft_config = model_loader.load() model, peft_config = model_loader.load()
if model.generation_config is not None: if hasattr(model, "generation_config") and model.generation_config is not None:
model.generation_config.do_sample = True model.generation_config.do_sample = True
# Apply freezing if specified # Apply freezing if specified
@@ -90,7 +90,10 @@ def setup_model_and_tokenizer(
any(embed in param for embed in ["lm_head", "embed_tokens"]) any(embed in param for embed in ["lm_head", "embed_tokens"])
for param in cfg.unfrozen_parameters for param in cfg.unfrozen_parameters
): ):
model.enable_input_require_grads() if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
LOG.warning("Model does not have enable_input_require_grads method, skipping")
return model, tokenizer, peft_config, processor return model, tokenizer, peft_config, processor
@@ -246,9 +249,12 @@ def save_trained_model(
LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.") LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")
# Post training module hooks # Post training module hooks
for name, module in model.named_modules(): if hasattr(model, "named_modules"):
if hasattr(module, "_post_training"): for name, module in model.named_modules():
module._post_training(model, name) # pylint: disable=protected-access if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access
else:
LOG.warning("Model does not have named_modules attribute, skipping post training hooks")
# handle QAT # handle QAT
if cfg.qat: if cfg.qat:
@@ -308,11 +314,17 @@ def save_trained_model(
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model: if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
trainer.model.save_pretrained( if hasattr(trainer.model, "save_pretrained"):
cfg.output_dir, safe_serialization=safe_serialization trainer.model.save_pretrained(
) cfg.output_dir, safe_serialization=safe_serialization
)
else:
LOG.warning("Trainer model does not have save_pretrained method, skipping save")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) if hasattr(model, "save_pretrained"):
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
else:
LOG.warning("Model does not have save_pretrained method, skipping save")
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor: if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
# TODO: add integration support so this can be implemented completely within the plugin # TODO: add integration support so this can be implemented completely within the plugin
@@ -398,7 +410,10 @@ def save_initial_configs(
tokenizer.save_pretrained(str(output_dir)) tokenizer.save_pretrained(str(output_dir))
if hasattr(model, "config"): if hasattr(model, "config"):
LOG.info(f"Pre-saving model config to {cfg.output_dir}...") LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
model.config.save_pretrained(str(output_dir)) if hasattr(model.config, "save_pretrained"):
model.config.save_pretrained(str(output_dir))
else:
LOG.warning("Model config does not have save_pretrained method, skipping config save")
if processor: if processor:
LOG.info(f"Pre-saving processor to {cfg.output_dir}...") LOG.info(f"Pre-saving processor to {cfg.output_dir}...")
@@ -461,9 +476,12 @@ def handle_untrained_tokens_fix(
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs) fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
if cfg.local_rank == 0: if cfg.local_rank == 0:
model.save_pretrained( if hasattr(model, "save_pretrained"):
str(Path(cfg.output_dir)), safe_serialization=safe_serialization model.save_pretrained(
) str(Path(cfg.output_dir)), safe_serialization=safe_serialization
)
else:
LOG.warning("Model does not have save_pretrained method, skipping save")
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[