don't train if eval split is too small (#873)
* allow zero-length dataset
* better handling and warning of small eval splits
* raise error if eval split is too small
* don't mess with calculating total num steps in distributed context
* fix eval_sample_packing training args logic
This commit is contained in:
@@ -658,7 +658,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.sample_packing if self.cfg.sample_packing else False
|
self.cfg.sample_packing if self.cfg.sample_packing else False
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["eval_sample_packing"] = (
|
training_arguments_kwargs["eval_sample_packing"] = (
|
||||||
self.cfg.sample_packing if self.cfg.sample_packing else False
|
self.cfg.sample_packing
|
||||||
|
if self.cfg.eval_sample_packing is not False
|
||||||
|
else False
|
||||||
)
|
)
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"sample_packing_seq_len_multiplier"
|
"sample_packing_seq_len_multiplier"
|
||||||
|
|||||||
@@ -79,6 +79,14 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||||
cfg, train_dataset, eval_dataset, tokenizer
|
cfg, train_dataset, eval_dataset, tokenizer
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
||||||
|
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
||||||
|
if total_eval_steps == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.max_steps:
|
if cfg.max_steps:
|
||||||
total_num_steps = min(
|
total_num_steps = min(
|
||||||
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
||||||
|
|||||||
@@ -182,7 +182,7 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
|
|
||||||
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
||||||
return max(
|
return max(
|
||||||
1,
|
0,
|
||||||
(
|
(
|
||||||
world_size
|
world_size
|
||||||
* math.floor(
|
* math.floor(
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
|||||||
return train_dataset, eval_dataset
|
return train_dataset, eval_dataset
|
||||||
|
|
||||||
|
|
||||||
def calculate_total_num_steps(cfg, train_dataset):
|
def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||||
if not cfg.total_num_tokens:
|
if not cfg.total_num_tokens:
|
||||||
total_num_tokens = np.sum(
|
total_num_tokens = np.sum(
|
||||||
train_dataset.data.column("input_ids")
|
train_dataset.data.column("input_ids")
|
||||||
@@ -150,7 +150,8 @@ def calculate_total_num_steps(cfg, train_dataset):
|
|||||||
.values
|
.values
|
||||||
)
|
)
|
||||||
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
|
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
|
||||||
cfg.total_num_tokens = total_num_tokens
|
if update:
|
||||||
|
cfg.total_num_tokens = total_num_tokens
|
||||||
|
|
||||||
if not cfg.total_supervised_tokens:
|
if not cfg.total_supervised_tokens:
|
||||||
total_supervised_tokens = (
|
total_supervised_tokens = (
|
||||||
@@ -163,7 +164,8 @@ def calculate_total_num_steps(cfg, train_dataset):
|
|||||||
f"`total_supervised_tokens: {total_supervised_tokens}`",
|
f"`total_supervised_tokens: {total_supervised_tokens}`",
|
||||||
main_process_only=True,
|
main_process_only=True,
|
||||||
)
|
)
|
||||||
cfg.total_supervised_tokens = total_supervised_tokens
|
if update:
|
||||||
|
cfg.total_supervised_tokens = total_supervised_tokens
|
||||||
|
|
||||||
if cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
# we have to drop anything longer than sequence len otherwise
|
# we have to drop anything longer than sequence len otherwise
|
||||||
@@ -232,7 +234,8 @@ def calculate_total_num_steps(cfg, train_dataset):
|
|||||||
sample_packing_eff_est = (
|
sample_packing_eff_est = (
|
||||||
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
|
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
|
||||||
)
|
)
|
||||||
cfg.sample_packing_eff_est = sample_packing_eff_est
|
if update:
|
||||||
|
cfg.sample_packing_eff_est = sample_packing_eff_est
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
|
f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
|
||||||
main_process_only=True,
|
main_process_only=True,
|
||||||
|
|||||||
Reference in New Issue
Block a user