From 53f93f67bb699fb3bde36995c3502f31bab2704d Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 29 Oct 2023 06:08:38 -0400
Subject: [PATCH] fix to set training args so projector properly saves

---
 src/axolotl/core/trainer_builder.py | 32 +++++++++++++++++++++--------
 src/axolotl/train.py                |  6 +++---
 src/axolotl/utils/trainer.py        |  6 ++++--
 3 files changed, 30 insertions(+), 14 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 01bcf359b..23370a953 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -271,21 +271,26 @@ class AxolotlTrainer(Trainer):
             run_dir = self._get_output_dir(trial=trial)
             output_dir = os.path.join(run_dir, checkpoint_folder)
 
-            # Only save Adapter
-            keys_to_match = ["mm_projector", "vision_resampler"]
-            if getattr(self.args, "use_im_start_end", False):
-                keys_to_match.extend(["embed_tokens", "embed_in"])
-
-            weight_to_save = get_mm_adapter_state_maybe_zero_3(
-                self.model.named_parameters(), keys_to_match
-            )
+            weights_to_save = self._get_mm_mlp_adapter_weights()
 
             if self.args.local_rank in (0, -1):
                 self.model.config.save_pretrained(output_dir)
-                torch.save(weight_to_save, os.path.join(output_dir, "mm_projector.bin"))
+                torch.save(
+                    weights_to_save, os.path.join(output_dir, "mm_projector.bin")
+                )
         else:
             super()._save_checkpoint(model, trial, metrics)
 
+    def _get_mm_mlp_adapter_weights(self):
+        # Only save Adapter
+        keys_to_match = ["mm_projector", "vision_resampler"]
+        if getattr(self.args, "use_im_start_end", False):
+            keys_to_match.extend(["embed_tokens", "embed_in"])
+
+        return get_mm_adapter_state_maybe_zero_3(
+            self.model.named_parameters(), keys_to_match
+        )
+
     def _save(self, output_dir: Optional[str] = None, state_dict=None):
         if getattr(self.args, "tune_mm_mlp_adapter", False):
             pass
@@ -659,8 +664,17 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "sample_packing_seq_len_multiplier"
             ] = self.cfg.micro_batch_size
+
         training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
         training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
+
+        # multimodal: llava
+        training_arguments_kwargs["tune_mm_mlp_adapter"] = self.cfg.tune_mm_mlp_adapter
+        training_arguments_kwargs[
+            "freeze_mm_mlp_adapter"
+        ] = self.cfg.freeze_mm_mlp_adapter
+        training_arguments_kwargs["mm_projector_lr"] = self.cfg.mm_projector_lr
+
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 54e8972e4..c5939e753 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -159,14 +159,14 @@ def train(
         # The model name saved is `pytorch_model.bin`
         unwrapped_model.save_pretrained(
             cfg.output_dir,
-            is_main_process=trainer.accelerator.is_main_process,
+            is_main_process=trainer.args.should_save,
             save_function=trainer.accelerator.save,
             state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
         )
-    elif cfg.local_rank == 0:
+    elif trainer.args.should_save:
         if cfg.flash_optimum:
             model = BetterTransformer.reverse(model)
-
+        # TODO figure out if `trainer.save_model(cfg.output_dir)` is sufficient here
         model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
 
         if not cfg.hub_model_id:
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 0d275cbf5..bfea684de 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -13,7 +13,7 @@ import torch.distributed as dist
 from datasets import set_caching_enabled
 from torch.utils.data import DistributedSampler, RandomSampler
 
-from axolotl.core.trainer_builder import HFCausalTrainerBuilder
+from axolotl.core.trainer_builder import AxolotlTrainer, HFCausalTrainerBuilder
 from axolotl.utils.collators import DataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
 from axolotl.utils.distributed import (
@@ -259,7 +259,9 @@ def setup_fsdp_envs(cfg):
         ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
 
 
-def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
+def setup_trainer(
+    cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
+) -> AxolotlTrainer:
     if cfg.fsdp:
         setup_fsdp_envs(cfg)
     elif cfg.deepspeed:
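
For context, a rough sketch of the substring-based filtering that the `keys_to_match` / `get_mm_adapter_state_maybe_zero_3` path above relies on. `filter_adapter_state` is a hypothetical stand-in, for illustration only; the real helper also appears to gather DeepSpeed ZeRO-3 partitioned parameters before materializing them, which is omitted here.

    # Hypothetical sketch, not the axolotl implementation: keep only the
    # multimodal-adapter weights by substring match on parameter names,
    # mirroring keys_to_match in _get_mm_mlp_adapter_weights above.
    from typing import Dict, Iterable, Tuple

    import torch


    def filter_adapter_state(
        named_parameters: Iterable[Tuple[str, torch.nn.Parameter]],
        keys_to_match: Iterable[str],
    ) -> Dict[str, torch.Tensor]:
        keys = list(keys_to_match)
        return {
            name: param.detach().cpu().clone()
            for name, param in named_parameters
            if any(key in name for key in keys)
        }


    # e.g. filter_adapter_state(model.named_parameters(),
    #                           ["mm_projector", "vision_resampler"])

The returned dict is what would get written out with `torch.save(..., "mm_projector.bin")` in `_save_checkpoint`.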