swap the data collator for evals if not using sample packing (#1076)
* swap the data collator for evals if not using sample packing
* drop the last batch from the dataloader (`dataloader_drop_last`) to help with issues with evals
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
# pylint: disable=too-many-lines
|
||||||
"""
|
"""
|
||||||
Builder for the training args and trainer
|
Builder for the training args and trainer
|
||||||
"""
|
"""
|
||||||
@@ -137,10 +138,19 @@ class AxolotlTrainer(Trainer):
|
|||||||
args = None # type: AxolotlTrainingArguments
|
args = None # type: AxolotlTrainingArguments
|
||||||
tag_names = ["axolotl"]
|
tag_names = ["axolotl"]
|
||||||
|
|
||||||
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
*_args,
|
||||||
|
num_epochs=1,
|
||||||
|
bench_data_collator=None,
|
||||||
|
eval_data_collator=None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
self.num_epochs = num_epochs
|
self.num_epochs = num_epochs
|
||||||
self.bench_data_collator = bench_data_collator
|
self.bench_data_collator = bench_data_collator
|
||||||
super().__init__(*args, **kwargs)
|
self.eval_data_collator = eval_data_collator
|
||||||
|
super().__init__(*_args, **kwargs)
|
||||||
|
self.train_data_collator = self.data_collator
|
||||||
|
|
||||||
def create_scheduler(
|
def create_scheduler(
|
||||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||||
@@ -239,6 +249,16 @@ class AxolotlTrainer(Trainer):
|
|||||||
return super().get_train_dataloader()
|
return super().get_train_dataloader()
|
||||||
|
|
||||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||||
|
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
||||||
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.eval_data_collator
|
||||||
|
)
|
||||||
|
dataloader = super().get_eval_dataloader(eval_dataset)
|
||||||
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.train_data_collator
|
||||||
|
)
|
||||||
|
return dataloader
|
||||||
|
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||||
eval_dataset = (
|
eval_dataset = (
|
||||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
@@ -269,6 +289,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
return self.accelerator.prepare_data_loader(
|
return self.accelerator.prepare_data_loader(
|
||||||
DataLoader(eval_dataset, **dataloader_params)
|
DataLoader(eval_dataset, **dataloader_params)
|
||||||
)
|
)
|
||||||
|
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
|
|
||||||
def _get_bench_sampler(
|
def _get_bench_sampler(
|
||||||
@@ -651,6 +672,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"dataloader_prefetch_factor"
|
"dataloader_prefetch_factor"
|
||||||
] = self.cfg.dataloader_prefetch_factor
|
] = self.cfg.dataloader_prefetch_factor
|
||||||
|
if self.cfg.dataloader_drop_last is not None:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"dataloader_drop_last"
|
||||||
|
] = self.cfg.dataloader_drop_last
|
||||||
|
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
|
||||||
|
training_arguments_kwargs["dataloader_drop_last"] = True
|
||||||
|
|
||||||
if self.cfg.val_set_size == 0:
|
if self.cfg.val_set_size == 0:
|
||||||
# no eval set, so don't eval
|
# no eval set, so don't eval
|
||||||
@@ -831,6 +858,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
eval_dataset=self.eval_dataset,
|
eval_dataset=self.eval_dataset,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
data_collator=self.build_collator(training_args, **data_collator_kwargs),
|
data_collator=self.build_collator(training_args, **data_collator_kwargs),
|
||||||
|
eval_data_collator=self.build_collator(
|
||||||
|
training_args, is_eval=True, **data_collator_kwargs
|
||||||
|
),
|
||||||
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
@@ -851,14 +881,22 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs):
|
def build_collator(
|
||||||
|
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||||
|
):
|
||||||
if training_args.pretraining:
|
if training_args.pretraining:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self.cfg.model_config_type == "mamba":
|
if self.cfg.model_config_type == "mamba":
|
||||||
return MambaDataCollator(tokenizer=self.tokenizer)
|
return MambaDataCollator(tokenizer=self.tokenizer)
|
||||||
|
|
||||||
if training_args.sample_packing:
|
use_batch_sampler_collator = False
|
||||||
|
if is_eval is False and training_args.sample_packing:
|
||||||
|
use_batch_sampler_collator = True
|
||||||
|
if is_eval and training_args.eval_sample_packing:
|
||||||
|
use_batch_sampler_collator = True
|
||||||
|
|
||||||
|
if use_batch_sampler_collator:
|
||||||
return BatchSamplerDataCollatorForSeq2Seq(
|
return BatchSamplerDataCollatorForSeq2Seq(
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
|
|||||||
Reference in New Issue
Block a user