Optionally configure sample packing for evals (#589)
This commit is contained in:
@@ -117,6 +117,10 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use sample packing for efficient training."},
|
metadata={"help": "Use sample packing for efficient training."},
|
||||||
)
|
)
|
||||||
|
eval_sample_packing: Optional[bool] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Use sample packing for efficient evals."},
|
||||||
|
)
|
||||||
sample_packing_efficiency: float = field(
|
sample_packing_efficiency: float = field(
|
||||||
default=1.0,
|
default=1.0,
|
||||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||||
@@ -212,7 +216,11 @@ class AxolotlTrainer(Trainer):
|
|||||||
def _get_eval_sampler(
|
def _get_eval_sampler(
|
||||||
self, eval_dataset: Dataset
|
self, eval_dataset: Dataset
|
||||||
) -> Optional[torch.utils.data.Sampler]:
|
) -> Optional[torch.utils.data.Sampler]:
|
||||||
if self.args.world_size > 1 and self.args.sample_packing:
|
if (
|
||||||
|
self.args.world_size > 1
|
||||||
|
and self.args.sample_packing
|
||||||
|
and self.args.eval_sample_packing is not False
|
||||||
|
):
|
||||||
return SequentialDistributedSampler(
|
return SequentialDistributedSampler(
|
||||||
eval_dataset,
|
eval_dataset,
|
||||||
num_replicas=self.args.world_size,
|
num_replicas=self.args.world_size,
|
||||||
@@ -241,7 +249,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
def get_eval_dataloader(
|
def get_eval_dataloader(
|
||||||
self, eval_dataset: Optional[Dataset] = None
|
self, eval_dataset: Optional[Dataset] = None
|
||||||
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||||
if self.args.sample_packing:
|
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||||
eval_dataset = (
|
eval_dataset = (
|
||||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
)
|
)
|
||||||
@@ -659,6 +667,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
else "cosine",
|
else "cosine",
|
||||||
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
||||||
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
|
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
|
||||||
|
eval_sample_packing=cfg.eval_sample_packing,
|
||||||
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
||||||
relora_steps=cfg.relora_steps,
|
relora_steps=cfg.relora_steps,
|
||||||
relora_warmup_steps=cfg.relora_warmup_steps,
|
relora_warmup_steps=cfg.relora_warmup_steps,
|
||||||
|
|||||||
Reference in New Issue
Block a user