diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 967179903..3a206d3da 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -80,7 +80,7 @@ def setup_model_and_tokenizer(
     model_loader = ModelLoader(cfg, tokenizer, processor=processor)
     model, peft_config = model_loader.load()
 
-    if model.generation_config is not None:
+    if hasattr(model, "generation_config") and model.generation_config is not None:
         model.generation_config.do_sample = True
 
     # Apply freezing if specified
@@ -90,7 +90,10 @@
         any(embed in param for embed in ["lm_head", "embed_tokens"])
         for param in cfg.unfrozen_parameters
     ):
-        model.enable_input_require_grads()
+        if hasattr(model, "enable_input_require_grads"):
+            model.enable_input_require_grads()
+        else:
+            LOG.warning("Model does not have enable_input_require_grads method, skipping")
 
     return model, tokenizer, peft_config, processor
 
@@ -246,9 +249,12 @@ def save_trained_model(
     LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")
 
     # Post training module hooks
-    for name, module in model.named_modules():
-        if hasattr(module, "_post_training"):
-            module._post_training(model, name)  # pylint: disable=protected-access
+    if hasattr(model, "named_modules"):
+        for name, module in model.named_modules():
+            if hasattr(module, "_post_training"):
+                module._post_training(model, name)  # pylint: disable=protected-access
+    else:
+        LOG.warning("Model does not have named_modules attribute, skipping post training hooks")
 
     # handle QAT
     if cfg.qat:
@@ -308,11 +314,17 @@ def save_trained_model(
             model = BetterTransformer.reverse(model)
 
     if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
-        trainer.model.save_pretrained(
-            cfg.output_dir, safe_serialization=safe_serialization
-        )
+        if hasattr(trainer.model, "save_pretrained"):
+            trainer.model.save_pretrained(
+                cfg.output_dir, safe_serialization=safe_serialization
+            )
+        else:
+            LOG.warning("Trainer model does not have save_pretrained method, skipping save")
 
-    model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+    if hasattr(model, "save_pretrained"):
+        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+    else:
+        LOG.warning("Model does not have save_pretrained method, skipping save")
 
     if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
         # TODO: add integration support so this can be implemented completely within the plugin
@@ -398,7 +410,10 @@ def save_initial_configs(
         tokenizer.save_pretrained(str(output_dir))
     if hasattr(model, "config"):
         LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
-        model.config.save_pretrained(str(output_dir))
+        if hasattr(model.config, "save_pretrained"):
+            model.config.save_pretrained(str(output_dir))
+        else:
+            LOG.warning("Model config does not have save_pretrained method, skipping config save")
 
     if processor:
         LOG.info(f"Pre-saving processor to {cfg.output_dir}...")
@@ -461,9 +476,12 @@ def handle_untrained_tokens_fix(
         fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
 
     if cfg.local_rank == 0:
-        model.save_pretrained(
-            str(Path(cfg.output_dir)), safe_serialization=safe_serialization
-        )
+        if hasattr(model, "save_pretrained"):
+            model.save_pretrained(
+                str(Path(cfg.output_dir)), safe_serialization=safe_serialization
+            )
+        else:
+            LOG.warning("Model does not have save_pretrained method, skipping save")
 
 
 def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[