Compare commits

..

9 Commits

Author SHA1 Message Date
Wing Lian
6c49083d8b improve check for base case 2025-01-24 12:02:34 -05:00
Wing Lian
94c226edb3 fixes last eos token not in labels on basic use case 2025-01-24 12:00:06 -05:00
Wing Lian
8fb72cbc0b use the extracted field_messages to parse the role fields (#2265) 2025-01-21 15:39:30 -05:00
Adithya Kamath
bb9d4102c4 Add 5000 line history limit to tmux for docker cloud (#2268) 2025-01-21 15:39:17 -05:00
Wing Lian
af727eedf7 option to not concatenate during pretraining (#2263)
* option to not concatenate during pretraining

* simplify conditional and add doc to config.qmd
2025-01-20 14:07:34 -05:00
jwongTensora
8606093921 fix for indexing error from token/embeddings mismatch (#2257)
Co-authored-by: jwong <jwongTensora@gmail.com>
2025-01-14 22:09:29 -05:00
NanoCode012
cba5a457d9 fix: use text_column even when not packing for pretraining (#2254)
* fix: use text_column even when not packing for pretraining

* feat: update test to check when not packing

* chore: lint

* Update src/axolotl/utils/data/pretraining.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-01-14 22:08:56 -05:00
Wing Lian
19cd83d408 rename references to dpo dataset prep to pref data (#2258) 2025-01-14 22:07:55 -05:00
Dan Saunders
1ed4de73b6 CLI cleanup and documentation (#2244)
* CLI init refactor

* fix

* cleanup and (partial) docs

* Adding documentation and continuing cleanup (in progress)

* remove finetune.py script

* continued cleanup and documentation

* pytest fixes

* review comments

* fix

* Fix

* typing fixes

* make sure the batch dataset patcher for multipack is always loaded when handling datasets

* review comments

* fix

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-01-13 17:55:29 +00:00
14 changed files with 58 additions and 26 deletions

View File

@@ -20,7 +20,8 @@ RUN apt install --yes --no-install-recommends openssh-server tmux && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \ printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \ printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \ chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh chmod +x /root/cloud-entrypoint.sh && \
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
ENTRYPOINT ["/root/cloud-entrypoint.sh"] ENTRYPOINT ["/root/cloud-entrypoint.sh"]
CMD ["sleep", "infinity"] CMD ["sleep", "infinity"]

View File

@@ -244,6 +244,8 @@ total_num_tokens:
sample_packing_group_size: 100000 sample_packing_group_size: 100000
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples. # The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
sample_packing_bin_size: 200 sample_packing_bin_size: 200
# whether to concatenate samples during pretraining
pretraining_sample_concatenation:
# Use batch flattening for speedups when not using sample_packing # Use batch flattening for speedups when not using sample_packing
batch_flattening: batch_flattening:

View File

@@ -30,7 +30,7 @@ def parse_dataset(dataset=None, split="train"):
) )
ds_cfg["field_messages"] = field_messages ds_cfg["field_messages"] = field_messages
message_fields = features["conversations"][0].keys() message_fields = features[field_messages][0].keys()
message_field_role = None message_field_role = None
for key in ["from", "role"]: for key in ["from", "role"]:
if key in message_fields: if key in message_fields:

View File

@@ -11,7 +11,7 @@ from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_dpo_datasets from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
@@ -103,9 +103,9 @@ def load_preference_datasets(
cli_args: Union[PreprocessCliArgs, TrainerCliArgs], cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta: ) -> TrainDatasetMeta:
""" """
Loads one or more training or evaluation datasets for DPO training, calling Loads one or more training or evaluation datasets for RL training using paired
`axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
information. Optionally, logs out debug information.
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.
@@ -115,7 +115,7 @@ def load_preference_datasets(
Dataclass with fields for training and evaluation datasets and the computed Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`. `total_num_steps`.
""" """
train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg) train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
total_num_steps = int( total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
) )

View File

@@ -1877,6 +1877,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
): ):
if training_args.pretraining: if training_args.pretraining:
if self.cfg.pretraining_sample_concatenation is False:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None return None
if self.cfg.model_config_type == "mamba": if self.cfg.model_config_type == "mamba":

View File

@@ -223,7 +223,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
# Old simple legacy behavior that works reliably. # Old simple legacy behavior that works reliably.
if ( if (
not self.roles_to_train (not self.roles_to_train or self.roles_to_train == ["assistant"])
and not self.train_on_eos and not self.train_on_eos
and not self.prompter.message_field_training and not self.prompter.message_field_training
and not self.prompter.message_field_training_detail and not self.prompter.message_field_training_detail

View File

@@ -706,6 +706,12 @@ class AxolotlInputConfig(
pad_to_sequence_len: Optional[bool] = None pad_to_sequence_len: Optional[bool] = None
curriculum_sampling: Optional[bool] = None curriculum_sampling: Optional[bool] = None
multipack_real_batches: Optional[bool] = None multipack_real_batches: Optional[bool] = None
pretraining_sample_concatenation: Optional[bool] = Field(
default=None,
json_schema_extra={
"description": "whether to soft pack/concatenate samples during pretraining",
},
)
batch_flattening: Optional[Union[Literal["auto"], bool]] = None batch_flattening: Optional[Union[Literal["auto"], bool]] = None

View File

@@ -5,7 +5,7 @@ from axolotl.utils.data.pretraining import ( # noqa: F401
encode_pretraining, encode_pretraining,
wrap_pretraining_dataset, wrap_pretraining_dataset,
) )
from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401 from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401
from axolotl.utils.data.sft import ( # noqa: F401 from axolotl.utils.data.sft import ( # noqa: F401
get_dataset_wrapper, get_dataset_wrapper,
load_prepare_datasets, load_prepare_datasets,

View File

@@ -18,10 +18,14 @@ LOG = logging.getLogger("axolotl")
def encode_pretraining( def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List] tokenizer: PreTrainedTokenizerBase,
max_tokens: int,
examples: Dict[str, List],
text_column: str = "text",
concatenate: bool = True,
) -> Dict[str, List]: ) -> Dict[str, List]:
res = tokenizer( res = tokenizer(
examples["text"], examples[text_column],
truncation=True, truncation=True,
max_length=max_tokens - 2, max_length=max_tokens - 2,
add_special_tokens=True, add_special_tokens=True,
@@ -30,6 +34,13 @@ def encode_pretraining(
input_ids = [torch.tensor(seq) for seq in res["input_ids"]] input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
targets = [torch.tensor(seq) for seq in res["input_ids"]] targets = [torch.tensor(seq) for seq in res["input_ids"]]
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]] attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
if not concatenate:
return {
"input_ids": [seq.tolist() for seq in input_ids],
"labels": [seq.tolist() for seq in targets],
"attention_mask": [seq.tolist() for seq in attention_mask],
}
new_input_ids = [] new_input_ids = []
new_labels = [] new_labels = []
new_attention_mask = [] new_attention_mask = []
@@ -196,7 +207,13 @@ def wrap_pretraining_dataset(
# set this to 1 so downstream data_loader doesn't try to increase the batch again # set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1 cfg.micro_batch_size = 1
else: else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens) encode = functools.partial(
encode_pretraining,
tokenizer,
max_tokens,
text_column=cfg.pretraining_dataset[0].text_column or "text",
concatenate=cfg.pretraining_sample_concatenation is True,
)
if cfg.shuffle_merged_datasets: if cfg.shuffle_merged_datasets:
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size) dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)

View File

@@ -115,7 +115,7 @@ def drop_long_rl_seq(
raise ValueError("Unknown RL type") raise ValueError("Unknown RL type")
def load_prepare_dpo_datasets(cfg): def load_prepare_preference_datasets(cfg):
def load_split(dataset_cfgs, _cfg): def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = [] split_datasets: List[Any] = []
for i, ds_cfg in enumerate(dataset_cfgs): for i, ds_cfg in enumerate(dataset_cfgs):

View File

@@ -1057,7 +1057,7 @@ class ModelLoader:
) )
if ( if (
hasattr(self.model, "get_input_embeddings") hasattr(self.model, "get_input_embeddings")
and self.model.get_input_embeddings().num_embeddings < embeddings_len and self.model.get_input_embeddings().num_embeddings != embeddings_len
): ):
resize_kwargs = {} resize_kwargs = {}
if self.cfg.mean_resizing_embeddings is not None: if self.cfg.mean_resizing_embeddings is not None:

View File

@@ -4,7 +4,8 @@ E2E tests for llama pretrain
import logging import logging
import os import os
import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
@@ -12,19 +13,22 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir from .utils import check_model_output_exists
LOG = logging.getLogger("axolotl.tests.e2e") LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
class TestPretrainLlama(unittest.TestCase): class TestPretrainLlama:
""" """
Test case for Llama models w pretraining Test case for Llama models w pretraining
""" """
@with_temp_dir @pytest.mark.parametrize(
def test_pretrain_w_sample_packing(self, temp_dir): "sample_packing",
[True, False],
)
def test_pretrain(self, temp_dir, sample_packing):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -32,7 +36,7 @@ class TestPretrainLlama(unittest.TestCase):
"tokenizer_type": "LlamaTokenizer", "tokenizer_type": "LlamaTokenizer",
"flash_attention": True, "flash_attention": True,
"sequence_len": 1024, "sequence_len": 1024,
"sample_packing": True, "sample_packing": sample_packing,
"special_tokens": { "special_tokens": {
"unk_token": "<unk>", "unk_token": "<unk>",
"bos_token": "<s>", "bos_token": "<s>",

View File

@@ -17,7 +17,7 @@ from huggingface_hub import snapshot_download
from transformers import AutoTokenizer from transformers import AutoTokenizer
from axolotl.utils.data import load_tokenized_prepared_datasets from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.data.rl import load_prepare_dpo_datasets from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -280,7 +280,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
train_dataset, _ = load_prepare_dpo_datasets(cfg) train_dataset, _ = load_prepare_preference_datasets(cfg)
assert len(train_dataset) == 1800 assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features assert "conversation" in train_dataset.features
@@ -329,7 +329,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
train_dataset, _ = load_prepare_dpo_datasets(cfg) train_dataset, _ = load_prepare_preference_datasets(cfg)
assert len(train_dataset) == 1800 assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features assert "conversation" in train_dataset.features

View File

@@ -12,7 +12,7 @@ from datasets import Dataset
from transformers import AutoTokenizer from transformers import AutoTokenizer
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_dpo_datasets from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets from axolotl.utils.data.utils import deduplicate_and_log_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.models import load_processor, load_tokenizer
@@ -236,7 +236,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
"""Verify that loading with deduplication removes duplicates.""" """Verify that loading with deduplication removes duplicates."""
# Load the dataset using the deduplication setting # Load the dataset using the deduplication setting
train_dataset, _ = load_prepare_dpo_datasets(self.cfg) train_dataset, _ = load_prepare_preference_datasets(self.cfg)
# Verify that the dataset has been deduplicated # Verify that the dataset has been deduplicated
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated" assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
@@ -245,7 +245,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
"""Verify that loading without deduplication retains duplicates.""" """Verify that loading without deduplication retains duplicates."""
self.cfg.dataset_exact_deduplication = False self.cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication # Load the dataset without deduplication
train_dataset, _ = load_prepare_dpo_datasets(self.cfg) train_dataset, _ = load_prepare_preference_datasets(self.cfg)
# Verify that the dataset retains duplicates # Verify that the dataset retains duplicates
assert ( assert (