allow for different sequence_len for evaluations (#2836) [skip ci]
* allow for different sequence_len for evaluations
* reversed 🤦
* add more information to filter msg
@@ -334,7 +334,10 @@ def _load_raw_datasets(
     dataset = merge_datasets(datasets, cfg)
 
     if not cfg.skip_prepare_dataset:
-        dataset = drop_long_seq_in_dataset(dataset, cfg)
+        if split == "test" and cfg.eval_sequence_len:
+            dataset = drop_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
+        else:
+            dataset = drop_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
 
         if cfg.sample_packing:
             dataset, _ = process_datasets_for_packing(cfg, dataset, None)
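Note: a minimal sketch of the selection logic this hunk adds, assuming a plain dict in place of axolotl's `DictDefault`; `pick_sequence_len` is a hypothetical helper (the diff inlines this branch):

```python
# Hypothetical helper mirroring the branch above: the "test" split uses
# eval_sequence_len when it is set; everything else uses sequence_len.
def pick_sequence_len(split: str, cfg: dict) -> int:
    if split == "test" and cfg.get("eval_sequence_len"):
        return cfg["eval_sequence_len"]
    return cfg["sequence_len"]

assert pick_sequence_len("test", {"sequence_len": 2048, "eval_sequence_len": 4096}) == 4096
assert pick_sequence_len("test", {"sequence_len": 2048}) == 2048  # falls back
assert pick_sequence_len("train", {"sequence_len": 2048, "eval_sequence_len": 4096}) == 2048
```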
@@ -148,11 +148,14 @@ def deduplicate_and_log_datasets(
     return dataset, other_dataset
 
 
-def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
+def drop_long_seq_in_dataset(
+    dataset: Dataset, sequence_len: int, cfg: DictDefault
+) -> Dataset:
     """Remove sequences longer than configured maximum from dataset.
 
     Args:
         dataset: Dataset to filter.
+        sequence_len: Maximum length for sequences to keep
         cfg: Dictionary mapping `axolotl` config keys to values.
 
     Returns:
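Note: a toy stand-in (not the axolotl implementation) showing the new calling convention, where the length cap is an explicit argument rather than being read from `cfg.sequence_len` inside the function:

```python
# Toy stand-in for illustration only; the real function operates on
# HF Datasets and the real predicate enforces cfg-driven minimums.
def drop_long_seq_in_dataset(dataset, sequence_len, cfg):
    min_len = cfg.get("min_sample_len") or 2
    return [s for s in dataset if min_len <= len(s) <= sequence_len]

samples = [[1, 2], [1, 2, 3], [1, 2, 3, 4]]
print(drop_long_seq_in_dataset(samples, 3, {}))  # [[1, 2], [1, 2, 3]]
```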
@@ -167,7 +170,7 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
 
     drop_long = functools.partial(
         drop_long_seq,
-        sequence_len=cfg.sequence_len,
+        sequence_len=sequence_len,
         min_sequence_len=cfg.min_sample_len,
     )
 
@@ -187,7 +190,7 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
 
     drop_long_kwargs = {}
     if filter_map_kwargs:
-        drop_long_kwargs["desc"] = "Dropping Long Sequences"
+        drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
 
     dataset = dataset.filter(
         drop_long,
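Note: the two hunks above keep the existing `functools.partial` + `Dataset.filter` pattern and only thread the explicit `sequence_len` through it. A runnable sketch of that pattern, with a simplified `drop_long_seq` and fabricated data (the `input_ids` field name is assumed):

```python
import functools
from datasets import Dataset

# Simplified predicate standing in for axolotl's drop_long_seq.
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
    return min_sequence_len <= len(sample["input_ids"]) <= sequence_len

sequence_len = 4
drop_long = functools.partial(drop_long_seq, sequence_len=sequence_len, min_sequence_len=2)

ds = Dataset.from_dict({"input_ids": [[1, 2], [1, 2, 3, 4, 5], [1, 2, 3]]})
ds = ds.filter(drop_long, desc=f"Dropping Long Sequences (>{sequence_len})")
print(ds["input_ids"])  # [[1, 2], [1, 2, 3]]
```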
@@ -366,6 +366,12 @@ class AxolotlInputConfig(
             "description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
         },
     )
+    eval_sequence_len: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "The maximum length of an input for evaluation. If not specified, defaults to sequence_len"
+        },
+    )
     min_sample_len: int | None = None
     max_prompt_len: int = Field(
         default=512,
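Note: a self-contained sketch of the pydantic field pattern above; `LenConfig` is an illustrative stand-in for `AxolotlInputConfig`, which carries many more fields:

```python
from pydantic import BaseModel, Field

class LenConfig(BaseModel):
    sequence_len: int = Field(default=2048)
    eval_sequence_len: int | None = Field(
        default=None,
        json_schema_extra={"description": "Max input length for evaluation"},
    )

cfg = LenConfig()
# The fallback to sequence_len happens at the call site, not in the model.
print(cfg.eval_sequence_len or cfg.sequence_len)  # 2048
```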
@@ -70,7 +70,7 @@ class TestBatchedSamplerPacking:
         )
         train_dataset = concatenate_datasets([dataset_wrapper])
 
-        train_dataset = drop_long_seq_in_dataset(train_dataset, cfg)
+        train_dataset = drop_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
 
         lengths = get_dataset_lengths(train_dataset)
         batch_sampler = MultipackBatchSampler(