don't train if eval split is too small (#873)

* allow zero len dataset

* better handling and warning of small eval splits

* raise error if eval split is too small

* don't mess with calculating total num steps in distributed context

* fix eval_sample_packing training args logic
Wing Lian
2023-11-16 11:35:42 -05:00
committed by GitHub
parent 0de1457189
commit 797f3dd1de
4 changed files with 19 additions and 6 deletions


@@ -658,7 +658,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             self.cfg.sample_packing if self.cfg.sample_packing else False
         )
         training_arguments_kwargs["eval_sample_packing"] = (
-            self.cfg.sample_packing if self.cfg.sample_packing else False
+            self.cfg.sample_packing
+            if self.cfg.eval_sample_packing is not False
+            else False
         )
         training_arguments_kwargs[
             "sample_packing_seq_len_multiplier"


@@ -79,6 +79,14 @@ def prepare_dataset(cfg, tokenizer):
     train_dataset, eval_dataset = process_datasets_for_packing(
         cfg, train_dataset, eval_dataset, tokenizer
     )
+
+    if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
+        total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
+        if total_eval_steps == 0:
+            raise ValueError(
+                "eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
+            )
+
     if cfg.max_steps:
         total_num_steps = min(
             calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
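The guard only fires when an eval split exists, sample packing is on, and the user has not already opted eval out of packing; passing `update=False` keeps the probe from caching eval-split statistics onto `cfg` (see the trainer changes below). A self-contained sketch of the same decision, with a stub standing in for `calculate_total_num_steps`:

from types import SimpleNamespace

def probe_eval_split(cfg, eval_dataset, estimate_steps):
    # estimate_steps stands in for calculate_total_num_steps(cfg, ds, update=False).
    if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
        if estimate_steps(cfg, eval_dataset) == 0:
            raise ValueError(
                "eval dataset split is too small for sample_packing. "
                "You should set `eval_sample_packing: False`."
            )

cfg = SimpleNamespace(sample_packing=True, eval_sample_packing=None)
probe_eval_split(cfg, ["tiny split"], lambda c, d: 0)  # raises ValueError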


@@ -182,7 +182,7 @@ class MultipackBatchSampler(BatchSampler):
 
         # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
         return max(
-            1,
+            0,
             (
                 world_size
                 * math.floor(
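Previously the sampler's length estimate was clamped to at least 1, so an eval split too small to fill a single packed batch still reported one step and training quietly proceeded; clamping at 0 lets the degenerate case surface to the new guard in prepare_dataset. A toy illustration with made-up numbers and a simplified form of the expression:

import math

world_size = 2           # illustrative value, not from the repo
estimated_batches = 0.6  # packing estimate: less than one full batch

old_len = max(1, world_size * math.floor(estimated_batches / world_size))
new_len = max(0, world_size * math.floor(estimated_batches / world_size))
assert old_len == 1  # pretends a batch exists
assert new_len == 0  # propagates "no full batch fits"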


@@ -141,7 +141,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     return train_dataset, eval_dataset
 
 
-def calculate_total_num_steps(cfg, train_dataset):
+def calculate_total_num_steps(cfg, train_dataset, update=True):
     if not cfg.total_num_tokens:
         total_num_tokens = np.sum(
             train_dataset.data.column("input_ids")
@@ -150,7 +150,8 @@ def calculate_total_num_steps(cfg, train_dataset):
             .values
         )
         LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
-        cfg.total_num_tokens = total_num_tokens
+        if update:
+            cfg.total_num_tokens = total_num_tokens
 
     if not cfg.total_supervised_tokens:
         total_supervised_tokens = (
@@ -163,7 +164,8 @@ def calculate_total_num_steps(cfg, train_dataset):
            f"`total_supervised_tokens: {total_supervised_tokens}`",
            main_process_only=True,
        )
-        cfg.total_supervised_tokens = total_supervised_tokens
+        if update:
+            cfg.total_supervised_tokens = total_supervised_tokens
 
     if cfg.sample_packing:
         # we have to drop anything longer then sequence len otherwise
@@ -232,7 +234,8 @@ def calculate_total_num_steps(cfg, train_dataset):
             sample_packing_eff_est = (
                 math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
             )
-            cfg.sample_packing_eff_est = sample_packing_eff_est
+            if update:
+                cfg.sample_packing_eff_est = sample_packing_eff_est
             LOG.debug(
                 f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
                 main_process_only=True,
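The new `update` flag turns `calculate_total_num_steps` into a side-effect-free probe: with `update=False` the token counts and packing-efficiency estimate are computed but never written back to `cfg`, so checking the eval split can't clobber the cached values that step accounting later derives from the training split (the "don't mess with calculating total num steps" bullet above). A stripped-down sketch of the pattern; `count_steps` and its arithmetic are hypothetical, with `SimpleNamespace` standing in for axolotl's config object:

from types import SimpleNamespace

def count_steps(cfg, num_tokens, update=True):
    # Compute a statistic; cache it on cfg only when update is True.
    if not cfg.total_num_tokens and update:
        cfg.total_num_tokens = num_tokens
    return num_tokens // (cfg.sequence_len * cfg.micro_batch_size)

cfg = SimpleNamespace(total_num_tokens=None, sequence_len=2048, micro_batch_size=2)
count_steps(cfg, 1_000, update=False)  # eval probe: 0 steps, cfg untouched
assert cfg.total_num_tokens is None
count_steps(cfg, 8_192_000)            # training path caches the count
assert cfg.total_num_tokens == 8_192_000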