From cba5a457d9541a1ffde6a99977bff575c4899966 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 15 Jan 2025 10:08:56 +0700 Subject: [PATCH] fix: use text_column even when not packing for pretraining (#2254) * fix: use text_column even when not packing for pretraining * feat: update test to check when not packing * chore: lint * Update src/axolotl/utils/data/pretraining.py Co-authored-by: Wing Lian --------- Co-authored-by: Wing Lian Co-authored-by: Wing Lian --- src/axolotl/utils/data/pretraining.py | 14 +++++++++++--- tests/e2e/test_llama_pretrain.py | 16 ++++++++++------ 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py index f493db70e..369d2d6fe 100644 --- a/src/axolotl/utils/data/pretraining.py +++ b/src/axolotl/utils/data/pretraining.py @@ -18,10 +18,13 @@ LOG = logging.getLogger("axolotl") def encode_pretraining( - tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List] + tokenizer: PreTrainedTokenizerBase, + max_tokens: int, + examples: Dict[str, List], + text_column: str = "text", ) -> Dict[str, List]: res = tokenizer( - examples["text"], + examples[text_column], truncation=True, max_length=max_tokens - 2, add_special_tokens=True, @@ -196,7 +199,12 @@ def wrap_pretraining_dataset( # set this to 1 so downstream data_loader doesn't try to increase the batch again cfg.micro_batch_size = 1 else: - encode = functools.partial(encode_pretraining, tokenizer, max_tokens) + encode = functools.partial( + encode_pretraining, + tokenizer, + max_tokens, + text_column=cfg.pretraining_dataset[0].text_column or "text", + ) if cfg.shuffle_merged_datasets: dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size) diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py index 68cd490be..117eba25d 100644 --- a/tests/e2e/test_llama_pretrain.py +++ b/tests/e2e/test_llama_pretrain.py @@ -4,7 +4,8 @@ E2E tests for llama pretrain import logging import os -import unittest + +import pytest from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets @@ -12,19 +13,22 @@ from axolotl.train import train from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from .utils import check_model_output_exists, with_temp_dir +from .utils import check_model_output_exists LOG = logging.getLogger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" -class TestPretrainLlama(unittest.TestCase): +class TestPretrainLlama: """ Test case for Llama models w pretraining """ - @with_temp_dir - def test_pretrain_w_sample_packing(self, temp_dir): + @pytest.mark.parametrize( + "sample_packing", + [True, False], + ) + def test_pretrain(self, temp_dir, sample_packing): # pylint: disable=duplicate-code cfg = DictDefault( { @@ -32,7 +36,7 @@ class TestPretrainLlama(unittest.TestCase): "tokenizer_type": "LlamaTokenizer", "flash_attention": True, "sequence_len": 1024, - "sample_packing": True, + "sample_packing": sample_packing, "special_tokens": { "unk_token": "", "bos_token": "",