flex sample packing WIP
This commit is contained in:
@@ -151,6 +151,35 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
return super().__call__(out_features, return_tensors=return_tensors)
|
return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FlexBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
    """
    Collator for multipack specific to Flex Attention using the BatchSampler

    Each pack (a list of per-sample feature dicts) is flattened into a single
    feature dict by concatenating every feature across the samples of the
    pack. The resulting dicts are then passed to the parent collator.
    """

    def __call__(self, features, return_tensors=None):
        """
        Merge the samples of each pack and delegate to the parent collator.

        Args:
            features: either a list of packs, where each pack is a list of
                per-sample feature dicts, or a flat list of feature dicts,
                which is treated as a single pack.
            return_tensors: forwarded unchanged to the parent ``__call__``.

        Returns:
            Whatever the parent ``__call__`` returns for the merged packs.
        """
        # A flat list of samples is wrapped so it is handled as one pack.
        if not isinstance(features[0], list):
            features = [features]

        out_features = [{} for _ in features]
        for pack_idx, pack in enumerate(features):
            # Only the keys of the first sample in the pack are considered;
            # samples that lack a given key are skipped for that key.
            for feature in pack[0].keys():
                if feature == "length":
                    continue
                if feature == "attention_mask":
                    # Scale each sample's mask by its 1-based position in the
                    # pack. The position counts every sample in the pack,
                    # including ones skipped for lacking this key.
                    # NOTE(review): presumably masks are 0/1 so the result
                    # carries a distinct id per sample -- confirm upstream.
                    arrays = [
                        (sample_idx + 1) * np.array(item[feature])
                        for sample_idx, item in enumerate(pack)
                        if feature in item
                    ]
                else:
                    arrays = [
                        np.array(item[feature]) for item in pack if feature in item
                    ]
                out_features[pack_idx][feature] = np.concatenate(arrays)

        return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -185,9 +185,11 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
min_sequence_len=cfg.min_sample_len or 2,
|
min_sequence_len=cfg.min_sample_len or 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
dataset_lengths = get_dataset_lengths(train_dataset)
|
||||||
|
|
||||||
|
min_input_len = np.min(dataset_lengths)
|
||||||
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
|
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
|
||||||
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
max_input_len = np.max(dataset_lengths)
|
||||||
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
||||||
|
|
||||||
if cfg.model_config_type == "mamba":
|
if cfg.model_config_type == "mamba":
|
||||||
|
|||||||
Reference in New Issue
Block a user