From ee262818ef7f142012f3754514dd666b8aab3b27 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 20 Aug 2023 23:39:42 -0400
Subject: [PATCH] fix evals (#447)

---
 .../monkeypatch/llama_attn_hijack_flash.py |  2 +-
 src/axolotl/utils/models.py                |  3 +-
 src/axolotl/utils/trainer.py               | 62 ++++++++++++-------
 3 files changed, 42 insertions(+), 25 deletions(-)

diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 14056fa54..a445c3a5a 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -169,7 +169,7 @@ def flashattn_forward(
         qkv = rearrange(qkv, "b s ... -> (b s) ...")

         output = flash_attn_varlen_qkvpacked_func(
-            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=is_causal
+            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     elif query_states.shape == key_states.shape:
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 66cf70b64..8e5445fff 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -438,7 +438,7 @@ def load_llama_adapter(model, cfg):
     )

     if cfg.lora_model_dir:
-        LOG.info("Loading pretained LORA")
+        LOG.debug("Loading pretrained PEFT - llama_adapter")
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
@@ -500,6 +500,7 @@ def load_lora(model, cfg):
     )

     if cfg.lora_model_dir:
+        LOG.debug("Loading pretrained PEFT - LoRA")
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 5f24e13c0..5245011f6 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -14,12 +14,15 @@ import bitsandbytes as bnb
 import numpy as np
 import torch.cuda
 import transformers
-from datasets import set_caching_enabled
+from datasets import Dataset, set_caching_enabled
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
-from transformers.trainer_pt_utils import get_parameter_names
+from transformers.trainer_pt_utils import (
+    SequentialDistributedSampler,
+    get_parameter_names,
+)

 from axolotl.utils.callbacks import (
     GPUStatsCallback,
@@ -171,6 +174,18 @@ class AxolotlTrainer(Trainer):
             )
         return super()._get_train_sampler()

+    def _get_eval_sampler(
+        self, eval_dataset: Dataset
+    ) -> Optional[torch.utils.data.Sampler]:
+        if self.args.world_size > 1 and self.args.sample_packing:
+            return SequentialDistributedSampler(
+                eval_dataset,
+                num_replicas=self.args.world_size,
+                rank=self.args.process_index,
+                batch_size=self.args.per_device_eval_batch_size,
+            )
+        return super()._get_eval_sampler(eval_dataset)
+
     def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
         if self.args.sample_packing:
             train_sampler = self._get_train_sampler()
@@ -188,27 +203,28 @@ class AxolotlTrainer(Trainer):
             )
         return super().get_train_dataloader()

-    # def get_eval_dataloader(
-    #     self, eval_dataset: Optional[Dataset] = None
-    # ) -> Union[DataLoader, MultipackDistributedDataloader]:
-    #     if self.args.sample_packing:
-    #         eval_dataset = (
-    #             eval_dataset if eval_dataset is not None else self.eval_dataset
-    #         )
-    #         eval_sampler = self._get_eval_sampler(eval_dataset)
-    #         return self.accelerator.prepare(
-    #             MultipackDistributedDataloader(
-    #                 eval_dataset,
-    #                 batch_size=self.args.eval_batch_size,
-    #                 seq_max_length=self.args.max_seq_length,
-    #                 collate_fn=self.data_collator,
-    #                 sampler=eval_sampler,
-    #                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
-    #                 sample_packing_seq_len_multiplier=self.args.eval_batch_size,
-    #                 device_count=int(os.environ.get("WORLD_SIZE", 1)),
-    #             )
-    #         )
-    #     return super().get_eval_dataloader(eval_dataset)
+    def get_eval_dataloader(
+        self, eval_dataset: Optional[Dataset] = None
+    ) -> Union[DataLoader, MultipackDistributedDataloader]:
+        if self.args.sample_packing:
+            eval_dataset = (
+                eval_dataset if eval_dataset is not None else self.eval_dataset
+            )
+
+            eval_sampler = self._get_eval_sampler(eval_dataset)
+            return self.accelerator.prepare(
+                MultipackDistributedDataloader(
+                    eval_dataset,
+                    batch_size=self.args.eval_batch_size,
+                    seq_max_length=self.args.max_seq_length,
+                    collate_fn=self.data_collator,
+                    sampler=eval_sampler,
+                    packing_efficiency_estimate=self.args.sample_packing_efficiency,
+                    sample_packing_seq_len_multiplier=self.args.eval_batch_size,
+                    device_count=int(os.environ.get("WORLD_SIZE", 1)),
+                )
+            )
+        return super().get_eval_dataloader(eval_dataset)

     def compute_loss(self, model, inputs, return_outputs=False):
         # use one's weighted cross entropy loss calc
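
Note on the sampler change above: with sample packing enabled, multi-GPU eval now uses
SequentialDistributedSampler from transformers.trainer_pt_utils, which hands each rank
one contiguous, padded slice of the eval set rather than the interleaved indices that
torch's DistributedSampler would produce, so per-rank predictions can be concatenated
back in the original order. The following minimal sketch (not part of the patch;
ToyDataset is a hypothetical stand-in that only implements __len__) shows how the
indices are split across two ranks:

    # Sketch only: illustrates how SequentialDistributedSampler slices
    # indices per rank; num_replicas/rank are passed explicitly so no
    # torch.distributed process group is required.
    from transformers.trainer_pt_utils import SequentialDistributedSampler


    class ToyDataset:
        """Hypothetical stand-in; the sampler only needs __len__."""

        def __len__(self):
            return 10


    dataset = ToyDataset()
    for rank in range(2):
        sampler = SequentialDistributedSampler(
            dataset, num_replicas=2, rank=rank, batch_size=4
        )
        # The index list is padded to a multiple of
        # batch_size * num_replicas (16 here) by wrapping around, then
        # cut into contiguous per-rank chunks:
        #   rank 0 -> [0, 1, 2, 3, 4, 5, 6, 7]
        #   rank 1 -> [8, 9, 0, 1, 2, 3, 4, 5]
        print(rank, list(iter(sampler)))

That contiguous, evenly padded split is what lets the trainer reassemble packed eval
batches deterministically after gathering across ranks, which a shuffled or interleaved
sampler would break.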