fix evals (#447)

This commit is contained in:
Wing Lian
2023-08-20 23:39:42 -04:00
committed by GitHub
parent 9d629d8bff
commit ee262818ef
3 changed files with 42 additions and 25 deletions

View File

@@ -169,7 +169,7 @@ def flashattn_forward(
qkv = rearrange(qkv, "b s ... -> (b s) ...") qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func( output = flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=is_causal qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape: elif query_states.shape == key_states.shape:

View File

@@ -438,7 +438,7 @@ def load_llama_adapter(model, cfg):
) )
if cfg.lora_model_dir: if cfg.lora_model_dir:
LOG.info("Loading pretained LORA") LOG.debug("Loading pretained PEFT - llama_adapter")
model = PeftModel.from_pretrained( model = PeftModel.from_pretrained(
model, model,
cfg.lora_model_dir, cfg.lora_model_dir,
@@ -500,6 +500,7 @@ def load_lora(model, cfg):
) )
if cfg.lora_model_dir: if cfg.lora_model_dir:
LOG.debug("Loading pretained PEFT - LoRA")
model = PeftModel.from_pretrained( model = PeftModel.from_pretrained(
model, model,
cfg.lora_model_dir, cfg.lora_model_dir,

View File

@@ -14,12 +14,15 @@ import bitsandbytes as bnb
import numpy as np import numpy as np
import torch.cuda import torch.cuda
import transformers import transformers
from datasets import set_caching_enabled from datasets import Dataset, set_caching_enabled
from torch import nn from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import (
SequentialDistributedSampler,
get_parameter_names,
)
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
GPUStatsCallback, GPUStatsCallback,
@@ -171,6 +174,18 @@ class AxolotlTrainer(Trainer):
) )
return super()._get_train_sampler() return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size > 1 and self.args.sample_packing:
return SequentialDistributedSampler(
eval_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
batch_size=self.args.per_device_eval_batch_size,
)
return super()._get_eval_sampler()
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]: def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing: if self.args.sample_packing:
train_sampler = self._get_train_sampler() train_sampler = self._get_train_sampler()
@@ -188,27 +203,28 @@ class AxolotlTrainer(Trainer):
) )
return super().get_train_dataloader() return super().get_train_dataloader()
# def get_eval_dataloader( def get_eval_dataloader(
# self, eval_dataset: Optional[Dataset] = None self, eval_dataset: Optional[Dataset] = None
# ) -> Union[DataLoader, MultipackDistributedDataloader]: ) -> Union[DataLoader, MultipackDistributedDataloader]:
# if self.args.sample_packing: if self.args.sample_packing:
# eval_dataset = ( eval_dataset = (
# eval_dataset if eval_dataset is not None else self.eval_dataset eval_dataset if eval_dataset is not None else self.eval_dataset
# ) )
# eval_sampler = self._get_eval_sampler(eval_dataset)
# return self.accelerator.prepare( eval_sampler = self._get_eval_sampler(eval_dataset)
# MultipackDistributedDataloader( return self.accelerator.prepare(
# eval_dataset, MultipackDistributedDataloader(
# batch_size=self.args.eval_batch_size, eval_dataset,
# seq_max_length=self.args.max_seq_length, batch_size=self.args.eval_batch_size,
# collate_fn=self.data_collator, seq_max_length=self.args.max_seq_length,
# sampler=eval_sampler, collate_fn=self.data_collator,
# packing_efficiency_estimate=self.args.sample_packing_efficiency, sampler=eval_sampler,
# sample_packing_seq_len_multiplier=self.args.eval_batch_size, packing_efficiency_estimate=self.args.sample_packing_efficiency,
# device_count=int(os.environ.get("WORLD_SIZE", 1)), sample_packing_seq_len_multiplier=self.args.eval_batch_size,
# ) device_count=int(os.environ.get("WORLD_SIZE", 1)),
# ) )
# return super().get_eval_dataloader(eval_dataset) )
return super().get_eval_dataloader(eval_dataset)
def compute_loss(self, model, inputs, return_outputs=False): def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc # use one's weighted cross entropy loss calc