drop empty token from beginning if tokenizer has no bos_token (in the case of qwen) (#1490)
This commit is contained in:
@@ -23,6 +23,7 @@ from torch.optim.lr_scheduler import OneCycleLR
|
|||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers import (
|
from transformers import (
|
||||||
EarlyStoppingCallback,
|
EarlyStoppingCallback,
|
||||||
|
PreTrainedModel,
|
||||||
Trainer,
|
Trainer,
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
@@ -802,6 +803,15 @@ class AxolotlDPOTrainer(DPOTrainer):
|
|||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
|
def tokenize_row(
|
||||||
|
self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
|
||||||
|
) -> Dict:
|
||||||
|
res = super().tokenize_row(feature, model=model)
|
||||||
|
if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
|
||||||
|
for key in res.keys():
|
||||||
|
res[key] = res[key][1:]
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user