fix evals (#447)
This commit is contained in:
@@ -169,7 +169,7 @@ def flashattn_forward(
|
|||||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
|
|
||||||
output = flash_attn_varlen_qkvpacked_func(
|
output = flash_attn_varlen_qkvpacked_func(
|
||||||
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=is_causal
|
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
||||||
)
|
)
|
||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
elif query_states.shape == key_states.shape:
|
elif query_states.shape == key_states.shape:
|
||||||
|
|||||||
@@ -438,7 +438,7 @@ def load_llama_adapter(model, cfg):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if cfg.lora_model_dir:
|
if cfg.lora_model_dir:
|
||||||
LOG.info("Loading pretained LORA")
|
LOG.debug("Loading pretained PEFT - llama_adapter")
|
||||||
model = PeftModel.from_pretrained(
|
model = PeftModel.from_pretrained(
|
||||||
model,
|
model,
|
||||||
cfg.lora_model_dir,
|
cfg.lora_model_dir,
|
||||||
@@ -500,6 +500,7 @@ def load_lora(model, cfg):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if cfg.lora_model_dir:
|
if cfg.lora_model_dir:
|
||||||
|
LOG.debug("Loading pretained PEFT - LoRA")
|
||||||
model = PeftModel.from_pretrained(
|
model = PeftModel.from_pretrained(
|
||||||
model,
|
model,
|
||||||
cfg.lora_model_dir,
|
cfg.lora_model_dir,
|
||||||
|
|||||||
@@ -14,12 +14,15 @@ import bitsandbytes as bnb
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
import transformers
|
import transformers
|
||||||
from datasets import set_caching_enabled
|
from datasets import Dataset, set_caching_enabled
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
|
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
|
||||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import (
|
||||||
|
SequentialDistributedSampler,
|
||||||
|
get_parameter_names,
|
||||||
|
)
|
||||||
|
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
GPUStatsCallback,
|
GPUStatsCallback,
|
||||||
@@ -171,6 +174,18 @@ class AxolotlTrainer(Trainer):
|
|||||||
)
|
)
|
||||||
return super()._get_train_sampler()
|
return super()._get_train_sampler()
|
||||||
|
|
||||||
|
def _get_eval_sampler(
|
||||||
|
self, eval_dataset: Dataset
|
||||||
|
) -> Optional[torch.utils.data.Sampler]:
|
||||||
|
if self.args.world_size > 1 and self.args.sample_packing:
|
||||||
|
return SequentialDistributedSampler(
|
||||||
|
eval_dataset,
|
||||||
|
num_replicas=self.args.world_size,
|
||||||
|
rank=self.args.process_index,
|
||||||
|
batch_size=self.args.per_device_eval_batch_size,
|
||||||
|
)
|
||||||
|
return super()._get_eval_sampler()
|
||||||
|
|
||||||
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||||
if self.args.sample_packing:
|
if self.args.sample_packing:
|
||||||
train_sampler = self._get_train_sampler()
|
train_sampler = self._get_train_sampler()
|
||||||
@@ -188,27 +203,28 @@ class AxolotlTrainer(Trainer):
|
|||||||
)
|
)
|
||||||
return super().get_train_dataloader()
|
return super().get_train_dataloader()
|
||||||
|
|
||||||
# def get_eval_dataloader(
|
def get_eval_dataloader(
|
||||||
# self, eval_dataset: Optional[Dataset] = None
|
self, eval_dataset: Optional[Dataset] = None
|
||||||
# ) -> Union[DataLoader, MultipackDistributedDataloader]:
|
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||||
# if self.args.sample_packing:
|
if self.args.sample_packing:
|
||||||
# eval_dataset = (
|
eval_dataset = (
|
||||||
# eval_dataset if eval_dataset is not None else self.eval_dataset
|
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
# )
|
)
|
||||||
# eval_sampler = self._get_eval_sampler(eval_dataset)
|
|
||||||
# return self.accelerator.prepare(
|
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||||
# MultipackDistributedDataloader(
|
return self.accelerator.prepare(
|
||||||
# eval_dataset,
|
MultipackDistributedDataloader(
|
||||||
# batch_size=self.args.eval_batch_size,
|
eval_dataset,
|
||||||
# seq_max_length=self.args.max_seq_length,
|
batch_size=self.args.eval_batch_size,
|
||||||
# collate_fn=self.data_collator,
|
seq_max_length=self.args.max_seq_length,
|
||||||
# sampler=eval_sampler,
|
collate_fn=self.data_collator,
|
||||||
# packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
sampler=eval_sampler,
|
||||||
# sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
# device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
||||||
# )
|
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
# )
|
)
|
||||||
# return super().get_eval_dataloader(eval_dataset)
|
)
|
||||||
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
|
|
||||||
def compute_loss(self, model, inputs, return_outputs=False):
|
def compute_loss(self, model, inputs, return_outputs=False):
|
||||||
# use one's weighted cross entropy loss calc
|
# use one's weighted cross entropy loss calc
|
||||||
|
|||||||
Reference in New Issue
Block a user