Add option to disable sample concatenation during pretraining
This commit is contained in:
@@ -1877,6 +1877,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||||
):
|
):
|
||||||
if training_args.pretraining:
|
if training_args.pretraining:
|
||||||
|
if self.cfg.pretraining_sample_concatenation is False:
|
||||||
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self.cfg.model_config_type == "mamba":
|
if self.cfg.model_config_type == "mamba":
|
||||||
|
|||||||
@@ -698,6 +698,12 @@ class AxolotlInputConfig(
|
|||||||
pad_to_sequence_len: Optional[bool] = None
|
pad_to_sequence_len: Optional[bool] = None
|
||||||
curriculum_sampling: Optional[bool] = None
|
curriculum_sampling: Optional[bool] = None
|
||||||
multipack_real_batches: Optional[bool] = None
|
multipack_real_batches: Optional[bool] = None
|
||||||
|
pretraining_sample_concatenation: Optional[bool] = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "whether to soft pack/concatenate samples during pretraining",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
batch_flattening: Optional[Union[Literal["auto"], bool]] = None
|
batch_flattening: Optional[Union[Literal["auto"], bool]] = None
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,10 @@ LOG = logging.getLogger("axolotl")
|
|||||||
|
|
||||||
|
|
||||||
def encode_pretraining(
|
def encode_pretraining(
|
||||||
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
max_tokens: int,
|
||||||
|
examples: Dict[str, List],
|
||||||
|
concatenate: bool = True,
|
||||||
) -> Dict[str, List]:
|
) -> Dict[str, List]:
|
||||||
res = tokenizer(
|
res = tokenizer(
|
||||||
examples["text"],
|
examples["text"],
|
||||||
@@ -30,6 +33,13 @@ def encode_pretraining(
|
|||||||
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
|
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
|
||||||
targets = [torch.tensor(seq) for seq in res["input_ids"]]
|
targets = [torch.tensor(seq) for seq in res["input_ids"]]
|
||||||
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
|
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
|
||||||
|
if not concatenate:
|
||||||
|
return {
|
||||||
|
"input_ids": [seq.tolist() for seq in input_ids],
|
||||||
|
"labels": [seq.tolist() for seq in targets],
|
||||||
|
"attention_mask": [seq.tolist() for seq in attention_mask],
|
||||||
|
}
|
||||||
|
|
||||||
new_input_ids = []
|
new_input_ids = []
|
||||||
new_labels = []
|
new_labels = []
|
||||||
new_attention_mask = []
|
new_attention_mask = []
|
||||||
@@ -195,6 +205,10 @@ def wrap_pretraining_dataset(
|
|||||||
)
|
)
|
||||||
# set this to 1 so downstream data_loader doesn't try to increase the batch again
|
# set this to 1 so downstream data_loader doesn't try to increase the batch again
|
||||||
cfg.micro_batch_size = 1
|
cfg.micro_batch_size = 1
|
||||||
|
elif cfg.pretraining_sample_concatenation is False:
|
||||||
|
encode = functools.partial(
|
||||||
|
encode_pretraining, tokenizer, max_tokens, concatenate=False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user