fix: use text_column even when not packing for pretraining (#2254)
* fix: use text_column even when not packing for pretraining * feat: update test to check when not packing * chore: lint * Update src/axolotl/utils/data/pretraining.py Co-authored-by: Wing Lian <wing.lian@gmail.com> --------- Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -18,10 +18,13 @@ LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def encode_pretraining(
|
||||
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
max_tokens: int,
|
||||
examples: Dict[str, List],
|
||||
text_column: str = "text",
|
||||
) -> Dict[str, List]:
|
||||
res = tokenizer(
|
||||
examples["text"],
|
||||
examples[text_column],
|
||||
truncation=True,
|
||||
max_length=max_tokens - 2,
|
||||
add_special_tokens=True,
|
||||
@@ -196,7 +199,12 @@ def wrap_pretraining_dataset(
|
||||
# set this to 1 so downstream data_loader doesn't try to increase the batch again
|
||||
cfg.micro_batch_size = 1
|
||||
else:
|
||||
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
||||
encode = functools.partial(
|
||||
encode_pretraining,
|
||||
tokenizer,
|
||||
max_tokens,
|
||||
text_column=cfg.pretraining_dataset[0].text_column or "text",
|
||||
)
|
||||
|
||||
if cfg.shuffle_merged_datasets:
|
||||
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
|
||||
|
||||
@@ -4,7 +4,8 @@ E2E tests for llama pretrain
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
@@ -12,19 +13,22 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
from .utils import check_model_output_exists
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestPretrainLlama(unittest.TestCase):
|
||||
class TestPretrainLlama:
|
||||
"""
|
||||
Test case for Llama models w pretraining
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_pretrain_w_sample_packing(self, temp_dir):
|
||||
@pytest.mark.parametrize(
|
||||
"sample_packing",
|
||||
[True, False],
|
||||
)
|
||||
def test_pretrain(self, temp_dir, sample_packing):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -32,7 +36,7 @@ class TestPretrainLlama(unittest.TestCase):
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": True,
|
||||
"sample_packing": sample_packing,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
|
||||
Reference in New Issue
Block a user