Fix patch to check position IDs, and stop using multipack for evals
This commit is contained in:
@@ -149,7 +149,7 @@ def flashattn_forward(
|
|||||||
# only on first autoregressive step q,k,v have same seqlen
|
# only on first autoregressive step q,k,v have same seqlen
|
||||||
is_causal = past_key_value is not None
|
is_causal = past_key_value is not None
|
||||||
|
|
||||||
if self.training and attention_mask.shape[0] == 1:
|
if self.training and position_ids.shape[0] == 1:
|
||||||
# special handling using sample packing
|
# special handling using sample packing
|
||||||
qkv = torch.stack(
|
qkv = torch.stack(
|
||||||
[query_states, key_states, value_states], dim=2
|
[query_states, key_states, value_states], dim=2
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import bitsandbytes as bnb
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
import transformers
|
import transformers
|
||||||
from datasets import Dataset, set_caching_enabled
|
from datasets import set_caching_enabled
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
|
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
|
||||||
@@ -188,27 +188,27 @@ class AxolotlTrainer(Trainer):
|
|||||||
)
|
)
|
||||||
return super().get_train_dataloader()
|
return super().get_train_dataloader()
|
||||||
|
|
||||||
def get_eval_dataloader(
|
# def get_eval_dataloader(
|
||||||
self, eval_dataset: Optional[Dataset] = None
|
# self, eval_dataset: Optional[Dataset] = None
|
||||||
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
# ) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||||
if self.args.sample_packing:
|
# if self.args.sample_packing:
|
||||||
eval_dataset = (
|
# eval_dataset = (
|
||||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
# eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
)
|
# )
|
||||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
# eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||||
return self.accelerator.prepare(
|
# return self.accelerator.prepare(
|
||||||
MultipackDistributedDataloader(
|
# MultipackDistributedDataloader(
|
||||||
eval_dataset,
|
# eval_dataset,
|
||||||
batch_size=self.args.eval_batch_size,
|
# batch_size=self.args.eval_batch_size,
|
||||||
seq_max_length=self.args.max_seq_length,
|
# seq_max_length=self.args.max_seq_length,
|
||||||
collate_fn=self.data_collator,
|
# collate_fn=self.data_collator,
|
||||||
sampler=eval_sampler,
|
# sampler=eval_sampler,
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
# packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
# sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
||||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
# device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
)
|
# )
|
||||||
)
|
# )
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
# return super().get_eval_dataloader(eval_dataset)
|
||||||
|
|
||||||
def compute_loss(self, model, inputs, return_outputs=False):
|
def compute_loss(self, model, inputs, return_outputs=False):
|
||||||
# use one's weighted cross entropy loss calc
|
# use one's weighted cross entropy loss calc
|
||||||
|
|||||||
Reference in New Issue
Block a user