From a6f2c5d583a82ab97a7df7e1402df8b3f7b35dba Mon Sep 17 00:00:00 2001
From: Sunny
Date: Wed, 15 Jan 2025 21:12:33 -0500
Subject: [PATCH] flex sample packing WIP

---
 src/axolotl/utils/collators/batching.py | 29 +++++++++++++++++++++++++
 src/axolotl/utils/trainer.py            |  6 +++--
 2 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index 7cf771421..13a6e1967 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -151,6 +151,35 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         return super().__call__(out_features, return_tensors=return_tensors)
 
 
+@dataclass
+class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+    """
+    Collator for multipack specific to Flex Attention using the BatchSampler
+    """
+
+    def __call__(self, features, return_tensors=None):
+        if not isinstance(features[0], list):
+            features = [features]
+        out_features = [{} for _ in features]
+        for i, features_ in enumerate(features):
+            for feature in features_[0].keys():
+                if feature == "length":
+                    continue
+                if feature == "attention_mask":
+                    arrays = [
+                        (sample_idx + 1) * np.array(item[feature])
+                        for sample_idx, item in enumerate(features_)
+                        if feature in item
+                    ]
+                    out_features[i][feature] = np.concatenate(arrays)
+                else:
+                    arrays = [
+                        np.array(item[feature]) for item in features_ if feature in item
+                    ]
+                    out_features[i][feature] = np.concatenate(arrays)
+        return super().__call__(out_features, return_tensors=return_tensors)
+
+
 @dataclass
 class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     """
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index de8fff625..24398d944 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -185,9 +185,11 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
         min_sequence_len=cfg.min_sample_len or 2,
     )
 
-    min_input_len = np.min(get_dataset_lengths(train_dataset))
+    dataset_lengths = get_dataset_lengths(train_dataset)
+
+    min_input_len = np.min(dataset_lengths)
     LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
-    max_input_len = np.max(get_dataset_lengths(train_dataset))
+    max_input_len = np.max(dataset_lengths)
     LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
 
     if cfg.model_config_type == "mamba":