remove the bos token from dpo outputs (#1733) [skip ci]
* remove the bos token from dpo outputs * don't forget to fix prompt_input_ids too * use processing_class instead of tokenizer * fix for processing class
This commit is contained in:
@@ -1038,24 +1038,37 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def tokenize_row(
|
||||
self,
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
) -> Dict:
|
||||
res = super().tokenize_row(
|
||||
res = DPOTrainer.tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
)
|
||||
if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
|
||||
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
||||
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
||||
for key in res.keys():
|
||||
res[key] = res[key][1:]
|
||||
|
||||
if processing_class.bos_token and processing_class.bos_token_id is not None:
|
||||
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
||||
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
||||
res["chosen_labels"] = res["chosen_labels"][1:]
|
||||
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
||||
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
||||
res["rejected_labels"] = res["rejected_labels"][1:]
|
||||
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
||||
|
||||
return res
|
||||
|
||||
def training_step(
|
||||
|
||||
Reference in New Issue
Block a user