checks
This commit is contained in:
@@ -80,7 +80,7 @@ def setup_model_and_tokenizer(
|
|||||||
|
|
||||||
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
|
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
|
||||||
model, peft_config = model_loader.load()
|
model, peft_config = model_loader.load()
|
||||||
if model.generation_config is not None:
|
if hasattr(model, "generation_config") and model.generation_config is not None:
|
||||||
model.generation_config.do_sample = True
|
model.generation_config.do_sample = True
|
||||||
|
|
||||||
# Apply freezing if specified
|
# Apply freezing if specified
|
||||||
@@ -90,7 +90,10 @@ def setup_model_and_tokenizer(
|
|||||||
any(embed in param for embed in ["lm_head", "embed_tokens"])
|
any(embed in param for embed in ["lm_head", "embed_tokens"])
|
||||||
for param in cfg.unfrozen_parameters
|
for param in cfg.unfrozen_parameters
|
||||||
):
|
):
|
||||||
model.enable_input_require_grads()
|
if hasattr(model, "enable_input_require_grads"):
|
||||||
|
model.enable_input_require_grads()
|
||||||
|
else:
|
||||||
|
LOG.warning("Model does not have enable_input_require_grads method, skipping")
|
||||||
|
|
||||||
return model, tokenizer, peft_config, processor
|
return model, tokenizer, peft_config, processor
|
||||||
|
|
||||||
@@ -246,9 +249,12 @@ def save_trained_model(
|
|||||||
LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")
|
LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")
|
||||||
|
|
||||||
# Post training module hooks
|
# Post training module hooks
|
||||||
for name, module in model.named_modules():
|
if hasattr(model, "named_modules"):
|
||||||
if hasattr(module, "_post_training"):
|
for name, module in model.named_modules():
|
||||||
module._post_training(model, name) # pylint: disable=protected-access
|
if hasattr(module, "_post_training"):
|
||||||
|
module._post_training(model, name) # pylint: disable=protected-access
|
||||||
|
else:
|
||||||
|
LOG.warning("Model does not have named_modules attribute, skipping post training hooks")
|
||||||
|
|
||||||
# handle QAT
|
# handle QAT
|
||||||
if cfg.qat:
|
if cfg.qat:
|
||||||
@@ -308,11 +314,17 @@ def save_trained_model(
|
|||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
|
|
||||||
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
|
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||||
trainer.model.save_pretrained(
|
if hasattr(trainer.model, "save_pretrained"):
|
||||||
cfg.output_dir, safe_serialization=safe_serialization
|
trainer.model.save_pretrained(
|
||||||
)
|
cfg.output_dir, safe_serialization=safe_serialization
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.warning("Trainer model does not have save_pretrained method, skipping save")
|
||||||
|
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
if hasattr(model, "save_pretrained"):
|
||||||
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
else:
|
||||||
|
LOG.warning("Model does not have save_pretrained method, skipping save")
|
||||||
|
|
||||||
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
|
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
|
||||||
# TODO: add integration support so this can be implemented completely within the plugin
|
# TODO: add integration support so this can be implemented completely within the plugin
|
||||||
@@ -398,7 +410,10 @@ def save_initial_configs(
|
|||||||
tokenizer.save_pretrained(str(output_dir))
|
tokenizer.save_pretrained(str(output_dir))
|
||||||
if hasattr(model, "config"):
|
if hasattr(model, "config"):
|
||||||
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
||||||
model.config.save_pretrained(str(output_dir))
|
if hasattr(model.config, "save_pretrained"):
|
||||||
|
model.config.save_pretrained(str(output_dir))
|
||||||
|
else:
|
||||||
|
LOG.warning("Model config does not have save_pretrained method, skipping config save")
|
||||||
|
|
||||||
if processor:
|
if processor:
|
||||||
LOG.info(f"Pre-saving processor to {cfg.output_dir}...")
|
LOG.info(f"Pre-saving processor to {cfg.output_dir}...")
|
||||||
@@ -461,9 +476,12 @@ def handle_untrained_tokens_fix(
|
|||||||
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
|
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
model.save_pretrained(
|
if hasattr(model, "save_pretrained"):
|
||||||
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
model.save_pretrained(
|
||||||
)
|
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.warning("Model does not have save_pretrained method, skipping save")
|
||||||
|
|
||||||
|
|
||||||
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
|
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
|
||||||
|
|||||||
Reference in New Issue
Block a user