filter out bad rows (logprobs field missing or None) before tokenization
This commit is contained in:
@@ -52,6 +52,8 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
if self.prompt_tokenizer.supports_batched:
|
if self.prompt_tokenizer.supports_batched:
|
||||||
map_kwargs["batched"] = True
|
map_kwargs["batched"] = True
|
||||||
map_kwargs["batch_size"] = 100
|
map_kwargs["batch_size"] = 100
|
||||||
|
if self.prompt_tokenizer.filter_rows:
|
||||||
|
dataset = dataset.filter(self.prompt_tokenizer.filter_rows)
|
||||||
return dataset.map(
|
return dataset.map(
|
||||||
self.prompt_tokenizer.tokenize_prompt,
|
self.prompt_tokenizer.tokenize_prompt,
|
||||||
num_proc=num_proc,
|
num_proc=num_proc,
|
||||||
|
|||||||
@@ -478,6 +478,13 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
):
|
):
|
||||||
self.logprobs_field = logprobs_field
|
self.logprobs_field = logprobs_field
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
|
|
||||||
|
# remove rows where the logprobs field is missing or set to None
|
||||||
|
self.filter_rows = (
|
||||||
|
lambda row: self.logprobs_field in row
|
||||||
|
and row[self.logprobs_field] is not None
|
||||||
|
)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
prompter,
|
prompter,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Tuple, Union
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from transformers import BatchEncoding, PreTrainedTokenizer
|
from transformers import BatchEncoding, PreTrainedTokenizer
|
||||||
|
|
||||||
@@ -34,6 +34,8 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
Abstract class for tokenizing strategies
|
Abstract class for tokenizing strategies
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
filter_rows: Optional[Callable] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
prompter: Prompter,
|
prompter: Prompter,
|
||||||
|
|||||||
Reference in New Issue
Block a user