Compare commits: kd-trainer...kd-trainer

64 Commits
| SHA1 |
|---|
| 7232cbdeab |
| e8fceb7091 |
| a5e0671738 |
| b9847553af |
| 513ec9e03b |
| 530347856d |
| 261e4fb619 |
| 158071e95f |
| 432f65f5e6 |
| 1d039f5486 |
| b9a42b396f |
| ff2fb0fc1b |
| 317f290186 |
| ab690f3f01 |
| 47932f21c4 |
| 808328e041 |
| 6784822cfb |
| 684b38291f |
| 01896b1bde |
| e659c01646 |
| 204d6c43b4 |
| d3c2b7ce9d |
| 93dfff92f1 |
| 6e409d2d88 |
| d5bc214300 |
| 92c6c1087e |
| feed96f95e |
| cba6165ae1 |
| cdfcd69afa |
| 885653d52e |
| 27faacbf5a |
| c51b0337c1 |
| fa055f9f69 |
| f60c623af0 |
| 746891eb5c |
| f09b5da60b |
| 689e1c10ba |
| a5c085e003 |
| 63146300b7 |
| ca5e397fc5 |
| 3416302b0d |
| 7366efc4ca |
| d8d817eaed |
| c0757e8a20 |
| e565694914 |
| 081928e55b |
| dc90c93894 |
| 18a46c338a |
| 119d586cf4 |
| c73acd7de0 |
| 0b59a242d4 |
| ed490517da |
| 00ce77e7ef |
| ae545e0165 |
| b592c05b93 |
| 7fe0ad088b |
| ddcf5c68b3 |
| e633a12dbe |
| d584354ee4 |
| 303cfa71aa |
| 88b3198894 |
| 8606093921 |
| cba5a457d9 |
| 19cd83d408 |
```diff
@@ -11,7 +11,7 @@ from datasets import Dataset
 import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import  # noqa: F401
 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
 from axolotl.utils.data import prepare_dataset
-from axolotl.utils.data.rl import load_prepare_dpo_datasets
+from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_processor, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
```
```diff
@@ -109,9 +109,9 @@ def load_preference_datasets(
     cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
 ) -> TrainDatasetMeta:
     """
-    Loads one or more training or evaluation datasets for DPO training, calling
-    `axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug
-    information.
+    Loads one or more training or evaluation datasets for RL training using paired
+    preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
+    Optionally, logs out debug information.
 
     Args:
         cfg: Dictionary mapping `axolotl` config keys to values.
@@ -121,7 +121,7 @@ def load_preference_datasets(
         Dataclass with fields for training and evaluation datasets and the computed
         `total_num_steps`.
     """
-    train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
+    train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
     total_num_steps = int(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
```
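For concreteness, a small worked example of the `total_num_steps` formula in the body above; the sizes are hypothetical, chosen to match the 1800-row datasets used in the tests later in this diff:

```python
import math

# Hypothetical sizes: 1800 preference pairs, 3 epochs, batch size 4.
num_examples, num_epochs, batch_size = 1800, 3, 4

# Same computation as the body of load_preference_datasets above.
total_num_steps = int(math.ceil(num_examples * num_epochs / batch_size))
print(total_num_steps)  # 1350
```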
```diff
@@ -5,7 +5,7 @@ from axolotl.utils.data.pretraining import (  # noqa: F401
     encode_pretraining,
     wrap_pretraining_dataset,
 )
-from axolotl.utils.data.rl import load_prepare_dpo_datasets  # noqa: F401
+from axolotl.utils.data.rl import load_prepare_preference_datasets  # noqa: F401
 from axolotl.utils.data.sft import (  # noqa: F401
     get_dataset_wrapper,
     load_prepare_datasets,
```
```diff
@@ -18,10 +18,13 @@ LOG = logging.getLogger("axolotl")
 
 
 def encode_pretraining(
-    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
+    tokenizer: PreTrainedTokenizerBase,
+    max_tokens: int,
+    examples: Dict[str, List],
+    text_column: str = "text",
 ) -> Dict[str, List]:
     res = tokenizer(
-        examples["text"],
+        examples[text_column],
         truncation=True,
         max_length=max_tokens - 2,
         add_special_tokens=True,
```
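A minimal usage sketch of the new `text_column` parameter, assuming only the signature shown in this hunk; the model name and the column name `content` are illustrative, not taken from the diff:

```python
from transformers import AutoTokenizer

from axolotl.utils.data.pretraining import encode_pretraining

# Any HF tokenizer works for illustration.
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

# A batch whose text lives under "content" instead of the default "text".
examples = {"content": ["first document", "second document"]}

# Before this change the column was hard-coded to examples["text"];
# now callers can point the encoder at any column.
encoded = encode_pretraining(tokenizer, 1024, examples, text_column="content")
```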
```diff
@@ -196,7 +199,12 @@ def wrap_pretraining_dataset(
         # set this to 1 so downstream data_loader doesn't try to increase the batch again
         cfg.micro_batch_size = 1
     else:
-        encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
+        encode = functools.partial(
+            encode_pretraining,
+            tokenizer,
+            max_tokens,
+            text_column=cfg.pretraining_dataset[0].text_column or "text",
+        )
 
     if cfg.shuffle_merged_datasets:
         dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
```
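The `functools.partial` call pre-binds everything except the batch itself, so the encoder can be handed to a `Dataset.map`-style caller that only passes `examples`. A self-contained sketch of the same binding pattern with a toy encoder (all names here are hypothetical):

```python
import functools


def encode(prefix, max_chars, examples, text_column="text"):
    # Toy stand-in for encode_pretraining: tag and truncate each string.
    return {"input": [(prefix + t)[:max_chars] for t in examples[text_column]]}


# Bind everything except the batch, as wrap_pretraining_dataset does.
encode_fn = functools.partial(encode, "doc: ", 12, text_column="content")

print(encode_fn({"content": ["hello world"]}))  # {'input': ['doc: hello w']}
```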
```diff
@@ -115,7 +115,7 @@ def drop_long_rl_seq(
     raise ValueError("Unknown RL type")
 
 
-def load_prepare_dpo_datasets(cfg):
+def load_prepare_preference_datasets(cfg):
     def load_split(dataset_cfgs, _cfg):
         split_datasets: List[Any] = []
         for i, ds_cfg in enumerate(dataset_cfgs):
```
```diff
@@ -1057,7 +1057,7 @@ class ModelLoader:
             )
             if (
                 hasattr(self.model, "get_input_embeddings")
-                and self.model.get_input_embeddings().num_embeddings < embeddings_len
+                and self.model.get_input_embeddings().num_embeddings != embeddings_len
             ):
                 resize_kwargs = {}
                 if self.cfg.mean_resizing_embeddings is not None:
```
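Loosening the check from `<` to `!=` means the embedding table is resized on any mismatch, shrunk as well as grown. A hedged sketch of the effect using the standard `transformers` resize API (model and target size are illustrative; ModelLoader also threads through mean-resizing options not shown here):

```python
from transformers import AutoModelForCausalLM

# Any causal LM illustrates the point; gpt2 ships 50257 embeddings.
model = AutoModelForCausalLM.from_pretrained("gpt2")

embeddings_len = 50264  # hypothetical target, e.g. vocab padded to a multiple of 8

current = model.get_input_embeddings().num_embeddings
# Old condition (`current < embeddings_len`) only ever grew the table;
# the new condition also trims an oversized table back to the target.
if current != embeddings_len:
    model.resize_token_embeddings(embeddings_len)
```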
```diff
@@ -4,7 +4,8 @@ E2E tests for llama pretrain
 """
 import logging
 import os
-import unittest
+
+import pytest
 
 from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
```
```diff
@@ -12,19 +13,22 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import check_model_output_exists, with_temp_dir
+from .utils import check_model_output_exists
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
 
 
-class TestPretrainLlama(unittest.TestCase):
+class TestPretrainLlama:
     """
     Test case for Llama models w pretraining
     """
 
-    @with_temp_dir
-    def test_pretrain_w_sample_packing(self, temp_dir):
+    @pytest.mark.parametrize(
+        "sample_packing",
+        [True, False],
+    )
+    def test_pretrain(self, temp_dir, sample_packing):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
```
```diff
@@ -32,7 +36,7 @@ class TestPretrainLlama(unittest.TestCase):
                 "tokenizer_type": "LlamaTokenizer",
                 "flash_attention": True,
                 "sequence_len": 1024,
-                "sample_packing": True,
+                "sample_packing": sample_packing,
                 "special_tokens": {
                     "unk_token": "<unk>",
                     "bos_token": "<s>",
```
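These two hunks replace a `unittest.TestCase` method plus the one-off `@with_temp_dir` decorator with plain pytest parametrization, so both packing modes run as separately collected cases. A minimal self-contained sketch of the same pattern, using pytest's built-in `tmp_path` fixture in place of the project's temp-dir helper:

```python
import pytest


@pytest.mark.parametrize("sample_packing", [True, False])
def test_pretrain(tmp_path, sample_packing):
    # Collected twice: test_pretrain[True] and test_pretrain[False].
    cfg = {
        "sequence_len": 1024,
        "sample_packing": sample_packing,
        "output_dir": str(tmp_path),
    }
    assert isinstance(cfg["sample_packing"], bool)
```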
```diff
@@ -17,7 +17,7 @@ from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 
 from axolotl.utils.data import load_tokenized_prepared_datasets
-from axolotl.utils.data.rl import load_prepare_dpo_datasets
+from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.dict import DictDefault
 
 
```
```diff
@@ -280,7 +280,7 @@ class TestDatasetPreparation(unittest.TestCase):
             }
         )
 
-        train_dataset, _ = load_prepare_dpo_datasets(cfg)
+        train_dataset, _ = load_prepare_preference_datasets(cfg)
 
         assert len(train_dataset) == 1800
         assert "conversation" in train_dataset.features
@@ -329,7 +329,7 @@ class TestDatasetPreparation(unittest.TestCase):
             }
         )
 
-        train_dataset, _ = load_prepare_dpo_datasets(cfg)
+        train_dataset, _ = load_prepare_preference_datasets(cfg)
 
         assert len(train_dataset) == 1800
         assert "conversation" in train_dataset.features
```
```diff
@@ -12,7 +12,7 @@ from datasets import Dataset
 from transformers import AutoTokenizer
 
 from axolotl.utils.data import prepare_dataset
-from axolotl.utils.data.rl import load_prepare_dpo_datasets
+from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.data.utils import deduplicate_and_log_datasets
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_processor, load_tokenizer
```
```diff
@@ -236,7 +236,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
         """Verify that loading with deduplication removes duplicates."""
 
         # Load the dataset using the deduplication setting
-        train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
+        train_dataset, _ = load_prepare_preference_datasets(self.cfg)
 
         # Verify that the dataset has been deduplicated
         assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
@@ -245,7 +245,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
         """Verify that loading without deduplication retains duplicates."""
         self.cfg.dataset_exact_deduplication = False
         # Load the dataset without deduplication
-        train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
+        train_dataset, _ = load_prepare_preference_datasets(self.cfg)
 
         # Verify that the dataset retains duplicates
         assert (
```
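These tests rely on `dataset_exact_deduplication` dropping exact repeats before training. As a rough sketch of what exact deduplication means here, one way to drop repeated rows from a `datasets.Dataset`; this illustrates the idea only and is not the actual `deduplicate_and_log_datasets` implementation:

```python
from datasets import Dataset

ds = Dataset.from_dict(
    {
        "prompt": ["a", "a", "b"],
        "chosen": ["x", "x", "y"],
        "rejected": ["p", "p", "q"],
    }
)

seen = set()


def first_occurrence(example):
    # Key on the full row so only exact repeats are dropped.
    key = (example["prompt"], example["chosen"], example["rejected"])
    if key in seen:
        return False
    seen.add(key)
    return True


deduped = ds.filter(first_occurrence)
print(len(ds), len(deduped))  # 3 2
```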