Compare commits

..

66 Commits

Author SHA1 Message Date
Wing Lian
35a84f2cb8 more fixes 2025-01-14 22:47:49 -05:00
Wing Lian
510cf45317 improve logprob masking and shift in trainer 2025-01-14 22:47:48 -05:00
Wing Lian
7232cbdeab chore: lint 2025-01-14 22:47:48 -05:00
Wing Lian
e8fceb7091 chore: lint 2025-01-14 22:47:48 -05:00
Wing Lian
a5e0671738 make sure to use tensorboard to capture loss for checks 2025-01-14 22:47:48 -05:00
Wing Lian
b9847553af fix adapter model check 2025-01-14 22:47:48 -05:00
Wing Lian
513ec9e03b make sure to use the correct tokenizer 2025-01-14 22:47:48 -05:00
Wing Lian
530347856d make sure to set tokenizer from l3 70b and save safetensors 2025-01-14 22:47:47 -05:00
Wing Lian
261e4fb619 lower lr 2025-01-14 22:47:47 -05:00
Wing Lian
158071e95f set lora_dropout explicitly 2025-01-14 22:47:47 -05:00
Wing Lian
432f65f5e6 make the kd e2e fit in vram for ci and add lora version 2025-01-14 22:47:47 -05:00
Wing Lian
1d039f5486 rename test files so it gets picked up 2025-01-14 22:47:47 -05:00
Wing Lian
b9a42b396f linting 2025-01-14 22:47:47 -05:00
Wing Lian
ff2fb0fc1b add kd trainer e2e test 2025-01-14 22:47:47 -05:00
Wing Lian
317f290186 reward model doesn't work well with batched 2025-01-14 22:47:46 -05:00
Wing Lian
ab690f3f01 improve check for batched 2025-01-14 22:47:46 -05:00
Wing Lian
47932f21c4 fix reward trainer calls for tokenization 2025-01-14 22:47:46 -05:00
Wing Lian
808328e041 reward can use same batch check 2025-01-14 22:47:46 -05:00
Wing Lian
6784822cfb tweak check for batched prompt data 2025-01-14 22:47:46 -05:00
Wing Lian
684b38291f ensure that batch vs single is done properly 2025-01-14 22:47:46 -05:00
Wing Lian
01896b1bde improve iterable support 2025-01-14 22:47:46 -05:00
Wing Lian
e659c01646 support streaming for processing sft datasts? 2025-01-14 22:47:45 -05:00
Wing Lian
204d6c43b4 make loss torch script compat 2025-01-14 22:47:45 -05:00
Wing Lian
d3c2b7ce9d kd sample packing 2025-01-14 22:47:45 -05:00
Wing Lian
93dfff92f1 be a bit pickier about loading dynamic prompt strategies 2025-01-14 22:47:45 -05:00
Wing Lian
6e409d2d88 more info on preprocess for kd and fix import 2025-01-14 22:47:45 -05:00
Wing Lian
d5bc214300 remove duplicate code 2025-01-14 22:47:45 -05:00
Wing Lian
92c6c1087e add copyrights 2025-01-14 22:47:45 -05:00
Wing Lian
feed96f95e increase logging around loading plugins 2025-01-14 22:47:44 -05:00
Wing Lian
cba6165ae1 make plugin setup concise 2025-01-14 22:47:44 -05:00
Wing Lian
cdfcd69afa remove moved class from import 2025-01-14 22:47:44 -05:00
Wing Lian
885653d52e move more things to kd plugin 2025-01-14 22:47:44 -05:00
Wing Lian
27faacbf5a refactor kd chat template loader 2025-01-14 22:47:44 -05:00
Wing Lian
c51b0337c1 support for custom trainer classes from plugins 2025-01-14 22:47:44 -05:00
Wing Lian
fa055f9f69 handle token/logprob shifting 2025-01-14 22:47:43 -05:00
Wing Lian
f60c623af0 remove references to triton kd for now 2025-01-14 22:47:43 -05:00
Wing Lian
746891eb5c add license block 2025-01-14 22:47:43 -05:00
Wing Lian
f09b5da60b refactor so we can easily add new loss functions 2025-01-14 22:47:43 -05:00
Wing Lian
689e1c10ba chore: lint 2025-01-14 22:47:43 -05:00
Wing Lian
a5c085e003 var naming and add todo 2025-01-14 22:47:43 -05:00
Wing Lian
63146300b7 fix kd loss so it's causal (fixes repeating tokens) 2025-01-14 22:47:43 -05:00
Wing Lian
ca5e397fc5 use kd_alpha in the correct loss method 2025-01-14 22:47:42 -05:00
Wing Lian
3416302b0d hash for temperature too 2025-01-14 22:47:42 -05:00
Wing Lian
7366efc4ca better rescaling for temperatures 2025-01-14 22:47:42 -05:00
Wing Lian
d8d817eaed don't use triton for now 2025-01-14 22:47:42 -05:00
Wing Lian
c0757e8a20 fix kwarg 2025-01-14 22:47:42 -05:00
Wing Lian
e565694914 v3 2025-01-14 22:47:42 -05:00
Wing Lian
081928e55b no torch.tensor 2025-01-14 22:47:42 -05:00
Wing Lian
dc90c93894 no log etc 2025-01-14 22:47:41 -05:00
Wing Lian
18a46c338a no torch.exp inside triton kernel 2025-01-14 22:47:41 -05:00
Wing Lian
119d586cf4 v2 trial 2025-01-14 22:47:41 -05:00
Wing Lian
c73acd7de0 no where support 2025-01-14 22:47:41 -05:00
Wing Lian
0b59a242d4 triton wip 2025-01-14 22:47:41 -05:00
Wing Lian
ed490517da chore: lint 2025-01-14 22:47:41 -05:00
Wing Lian
00ce77e7ef make sure to multiply against the correct loss 2025-01-14 22:47:41 -05:00
Wing Lian
ae545e0165 cross entropy loss coefficient during KD 2025-01-14 22:47:40 -05:00
Wing Lian
b592c05b93 flipped the slice 2025-01-14 22:47:40 -05:00
Wing Lian
7fe0ad088b make it work 2025-01-14 22:47:40 -05:00
Wing Lian
ddcf5c68b3 handle padding/collation for KD datasets 2025-01-14 22:47:40 -05:00
Wing Lian
e633a12dbe make batch smaller 2025-01-14 22:47:40 -05:00
Wing Lian
d584354ee4 filter bad rows 2025-01-14 22:47:40 -05:00
Wing Lian
303cfa71aa KD dataset loading and KD with logprobs 2025-01-14 22:47:40 -05:00
Wing Lian
88b3198894 refactor trainer to prevent circular dependencies later
fix loader default
2025-01-14 22:47:39 -05:00
jwongTensora
8606093921 fix for indexing error from token/embeddings mismatch (#2257)
Co-authored-by: jwong <jwongTensora@gmail.com>
2025-01-14 22:09:29 -05:00
NanoCode012
cba5a457d9 fix: use text_column even when not packing for pretraining (#2254)
* fix: use text_column even when not packing for pretraining

* feat: update test to check when not packing

* chore: lint

* Update src/axolotl/utils/data/pretraining.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-01-14 22:08:56 -05:00
Wing Lian
19cd83d408 rename references to dpo dataset prep to pref data (#2258) 2025-01-14 22:07:55 -05:00
11 changed files with 68 additions and 34 deletions

View File

@@ -11,7 +11,7 @@ from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_dpo_datasets from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
@@ -109,9 +109,9 @@ def load_preference_datasets(
cli_args: Union[PreprocessCliArgs, TrainerCliArgs], cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta: ) -> TrainDatasetMeta:
""" """
Loads one or more training or evaluation datasets for DPO training, calling Loads one or more training or evaluation datasets for RL training using paired
`axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
information. Optionally, logs out debug information.
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Dictionary mapping `axolotl` config keys to values.
@@ -121,7 +121,7 @@ def load_preference_datasets(
Dataclass with fields for training and evaluation datasets and the computed Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`. `total_num_steps`.
""" """
train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg) train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
total_num_steps = int( total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
) )

View File

@@ -53,6 +53,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
) )
def transform_logprobs(self, sample): def transform_logprobs(self, sample):
"""
Transform logprobs to target format for KD training
"""
logprobs = sample.pop(self.logprobs_field) logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs) target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"]) input_seq_len = len(sample["input_ids"])
@@ -62,16 +66,33 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_token_ids = [] target_token_ids = []
target_mask = [] target_mask = []
if input_padding_len < 0:
# logprobs is longer than target_seq_len,
# so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len]
input_padding_len = 0
# target_seq_len = input_seq_len
# fill with -inf for padding_len tokens for top_k tokens # fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf # extend target_logprobs with a padding_len x top_k 2D list filled with -inf
for _ in range(1, input_padding_len): # start at 1 since this is causal
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
# otherwise, we need to shift in the trainer
shift = 0
for _ in range(shift, input_padding_len):
target_logprobs.append([-float("inf")] * top_k) target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k))) target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k) target_mask.append([0] * top_k)
for _ in range(target_seq_len): # for _ in range(target_seq_len):
# TODO also check against sample["labels"] # # TODO also check against sample["labels"]
target_mask.append([1] * top_k) # target_mask.append([1] * top_k)
for position in range(input_padding_len, input_seq_len):
if sample["labels"][position] == -100:
target_mask.append([0] * top_k)
else:
target_mask.append([1] * top_k)
for _, token_pos_logprobs in enumerate(logprobs): for _, token_pos_logprobs in enumerate(logprobs):
# Initialize collections for logprobs and token_ids # Initialize collections for logprobs and token_ids
@@ -120,10 +141,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_logprobs.append(position_logprobs_scaled) target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(position_token_ids) target_token_ids.append(position_token_ids)
# since we started at index 1 for causal, we need one more padding token if shift == 1:
target_logprobs.append([-float("inf")] * top_k) # since we started at index 1 for causal, we need one more padding token
target_token_ids.append(list(range(top_k))) target_logprobs.append([-float("inf")] * top_k)
target_mask.append([0] * top_k) target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
# Update sample with transformed logprobs # Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs sample["target_logprobs"] = target_logprobs

View File

@@ -45,7 +45,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
inputs, inputs,
return_outputs=False, return_outputs=False,
num_items_in_batch=None, num_items_in_batch=None,
shift_targets=False, shift_targets=True,
): ):
""" """
How the loss is computed by Trainer. By default, all models return the loss in the first element. How the loss is computed by Trainer. By default, all models return the loss in the first element.

View File

@@ -5,7 +5,7 @@ from axolotl.utils.data.pretraining import ( # noqa: F401
encode_pretraining, encode_pretraining,
wrap_pretraining_dataset, wrap_pretraining_dataset,
) )
from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401 from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401
from axolotl.utils.data.sft import ( # noqa: F401 from axolotl.utils.data.sft import ( # noqa: F401
get_dataset_wrapper, get_dataset_wrapper,
load_prepare_datasets, load_prepare_datasets,

View File

@@ -18,10 +18,13 @@ LOG = logging.getLogger("axolotl")
def encode_pretraining( def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List] tokenizer: PreTrainedTokenizerBase,
max_tokens: int,
examples: Dict[str, List],
text_column: str = "text",
) -> Dict[str, List]: ) -> Dict[str, List]:
res = tokenizer( res = tokenizer(
examples["text"], examples[text_column],
truncation=True, truncation=True,
max_length=max_tokens - 2, max_length=max_tokens - 2,
add_special_tokens=True, add_special_tokens=True,
@@ -196,7 +199,12 @@ def wrap_pretraining_dataset(
# set this to 1 so downstream data_loader doesn't try to increase the batch again # set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1 cfg.micro_batch_size = 1
else: else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens) encode = functools.partial(
encode_pretraining,
tokenizer,
max_tokens,
text_column=cfg.pretraining_dataset[0].text_column or "text",
)
if cfg.shuffle_merged_datasets: if cfg.shuffle_merged_datasets:
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size) dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)

View File

@@ -115,7 +115,7 @@ def drop_long_rl_seq(
raise ValueError("Unknown RL type") raise ValueError("Unknown RL type")
def load_prepare_dpo_datasets(cfg): def load_prepare_preference_datasets(cfg):
def load_split(dataset_cfgs, _cfg): def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = [] split_datasets: List[Any] = []
for i, ds_cfg in enumerate(dataset_cfgs): for i, ds_cfg in enumerate(dataset_cfgs):

View File

@@ -1057,7 +1057,7 @@ class ModelLoader:
) )
if ( if (
hasattr(self.model, "get_input_embeddings") hasattr(self.model, "get_input_embeddings")
and self.model.get_input_embeddings().num_embeddings < embeddings_len and self.model.get_input_embeddings().num_embeddings != embeddings_len
): ):
resize_kwargs = {} resize_kwargs = {}
if self.cfg.mean_resizing_embeddings is not None: if self.cfg.mean_resizing_embeddings is not None:

View File

@@ -279,6 +279,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long_kwargs["desc"] = "Dropping Long Sequences" drop_long_kwargs["desc"] = "Dropping Long Sequences"
train_dataset = train_dataset.filter( train_dataset = train_dataset.filter(
drop_long, drop_long,
batched=True,
**filter_map_kwargs, **filter_map_kwargs,
**drop_long_kwargs, **drop_long_kwargs,
) )
@@ -310,8 +311,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
""" """
labels = sample["labels"] labels = sample["labels"]
if not labels: if not labels:
# Edge case: if labels is empty, decide if you want to keep or drop return True
return True # or False
# Check if single example or batch # Check if single example or batch
# If first element is an int, we assume a single example # If first element is an int, we assume a single example

View File

@@ -4,7 +4,8 @@ E2E tests for llama pretrain
import logging import logging
import os import os
import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
@@ -12,19 +13,22 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir from .utils import check_model_output_exists
LOG = logging.getLogger("axolotl.tests.e2e") LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
class TestPretrainLlama(unittest.TestCase): class TestPretrainLlama:
""" """
Test case for Llama models w pretraining Test case for Llama models w pretraining
""" """
@with_temp_dir @pytest.mark.parametrize(
def test_pretrain_w_sample_packing(self, temp_dir): "sample_packing",
[True, False],
)
def test_pretrain(self, temp_dir, sample_packing):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -32,7 +36,7 @@ class TestPretrainLlama(unittest.TestCase):
"tokenizer_type": "LlamaTokenizer", "tokenizer_type": "LlamaTokenizer",
"flash_attention": True, "flash_attention": True,
"sequence_len": 1024, "sequence_len": 1024,
"sample_packing": True, "sample_packing": sample_packing,
"special_tokens": { "special_tokens": {
"unk_token": "<unk>", "unk_token": "<unk>",
"bos_token": "<s>", "bos_token": "<s>",

View File

@@ -17,7 +17,7 @@ from huggingface_hub import snapshot_download
from transformers import AutoTokenizer from transformers import AutoTokenizer
from axolotl.utils.data import load_tokenized_prepared_datasets from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.data.rl import load_prepare_dpo_datasets from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -280,7 +280,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
train_dataset, _ = load_prepare_dpo_datasets(cfg) train_dataset, _ = load_prepare_preference_datasets(cfg)
assert len(train_dataset) == 1800 assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features assert "conversation" in train_dataset.features
@@ -329,7 +329,7 @@ class TestDatasetPreparation(unittest.TestCase):
} }
) )
train_dataset, _ = load_prepare_dpo_datasets(cfg) train_dataset, _ = load_prepare_preference_datasets(cfg)
assert len(train_dataset) == 1800 assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features assert "conversation" in train_dataset.features

View File

@@ -12,7 +12,7 @@ from datasets import Dataset
from transformers import AutoTokenizer from transformers import AutoTokenizer
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_dpo_datasets from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets from axolotl.utils.data.utils import deduplicate_and_log_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.models import load_processor, load_tokenizer
@@ -236,7 +236,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
"""Verify that loading with deduplication removes duplicates.""" """Verify that loading with deduplication removes duplicates."""
# Load the dataset using the deduplication setting # Load the dataset using the deduplication setting
train_dataset, _ = load_prepare_dpo_datasets(self.cfg) train_dataset, _ = load_prepare_preference_datasets(self.cfg)
# Verify that the dataset has been deduplicated # Verify that the dataset has been deduplicated
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated" assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
@@ -245,7 +245,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
"""Verify that loading without deduplication retains duplicates.""" """Verify that loading without deduplication retains duplicates."""
self.cfg.dataset_exact_deduplication = False self.cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication # Load the dataset without deduplication
train_dataset, _ = load_prepare_dpo_datasets(self.cfg) train_dataset, _ = load_prepare_preference_datasets(self.cfg)
# Verify that the dataset retains duplicates # Verify that the dataset retains duplicates
assert ( assert (