diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py index fbb44ccfa..c02688968 100644 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ b/src/axolotl/prompt_strategies/sharegpt.py @@ -39,6 +39,23 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): return strategy +def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): + conversation = ( + ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None + ) + strategy = UltrachatShareGPTPromptTokenizingStrategy( + ShareGPTPrompterV2( + conversation=conversation, + ), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + if ds_cfg and "strict" in ds_cfg: + strategy.strict = ds_cfg["strict"] + return strategy + + def load_role(tokenizer, cfg): return SimpleRoleShareGPTPromptTokenizingStrategy( ShareGPTPrompterV2(), @@ -109,3 +126,17 @@ class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): {"from": role_map[t["role"]], "value": t["text"]} for t in conversations ] return turns + + +class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy): + """ + sharegpt strategy that remaps ultrachat data to sharegpt format + """ + + def get_conversation_thread(self, prompt): + conversations = prompt["messages"] + role_map = {"user": "human", "assistant": "gpt"} + turns = [ + {"from": role_map[t["role"]], "value": t["content"]} for t in conversations + ] + return turns