filter bad rows (drop rows whose logprobs field is missing or None before tokenizing)

This commit is contained in:
Wing Lian
2024-12-18 15:47:18 -05:00
parent 303cfa71aa
commit d584354ee4
3 changed files with 12 additions and 1 deletion

View File

@@ -52,6 +52,8 @@ class TokenizedPromptDataset(Dataset):
if self.prompt_tokenizer.supports_batched: if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True map_kwargs["batched"] = True
map_kwargs["batch_size"] = 100 map_kwargs["batch_size"] = 100
if self.prompt_tokenizer.filter_rows:
dataset = dataset.filter(self.prompt_tokenizer.filter_rows)
return dataset.map( return dataset.map(
self.prompt_tokenizer.tokenize_prompt, self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc, num_proc=num_proc,

View File

@@ -478,6 +478,13 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
): ):
self.logprobs_field = logprobs_field self.logprobs_field = logprobs_field
self.temperature = temperature self.temperature = temperature
# remove rows where the logprob field is not available
self.filter_rows = (
lambda row: self.logprobs_field in row
and row[self.logprobs_field] is not None
)
super().__init__( super().__init__(
prompter, prompter,
tokenizer, tokenizer,

View File

@@ -2,7 +2,7 @@
import abc import abc
import logging import logging
from typing import Dict, List, Tuple, Union from typing import Callable, Dict, List, Optional, Tuple, Union
from transformers import BatchEncoding, PreTrainedTokenizer from transformers import BatchEncoding, PreTrainedTokenizer
@@ -34,6 +34,8 @@ class PromptTokenizingStrategy(abc.ABC):
Abstract class for tokenizing strategies Abstract class for tokenizing strategies
""" """
filter_rows: Optional[Callable] = None
def __init__( def __init__(
self, self,
prompter: Prompter, prompter: Prompter,