option to not concatenate during pretraining (#2263)
* option to not concatenate during pretraining * simplify conditional and add doc to config.qmd
This commit is contained in:
@@ -1877,6 +1877,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||
):
|
||||
if training_args.pretraining:
|
||||
if self.cfg.pretraining_sample_concatenation is False:
|
||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||
return None
|
||||
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
|
||||
@@ -706,6 +706,12 @@ class AxolotlInputConfig(
|
||||
pad_to_sequence_len: Optional[bool] = None
|
||||
curriculum_sampling: Optional[bool] = None
|
||||
multipack_real_batches: Optional[bool] = None
|
||||
pretraining_sample_concatenation: Optional[bool] = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "whether to soft pack/concatenate samples during pretraining",
|
||||
},
|
||||
)
|
||||
|
||||
batch_flattening: Optional[Union[Literal["auto"], bool]] = None
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ def encode_pretraining(
|
||||
max_tokens: int,
|
||||
examples: Dict[str, List],
|
||||
text_column: str = "text",
|
||||
concatenate: bool = True,
|
||||
) -> Dict[str, List]:
|
||||
res = tokenizer(
|
||||
examples[text_column],
|
||||
@@ -33,6 +34,13 @@ def encode_pretraining(
|
||||
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
|
||||
targets = [torch.tensor(seq) for seq in res["input_ids"]]
|
||||
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
|
||||
if not concatenate:
|
||||
return {
|
||||
"input_ids": [seq.tolist() for seq in input_ids],
|
||||
"labels": [seq.tolist() for seq in targets],
|
||||
"attention_mask": [seq.tolist() for seq in attention_mask],
|
||||
}
|
||||
|
||||
new_input_ids = []
|
||||
new_labels = []
|
||||
new_attention_mask = []
|
||||
@@ -204,6 +212,7 @@ def wrap_pretraining_dataset(
|
||||
tokenizer,
|
||||
max_tokens,
|
||||
text_column=cfg.pretraining_dataset[0].text_column or "text",
|
||||
concatenate=cfg.pretraining_sample_concatenation is True,
|
||||
)
|
||||
|
||||
if cfg.shuffle_merged_datasets:
|
||||
|
||||
Reference in New Issue
Block a user