diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index d0b8ab743..aa88e9924 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -334,7 +334,10 @@ def _load_raw_datasets(
     dataset = merge_datasets(datasets, cfg)
 
     if not cfg.skip_prepare_dataset:
-        dataset = drop_long_seq_in_dataset(dataset, cfg)
+        if split == "test" and cfg.eval_sequence_len:
+            dataset = drop_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
+        else:
+            dataset = drop_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
 
     if cfg.sample_packing:
         dataset, _ = process_datasets_for_packing(cfg, dataset, None)
diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py
index 4f7f6f8dd..c0efb7a42 100644
--- a/src/axolotl/utils/data/utils.py
+++ b/src/axolotl/utils/data/utils.py
@@ -148,11 +148,14 @@ def deduplicate_and_log_datasets(
     return dataset, other_dataset
 
 
-def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
+def drop_long_seq_in_dataset(
+    dataset: Dataset, sequence_len: int, cfg: DictDefault
+) -> Dataset:
     """Remove sequences longer than configured maximum from dataset.
 
     Args:
         dataset: Dataset to filter.
+        sequence_len: Maximum length for sequences to keep.
         cfg: Dictionary mapping `axolotl` config keys to values.
 
     Returns:
@@ -167,7 +170,7 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
 
     drop_long = functools.partial(
         drop_long_seq,
-        sequence_len=cfg.sequence_len,
+        sequence_len=sequence_len,
         min_sequence_len=cfg.min_sample_len,
     )
 
@@ -187,7 +190,7 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
 
     drop_long_kwargs = {}
     if filter_map_kwargs:
-        drop_long_kwargs["desc"] = "Dropping Long Sequences"
+        drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
 
     dataset = dataset.filter(
         drop_long,
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index da0d3c935..1530fabe0 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -366,6 +366,12 @@ class AxolotlInputConfig(
             "description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
         },
     )
+    eval_sequence_len: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "The maximum length of an input for evaluation. If not specified, defaults to sequence_len"
+        },
+    )
     min_sample_len: int | None = None
     max_prompt_len: int = Field(
         default=512,
diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py
index d91f63d94..7cb645db7 100644
--- a/tests/test_packed_batch_sampler.py
+++ b/tests/test_packed_batch_sampler.py
@@ -70,7 +70,7 @@ class TestBatchedSamplerPacking:
         )
 
         train_dataset = concatenate_datasets([dataset_wrapper])
-        train_dataset = drop_long_seq_in_dataset(train_dataset, cfg)
+        train_dataset = drop_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
         lengths = get_dataset_lengths(train_dataset)
 
         batch_sampler = MultipackBatchSampler(
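
For context, a minimal config sketch showing how the new option introduced in this patch might be used. The keys `sequence_len` and `eval_sequence_len` come from the schema change above; the specific values and comments are illustrative assumptions, not part of the patch:

    # Hypothetical excerpt from an axolotl YAML config -- values are illustrative.
    sequence_len: 4096        # cap applied when filtering the train split
    eval_sequence_len: 2048   # cap applied when filtering the test split;
                              # when unset, filtering falls back to sequence_len

Passing `sequence_len` explicitly into `drop_long_seq_in_dataset` (rather than reading it from `cfg` inside the function) is what lets `_load_raw_datasets` choose a different cap per split at the call site.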