filter bad rows

This commit is contained in:
Wing Lian
2024-12-18 15:47:18 -05:00
parent 303cfa71aa
commit d584354ee4
3 changed files with 12 additions and 1 deletions

View File

@@ -52,6 +52,8 @@ class TokenizedPromptDataset(Dataset):
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
map_kwargs["batch_size"] = 100
if self.prompt_tokenizer.filter_rows:
dataset = dataset.filter(self.prompt_tokenizer.filter_rows)
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,

View File

@@ -478,6 +478,13 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
):
self.logprobs_field = logprobs_field
self.temperature = temperature
# remove rows where the logprob field is not available
self.filter_rows = (
lambda row: self.logprobs_field in row
and row[self.logprobs_field] is not None
)
super().__init__(
prompter,
tokenizer,

View File

@@ -2,7 +2,7 @@
import abc
import logging
from typing import Dict, List, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
from transformers import BatchEncoding, PreTrainedTokenizer
@@ -34,6 +34,8 @@ class PromptTokenizingStrategy(abc.ABC):
Abstract class for tokenizing strategies
"""
filter_rows: Optional[Callable] = None
def __init__(
self,
prompter: Prompter,