Pretrain multipack v2 (#1470)
This commit is contained in:
@@ -40,3 +40,4 @@ gcsfs
|
|||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
|
trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
|
||||||
|
zstandard==0.22.0
|
||||||
|
|||||||
@@ -217,13 +217,24 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
Collator for multipack specific to using the BatchSampler
|
Collator for multipack specific to using the BatchSampler
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, multipack_attn=True, **kwargs):
    """Initialize the collator and record the multipack attention flag.

    Args:
        *args: Positional arguments forwarded unchanged to the parent
            collator's ``__init__``.
        multipack_attn: Flag stored as ``self.multipack_attn``. ``__call__``
            reads it to choose how the ``attention_mask`` arrays are built.
            Defaults to ``True``.
        **kwargs: Keyword arguments forwarded unchanged to the parent
            collator's ``__init__``.
    """
    super().__init__(*args, **kwargs)
    self.multipack_attn = multipack_attn
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
chunked_data = {}
|
chunked_data = {}
|
||||||
for feature in features.keys():
|
for feature in features.keys():
|
||||||
if feature == "length":
|
if feature == "length":
|
||||||
continue
|
continue
|
||||||
if feature == "attention_mask":
|
if feature == "attention_mask":
|
||||||
arrays = [(1) * np.array(item) for item in features[feature]]
|
if self.multipack_attn:
|
||||||
|
arrays = [
|
||||||
|
(i + 1) * np.array(item[feature])
|
||||||
|
for i, item in enumerate(features[feature])
|
||||||
|
if feature in item
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
arrays = [(1) * np.array(item) for item in features[feature]]
|
||||||
chunked_data[feature] = np.concatenate(arrays)
|
chunked_data[feature] = np.concatenate(arrays)
|
||||||
else:
|
else:
|
||||||
arrays = [np.array(item) for item in features[feature]]
|
arrays = [np.array(item) for item in features[feature]]
|
||||||
|
|||||||
@@ -511,6 +511,14 @@ class AxolotlInputConfig(
|
|||||||
eval_sample_packing: Optional[bool] = None
|
eval_sample_packing: Optional[bool] = None
|
||||||
pad_to_sequence_len: Optional[bool] = None
|
pad_to_sequence_len: Optional[bool] = None
|
||||||
|
|
||||||
|
pretrain_multipack_buffer_size: Optional[int] = 10_000
|
||||||
|
pretrain_multipack_attn: Optional[bool] = Field(
|
||||||
|
default=True,
|
||||||
|
metadata={
|
||||||
|
"help": "whether to prevent cross attention for packed sequences during pretraining",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
xformers_attention: Optional[bool] = None
|
xformers_attention: Optional[bool] = None
|
||||||
sdp_attention: Optional[bool] = None
|
sdp_attention: Optional[bool] = None
|
||||||
s2_attention: Optional[bool] = None
|
s2_attention: Optional[bool] = None
|
||||||
|
|||||||
@@ -108,6 +108,7 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
max_tokens=cfg.sequence_len,
|
max_tokens=cfg.sequence_len,
|
||||||
batch_size=cfg.micro_batch_size,
|
batch_size=cfg.micro_batch_size,
|
||||||
seed=cfg.seed or 42,
|
seed=cfg.seed or 42,
|
||||||
|
buffer_size=cfg.pretrain_multipack_buffer_size or 10_000,
|
||||||
)
|
)
|
||||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||||
train_dataset = train_dataset.with_format("torch")
|
train_dataset = train_dataset.with_format("torch")
|
||||||
@@ -816,6 +817,7 @@ def wrap_pretraining_dataset(
|
|||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
padding=True,
|
padding=True,
|
||||||
pad_to_multiple_of=max_tokens * batch_size,
|
pad_to_multiple_of=max_tokens * batch_size,
|
||||||
|
multipack_attn=cfg.pretrain_multipack_attn,
|
||||||
)
|
)
|
||||||
encode = functools.partial(
|
encode = functools.partial(
|
||||||
encode_packed_pretraining,
|
encode_packed_pretraining,
|
||||||
@@ -823,6 +825,7 @@ def wrap_pretraining_dataset(
|
|||||||
ds_wrapper_fn,
|
ds_wrapper_fn,
|
||||||
max_seq_length=max_tokens,
|
max_seq_length=max_tokens,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
multipack_attn=cfg.pretrain_multipack_attn,
|
||||||
)
|
)
|
||||||
# set this to 1 so downstream data_loader doesn't try to increase the batch again
|
# set this to 1 so downstream data_loader doesn't try to increase the batch again
|
||||||
cfg.micro_batch_size = 1
|
cfg.micro_batch_size = 1
|
||||||
@@ -861,6 +864,7 @@ def encode_packed_pretraining(
|
|||||||
examples: Dict[str, List],
|
examples: Dict[str, List],
|
||||||
max_seq_length: int = 2048,
|
max_seq_length: int = 2048,
|
||||||
batch_size: int = 4,
|
batch_size: int = 4,
|
||||||
|
multipack_attn: Optional[bool] = False,
|
||||||
) -> Dict[str, List]:
|
) -> Dict[str, List]:
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
# tokenize all the examples
|
# tokenize all the examples
|
||||||
@@ -868,7 +872,9 @@ def encode_packed_pretraining(
|
|||||||
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
|
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
|
||||||
|
|
||||||
train_dataset = process_pretraining_datasets_for_packing(
|
train_dataset = process_pretraining_datasets_for_packing(
|
||||||
train_dataset, max_seq_length
|
train_dataset,
|
||||||
|
max_seq_length,
|
||||||
|
skip_position_ids=not multipack_attn,
|
||||||
)
|
)
|
||||||
|
|
||||||
sampler = MultipackBatchSampler(
|
sampler = MultipackBatchSampler(
|
||||||
|
|||||||
@@ -172,17 +172,21 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
return train_dataset, eval_dataset
|
return train_dataset, eval_dataset
|
||||||
|
|
||||||
|
|
||||||
def process_pretraining_datasets_for_packing(
    train_dataset, sequence_len, skip_position_ids=True
):
    """Filter a pretraining dataset for packing and optionally add position ids.

    Args:
        train_dataset: Dataset object exposing ``filter`` and ``map``.
        sequence_len: Forwarded to ``drop_long_seq`` as its ``sequence_len``
            keyword argument.
        skip_position_ids: When true, the ``add_position_ids`` map step is
            skipped. When false, that step is applied after filtering.

    Returns:
        The dataset after filtering and, unless skipped, the
        ``add_position_ids`` map.
    """
    drop_long = partial(drop_long_seq, sequence_len=sequence_len)

    train_dataset = train_dataset.filter(
        drop_long,
        desc="Dropping Long Sequences",
    )
    # The flag means "skip", so the map must run only when it is false.
    # The previous condition (`if skip_position_ids:`) was inverted and added
    # position ids exactly when the caller asked to skip them.
    # NOTE(review): callers relying on the default now get no position ids;
    # the visible call site passes the flag explicitly.
    if not skip_position_ids:
        train_dataset = train_dataset.map(
            add_position_ids,
            desc="Add position_id column (Pretraining Sample Packing)",
        )

    return train_dataset
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user