diff --git a/examples/qwen3/README.md b/examples/qwen3/README.md new file mode 100644 index 000000000..a3d35881d --- /dev/null +++ b/examples/qwen3/README.md @@ -0,0 +1,46 @@ +# Finetune Qwen3 with Axolotl + +[Qwen3](https://huggingface.co/collections/Qwen/qwen3) is a family of open source models trained by Alibaba. + +This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + +2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage. + +3. Run the finetuning example: + + ```bash + axolotl train examples/qwen3/32b-qlora.yaml + ``` + +Let us know how it goes. Happy finetuning! 🚀 + +### Chat template masking a few tokens off + +If you notice that the `chat_template` masking for assistant prompts is off by a few tokens, please ensure that you are adding the below to the yaml. + +```yaml +chat_template: qwen3 +``` + +### TIPS + +- For inference, please check the official model card as it depends on your reasoning mode. + +- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. + +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). + +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + +## Optimization Guides + +Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html). 
+ +## Related Resources + +- [Qwen3 Blog](https://qwenlm.github.io/blog/qwen3/) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 28155810f..0fec64d81 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -95,6 +95,7 @@ class ChatTemplatePrompter(Prompter): add_generation_prompt=False, images=None, tools=None, + real_last_index=None, ): """ Build a prompt from a conversation. @@ -114,6 +115,9 @@ class ChatTemplatePrompter(Prompter): if tools: chat_template_kwargs["tools"] = tools + if real_last_index: + chat_template_kwargs["real_last_index"] = real_last_index + if self.processor: if not callable(self.processor): raise TypeError("Processor must be callable") @@ -631,11 +635,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): turns_with_empty = turns[:turn_idx] + [empty_turn] turns_with_content = turns[: turn_idx + 1] + real_last_index = len(turns) - 1 + # Generate the conversation up to the turn, with final turn replaced with dummy content - dummy_ids = self.prompter.build_prompt(turns_with_empty, tools=tools) # type: ignore + dummy_ids = self.prompter.build_prompt( + turns_with_empty, tools=tools, real_last_index=real_last_index + ) # type: ignore # Generate the conversation up to the turn, with final turn included - full_ids = self.prompter.build_prompt(turns_with_content, tools=tools) # type: ignore + full_ids = self.prompter.build_prompt( + turns_with_content, tools=tools, real_last_index=real_last_index + ) # type: ignore if not full_ids or not dummy_ids: LOG.warning(f"Empty template generated for turn {turn_idx}") diff --git a/src/axolotl/utils/chat_templates/templates/qwen3.jinja 
b/src/axolotl/utils/chat_templates/templates/qwen3.jinja index 09b82ed03..77ea906e7 100644 --- a/src/axolotl/utils/chat_templates/templates/qwen3.jinja +++ b/src/axolotl/utils/chat_templates/templates/qwen3.jinja @@ -15,6 +15,12 @@ {%- endif %} {%- endif %} {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{#- Determine the real last index: use provided value or default to messages length - 1 #} +{%- if real_last_index is defined and real_last_index is not none %} + {%- set ns.real_last_index = real_last_index %} +{%- else %} + {%- set ns.real_last_index = messages|length - 1 %} +{%- endif %} {%- for message in messages[::-1] %} {%- set index = (messages|length - 1) - loop.index0 %} {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %} @@ -37,7 +43,7 @@ {%- endif %} {%- endif %} {%- if loop.index0 > ns.last_query_index %} - {%- if loop.last or (not loop.last and reasoning_content) %} + {%- if loop.index0 == ns.real_last_index or (loop.index0 != ns.real_last_index and reasoning_content) %} {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }} {%- else %} {{- '<|im_start|>' + message.role + '\n' + content }} diff --git a/src/axolotl/utils/mistral/mistral_tokenizer.py b/src/axolotl/utils/mistral/mistral_tokenizer.py index af174cdac..3ce6be780 100644 --- a/src/axolotl/utils/mistral/mistral_tokenizer.py +++ b/src/axolotl/utils/mistral/mistral_tokenizer.py @@ -80,6 +80,9 @@ class HFMistralTokenizer(MistralCommonTokenizer): ) -> str | list[int]: """Patched fn to handle setting serving mode, continue_final_message, remove chat_template and add_generation_prompt kwarg""" + # pop unnecessary kwarg for mistral + kwargs.pop("real_last_index", None) + try: if add_generation_prompt: self._set_mode(ValidationMode.serving)