Compare commits


5 Commits

Author          SHA1        Message                                                             Date
Wing Lian       6c49083d8b  improve check for base case                                         2025-01-24 12:02:34 -05:00
Wing Lian       94c226edb3  fixes last eos token not in labels on basic use case               2025-01-24 12:00:06 -05:00
Wing Lian       8fb72cbc0b  use the extracted field_messages to parse the role fields (#2265)  2025-01-21 15:39:30 -05:00
Adithya Kamath  bb9d4102c4  Add 5000 line history limit to tmux for docker cloud (#2268)       2025-01-21 15:39:17 -05:00
Wing Lian       af727eedf7  option to not concatenate during pretraining (#2263)               2025-01-20 14:07:34 -05:00
                            * option to not concatenate during pretraining
                            * simplify conditional and add doc to config.qmd
6 changed files with 22 additions and 2 deletions

View File

@@ -20,7 +20,8 @@ RUN apt install --yes --no-install-recommends openssh-server tmux && \
     printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
     printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
     chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
-    chmod +x /root/cloud-entrypoint.sh
+    chmod +x /root/cloud-entrypoint.sh && \
+    echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
 
 ENTRYPOINT ["/root/cloud-entrypoint.sh"]
 CMD ["sleep", "infinity"]

View File

@@ -244,6 +244,8 @@ total_num_tokens:
 sample_packing_group_size: 100000
 # The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
 sample_packing_bin_size: 200
+# whether to concatenate samples during pretraining
+pretraining_sample_concatenation:
 # Use batch flattening for speedups when not using sample_packing
 batch_flattening:

View File

@@ -1877,6 +1877,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
     ):
         if training_args.pretraining:
+            if self.cfg.pretraining_sample_concatenation is False:
+                return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
             return None
         if self.cfg.model_config_type == "mamba":
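When `pretraining_sample_concatenation` is set to `false`, the trainer no longer returns `None` (the packed-pretraining default) and instead hands batching to Hugging Face's `DataCollatorForSeq2Seq`, which pads each sample in a batch to the batch's longest sequence. A minimal sketch of that padding behavior, using a GPT-2 tokenizer purely for illustration:

```python
# Sketch (not from this PR): how DataCollatorForSeq2Seq pads the
# variable-length rows that encode_pretraining emits when concatenate=False.
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

collator = DataCollatorForSeq2Seq(tokenizer, return_tensors="pt")
features = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]},
    {"input_ids": [4, 5], "attention_mask": [1, 1], "labels": [4, 5]},
]
batch = collator(features)
# input_ids are padded with pad_token_id; labels are padded with -100 so the
# loss ignores the padding positions.
print(batch["input_ids"].shape)  # torch.Size([2, 3])
print(batch["labels"])           # second row ends with -100
```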

View File

@@ -223,7 +223,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
     def tokenize_prompt(self, prompt):
         # Old simple legacy behavior that works reliably.
         if (
-            not self.roles_to_train
+            (not self.roles_to_train or self.roles_to_train == ["assistant"])
             and not self.train_on_eos
             and not self.prompter.message_field_training
             and not self.prompter.message_field_training_detail
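This widens the fast legacy path: previously only an empty `roles_to_train` qualified, but `["assistant"]` is the common default, so it is now treated the same way. An illustrative sketch of the gate as a standalone function, with hypothetical names mirroring the diff:

```python
# Simplified, hypothetical re-statement of the condition above.
def uses_legacy_path(roles_to_train, train_on_eos,
                     field_training, field_training_detail):
    return (
        (not roles_to_train or roles_to_train == ["assistant"])
        and not train_on_eos
        and not field_training
        and not field_training_detail
    )

# Before this change only an empty roles_to_train took the legacy path;
# now the default value ["assistant"] does too.
assert uses_legacy_path([], None, None, None)
assert uses_legacy_path(["assistant"], None, None, None)
assert not uses_legacy_path(["assistant", "user"], None, None, None)
```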

View File

@@ -706,6 +706,12 @@ class AxolotlInputConfig(
     pad_to_sequence_len: Optional[bool] = None
     curriculum_sampling: Optional[bool] = None
     multipack_real_batches: Optional[bool] = None
+    pretraining_sample_concatenation: Optional[bool] = Field(
+        default=None,
+        json_schema_extra={
+            "description": "whether to soft pack/concatenate samples during pretraining",
+        },
+    )
     batch_flattening: Optional[Union[Literal["auto"], bool]] = None
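The new field mirrors the documented YAML option, with `json_schema_extra` carrying the human-readable description into the generated JSON schema. A minimal sketch of the pattern, assuming pydantic v2 and using a hypothetical `MiniConfig` model rather than axolotl's real `AxolotlInputConfig`:

```python
# Minimal sketch of the Field(json_schema_extra=...) pattern used above.
from typing import Optional
from pydantic import BaseModel, Field

class MiniConfig(BaseModel):  # hypothetical stand-in model
    pretraining_sample_concatenation: Optional[bool] = Field(
        default=None,
        json_schema_extra={
            "description": "whether to soft pack/concatenate samples during pretraining",
        },
    )

schema = MiniConfig.model_json_schema()
# The description lands in the JSON schema, where doc tooling can pick it up.
print(schema["properties"]["pretraining_sample_concatenation"]["description"])
```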

View File

@@ -22,6 +22,7 @@ def encode_pretraining(
     max_tokens: int,
     examples: Dict[str, List],
     text_column: str = "text",
+    concatenate: bool = True,
 ) -> Dict[str, List]:
     res = tokenizer(
         examples[text_column],
@@ -33,6 +34,13 @@ def encode_pretraining(
     input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
     targets = [torch.tensor(seq) for seq in res["input_ids"]]
     attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
+    if not concatenate:
+        return {
+            "input_ids": [seq.tolist() for seq in input_ids],
+            "labels": [seq.tolist() for seq in targets],
+            "attention_mask": [seq.tolist() for seq in attention_mask],
+        }
     new_input_ids = []
     new_labels = []
     new_attention_mask = []
@@ -204,6 +212,7 @@ def wrap_pretraining_dataset(
         tokenizer,
         max_tokens,
         text_column=cfg.pretraining_dataset[0].text_column or "text",
+        concatenate=cfg.pretraining_sample_concatenation is True,
     )
     if cfg.shuffle_merged_datasets:
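Taken together, `encode_pretraining` now has two output shapes: the default path concatenates all tokenized texts and emits fixed-length `max_tokens` blocks, while `concatenate=False` returns one row per input sample, truncated to `max_tokens`, for the collator above to pad. A toy re-implementation to show the difference (not the axolotl function itself, which also handles EOS insertion, labels, and attention masks):

```python
# Toy sketch of the two modes: concatenate=True packs token ids into fixed
# max_tokens blocks; concatenate=False keeps one row per sample.
from typing import Dict, List

def encode_toy(batches: List[List[int]], max_tokens: int,
               concatenate: bool = True) -> Dict[str, List[List[int]]]:
    if not concatenate:
        # One (possibly short) row per sample; a downstream collator such as
        # DataCollatorForSeq2Seq pads these to a common length.
        return {"input_ids": [seq[:max_tokens] for seq in batches]}
    flat: List[int] = [tok for seq in batches for tok in seq]
    blocks = [flat[i : i + max_tokens] for i in range(0, len(flat), max_tokens)]
    # Keep only full blocks, dropping the trailing partial one.
    return {"input_ids": [b for b in blocks if len(b) == max_tokens]}

print(encode_toy([[1, 2, 3], [4, 5]], max_tokens=4, concatenate=True))
# {'input_ids': [[1, 2, 3, 4]]}
print(encode_toy([[1, 2, 3], [4, 5]], max_tokens=4, concatenate=False))
# {'input_ids': [[1, 2, 3], [4, 5]]}
```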