fix: use text_column even when not packing for pretraining (#2254)

* fix: use text_column even when not packing for pretraining

* feat: update test to check when not packing

* chore: lint

* Update src/axolotl/utils/data/pretraining.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
NanoCode012
2025-01-15 10:08:56 +07:00
committed by GitHub
parent 19cd83d408
commit cba5a457d9
2 changed files with 21 additions and 9 deletions

View File

@@ -18,10 +18,13 @@ LOG = logging.getLogger("axolotl")
def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
tokenizer: PreTrainedTokenizerBase,
max_tokens: int,
examples: Dict[str, List],
text_column: str = "text",
) -> Dict[str, List]:
res = tokenizer(
examples["text"],
examples[text_column],
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,
@@ -196,7 +199,12 @@ def wrap_pretraining_dataset(
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
encode = functools.partial(
encode_pretraining,
tokenizer,
max_tokens,
text_column=cfg.pretraining_dataset[0].text_column or "text",
)
if cfg.shuffle_merged_datasets:
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)

View File

@@ -4,7 +4,8 @@ E2E tests for llama pretrain
import logging
import os
import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
@@ -12,19 +13,22 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
from .utils import check_model_output_exists
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestPretrainLlama(unittest.TestCase):
class TestPretrainLlama:
"""
Test case for Llama models w pretraining
"""
@with_temp_dir
def test_pretrain_w_sample_packing(self, temp_dir):
@pytest.mark.parametrize(
"sample_packing",
[True, False],
)
def test_pretrain(self, temp_dir, sample_packing):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
@@ -32,7 +36,7 @@ class TestPretrainLlama(unittest.TestCase):
"tokenizer_type": "LlamaTokenizer",
"flash_attention": True,
"sequence_len": 1024,
"sample_packing": True,
"sample_packing": sample_packing,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",