From d584354ee4d4bf696bc48cf8f3c878197babc876 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 18 Dec 2024 15:47:18 -0500 Subject: [PATCH] filter bad rows --- src/axolotl/datasets.py | 2 ++ src/axolotl/prompt_strategies/chat_template.py | 7 +++++++ src/axolotl/prompt_tokenizers.py | 4 +++- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index b5638a614..a8880a00f 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -52,6 +52,8 @@ class TokenizedPromptDataset(Dataset): if self.prompt_tokenizer.supports_batched: map_kwargs["batched"] = True map_kwargs["batch_size"] = 100 + if self.prompt_tokenizer.filter_rows: + dataset = dataset.filter(self.prompt_tokenizer.filter_rows) return dataset.map( self.prompt_tokenizer.tokenize_prompt, num_proc=num_proc, diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index c2a1060aa..4857d0ffc 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -478,6 +478,13 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): ): self.logprobs_field = logprobs_field self.temperature = temperature + + # remove rows where the logprob field is not available + self.filter_rows = ( + lambda row: self.logprobs_field in row + and row[self.logprobs_field] is not None + ) + super().__init__( prompter, tokenizer, diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index bd6e3f9dc..c29fd05a4 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -2,7 +2,7 @@ import abc import logging -from typing import Dict, List, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union from transformers import BatchEncoding, PreTrainedTokenizer @@ -34,6 +34,8 @@ class PromptTokenizingStrategy(abc.ABC): Abstract class for tokenizing strategies """ + filter_rows: Optional[Callable] = None + def __init__( self, prompter: Prompter,