fix: use text_column even when not packing for pretraining (#2254)
* fix: use text_column even when not packing for pretraining * feat: update test to check when not packing * chore: lint * Update src/axolotl/utils/data/pretraining.py Co-authored-by: Wing Lian <wing.lian@gmail.com> --------- Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -18,10 +18,13 @@ LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def encode_pretraining(
|
||||
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
max_tokens: int,
|
||||
examples: Dict[str, List],
|
||||
text_column: str = "text",
|
||||
) -> Dict[str, List]:
|
||||
res = tokenizer(
|
||||
examples["text"],
|
||||
examples[text_column],
|
||||
truncation=True,
|
||||
max_length=max_tokens - 2,
|
||||
add_special_tokens=True,
|
||||
@@ -196,7 +199,12 @@ def wrap_pretraining_dataset(
|
||||
# set this to 1 so downstream data_loader doesn't try to increase the batch again
|
||||
cfg.micro_batch_size = 1
|
||||
else:
|
||||
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
||||
encode = functools.partial(
|
||||
encode_pretraining,
|
||||
tokenizer,
|
||||
max_tokens,
|
||||
text_column=cfg.pretraining_dataset[0].text_column or "text",
|
||||
)
|
||||
|
||||
if cfg.shuffle_merged_datasets:
|
||||
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
|
||||
|
||||
Reference in New Issue
Block a user