From 797f3dd1de8fd8c0eafbd1c9fdb172abd9ff840a Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 16 Nov 2023 11:35:42 -0500
Subject: [PATCH] don't train if eval split is too small (#873)

* allow zero len dataset

* better handling and warning of small eval splits

* raise error if eval split is too small

* don't mess with calculating total num steps in distributed context

* fix eval_sample_packing training args logic
---
 src/axolotl/core/trainer_builder.py     |  4 +++-
 src/axolotl/utils/data.py               |  8 ++++++++
 src/axolotl/utils/samplers/multipack.py |  2 +-
 src/axolotl/utils/trainer.py            | 11 +++++++----
 4 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index bcd5e3219..6b78f1f1a 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -658,7 +658,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             self.cfg.sample_packing if self.cfg.sample_packing else False
         )
         training_arguments_kwargs["eval_sample_packing"] = (
-            self.cfg.sample_packing if self.cfg.sample_packing else False
+            self.cfg.sample_packing
+            if self.cfg.eval_sample_packing is not False
+            else False
         )
         training_arguments_kwargs[
             "sample_packing_seq_len_multiplier"
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index a62b34e1d..49b36202c 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -79,6 +79,14 @@ def prepare_dataset(cfg, tokenizer):
         train_dataset, eval_dataset = process_datasets_for_packing(
             cfg, train_dataset, eval_dataset, tokenizer
         )
+
+    if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
+        total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
+        if total_eval_steps == 0:
+            raise ValueError(
+                "eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
+            )
+
     if cfg.max_steps:
         total_num_steps = min(
             calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py
index 1acaa51b9..629a1a44c 100644
--- a/src/axolotl/utils/samplers/multipack.py
+++ b/src/axolotl/utils/samplers/multipack.py
@@ -182,7 +182,7 @@ class MultipackBatchSampler(BatchSampler):
         # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
         return max(
-            1,
+            0,
             (
                 world_size
                 * math.floor(
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index cac760700..6d09a4559 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -141,7 +141,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     return train_dataset, eval_dataset
 
 
-def calculate_total_num_steps(cfg, train_dataset):
+def calculate_total_num_steps(cfg, train_dataset, update=True):
     if not cfg.total_num_tokens:
         total_num_tokens = np.sum(
             train_dataset.data.column("input_ids")
@@ -150,7 +150,8 @@
             .values
         )
         LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
-        cfg.total_num_tokens = total_num_tokens
+        if update:
+            cfg.total_num_tokens = total_num_tokens
 
     if not cfg.total_supervised_tokens:
         total_supervised_tokens = (
@@ -163,7 +164,8 @@
             f"`total_supervised_tokens: {total_supervised_tokens}`",
             main_process_only=True,
         )
-        cfg.total_supervised_tokens = total_supervised_tokens
+        if update:
+            cfg.total_supervised_tokens = total_supervised_tokens
 
     if cfg.sample_packing:
         # we have to drop anything longer then sequence len otherwise
@@ -232,7 +234,8 @@
             sample_packing_eff_est = (
                 math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
             )
-            cfg.sample_packing_eff_est = sample_packing_eff_est
+            if update:
+                cfg.sample_packing_eff_est = sample_packing_eff_est
             LOG.debug(
                 f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
                 main_process_only=True,