flex sample packing WIP

Sunny
2025-01-15 21:12:33 -05:00
parent dbcd11e533
commit a6f2c5d583
2 changed files with 33 additions and 2 deletions

@@ -151,6 +151,35 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         return super().__call__(out_features, return_tensors=return_tensors)
+@dataclass
+class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+    """
+    Collator for multipack specific to Flex Attention using the BatchSampler
+    """
+    def __call__(self, features, return_tensors=None):
+        if not isinstance(features[0], list):
+            features = [features]
+        out_features = [{} for _ in features]
+        for i, features_ in enumerate(features):
+            for feature in features_[0].keys():
+                if feature == "length":
+                    continue
+                if feature == "attention_mask":
+                    arrays = [
+                        (i + 1) * np.array(item[feature])
+                        for i, item in enumerate(features_)
+                        if feature in item
+                    ]
+                    out_features[i][feature] = np.concatenate(arrays)
+                else:
+                    arrays = [
+                        np.array(item[feature]) for item in features_ if feature in item
+                    ]
+                    out_features[i][feature] = np.concatenate(arrays)
+        return super().__call__(out_features, return_tensors=return_tensors)
 @dataclass
 class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     """


@@ -185,9 +185,11 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
         min_sequence_len=cfg.min_sample_len or 2,
     )
-    min_input_len = np.min(get_dataset_lengths(train_dataset))
+    dataset_lengths = get_dataset_lengths(train_dataset)
+    min_input_len = np.min(dataset_lengths)
     LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
-    max_input_len = np.max(get_dataset_lengths(train_dataset))
+    max_input_len = np.max(dataset_lengths)
     LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
     if cfg.model_config_type == "mamba":