fix: use text_column even when not packing for pretraining (#2254)

* fix: use text_column even when not packing for pretraining

* feat: update test to check when not packing

* chore: lint

* Update src/axolotl/utils/data/pretraining.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
NanoCode012
2025-01-15 10:08:56 +07:00
committed by GitHub
parent 19cd83d408
commit cba5a457d9
2 changed files with 21 additions and 9 deletions

View File

@@ -18,10 +18,13 @@ LOG = logging.getLogger("axolotl")
def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
tokenizer: PreTrainedTokenizerBase,
max_tokens: int,
examples: Dict[str, List],
text_column: str = "text",
) -> Dict[str, List]:
res = tokenizer(
examples["text"],
examples[text_column],
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,
@@ -196,7 +199,12 @@ def wrap_pretraining_dataset(
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
encode = functools.partial(
encode_pretraining,
tokenizer,
max_tokens,
text_column=cfg.pretraining_dataset[0].text_column or "text",
)
if cfg.shuffle_merged_datasets:
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)