Filter out dataset rows whose logprobs field is missing or None before tokenizing
This commit is contained in:
@@ -52,6 +52,8 @@ class TokenizedPromptDataset(Dataset):
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
map_kwargs["batch_size"] = 100
|
||||
if self.prompt_tokenizer.filter_rows:
|
||||
dataset = dataset.filter(self.prompt_tokenizer.filter_rows)
|
||||
return dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
|
||||
@@ -478,6 +478,13 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
):
|
||||
self.logprobs_field = logprobs_field
|
||||
self.temperature = temperature
|
||||
|
||||
# remove rows where the logprob field is not available
|
||||
self.filter_rows = (
|
||||
lambda row: self.logprobs_field in row
|
||||
and row[self.logprobs_field] is not None
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
prompter,
|
||||
tokenizer,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from typing import Dict, List, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from transformers import BatchEncoding, PreTrainedTokenizer
|
||||
|
||||
@@ -34,6 +34,8 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
Abstract class for tokenizing strategies
|
||||
"""
|
||||
|
||||
filter_rows: Optional[Callable] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompter: Prompter,
|
||||
|
||||
Reference in New Issue
Block a user