Compare commits
5 Commits
cli-refact...chat-datas
| Author | SHA1 | Date |
|---|---|---|
| | 7bf9741831 | |
| | 8606093921 | |
| | cba5a457d9 | |
| | 19cd83d408 | |
| | 1ed4de73b6 | |
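The combined diff follows. The main changes: `parse_dataset` honors a configurable `field_messages` column instead of a hard-coded `"conversations"` key; `load_prepare_dpo_datasets` is renamed to `load_prepare_preference_datasets` across the codebase; `encode_pretraining` gains a `text_column` parameter; the embedding-resize check fires on any size mismatch rather than only when the model's table is too small; and the llama pretrain e2e test moves from `unittest` to parametrized `pytest`.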
```diff
@@ -30,7 +30,7 @@ def parse_dataset(dataset=None, split="train"):
         )
         ds_cfg["field_messages"] = field_messages

-        message_fields = features["conversations"][0].keys()
+        message_fields = features[field_messages][0].keys()
         message_field_role = None
         for key in ["from", "role"]:
             if key in message_fields:
```
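This hunk generalizes the hard-coded `"conversations"` key so the message list can live under any configured `field_messages` column. A minimal sketch of the key-discovery logic, with an illustrative `features` dict standing in for the `datasets.Features` object:

```python
from typing import Optional

def find_role_key(features: dict, field_messages: str = "conversations") -> Optional[str]:
    """Return the per-message role key ("from" or "role"), if present."""
    message_fields = features[field_messages][0].keys()
    for key in ["from", "role"]:
        if key in message_fields:
            return key
    return None

# A dataset whose messages live under "messages" rather than "conversations"
features = {"messages": [{"role": "user", "content": "hi"}]}
assert find_role_key(features, field_messages="messages") == "role"
```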
```diff
@@ -11,7 +11,7 @@ from datasets import Dataset
 import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import  # noqa: F401
 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
 from axolotl.utils.data import prepare_dataset
-from axolotl.utils.data.rl import load_prepare_dpo_datasets
+from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_processor, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
```
```diff
@@ -103,9 +103,9 @@ def load_preference_datasets(
     cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
 ) -> TrainDatasetMeta:
     """
-    Loads one or more training or evaluation datasets for DPO training, calling
-    `axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug
-    information.
+    Loads one or more training or evaluation datasets for RL training using paired
+    preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
+    Optionally, logs out debug information.

     Args:
         cfg: Dictionary mapping `axolotl` config keys to values.
```
```diff
@@ -115,7 +115,7 @@ def load_preference_datasets(
         Dataclass with fields for training and evaluation datasets and the computed
         `total_num_steps`.
     """
-    train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
+    train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
     total_num_steps = int(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
```
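The step computation in this hunk is unchanged by the rename. A worked example of the arithmetic, with illustrative numbers:

```python
import math

# 1800 training rows, 2 epochs, batch size 8 -> 450 optimizer steps
num_rows, num_epochs, batch_size = 1800, 2, 8
total_num_steps = int(math.ceil(num_rows * num_epochs / batch_size))
assert total_num_steps == 450
```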
```diff
@@ -5,7 +5,7 @@ from axolotl.utils.data.pretraining import ( # noqa: F401
     encode_pretraining,
     wrap_pretraining_dataset,
 )
-from axolotl.utils.data.rl import load_prepare_dpo_datasets  # noqa: F401
+from axolotl.utils.data.rl import load_prepare_preference_datasets  # noqa: F401
 from axolotl.utils.data.sft import ( # noqa: F401
     get_dataset_wrapper,
     load_prepare_datasets,
```
```diff
@@ -18,10 +18,13 @@ LOG = logging.getLogger("axolotl")


 def encode_pretraining(
-    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
+    tokenizer: PreTrainedTokenizerBase,
+    max_tokens: int,
+    examples: Dict[str, List],
+    text_column: str = "text",
 ) -> Dict[str, List]:
     res = tokenizer(
-        examples["text"],
+        examples[text_column],
         truncation=True,
         max_length=max_tokens - 2,
         add_special_tokens=True,
```
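The only behavioral change here is that the tokenized column is configurable. A self-contained sketch of the new keyword in action, using a stand-in callable so the example runs without model weights (the real function takes a `transformers.PreTrainedTokenizerBase`):

```python
from typing import Dict, List

def fake_tokenizer(texts, truncation=True, max_length=None, add_special_tokens=True):
    # Stand-in tokenizer: one "token id" per row, just to make the sketch runnable.
    return {"input_ids": [[len(t)] for t in texts]}

def encode_pretraining_sketch(
    tokenizer, max_tokens: int, examples: Dict[str, List], text_column: str = "text"
) -> Dict[str, List]:
    # Mirrors the hunk above: the column name is a parameter, defaulting to "text".
    return tokenizer(
        examples[text_column],
        truncation=True,
        max_length=max_tokens - 2,
        add_special_tokens=True,
    )

batch = {"content": ["a pretraining document"]}
print(encode_pretraining_sketch(fake_tokenizer, 2048, batch, text_column="content"))
```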
```diff
@@ -196,7 +199,12 @@ def wrap_pretraining_dataset(
         # set this to 1 so downstream data_loader doesn't try to increase the batch again
         cfg.micro_batch_size = 1
     else:
-        encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
+        encode = functools.partial(
+            encode_pretraining,
+            tokenizer,
+            max_tokens,
+            text_column=cfg.pretraining_dataset[0].text_column or "text",
+        )

     if cfg.shuffle_merged_datasets:
         dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
```
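The `functools.partial` call freezes the tokenizer, the token budget, and now the text column, leaving a callable that takes only a batch, which is what a dataset `map` expects; the `or "text"` fallback preserves the old behavior when no `text_column` is configured. A sketch of the pattern with illustrative stand-ins:

```python
import functools

def encode(tokenizer, max_tokens, examples, text_column="text"):
    # Stand-in for encode_pretraining: count "tokens" per row, capped at max_tokens.
    return {"length": [min(len(tokenizer(t)), max_tokens) for t in examples[text_column]]}

# Bind the positional args and the new keyword up front; the result takes only a batch.
encode_fn = functools.partial(encode, str.split, 2048, text_column="body")
print(encode_fn({"body": ["one two three"]}))  # {'length': [3]}
```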
```diff
@@ -115,7 +115,7 @@ def drop_long_rl_seq(
     raise ValueError("Unknown RL type")


-def load_prepare_dpo_datasets(cfg):
+def load_prepare_preference_datasets(cfg):
     def load_split(dataset_cfgs, _cfg):
         split_datasets: List[Any] = []
         for i, ds_cfg in enumerate(dataset_cfgs):
```
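The old name disappears outright; every caller in the repo is updated in the hunks below. If backwards compatibility were wanted, a deprecation shim (hypothetical, not part of this diff; it assumes the renamed function is in the same module) could keep the old symbol importable:

```python
import warnings

def load_prepare_dpo_datasets(cfg):
    """Deprecated alias for load_prepare_preference_datasets (hypothetical shim)."""
    warnings.warn(
        "load_prepare_dpo_datasets is deprecated; use load_prepare_preference_datasets",
        DeprecationWarning,
        stacklevel=2,
    )
    return load_prepare_preference_datasets(cfg)
```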
```diff
@@ -1057,7 +1057,7 @@ class ModelLoader:
             )
             if (
                 hasattr(self.model, "get_input_embeddings")
-                and self.model.get_input_embeddings().num_embeddings < embeddings_len
+                and self.model.get_input_embeddings().num_embeddings != embeddings_len
             ):
                 resize_kwargs = {}
                 if self.cfg.mean_resizing_embeddings is not None:
```
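Switching `<` to `!=` makes the resize fire on any mismatch: previously a model whose embedding table was already larger than the target length was left alone; now it is shrunk to match. With illustrative numbers:

```python
# Model ships with a padded embedding table (32128) but the target length is 32000.
num_embeddings, embeddings_len = 32128, 32000

print(num_embeddings < embeddings_len)   # False -> old code skipped the resize
print(num_embeddings != embeddings_len)  # True  -> new code resizes (shrinks) to 32000
```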
```diff
@@ -4,7 +4,8 @@ E2E tests for llama pretrain

 import logging
 import os
-import unittest
+
+import pytest

 from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
```
```diff
@@ -12,19 +13,22 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import check_model_output_exists, with_temp_dir
+from .utils import check_model_output_exists

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"


-class TestPretrainLlama(unittest.TestCase):
+class TestPretrainLlama:
     """
     Test case for Llama models w pretraining
     """

-    @with_temp_dir
-    def test_pretrain_w_sample_packing(self, temp_dir):
+    @pytest.mark.parametrize(
+        "sample_packing",
+        [True, False],
+    )
+    def test_pretrain(self, temp_dir, sample_packing):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
```
```diff
@@ -32,7 +36,7 @@ class TestPretrainLlama(unittest.TestCase):
             "tokenizer_type": "LlamaTokenizer",
             "flash_attention": True,
             "sequence_len": 1024,
-            "sample_packing": True,
+            "sample_packing": sample_packing,
             "special_tokens": {
                 "unk_token": "<unk>",
                 "bos_token": "<s>",
```
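`unittest.TestCase` subclasses cannot use `pytest.mark.parametrize`, which is why the base class is dropped; the single packed test then fans out over both `sample_packing` values, and the parametrized variable feeds straight into the config above. A minimal self-contained sketch of the pattern (the real test's `temp_dir` comes from the e2e test utilities; here pytest's built-in `tmp_path` fixture stands in):

```python
import pytest

class TestPretrainSketch:
    """Plain class (no unittest.TestCase), so parametrize works."""

    @pytest.mark.parametrize("sample_packing", [True, False])
    def test_pretrain(self, tmp_path, sample_packing):
        # The parametrized value lands directly in the training config.
        cfg = {"sample_packing": sample_packing, "output_dir": str(tmp_path)}
        assert cfg["sample_packing"] in (True, False)
```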
```diff
@@ -17,7 +17,7 @@ from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer

 from axolotl.utils.data import load_tokenized_prepared_datasets
-from axolotl.utils.data.rl import load_prepare_dpo_datasets
+from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.dict import DictDefault


```
```diff
@@ -280,7 +280,7 @@ class TestDatasetPreparation(unittest.TestCase):
             }
         )

-        train_dataset, _ = load_prepare_dpo_datasets(cfg)
+        train_dataset, _ = load_prepare_preference_datasets(cfg)

         assert len(train_dataset) == 1800
         assert "conversation" in train_dataset.features
```
```diff
@@ -329,7 +329,7 @@ class TestDatasetPreparation(unittest.TestCase):
             }
         )

-        train_dataset, _ = load_prepare_dpo_datasets(cfg)
+        train_dataset, _ = load_prepare_preference_datasets(cfg)

         assert len(train_dataset) == 1800
         assert "conversation" in train_dataset.features
```
```diff
@@ -12,7 +12,7 @@ from datasets import Dataset
 from transformers import AutoTokenizer

 from axolotl.utils.data import prepare_dataset
-from axolotl.utils.data.rl import load_prepare_dpo_datasets
+from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.data.utils import deduplicate_and_log_datasets
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_processor, load_tokenizer
```
```diff
@@ -236,7 +236,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
         """Verify that loading with deduplication removes duplicates."""

         # Load the dataset using the deduplication setting
-        train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
+        train_dataset, _ = load_prepare_preference_datasets(self.cfg)

         # Verify that the dataset has been deduplicated
         assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
```
```diff
@@ -245,7 +245,7 @@ class TestDeduplicateRLDataset(unittest.TestCase):
         """Verify that loading without deduplication retains duplicates."""
         self.cfg.dataset_exact_deduplication = False
         # Load the dataset without deduplication
-        train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
+        train_dataset, _ = load_prepare_preference_datasets(self.cfg)

         # Verify that the dataset retains duplicates
         assert (
```