automatically split out reasoning trace from dataset (#2579)
* automatically split out reasoning trace from dataset * chore: lint * fix import
This commit is contained in:
@@ -228,6 +228,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
train_on_eos: Optional[str] = None,
|
||||
train_on_eot: Optional[str] = None,
|
||||
eot_tokens: Optional[List[str]] = None,
|
||||
split_thinking: Optional[bool] = False,
|
||||
):
|
||||
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
||||
self.prompter: ChatTemplatePrompter = prompter
|
||||
@@ -247,6 +248,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
self.eot_tokens = (
|
||||
eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token]
|
||||
)
|
||||
self.split_thinking = split_thinking
|
||||
|
||||
self.images = "images"
|
||||
|
||||
@@ -655,6 +657,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
transformed_message["role"], transformed_message["role"]
|
||||
)
|
||||
|
||||
# TODO handle reasoning_content with split_thinking
|
||||
# if the role is assistant that we want to use reasoning_content
|
||||
if self.split_thinking and transformed_message["role"] == "assistant":
|
||||
content = transformed_message["content"]
|
||||
pairs = [("<think>", "</think>"), ("<reasoning>", "</reasoning>")]
|
||||
for pair in pairs:
|
||||
if pair[0] in content and pair[1] in content:
|
||||
start_idx = content.find(pair[0])
|
||||
end_idx = content.find(pair[1])
|
||||
thinking_content = content[start_idx + len(pair[0]) : end_idx]
|
||||
transformed_message["reasoning_content"] = thinking_content.strip()
|
||||
transformed_message["content"] = content[
|
||||
end_idx + len(pair[1]) :
|
||||
].lstrip()
|
||||
break
|
||||
|
||||
# Determine which keys in the original message were not mapped
|
||||
mapped_values = set(self.prompter.message_property_mappings.values())
|
||||
remaining_keys = set(message) - mapped_values
|
||||
@@ -689,6 +707,7 @@ class StrategyLoader:
|
||||
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
|
||||
"train_on_eot": ds_cfg.get("train_on_eot", None),
|
||||
"eot_tokens": cfg.get("eot_tokens", None), # loads from cfg, not ds_cfg
|
||||
"split_thinking": ds_cfg.get("split_thinking", False),
|
||||
}
|
||||
|
||||
def __call__(
|
||||
|
||||
@@ -50,6 +50,7 @@ class SFTDataset(BaseModel):
|
||||
message_property_mappings: dict[str, str] | None = None
|
||||
message_field_training: str | None = None
|
||||
message_field_training_detail: str | None = None
|
||||
split_thinking: bool | None = None
|
||||
logprobs_field: str | None = None
|
||||
temperature: float | None = None
|
||||
roles_to_train: list[str] | None = None
|
||||
|
||||
Reference in New Issue
Block a user