Files
axolotl/search.json
Quarto GHA Workflow Runner f807756bde Built site for gh-pages
2026-04-02 14:25:09 +00:00

[
{
"objectID": "FAQS.html",
"href": "FAQS.html",
"title": "FAQs",
"section": "",
"text": "FAQs\n\nCan you train StableLM with this? Yes, but only with a single GPU atm. Multi GPU support is coming soon! Just waiting on this PR\nWill this work with Deepspeed? Thats still a WIP, but setting export ACCELERATE_USE_DEEPSPEED=true should work in some cases\nError invalid argument at line 359 in file /workspace/bitsandbytes/csrc/pythonInterface.c\n/arrow/cpp/src/arrow/filesystem/s3fs.cc:2598: arrow::fs::FinalizeS3 was not called even though S3 was initialized.\nThis could lead to a segmentation fault at exit. Try reinstalling bitsandbytes and transformers from source."
},
{
"objectID": "docs/dataset-formats/template_free.html",
"href": "docs/dataset-formats/template_free.html",
"title": "Template-Free",
"section": "",
"text": "One of the most popular features of\naxolotl is\nsetting the following configuration value:\ntrain_on_inputs: false\nIf you declare a dataset formats\nsuch as alpaca or chatml, axolotl knows what is an input\n(i.e. human) vs. an output (i.e. the assistant) and masks the input\nlabels so that your model can focus on predicting the outputs only.\n\n\n\nHowever, there are many situations where you dont want to use one of\nthese formats or templates. This is because they can:\n\nAdd unnecessary boilerplate to your prompts.\nCreate artifacts like special delimiters <|im_start|> that can\nquickly become footguns if you dont include them correctly at\ninference time.\nEnforce a chat interface when you do not want one. Sometimes you\njust want to fine-tune a model to a very specific task and do NOT\nwant multi-turn conversations, roles, etc.\nLimit you to only certain roles that the template allows.\n\n\n\n\nYou can construct your prompts without a template by using the\ninput_output format, by setting type: input_output in your\nconfiguration file like this:\nconfig.yml\ntrain_on_inputs: false # Mask segments of your data\ndatasets:\n - path: output.jsonl\n type: input_output # use template free prompt construction\nUnlike type: completion, which is also template-free,\ntype: input_output allows you to mask segments of your text. More\ndetails on how this works are described below.",
"crumbs": [
"Dataset Formats",
"Template-Free"
]
},
{
"objectID": "docs/dataset-formats/template_free.html#sec-background",
"href": "docs/dataset-formats/template_free.html#sec-background",
"title": "Template-Free",
"section": "",
"text": "One of the most popular features of\naxolotl is\nsetting the following configuration value:\ntrain_on_inputs: false\nIf you declare a dataset formats\nsuch as alpaca or chatml, axolotl knows what is an input\n(i.e. human) vs. an output (i.e. the assistant) and masks the input\nlabels so that your model can focus on predicting the outputs only.\n\n\n\nHowever, there are many situations where you dont want to use one of\nthese formats or templates. This is because they can:\n\nAdd unnecessary boilerplate to your prompts.\nCreate artifacts like special delimiters <|im_start|> that can\nquickly become footguns if you dont include them correctly at\ninference time.\nEnforce a chat interface when you do not want one. Sometimes you\njust want to fine-tune a model to a very specific task and do NOT\nwant multi-turn conversations, roles, etc.\nLimit you to only certain roles that the template allows.\n\n\n\n\nYou can construct your prompts without a template by using the\ninput_output format, by setting type: input_output in your\nconfiguration file like this:\nconfig.yml\ntrain_on_inputs: false # Mask segments of your data\ndatasets:\n - path: output.jsonl\n type: input_output # use template free prompt construction\nUnlike type: completion, which is also template-free,\ntype: input_output allows you to mask segments of your text. More\ndetails on how this works are described below.",
"crumbs": [
"Dataset Formats",
"Template-Free"
]
},
{
"objectID": "docs/dataset-formats/template_free.html#sec-usage",
"href": "docs/dataset-formats/template_free.html#sec-usage",
"title": "Template-Free",
"section": "Usage",
"text": "Usage\nThis is how you can use the input_output format:\n\n1. Prepare Data\nTo use the input_output format, collect your data in the following\nformat into a jsonl file (below is the first row from the file\noutput.jsonl` pretty printed):\n$ head -n1 output.jsonl | python -m json.tool\n\n{\n \"segments\": [\n {\n \"label\": true,\n \"text\": \"<s>Hello\\n\"\n },\n {\n \"label\": true,\n \"text\": \"hi there!. \"\n },\n {\n \"label\": false,\n \"text\": \"goodbye \"\n },\n {\n \"label\": true,\n \"text\": \"farewell</s>\"\n }\n ]\n}\n\nSet label:false when you want to mask a segment of text so that the\nmodel isnt trained on it. Some things to keep in mind:\n\n[!IMPORTANT]\n1. EOS, BOS, spaces, newlines etc. are entirely up to you. Axolotl\nconcatenates all the segments as-is. The tokenizer doesnt add\nanything additional. Notice how I added spaces, newlines, <s>\n(BOS), and </s> (EOS) myself.\n2. Make sure you check the materialized output to validate that the\nprompt is getting assembled how you like.\n\n\n\n2. Use type: input_output\nLets materialize data with our output.jsonl file by setting\ntype: input_output in our axolotl config:\n# training_config.yaml\nbase_model: mistralai/Mistral-7B-v0.1\ndata_seed: 49\nseed: 49\n\ndatasets:\n - path: output.jsonl\n type: input_output\nval_set_size: 0.1\n\nsequence_len: 896\nsample_packing: false\n\nmicro_batch_size: 2\ngradient_accumulation_steps: 3\neval_batch_size: 2\nnum_epochs: 1\nlearning_rate: 0.0002\n\ntrain_on_inputs: false\nspecial_tokens:\n bos_token: \"<s>\"\n eos_token: \"</s>\"\n unk_token: \"<unk>\"\nYou can use the following command to materialize your data. 
The\n--debug flag will print the tokens, along with the labels so you can\nverify that the correct items are being ignored:\naxolotl preprocess training_config.yaml --debug\n\n...\n[2024-03-05 23:36:46,969] [INFO] [axolotl.check_example_labels:35] [PID:607731] [RANK:0] <s>(1, 1) Hello(22557, 22557)\n(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)\nThe format is decoded_token(label, token_id), for example,\n<s>(1, 1) means that the token is <s>, the label is 1 and the\ntoken_id is 1. When the label is -100 then that token is ignored for\ntraining.\n\n\n3. Check the prompts\nHere is another way to check the materialized output:\nfrom transformers import AutoTokenizer\nfrom datasets import load_from_disk\nimport yaml\n\ndirectory = !ls last_run_prepared/\nwith open('training_config.yaml', 'r') as f:\n cfg = yaml.safe_load(f)\nmodel_id = cfg['base_model']\ntok = AutoTokenizer.from_pretrained(model_id)\nds = load_from_disk(f'last_run_prepared/{directory[0]}/')\n>>> row = ds[0]\n>>> print(tok.decode(row['input_ids']))\n<s> Hello\n hi there!. goodbye farewell</s>\nWe can check that the right tokens are ignored by comparing the labels\nto each token:\nimport pandas as pd\npd.DataFrame([{'token': tok.decode(i), 'label': l, 'id':i} for i,l in\n zip(row['input_ids'], row['labels'])])\n\n\n\ntoken\nlabel\nid\n\n\n\n\n0\n<s>\n1\n\n\n1\nHello\n22557\n\n\n2\n\\n\n13\n\n\n3\nhi\n12014\n\n\n4\nthere\n736\n\n\n5\n!\n28808\n\n\n6\n.\n28723\n\n\n7\n\n28705\n\n\n8\ngood\n-100\n\n\n9\nbye\n-100\n\n\n10\n\n-100\n\n\n11\nfare\n19111\n\n\n12\nwell\n5458\n\n\n13\n</s>\n2\n\n\n\nIf we look at the input data, the above table seems correct! 
(The jsonl\nversion is repeated below for reference):\n$ head -n1 output.jsonl | python -m json.tool\n\n{\n \"segments\": [\n {\n \"label\": true,\n \"text\": \"<s>Hello\\n\"\n },\n {\n \"label\": true,\n \"text\": \"hi there!. \"\n },\n {\n \"label\": false,\n \"text\": \"goodbye \"\n },\n {\n \"label\": true,\n \"text\": \"farewell</s>\"\n }\n ]\n}",
"crumbs": [
"Dataset Formats",
"Template-Free"
]
},
{
"objectID": "docs/dataset-formats/conversation.html",
"href": "docs/dataset-formats/conversation.html",
"title": "Conversation",
"section": "",
"text": "Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizers template, a supported template, or custom jinja2.\n\n\ndata.jsonl\n\n{\"messages\": [{\"role\": \"...\", \"content\": \"...\"}, {\"role\": \"...\", \"content\": \"...\"}, ...]}\n\nSee configs for full configs and supported templates.\n\n\nMost configs can be adapted as follows:\n# old\nchat_template: chatml\ndatasets:\n - path: ...\n type: sharegpt\n conversation: chatml\n\n# new (if using tokenizer's chat_template)\ndatasets:\n - path: ...\n type: chat_template\n\n field_messages: conversations\n message_property_mappings:\n role: from\n content: value\n\n# new (if setting a new chat_template like chatml, gemma, etc)\nchat_template: chatml\ndatasets:\n - path: ...\n type: chat_template\n\n field_messages: conversations\n message_property_mappings:\n role: from\n content: value\nWe recommend checking the below examples for other usecases.\n\n\n\n\n\n(Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.\ndatasets:\n - path: ...\n type: chat_template\n roles_to_train:\n train_on_eos:\n\n\n\n\n\n\nTip\n\n\n\nIf you receive an error like “chat_template choice is tokenizer_default but tokenizers chat_template is null.”, it means the tokenizer does not have a default chat_template. 
Follow the examples below instead to set a custom chat_template.\n\n\n\n\n\nUsing the gemma chat template to override the tokenizer_config.jsons chat template on OpenAI messages format, training on all assistant messages.\nchat_template: gemma # this overwrites the tokenizer's chat_template\ndatasets:\n - path: ...\n type: chat_template\n roles_to_train: [\"assistant\"] # default value\n\n\n\n\n\n\nNote\n\n\n\nIf you want to use built-in chat_template, use chat_template: tokenizer_default (this is set by default).\n\n\n\n\n\nUsing the tokenizer_config.jsons chat template or chatml as fallback if the formers chat template does not exist, on OpenAI messages format, training on all assistant messages.\nchat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template\ndatasets:\n - path: ...\n type: chat_template\n\n\n\nUsing a custom jinja template on OpenAI messages format, training on all assistant messages.\n# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty\nchat_template_jinja: \"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\\n' + message['content'] + '<|end|>' + '\\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\\n' + message['content'] + '<|end|>' + '\\n' + '<|assistant|>' + '\\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\\n'}}{% endif %}{% endfor %}\"\n\ndatasets:\n - path: ...\n type: chat_template\n\n\n\n\n\n\nImportant\n\n\n\nPlease make sure that your tokenizer.eos_token is same as EOS (End-of-Sequence) token in template. Otherwise, set eos_token under special_tokens:.\n\n\n\n\n\n\nIf you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the eot_tokens: config. 
The handling of EOT tokens follows train_on_eos: which defaults to turn.\n\neot_tokens:\n - \"[/INST]\"\n # - \"[/SYSTEM_PROMPT]\"\n\ndatasets:\n - path: ...\n type: chat_template\n\n # optional\n train_on_eot: turn # defaults read from train_on_eos (which defaults to turn)\n\n\n\n\n\n\nTip\n\n\n\nSee config documentation for detailed explanations of “turn”, “last”, and “all” options for training on tokens.\n\n\n\n\n\n\n\n\nNote\n\n\n\nUsing eot_tokens requires each token that exists in chat_template to be a single token in the tokenizer. Otherwise, the tokenizer will split the token and cause unexpected behavior.\nYou can add those tokens as new tokens under tokens: or (recommended) override unused added_tokens via added_tokens_overrides:. See config for more details.\n\n\n\nContinuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set train_on_eos: last.\n\neot_tokens:\n - \"[/INST]\"\n # ...\n\ndatasets:\n - path: ...\n type: chat_template\n\n train_on_eos: last\n train_on_eot: turn\n\n\n\n\n\n\nTip\n\n\n\nIf EOS token only appears at the end of a prompt, train_on_eos: last is equivalent to train_on_eos: turn. 
Therefore, generally, you can leave them to their defaults and omit them.\n\n\n\n\n\nInstead of passing tools via the system prompt, an alternative method would be to have the tools in a separate column and loaded via chat_template to let the template dynamically build it.\n{\n \"tools\": [\n {\n \"type\": \"...\",\n \"function\": {\n \"name\": \"...\",\n \"description\": \"...\",\n \"parameters\": {\n \"type\": \"...\",\n \"properties\": {\n // ...\n },\n \"required\": [\"...\"],\n },\n },\n },\n ],\n \"messages\": [\n // ...\n {\n \"role\": \"assistant\", // call the function via assistant\n \"tool_calls\": [\n {\n \"id\": \"...\", // required only for mistral\n \"type\": \"function\",\n \"function\": {\n \"name\": \"...\",\n \"arguments\": {\n \"...\": \"...\",\n }\n }\n }\n ]\n },\n {\n \"role\": \"tool\",\n \"tool_call_id\": \"...\", // required only for mistral\n \"name\": \"...\",\n \"content\": \"...\"\n },\n ],\n}\n\n\n\n\n\n\nNote\n\n\n\nTools need to follow JSON schema.\n\n\n\n\n\n\n\n\nWarning\n\n\n\nIf you have tool arguments with same name but different dtypes (like \"time\": string and \"time\": number), please save arguments: as JSON string to prevent datasets from having casting issues.\n\"arguments\": \"{\\\"...\\\": \\\"...\\\"}\"\nThe same is applicable for tool parameters.\n\"parameters\": \"{\\\"...\\\": \\\"...\\\"}\"\n\n\nExample config for Llama4:\nchat_template: llama4\ndatasets:\n - path: Nanobit/text-tools-2k-test\n type: chat_template\n # field_tools: tools # default is `tools`\n\n\n\n\n\n\nTip\n\n\n\nLook into the chat_template you are using to see if it supports tools and what the expected role is for the tool answer. 
In the example above, the tool answer is expected to be in the tool or ipython role for llama4 template.\n\n\n\n\n\n(Advanced) Using fine-grained control over tokens and turns to train in a conversation\nFor a data sample that looks like:\n\n\ndata.jsonl\n\n{\n \"conversations\": [\n {\"from\": \"system\", \"value\": \"You are an AI assistant.\", \"train\": false},\n {\"from\": \"human\", \"value\": \"Hello\", \"train\": false},\n {\"from\": \"assistant\", \"value\": \"Hello\", \"train\": true},\n {\"from\": \"human\", \"value\": \"How are you?\", \"train\": true},\n {\n \"from\": \"assistant\",\n \"value\": \"I'm doing very well, thank you!\",\n \"train_detail\": [\n {\"begin_offset\": 0, \"end_offset\": 8, \"train\": false},\n {\"begin_offset\": 9, \"end_offset\": 18, \"train\": true},\n {\"begin_offset\": 19, \"end_offset\": 30, \"train\": false},\n ],\n },\n {\n \"from\": \"human\",\n \"value\": \"I'm doing very well, thank you!\",\n \"train\": true,\n },\n {\"from\": \"assistant\", \"value\": \"Hi there!\", \"train\": true}\n ]\n}\n\nThe configuration would look like:\ndatasets:\n - path: ...\n type: chat_template\n chat_template: tokenizer_default\n field_messages: conversations\n message_property_mappings:\n role: from\n content: value\n roles_to_train: []\n train_on_eos: turn\n message_field_training: train\n message_field_training_detail: train_detail\n\n\n\n\n\n\nTip\n\n\n\nIt is not necessary to set both message_field_training and message_field_training_detail at once.\n\n\n\n\n\n(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.\ndatasets:\n - path: ...\n type: chat_template\n chat_template: qwen3\n split_thinking: true\nFor example, a content can look like:\n{\n \"content\": \"<think>Some thinking outputs</think>Output after thinking.\"\n}\nAfter split, it will look like:\n{\n \"reasoning_content\": \"Some thinking outputs\",\n \"content\": \"Output after 
thinking...\"\n}",
"crumbs": [
"Dataset Formats",
"Conversation"
]
},
{
"objectID": "docs/dataset-formats/conversation.html#chat_template",
"href": "docs/dataset-formats/conversation.html#chat_template",
"title": "Conversation",
"section": "",
"text": "Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizers template, a supported template, or custom jinja2.\n\n\ndata.jsonl\n\n{\"messages\": [{\"role\": \"...\", \"content\": \"...\"}, {\"role\": \"...\", \"content\": \"...\"}, ...]}\n\nSee configs for full configs and supported templates.\n\n\nMost configs can be adapted as follows:\n# old\nchat_template: chatml\ndatasets:\n - path: ...\n type: sharegpt\n conversation: chatml\n\n# new (if using tokenizer's chat_template)\ndatasets:\n - path: ...\n type: chat_template\n\n field_messages: conversations\n message_property_mappings:\n role: from\n content: value\n\n# new (if setting a new chat_template like chatml, gemma, etc)\nchat_template: chatml\ndatasets:\n - path: ...\n type: chat_template\n\n field_messages: conversations\n message_property_mappings:\n role: from\n content: value\nWe recommend checking the below examples for other usecases.\n\n\n\n\n\n(Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.\ndatasets:\n - path: ...\n type: chat_template\n roles_to_train:\n train_on_eos:\n\n\n\n\n\n\nTip\n\n\n\nIf you receive an error like “chat_template choice is tokenizer_default but tokenizers chat_template is null.”, it means the tokenizer does not have a default chat_template. 
Follow the examples below instead to set a custom chat_template.\n\n\n\n\n\nUsing the gemma chat template to override the tokenizer_config.jsons chat template on OpenAI messages format, training on all assistant messages.\nchat_template: gemma # this overwrites the tokenizer's chat_template\ndatasets:\n - path: ...\n type: chat_template\n roles_to_train: [\"assistant\"] # default value\n\n\n\n\n\n\nNote\n\n\n\nIf you want to use built-in chat_template, use chat_template: tokenizer_default (this is set by default).\n\n\n\n\n\nUsing the tokenizer_config.jsons chat template or chatml as fallback if the formers chat template does not exist, on OpenAI messages format, training on all assistant messages.\nchat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template\ndatasets:\n - path: ...\n type: chat_template\n\n\n\nUsing a custom jinja template on OpenAI messages format, training on all assistant messages.\n# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty\nchat_template_jinja: \"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\\n' + message['content'] + '<|end|>' + '\\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\\n' + message['content'] + '<|end|>' + '\\n' + '<|assistant|>' + '\\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\\n'}}{% endif %}{% endfor %}\"\n\ndatasets:\n - path: ...\n type: chat_template\n\n\n\n\n\n\nImportant\n\n\n\nPlease make sure that your tokenizer.eos_token is same as EOS (End-of-Sequence) token in template. Otherwise, set eos_token under special_tokens:.\n\n\n\n\n\n\nIf you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the eot_tokens: config. 
The handling of EOT tokens follows train_on_eos: which defaults to turn.\n\neot_tokens:\n - \"[/INST]\"\n # - \"[/SYSTEM_PROMPT]\"\n\ndatasets:\n - path: ...\n type: chat_template\n\n # optional\n train_on_eot: turn # defaults read from train_on_eos (which defaults to turn)\n\n\n\n\n\n\nTip\n\n\n\nSee config documentation for detailed explanations of “turn”, “last”, and “all” options for training on tokens.\n\n\n\n\n\n\n\n\nNote\n\n\n\nUsing eot_tokens requires each token that exists in chat_template to be a single token in the tokenizer. Otherwise, the tokenizer will split the token and cause unexpected behavior.\nYou can add those tokens as new tokens under tokens: or (recommended) override unused added_tokens via added_tokens_overrides:. See config for more details.\n\n\n\nContinuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set train_on_eos: last.\n\neot_tokens:\n - \"[/INST]\"\n # ...\n\ndatasets:\n - path: ...\n type: chat_template\n\n train_on_eos: last\n train_on_eot: turn\n\n\n\n\n\n\nTip\n\n\n\nIf EOS token only appears at the end of a prompt, train_on_eos: last is equivalent to train_on_eos: turn. 
Therefore, generally, you can leave them to their defaults and omit them.\n\n\n\n\n\nInstead of passing tools via the system prompt, an alternative method would be to have the tools in a separate column and loaded via chat_template to let the template dynamically build it.\n{\n \"tools\": [\n {\n \"type\": \"...\",\n \"function\": {\n \"name\": \"...\",\n \"description\": \"...\",\n \"parameters\": {\n \"type\": \"...\",\n \"properties\": {\n // ...\n },\n \"required\": [\"...\"],\n },\n },\n },\n ],\n \"messages\": [\n // ...\n {\n \"role\": \"assistant\", // call the function via assistant\n \"tool_calls\": [\n {\n \"id\": \"...\", // required only for mistral\n \"type\": \"function\",\n \"function\": {\n \"name\": \"...\",\n \"arguments\": {\n \"...\": \"...\",\n }\n }\n }\n ]\n },\n {\n \"role\": \"tool\",\n \"tool_call_id\": \"...\", // required only for mistral\n \"name\": \"...\",\n \"content\": \"...\"\n },\n ],\n}\n\n\n\n\n\n\nNote\n\n\n\nTools need to follow JSON schema.\n\n\n\n\n\n\n\n\nWarning\n\n\n\nIf you have tool arguments with same name but different dtypes (like \"time\": string and \"time\": number), please save arguments: as JSON string to prevent datasets from having casting issues.\n\"arguments\": \"{\\\"...\\\": \\\"...\\\"}\"\nThe same is applicable for tool parameters.\n\"parameters\": \"{\\\"...\\\": \\\"...\\\"}\"\n\n\nExample config for Llama4:\nchat_template: llama4\ndatasets:\n - path: Nanobit/text-tools-2k-test\n type: chat_template\n # field_tools: tools # default is `tools`\n\n\n\n\n\n\nTip\n\n\n\nLook into the chat_template you are using to see if it supports tools and what the expected role is for the tool answer. 
In the example above, the tool answer is expected to be in the tool or ipython role for llama4 template.\n\n\n\n\n\n(Advanced) Using fine-grained control over tokens and turns to train in a conversation\nFor a data sample that looks like:\n\n\ndata.jsonl\n\n{\n \"conversations\": [\n {\"from\": \"system\", \"value\": \"You are an AI assistant.\", \"train\": false},\n {\"from\": \"human\", \"value\": \"Hello\", \"train\": false},\n {\"from\": \"assistant\", \"value\": \"Hello\", \"train\": true},\n {\"from\": \"human\", \"value\": \"How are you?\", \"train\": true},\n {\n \"from\": \"assistant\",\n \"value\": \"I'm doing very well, thank you!\",\n \"train_detail\": [\n {\"begin_offset\": 0, \"end_offset\": 8, \"train\": false},\n {\"begin_offset\": 9, \"end_offset\": 18, \"train\": true},\n {\"begin_offset\": 19, \"end_offset\": 30, \"train\": false},\n ],\n },\n {\n \"from\": \"human\",\n \"value\": \"I'm doing very well, thank you!\",\n \"train\": true,\n },\n {\"from\": \"assistant\", \"value\": \"Hi there!\", \"train\": true}\n ]\n}\n\nThe configuration would look like:\ndatasets:\n - path: ...\n type: chat_template\n chat_template: tokenizer_default\n field_messages: conversations\n message_property_mappings:\n role: from\n content: value\n roles_to_train: []\n train_on_eos: turn\n message_field_training: train\n message_field_training_detail: train_detail\n\n\n\n\n\n\nTip\n\n\n\nIt is not necessary to set both message_field_training and message_field_training_detail at once.\n\n\n\n\n\n(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.\ndatasets:\n - path: ...\n type: chat_template\n chat_template: qwen3\n split_thinking: true\nFor example, a content can look like:\n{\n \"content\": \"<think>Some thinking outputs</think>Output after thinking.\"\n}\nAfter split, it will look like:\n{\n \"reasoning_content\": \"Some thinking outputs\",\n \"content\": \"Output after 
thinking...\"\n}",
"crumbs": [
"Dataset Formats",
"Conversation"
]
},
{
"objectID": "docs/dataset-formats/conversation.html#sharegpt",
"href": "docs/dataset-formats/conversation.html#sharegpt",
"title": "Conversation",
"section": "sharegpt",
"text": "sharegpt\n\n\n\n\n\n\nImportant\n\n\n\nShareGPT is deprecated!. Please see chat_template section.",
"crumbs": [
"Dataset Formats",
"Conversation"
]
},
{
"objectID": "docs/dataset-formats/conversation.html#pygmalion",
"href": "docs/dataset-formats/conversation.html#pygmalion",
"title": "Conversation",
"section": "pygmalion",
"text": "pygmalion\n\n\ndata.jsonl\n\n{\"conversations\": [{\"role\": \"...\", \"value\": \"...\"}]}",
"crumbs": [
"Dataset Formats",
"Conversation"
]
},
{
"objectID": "docs/dataset-formats/pretraining.html",
"href": "docs/dataset-formats/pretraining.html",
"title": "Pre-training",
"section": "",
"text": "Note\n\n\n\nPre-training documentation has been consolidated:\n\nStreaming pretraining (large datasets): See Streaming Datasets\nNon-streaming pretraining (type: completion): See Dataset Formats",
"crumbs": [
"Dataset Formats",
"Pre-training"
]
},
{
"objectID": "docs/dataset-formats/index.html",
"href": "docs/dataset-formats/index.html",
"title": "Dataset Formats",
"section": "",
"text": "Axolotl is a training framework that aims to make the process convenient yet flexible to users by simply passing a config yaml file.\nAs there are a lot of available options in Axolotl, this guide aims to provide an simplify the user experience to choosing the proper choice.\nAxolotl supports 3 kinds of training methods: pre-training, supervised fine-tuning, and preference-based post-training (e.g. DPO, ORPO, PRMs). Each method has their own dataset format which are described below.",
"crumbs": [
"Dataset Formats"
]
},
{
"objectID": "docs/dataset-formats/index.html#pre-training",
"href": "docs/dataset-formats/index.html#pre-training",
"title": "Dataset Formats",
"section": "Pre-training",
"text": "Pre-training\nPre-training trains on raw text corpora with no input masking. The dataset format is simple:\n{\"text\": \"first row\"}\n{\"text\": \"second row\"}\nAxolotl supports two approaches:\n\nStreaming (large datasets)\nFor large corpora that dont fit in memory, use pretraining_dataset with streaming. Data is tokenized on-demand during training.\npretraining_dataset:\n - path: HuggingFaceFW/fineweb-edu\n type: pretrain\n text_column: text\n split: train\n\n\n\n\n\n\nImportant\n\n\n\nStreaming requires max_steps in your config — Axolotl cannot infer the dataset size. One step = sequence_len * micro_batch_size * gradient_accumulation_steps * num_gpus tokens.\n\n\nSee Streaming Datasets for full configuration details.\n\n\nNon-streaming (smaller datasets)\nFor datasets that fit in memory, use type: completion under datasets:. The entire dataset is pre-tokenized before training, which can be done on a CPU-only machine.\ndatasets:\n - path: my_corpus\n type: completion\n\n\n\n\n\n\nNote\n\n\n\nWith completion, texts exceeding sequence_len are split into multiple samples automatically.",
"crumbs": [
"Dataset Formats"
]
},
{
"objectID": "docs/dataset-formats/index.html#supervised-fine-tuning-sft",
"href": "docs/dataset-formats/index.html#supervised-fine-tuning-sft",
"title": "Dataset Formats",
"section": "Supervised fine-tuning (SFT)",
"text": "Supervised fine-tuning (SFT)\nSupervised fine-tuning is the process of training models to respond to an instruction or chat input.\nAs there are a wide variety of dataset formats, Axolotl tries to support a majority of the formats available in public datasets.\nAxolotl provides four approaches for loading datasets, however, its easier to work backwards from the dataset you have available to figure out which approach to use.\nA flow chart is as follows:\n\nDo you already have the dataset tokenized? If yes, check Pre-Tokenized Dataset.\nDo you want to format the dataset yourself and manually choose each section to mask? If yes, check Template Free Dataset\nIs your dataset in a “conversation” format, containing a list[messages]? If yes, check Conversation Dataset\nIs your dataset in an “instruct” format, containing { instruction, response }? If yes, check Instruction Dataset\n\nIf you went through the flow chart and did not find one that matches, it is recommended to preprocess your dataset into one of the above or create a thread on Github Discussion.\n\n\n\n\n\n\nTip\n\n\n\nYou can mix and match within each approach or across approaches to train a model on a variety of datasets.\n\n\n\nPre-Tokenized Dataset\nWe suggest this approach when you want to bring your own tokenized dataset.\nAxolotl expects the dataset to have three keys:\n\ninput_ids: from tokenizing formatted prompt\nattention_mask: for masking padding. 
If you don't add padding, it would be equal to len(input_ids) * [1]\nlabels: this is the same as input_ids; however, if you want to mask certain tokens, you would set those indices to -100.\n\n\n\n\n\n\n\nTip\n\n\n\nMake sure to add BOS/EOS tokens to your prompt and mask them appropriately.\n\n\nA config for this would look like:\ndatasets:\n - path: A.jsonl\n type:\n\n\n\n\n\n\nNote\n\n\n\ntype: is empty!\n\n\nReference: Pre-Tokenized Dataset Documentation.\n\n\nTemplate Free Dataset\nWe recommend this approach when you want granular control over the prompt formatting, special tokens, and masking, whilst letting Axolotl handle the tokenization. This is very useful if your dataset has unique prompts that differ across samples and where one single general template wouldn't suffice.\nIn the example below, you can see that there is no proper structure. At the same time, it's very flexible as there are no constraints on how your prompt can look.\n{\n \"segments\": [\n {\n \"label\": true,\n \"text\": \"<s>Hello\\n\"\n },\n {\n \"label\": true,\n \"text\": \"hi there!. \"\n },\n {\n \"label\": false,\n \"text\": \"goodbye \"\n },\n {\n \"label\": true,\n \"text\": \"farewell</s>\"\n }\n ]\n}\nEach prompt must have a key called segments which is a list of { text, label }.\ndatasets:\n - path: A.jsonl\n type: input_output\nReference: Template Free Documentation.\n\n\nConversation Dataset\nconversation messages are a list of messages which usually contain a role and content key.\n\n\n\n\n\n\nTip\n\n\n\nFun fact: Axolotl synonymously refers to “chat” messages as conversation messages due to how FastChat initially used this term to build a widely used fastchat conversation method for formatting chat messages prior to the creation of chat_templates.\n\n\n\nWhat are chat_templates?\nThe current most popular and convenient method for inference is to use chat_templates for formatting prompts. 
Axolotl supports using chat_templates for training to ensure that the model performs in the same environment as in inference.\nHere's a quick rundown on chat_template: A chat_template is a Jinja2 template which formats a list of messages into a prompt.\nAn example of a prompt formatted into a popular template called ChatML can be seen below:\nSingle prompt (pretty-printed):\n{\n \"messages\": [\n {\n \"role\": \"user\",\n \"content\": \"Hi\"\n },\n {\n \"role\": \"assistant\",\n \"content\": \"How can I help you?\"\n },\n {\n \"role\": \"user\",\n \"content\": \"Can you add 3+5?\"\n },\n {\n \"role\": \"assistant\",\n \"content\": \"The answer is 8.\"\n }\n ]\n}\nThe ChatML template is as follows:\n{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}\nThe above prompt formatted into this template will result in:\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHow can I help you?<|im_end|>\n<|im_start|>user\nCan you add 3+5?<|im_end|>\n<|im_start|>assistant\nThe answer is 8.<|im_end|>\nBy using delimiters (<|im_start|> and <|im_end|>), a prompt separates different speakers which helps the model identify which portion belongs to whom.\n\n\nCommon Conversation Dataset formats\nOlder conversation datasets with the following format are colloquially called sharegpt datasets.\n{\"conversations\": [{\"from\": \"...\", \"value\": \"...\"}]}\nNewer conversation datasets usually follow the OpenAI format.\n{\"messages\": [{\"role\": \"...\", \"content\": \"...\"}]}\nAxolotl supports both as well as allowing customization of any kind of key.\n\n\nChat Template Usage\nTo properly use this method, it is important to identify three things:\n\nWhich chat_template would you use?\nWhat are the keys in your dataset, and what are the 
possible roles? For example, in OpenAI format, the keys would be messages, role, and content, respectively, whereas the possible roles are system, user, and assistant.\nWhat do you want to mask? For instance, only assistant messages, only the last message, or nothing.\n\n\nChoosing a chat_template\nThere are a lot of chat_templates out there. Axolotl supports the common ones: supported chat templates. For example, to use ChatML, it would be chat_template: chatml.\nHowever, it is also possible to use the already configured template within the tokenizer by specifying chat_template: tokenizer_default. If you want a fallback (in case some tokenizer does not have it pre-configured), you can do chat_template: tokenizer_default_fallback_chatml to fall back to the ChatML template if a tokenizer template was not found.\nOne last but powerful approach is to bring your own template. This can be set via:\nchat_template_jinja: # your template\n\n\nSetting chat_template dataset keys\nWe currently default to OpenAI format for dataset keys, so if that's your current dataset format, there's nothing to do here.\nIf your dataset format is different, here are the keys you should check (with their defaults):\ndatasets:\n ...\n field_messages: messages # this should point to the key containing the list of conversations\n message_property_mappings: # this is a mapping from keys in your dataset to keys in chat_template\n role: role\n content: content\nIn some chat_templates (e.g. Gemma), the roles are hardcoded to user and assistant. Consequently, you may find it necessary to map the roles in your dataset to the ones above. We currently have some defaults that should work for common datasets, but if you get a KeyError, it would be necessary to add a mapping for your roles. Here is an example of what it would look like:\ndatasets:\n ...\n roles:\n assistant:\n - gpt\n - model\n user:\n - human\nIn the example above, all gpt and model values are converted to assistant. 
All human values are converted to user.\n\n\nHandling masking\nThe common use case for chat_template is chat messages, so it is common to mask all non-assistant messages. Assistant messages are the bot messages that you want the model to learn from.\nTo train on all assistant messages, you would set the following configs.\ndatasets:\n ...\n roles_to_train: [\"assistant\"]\n train_on_eos: \"turn\"\nThe train_on_eos setting above masks all EOS tokens for turns that aren't assistant turns. The other options, all and last, choose which EOS tokens to train on.\nIf you want to train on both the assistant and narrator roles, simply add narrator to the list of roles_to_train. You also need to add it to the roles mapping above.\ndatasets:\n ...\n roles_to_train: [\"assistant\", \"narrator\"]\n roles:\n assistant:\n - gpt\n - model\n user:\n - human\n narrator: [\"narrator\"]\n\n\n\n\n\n\nTip\n\n\n\nAs chat_templates may use hardcoded EOS/EOT tokens that are different from the tokenizer's EOS, it is highly recommended to set them. 
For example, ChatML uses <|im_end|> to end turns.\nspecial_tokens:\n eos_token: <|im_end|>\n\n\n\n\nApplying chat_template\nOnce all the above steps are completed, you can combine these configs to form a bespoke configuration for your custom dataset.\ndatasets:\n - path: A.jsonl\n type: chat_template\n\n # step 1\n chat_template: chatml\n\n # step 2\n field_messages: messages\n message_property_mappings:\n role: role\n content: content\n\n roles:\n assistant:\n - gpt\n - model\n - assistant\n user:\n - human\n - user\n\n # step 3\n roles_to_train: [\"assistant\"]\n train_on_eos: \"turn\"\n\nspecial_tokens:\n eos_token: <|im_end|>\nIf this config were applied to the sample dataset above, the output would look like this (it can be retrieved via axolotl preprocess config.yaml --debug):\n<|im_start|>(-100, 128256) user(-100, 882)\n(-100, 198) Hi(-100, 13347) <|im_end|>(-100, 128257)\n(-100, 198) <|im_start|>(-100, 128256) assistant(-100, 78191)\n(-100, 198) How(4438, 4438) can(649, 649) I(358, 358) help(1520, 1520) you(499, 499) ?(30, 30) <|im_end|>(128257, 128257)\n(-100, 198) <|im_start|>(-100, 128256) user(-100, 882)\n(-100, 198) Can(-100, 6854) you(-100, 499) add(-100, 923) (-100, 220) 3(-100, 18) +(-100, 10) 5(-100, 20) ?(-100, 30) <|im_end|>(-100, 128257)\n(-100, 198) <|im_start|>(-100, 128256) assistant(-100, 78191)\n(-100, 198) The(791, 791) answer(4320, 4320) is(374, 374) (220, 220) 8(23, 23) .(13, 13) <|im_end|>(128257, 128257)\n(-100, 198)\nThe first number refers to the label, the second refers to the token_id. For example, -100 labels appear on non-assistant portions, meaning that they are masked during training. 
For assistant portions, the label is the same as the token_id.\n\n\n\n\n\n\nNote\n\n\n\nIf, during preprocessing, there are a lot of Could not find content __ boundary warnings, please check the FAQ section for chat_templates.\n\n\n\n\n\nReference\nPlease see docs here.\n\n\n\nInstruction Dataset\nInstruction datasets are used to train instruction-following models and comprise a prompt, containing an instruction, and a single response. In contrast to chat datasets, which may be multi-turn, instruct datasets are typically single-turn.\nAn example of a common format is Alpaca:\n{\"instruction\": \"...\", \"input\": \"...\", \"output\": \"...\"}\nUsing those keys, a prompt can be built as follows.\nBelow is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n{output}\nThis can be configured as such:\ndatasets:\n - path: A.jsonl\n type: alpaca\nAxolotl supports many kinds of instruction datasets. All of them can be found in the Instruction Dataset Documentation with their respective type and sample row format.\n\nCustom Instruct Prompt Format\nDue to the myriad possibilities of instruction formats, Axolotl allows customizing your own instruction format without having to dive into the code directly.\nIn the example below, a sample row is used to output in mistral_v1 format.\n{\"input\": \"...\", \"output\": \"...\"}\ndatasets:\n - path: repo\n type:\n system_prompt: \"\"\n\n field_system:\n field_instruction: input\n field_input:\n field_output: output\n\n # multi-line example with input\n format: |-\n [INST] {instruction} {input} [/INST]\n\n # single-line example without input\n no_input_format: \"[INST] {instruction} [/INST]\"\nThe config specifies that field_instruction is actually named input, and that field_input is empty because we don't have an input in this sample. 
Generally, instruction can be thought of as the question to the model, input as the additional information, and output as the response. It is not necessary to have an input or a system prompt. In the end, the most important part is to understand what format you want the prompt to have and how you can customize this to your use case.\nReference: Custom Instruct Prompt Format Documentation.",
"crumbs": [
"Dataset Formats"
]
},
{
"objectID": "docs/dataset-formats/index.html#reinforcement-learning-from-human-feedback-rlhf",
"href": "docs/dataset-formats/index.html#reinforcement-learning-from-human-feedback-rlhf",
"title": "Dataset Formats",
"section": "Reinforcement Learning from Human Feedback (RLHF)",
"text": "Reinforcement Learning from Human Feedback (RLHF)\nAs there are multiple RLHF methods with their own dataset requirements. Please see RLHF documentation for more detail.",
"crumbs": [
"Dataset Formats"
]
},
{
"objectID": "docs/api/cli.args.html",
"href": "docs/api/cli.args.html",
"title": "cli.args",
"section": "",
"text": "cli.args\nModule for axolotl CLI command arguments.\n\n\n\n\n\nName\nDescription\n\n\n\n\nEvaluateCliArgs\nDataclass with CLI arguments for axolotl evaluate command.\n\n\nInferenceCliArgs\nDataclass with CLI arguments for axolotl inference command.\n\n\nPreprocessCliArgs\nDataclass with CLI arguments for axolotl preprocess command.\n\n\nQuantizeCliArgs\nDataclass with CLI arguments for axolotl quantize command.\n\n\nTrainerCliArgs\nDataclass with CLI arguments for axolotl train command.\n\n\nVllmServeCliArgs\nDataclass with CLI arguments for axolotl vllm-serve command.\n\n\n\n\n\ncli.args.EvaluateCliArgs(\n debug=False,\n debug_text_only=False,\n debug_num_examples=0,\n)\nDataclass with CLI arguments for axolotl evaluate command.\n\n\n\ncli.args.InferenceCliArgs(prompter=None)\nDataclass with CLI arguments for axolotl inference command.\n\n\n\ncli.args.PreprocessCliArgs(\n debug=False,\n debug_text_only=False,\n debug_num_examples=1,\n prompter=None,\n download=True,\n iterable=False,\n)\nDataclass with CLI arguments for axolotl preprocess command.\n\n\n\ncli.args.QuantizeCliArgs(\n base_model=None,\n weight_dtype=None,\n activation_dtype=None,\n quantize_embedding=None,\n group_size=None,\n output_dir=None,\n hub_model_id=None,\n)\nDataclass with CLI arguments for axolotl quantize command.\n\n\n\ncli.args.TrainerCliArgs(\n debug=False,\n debug_text_only=False,\n debug_num_examples=0,\n prompter=None,\n shard=False,\n)\nDataclass with CLI arguments for axolotl train command.\n\n\n\ncli.args.VllmServeCliArgs(\n tensor_parallel_size=None,\n data_parallel_size=None,\n host=None,\n port=None,\n gpu_memory_utilization=None,\n dtype=None,\n max_model_len=None,\n enable_prefix_caching=None,\n serve_module=None,\n enable_reasoning=None,\n reasoning_parser=None,\n)\nDataclass with CLI arguments for axolotl vllm-serve command."
},
{
"objectID": "docs/api/cli.args.html#classes",
"href": "docs/api/cli.args.html#classes",
"title": "cli.args",
"section": "",
"text": "Name\nDescription\n\n\n\n\nEvaluateCliArgs\nDataclass with CLI arguments for axolotl evaluate command.\n\n\nInferenceCliArgs\nDataclass with CLI arguments for axolotl inference command.\n\n\nPreprocessCliArgs\nDataclass with CLI arguments for axolotl preprocess command.\n\n\nQuantizeCliArgs\nDataclass with CLI arguments for axolotl quantize command.\n\n\nTrainerCliArgs\nDataclass with CLI arguments for axolotl train command.\n\n\nVllmServeCliArgs\nDataclass with CLI arguments for axolotl vllm-serve command.\n\n\n\n\n\ncli.args.EvaluateCliArgs(\n debug=False,\n debug_text_only=False,\n debug_num_examples=0,\n)\nDataclass with CLI arguments for axolotl evaluate command.\n\n\n\ncli.args.InferenceCliArgs(prompter=None)\nDataclass with CLI arguments for axolotl inference command.\n\n\n\ncli.args.PreprocessCliArgs(\n debug=False,\n debug_text_only=False,\n debug_num_examples=1,\n prompter=None,\n download=True,\n iterable=False,\n)\nDataclass with CLI arguments for axolotl preprocess command.\n\n\n\ncli.args.QuantizeCliArgs(\n base_model=None,\n weight_dtype=None,\n activation_dtype=None,\n quantize_embedding=None,\n group_size=None,\n output_dir=None,\n hub_model_id=None,\n)\nDataclass with CLI arguments for axolotl quantize command.\n\n\n\ncli.args.TrainerCliArgs(\n debug=False,\n debug_text_only=False,\n debug_num_examples=0,\n prompter=None,\n shard=False,\n)\nDataclass with CLI arguments for axolotl train command.\n\n\n\ncli.args.VllmServeCliArgs(\n tensor_parallel_size=None,\n data_parallel_size=None,\n host=None,\n port=None,\n gpu_memory_utilization=None,\n dtype=None,\n max_model_len=None,\n enable_prefix_caching=None,\n serve_module=None,\n enable_reasoning=None,\n reasoning_parser=None,\n)\nDataclass with CLI arguments for axolotl vllm-serve command."
},
{
"objectID": "docs/api/prompt_strategies.orcamini.html",
"href": "docs/api/prompt_strategies.orcamini.html",
"title": "prompt_strategies.orcamini",
"section": "",
"text": "prompt_strategies.orcamini\nPrompt Strategy for finetuning Orca Mini (v2) models\nsee also https://huggingface.co/psmathur/orca_mini_v2_7b for more information\nUse dataset type: orcamini in conig.yml to use this prompt style.\nCompared to the alpaca_w_system.open_orca dataset type,\nthis one specifies the system prompt with “### System:”.\nNot suited/tested for multiple-turn conversations without further adjustments.\n\n\n\n\n\nName\nDescription\n\n\n\n\nOrcaMiniPrompter\nAdjusted Prompter for Orca Mini (v2) datasets\n\n\n\n\n\nprompt_strategies.orcamini.OrcaMiniPrompter(\n prompt_style=PromptStyle.INSTRUCT.value,\n)\nAdjusted Prompter for Orca Mini (v2) datasets"
},
{
"objectID": "docs/api/prompt_strategies.orcamini.html#classes",
"href": "docs/api/prompt_strategies.orcamini.html#classes",
"title": "prompt_strategies.orcamini",
"section": "",
"text": "Name\nDescription\n\n\n\n\nOrcaMiniPrompter\nAdjusted Prompter for Orca Mini (v2) datasets\n\n\n\n\n\nprompt_strategies.orcamini.OrcaMiniPrompter(\n prompt_style=PromptStyle.INSTRUCT.value,\n)\nAdjusted Prompter for Orca Mini (v2) datasets"
},
{
"objectID": "docs/api/cli.preprocess.html",
"href": "docs/api/cli.preprocess.html",
"title": "cli.preprocess",
"section": "",
"text": "cli.preprocess\nCLI to run preprocessing of a dataset.\n\n\n\n\n\nName\nDescription\n\n\n\n\ndo_cli\nParses axolotl config, CLI args, and calls do_preprocess.\n\n\ndo_preprocess\nPreprocesses dataset specified in axolotl config.\n\n\n\n\n\ncli.preprocess.do_cli(config=Path('examples/'), **kwargs)\nParses axolotl config, CLI args, and calls do_preprocess.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nUnion[Path, str]\nPath to axolotl config YAML file.\nPath('examples/')\n\n\nkwargs\n\nAdditional keyword arguments to override config file values.\n{}\n\n\n\n\n\n\n\ncli.preprocess.do_preprocess(cfg, cli_args)\nPreprocesses dataset specified in axolotl config.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ncli_args\nPreprocessCliArgs\nPreprocessing-specific CLI arguments.\nrequired"
},
{
"objectID": "docs/api/cli.preprocess.html#functions",
"href": "docs/api/cli.preprocess.html#functions",
"title": "cli.preprocess",
"section": "",
"text": "Name\nDescription\n\n\n\n\ndo_cli\nParses axolotl config, CLI args, and calls do_preprocess.\n\n\ndo_preprocess\nPreprocesses dataset specified in axolotl config.\n\n\n\n\n\ncli.preprocess.do_cli(config=Path('examples/'), **kwargs)\nParses axolotl config, CLI args, and calls do_preprocess.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nUnion[Path, str]\nPath to axolotl config YAML file.\nPath('examples/')\n\n\nkwargs\n\nAdditional keyword arguments to override config file values.\n{}\n\n\n\n\n\n\n\ncli.preprocess.do_preprocess(cfg, cli_args)\nPreprocesses dataset specified in axolotl config.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ncli_args\nPreprocessCliArgs\nPreprocessing-specific CLI arguments.\nrequired"
},
{
"objectID": "docs/api/utils.collators.core.html",
"href": "docs/api/utils.collators.core.html",
"title": "utils.collators.core",
"section": "",
"text": "utils.collators.core\nutils.collators.core\nbasic shared collator constants"
},
{
"objectID": "docs/api/prompt_strategies.dpo.llama3.html",
"href": "docs/api/prompt_strategies.dpo.llama3.html",
"title": "prompt_strategies.dpo.llama3",
"section": "",
"text": "prompt_strategies.dpo.llama3\nDPO strategies for llama-3 chat template\n\n\n\n\n\nName\nDescription\n\n\n\n\nargilla_chat\nfor argilla/dpo-mix-7k conversations\n\n\nicr\nchatml transforms for datasets with system, input, chosen, rejected\n\n\nintel\nFor Intel Orca DPO Pairs\n\n\nultra\nfor ultrafeedback binarized conversations\n\n\n\n\n\nprompt_strategies.dpo.llama3.argilla_chat(cfg, **kwargs)\nfor argilla/dpo-mix-7k conversations\n\n\n\nprompt_strategies.dpo.llama3.icr(cfg, **kwargs)\nchatml transforms for datasets with system, input, chosen, rejected\nex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs\n\n\n\nprompt_strategies.dpo.llama3.intel(cfg, **kwargs)\nFor Intel Orca DPO Pairs\n\n\n\nprompt_strategies.dpo.llama3.ultra(cfg, **kwargs)\nfor ultrafeedback binarized conversations"
},
{
"objectID": "docs/api/prompt_strategies.dpo.llama3.html#functions",
"href": "docs/api/prompt_strategies.dpo.llama3.html#functions",
"title": "prompt_strategies.dpo.llama3",
"section": "",
"text": "Name\nDescription\n\n\n\n\nargilla_chat\nfor argilla/dpo-mix-7k conversations\n\n\nicr\nchatml transforms for datasets with system, input, chosen, rejected\n\n\nintel\nFor Intel Orca DPO Pairs\n\n\nultra\nfor ultrafeedback binarized conversations\n\n\n\n\n\nprompt_strategies.dpo.llama3.argilla_chat(cfg, **kwargs)\nfor argilla/dpo-mix-7k conversations\n\n\n\nprompt_strategies.dpo.llama3.icr(cfg, **kwargs)\nchatml transforms for datasets with system, input, chosen, rejected\nex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs\n\n\n\nprompt_strategies.dpo.llama3.intel(cfg, **kwargs)\nFor Intel Orca DPO Pairs\n\n\n\nprompt_strategies.dpo.llama3.ultra(cfg, **kwargs)\nfor ultrafeedback binarized conversations"
},
{
"objectID": "docs/api/utils.schemas.enums.html",
"href": "docs/api/utils.schemas.enums.html",
"title": "utils.schemas.enums",
"section": "",
"text": "utils.schemas.enums\nEnums for Axolotl input config\n\n\n\n\n\nName\nDescription\n\n\n\n\nChatTemplate\nChat templates configuration subset\n\n\nCustomSupportedOptimizers\nCustom supported optimizers\n\n\nRLType\nRL trainer type configuration subset\n\n\nRingAttnFunc\nEnum class for supported ring-flash-attn implementations\n\n\n\n\n\nutils.schemas.enums.ChatTemplate()\nChat templates configuration subset\n\n\n\nutils.schemas.enums.CustomSupportedOptimizers()\nCustom supported optimizers\n\n\n\nutils.schemas.enums.RLType()\nRL trainer type configuration subset\n\n\n\nutils.schemas.enums.RingAttnFunc()\nEnum class for supported ring-flash-attn implementations"
},
{
"objectID": "docs/api/utils.schemas.enums.html#classes",
"href": "docs/api/utils.schemas.enums.html#classes",
"title": "utils.schemas.enums",
"section": "",
"text": "Name\nDescription\n\n\n\n\nChatTemplate\nChat templates configuration subset\n\n\nCustomSupportedOptimizers\nCustom supported optimizers\n\n\nRLType\nRL trainer type configuration subset\n\n\nRingAttnFunc\nEnum class for supported ring-flash-attn implementations\n\n\n\n\n\nutils.schemas.enums.ChatTemplate()\nChat templates configuration subset\n\n\n\nutils.schemas.enums.CustomSupportedOptimizers()\nCustom supported optimizers\n\n\n\nutils.schemas.enums.RLType()\nRL trainer type configuration subset\n\n\n\nutils.schemas.enums.RingAttnFunc()\nEnum class for supported ring-flash-attn implementations"
},
{
"objectID": "docs/api/utils.lora.html",
"href": "docs/api/utils.lora.html",
"title": "utils.lora",
"section": "",
"text": "utils.lora\nmodule to get the state dict of a merged lora model\n\n\n\n\n\nName\nDescription\n\n\n\n\nget_lora_merged_state_dict\nCreate and return a state_dict that has the LoRA deltas\n\n\n\n\n\nutils.lora.get_lora_merged_state_dict(model)\nCreate and return a state_dict that has the LoRA deltas\nmerged into the base models weights, without modifying model in place.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nmodel\ntorch.nn.Module\nA model that has LoRA/PEFT adapters attached.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\ndict\ndict\nA state_dict of the merged parameters."
},
{
"objectID": "docs/api/utils.lora.html#functions",
"href": "docs/api/utils.lora.html#functions",
"title": "utils.lora",
"section": "",
"text": "Name\nDescription\n\n\n\n\nget_lora_merged_state_dict\nCreate and return a state_dict that has the LoRA deltas\n\n\n\n\n\nutils.lora.get_lora_merged_state_dict(model)\nCreate and return a state_dict that has the LoRA deltas\nmerged into the base models weights, without modifying model in place.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nmodel\ntorch.nn.Module\nA model that has LoRA/PEFT adapters attached.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\ndict\ndict\nA state_dict of the merged parameters."
},
{
"objectID": "docs/api/common.datasets.html",
"href": "docs/api/common.datasets.html",
"title": "common.datasets",
"section": "",
"text": "common.datasets\nDataset loading utilities.\n\n\n\n\n\nName\nDescription\n\n\n\n\nTrainDatasetMeta\nDataclass with fields for training and validation datasets and metadata.\n\n\n\n\n\ncommon.datasets.TrainDatasetMeta(\n train_dataset,\n eval_dataset=None,\n total_num_steps=None,\n)\nDataclass with fields for training and validation datasets and metadata.\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\nload_datasets\nLoads one or more training or evaluation datasets, calling\n\n\nload_preference_datasets\nLoads one or more training or evaluation datasets for RL training using paired\n\n\nsample_dataset\nRandomly sample num_samples samples with replacement from dataset.\n\n\n\n\n\ncommon.datasets.load_datasets(cfg, cli_args=None, debug=False)\nLoads one or more training or evaluation datasets, calling\naxolotl.utils.data.prepare_datasets. Optionally, logs out debug information.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ncli_args\nPreprocessCliArgs | TrainerCliArgs | None\nCommand-specific CLI arguments.\nNone\n\n\ndebug\nbool\nWhether to print out tokenization of sample. 
This is duplicated in cfg and cli_args, but is kept due to use in our Colab notebooks.\nFalse\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nTrainDatasetMeta\nDataclass with fields for training and evaluation datasets and the computed total_num_steps.\n\n\n\n\n\n\n\ncommon.datasets.load_preference_datasets(cfg, cli_args=None)\nLoads one or more training or evaluation datasets for RL training using paired\npreference data, calling axolotl.utils.data.rl.prepare_preference_datasets.\nOptionally, logs out debug information.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ncli_args\nPreprocessCliArgs | TrainerCliArgs | None\nCommand-specific CLI arguments.\nNone\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nTrainDatasetMeta\nDataclass with fields for training and evaluation datasets and the computed\n\n\n\nTrainDatasetMeta\ntotal_num_steps.\n\n\n\n\n\n\n\ncommon.datasets.sample_dataset(dataset, num_samples)\nRandomly sample num_samples samples with replacement from dataset."
},
{
"objectID": "docs/api/common.datasets.html#classes",
"href": "docs/api/common.datasets.html#classes",
"title": "common.datasets",
"section": "",
"text": "Name\nDescription\n\n\n\n\nTrainDatasetMeta\nDataclass with fields for training and validation datasets and metadata.\n\n\n\n\n\ncommon.datasets.TrainDatasetMeta(\n train_dataset,\n eval_dataset=None,\n total_num_steps=None,\n)\nDataclass with fields for training and validation datasets and metadata."
},
{
"objectID": "docs/api/common.datasets.html#functions",
"href": "docs/api/common.datasets.html#functions",
"title": "common.datasets",
"section": "",
"text": "Name\nDescription\n\n\n\n\nload_datasets\nLoads one or more training or evaluation datasets, calling\n\n\nload_preference_datasets\nLoads one or more training or evaluation datasets for RL training using paired\n\n\nsample_dataset\nRandomly sample num_samples samples with replacement from dataset.\n\n\n\n\n\ncommon.datasets.load_datasets(cfg, cli_args=None, debug=False)\nLoads one or more training or evaluation datasets, calling\naxolotl.utils.data.prepare_datasets. Optionally, logs out debug information.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ncli_args\nPreprocessCliArgs | TrainerCliArgs | None\nCommand-specific CLI arguments.\nNone\n\n\ndebug\nbool\nWhether to print out tokenization of sample. This is duplicated in cfg and cli_args, but is kept due to use in our Colab notebooks.\nFalse\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nTrainDatasetMeta\nDataclass with fields for training and evaluation datasets and the computed total_num_steps.\n\n\n\n\n\n\n\ncommon.datasets.load_preference_datasets(cfg, cli_args=None)\nLoads one or more training or evaluation datasets for RL training using paired\npreference data, calling axolotl.utils.data.rl.prepare_preference_datasets.\nOptionally, logs out debug information.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ncli_args\nPreprocessCliArgs | TrainerCliArgs | None\nCommand-specific CLI arguments.\nNone\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nTrainDatasetMeta\nDataclass with fields for training and evaluation datasets and the computed\n\n\n\nTrainDatasetMeta\ntotal_num_steps.\n\n\n\n\n\n\n\ncommon.datasets.sample_dataset(dataset, num_samples)\nRandomly sample num_samples samples with replacement from dataset."
},
{
"objectID": "docs/api/monkeypatch.relora.html",
"href": "docs/api/monkeypatch.relora.html",
"title": "monkeypatch.relora",
"section": "",
"text": "monkeypatch.relora\nImplements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune.\n\n\n\n\n\nName\nDescription\n\n\n\n\nReLoRACallback\nCallback to merge LoRA weights into the base model and save full-weight checkpoints\n\n\n\n\n\nmonkeypatch.relora.ReLoRACallback(cfg)\nCallback to merge LoRA weights into the base model and save full-weight checkpoints"
},
{
"objectID": "docs/api/monkeypatch.relora.html#classes",
"href": "docs/api/monkeypatch.relora.html#classes",
"title": "monkeypatch.relora",
"section": "",
"text": "Name\nDescription\n\n\n\n\nReLoRACallback\nCallback to merge LoRA weights into the base model and save full-weight checkpoints\n\n\n\n\n\nmonkeypatch.relora.ReLoRACallback(cfg)\nCallback to merge LoRA weights into the base model and save full-weight checkpoints"
},
{
"objectID": "docs/api/core.builders.base.html",
"href": "docs/api/core.builders.base.html",
"title": "core.builders.base",
"section": "",
"text": "core.builders.base\nBase class for trainer builder\n\n\n\n\n\nName\nDescription\n\n\n\n\nTrainerBuilderBase\nBase class for trainer builder.\n\n\n\n\n\ncore.builders.base.TrainerBuilderBase(cfg, model, tokenizer, processor=None)\nBase class for trainer builder.\n\n\n\n\n\nName\nDescription\n\n\n\n\nget_post_trainer_create_callbacks\nCallbacks added after the trainer is created, usually b/c these need access to the trainer\n\n\n\n\n\ncore.builders.base.TrainerBuilderBase.get_post_trainer_create_callbacks(trainer)\nCallbacks added after the trainer is created, usually b/c these need access to the trainer"
},
{
"objectID": "docs/api/core.builders.base.html#classes",
"href": "docs/api/core.builders.base.html#classes",
"title": "core.builders.base",
"section": "",
"text": "Name\nDescription\n\n\n\n\nTrainerBuilderBase\nBase class for trainer builder.\n\n\n\n\n\ncore.builders.base.TrainerBuilderBase(cfg, model, tokenizer, processor=None)\nBase class for trainer builder.\n\n\n\n\n\nName\nDescription\n\n\n\n\nget_post_trainer_create_callbacks\nCallbacks added after the trainer is created, usually b/c these need access to the trainer\n\n\n\n\n\ncore.builders.base.TrainerBuilderBase.get_post_trainer_create_callbacks(trainer)\nCallbacks added after the trainer is created, usually b/c these need access to the trainer"
},
{
"objectID": "docs/api/prompt_strategies.input_output.html",
"href": "docs/api/prompt_strategies.input_output.html",
"title": "prompt_strategies.input_output",
"section": "",
"text": "prompt_strategies.input_output\nModule for plain input/output prompt pairs\n\n\n\n\n\nName\nDescription\n\n\n\n\nRawInputOutputPrompter\nprompter for raw i/o data\n\n\nRawInputOutputStrategy\nPrompt Strategy class for input/output pairs\n\n\n\n\n\nprompt_strategies.input_output.RawInputOutputPrompter()\nprompter for raw i/o data\n\n\n\nprompt_strategies.input_output.RawInputOutputStrategy(\n *args,\n eos_token=None,\n **kwargs,\n)\nPrompt Strategy class for input/output pairs"
},
{
"objectID": "docs/api/prompt_strategies.input_output.html#classes",
"href": "docs/api/prompt_strategies.input_output.html#classes",
"title": "prompt_strategies.input_output",
"section": "",
"text": "Name\nDescription\n\n\n\n\nRawInputOutputPrompter\nprompter for raw i/o data\n\n\nRawInputOutputStrategy\nPrompt Strategy class for input/output pairs\n\n\n\n\n\nprompt_strategies.input_output.RawInputOutputPrompter()\nprompter for raw i/o data\n\n\n\nprompt_strategies.input_output.RawInputOutputStrategy(\n *args,\n eos_token=None,\n **kwargs,\n)\nPrompt Strategy class for input/output pairs"
},
{
"objectID": "docs/api/integrations.lm_eval.args.html",
"href": "docs/api/integrations.lm_eval.args.html",
"title": "integrations.lm_eval.args",
"section": "",
"text": "integrations.lm_eval.args\nModule for handling lm eval harness input arguments.\n\n\n\n\n\nName\nDescription\n\n\n\n\nLMEvalArgs\nInput args for lm eval harness\n\n\n\n\n\nintegrations.lm_eval.args.LMEvalArgs()\nInput args for lm eval harness"
},
{
"objectID": "docs/api/integrations.lm_eval.args.html#classes",
"href": "docs/api/integrations.lm_eval.args.html#classes",
"title": "integrations.lm_eval.args",
"section": "",
"text": "Name\nDescription\n\n\n\n\nLMEvalArgs\nInput args for lm eval harness\n\n\n\n\n\nintegrations.lm_eval.args.LMEvalArgs()\nInput args for lm eval harness"
},
{
"objectID": "docs/api/cli.inference.html",
"href": "docs/api/cli.inference.html",
"title": "cli.inference",
"section": "",
"text": "cli.inference\nCLI to run inference on a trained model.\n\n\n\n\n\nName\nDescription\n\n\n\n\ndo_cli\nParses axolotl config, CLI args, and calls do_inference or do_inference_gradio.\n\n\ndo_inference\nRuns inference on the command line in a loop. User input is accepted, a chat\n\n\ndo_inference_gradio\nRuns inference in a Gradio interface. User input is accepted, a chat template is\n\n\nget_multi_line_input\nGets multi-line input from terminal.\n\n\n\n\n\ncli.inference.do_cli(config=Path('examples/'), gradio=False, **kwargs)\nParses axolotl config, CLI args, and calls do_inference or do_inference_gradio.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nUnion[Path, str]\nPath to axolotl config YAML file.\nPath('examples/')\n\n\nkwargs\n\nAdditional keyword arguments to override config file values.\n{}\n\n\n\n\n\n\n\ncli.inference.do_inference(cfg, cli_args)\nRuns inference on the command line in a loop. User input is accepted, a chat\ntemplate is (optionally) applied, and the model specified in the axolotl config is\nused to generate completions according to a default generation config.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ncli_args\nInferenceCliArgs\nInference-specific CLI arguments.\nrequired\n\n\n\n\n\n\n\ncli.inference.do_inference_gradio(cfg, cli_args)\nRuns inference in a Gradio interface. 
User input is accepted, a chat template is\n(optionally) applied, and the model specified in the axolotl config is used to\ngenerate completions according to a default generation config.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ncli_args\nInferenceCliArgs\nInference-specific CLI arguments.\nrequired\n\n\n\n\n\n\n\ncli.inference.get_multi_line_input()\nGets multi-line input from terminal.\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nstr\nPossibly multi-line, possibly empty stdin input as a string."
},
{
"objectID": "docs/api/cli.inference.html#functions",
"href": "docs/api/cli.inference.html#functions",
"title": "cli.inference",
"section": "",
"text": "Name\nDescription\n\n\n\n\ndo_cli\nParses axolotl config, CLI args, and calls do_inference or do_inference_gradio.\n\n\ndo_inference\nRuns inference on the command line in a loop. User input is accepted, a chat\n\n\ndo_inference_gradio\nRuns inference in a Gradio interface. User input is accepted, a chat template is\n\n\nget_multi_line_input\nGets multi-line input from terminal.\n\n\n\n\n\ncli.inference.do_cli(config=Path('examples/'), gradio=False, **kwargs)\nParses axolotl config, CLI args, and calls do_inference or do_inference_gradio.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nUnion[Path, str]\nPath to axolotl config YAML file.\nPath('examples/')\n\n\nkwargs\n\nAdditional keyword arguments to override config file values.\n{}\n\n\n\n\n\n\n\ncli.inference.do_inference(cfg, cli_args)\nRuns inference on the command line in a loop. User input is accepted, a chat\ntemplate is (optionally) applied, and the model specified in the axolotl config is\nused to generate completions according to a default generation config.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ncli_args\nInferenceCliArgs\nInference-specific CLI arguments.\nrequired\n\n\n\n\n\n\n\ncli.inference.do_inference_gradio(cfg, cli_args)\nRuns inference in a Gradio interface. 
User input is accepted, a chat template is\n(optionally) applied, and the model specified in the axolotl config is used to\ngenerate completions according to a default generation config.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ncli_args\nInferenceCliArgs\nInference-specific CLI arguments.\nrequired\n\n\n\n\n\n\n\ncli.inference.get_multi_line_input()\nGets multi-line input from terminal.\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nstr\nPossibly multi-line, possibly empty stdin input as a string."
},
{
"objectID": "docs/api/monkeypatch.gradient_checkpointing.offload_disk.html",
"href": "docs/api/monkeypatch.gradient_checkpointing.offload_disk.html",
"title": "monkeypatch.gradient_checkpointing.offload_disk",
"section": "",
"text": "monkeypatch.gradient_checkpointing.offload_disk\nDISCO - DIsk-based Storage and Checkpointing with Optimized prefetching\n\n\n\n\n\nName\nDescription\n\n\n\n\nDisco\nDisco: DIsk-based Storage and Checkpointing with Optimized prefetching\n\n\nDiskOffloadManager\nManages offloaded tensors and handles prefetching in a separate thread.\n\n\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.Disco()\nDisco: DIsk-based Storage and Checkpointing with Optimized prefetching\nAdvanced disk-based gradient checkpointer with prefetching.\n\n\n\n\n\nName\nDescription\n\n\n\n\nbackward\nBackward pass that loads activations from disk with prefetching\n\n\nforward\nForward pass that offloads activations to disk asynchronously\n\n\nget_instance\nGet or create the offload manager\n\n\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.Disco.backward(\n ctx,\n *grad_outputs,\n)\nBackward pass that loads activations from disk with prefetching\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.Disco.forward(\n ctx,\n forward_function,\n hidden_states,\n *args,\n prefetch_size=1,\n prefetch_to_gpu=True,\n save_workers=4,\n)\nForward pass that offloads activations to disk asynchronously\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.Disco.get_instance(\n prefetch_size=1,\n prefetch_to_gpu=True,\n save_workers=4,\n)\nGet or create the offload manager\n\n\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager(\n prefetch_size=3,\n prefetch_to_gpu=True,\n save_workers=4,\n)\nManages offloaded tensors and handles prefetching in a separate thread.\nIncludes synchronization to prevent race conditions.\n\n\n\n\n\nName\nDescription\n\n\n\n\ncleanup\nClean up all temp files and stop prefetch thread with proper synchronization\n\n\ncleanup_tensor\nClean up a specific tensor file after it's been used\n\n\nload_tensor\nLoad tensor from disk or prefetch cache with proper synchronization\n\n\nsave_tensor\nSave tensor to disk asynchronously and return file 
path with thread-safe operations\n\n\ntrigger_prefetch\nTrigger prefetching of the next N tensors with proper synchronization\n\n\nwait_for_save\nWait for a tensor to be saved to disk\n\n\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager.cleanup()\nClean up all temp files and stop prefetch thread with proper synchronization\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager.cleanup_tensor(\n file_path,\n)\nClean up a specific tensor file after it's been used\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager.load_tensor(\n file_path,\n target_device='cuda',\n)\nLoad tensor from disk or prefetch cache with proper synchronization\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager.save_tensor(\n tensor,\n)\nSave tensor to disk asynchronously and return file path with thread-safe operations\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager.trigger_prefetch(\n n=None,\n)\nTrigger prefetching of the next N tensors with proper synchronization\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager.wait_for_save(\n file_path,\n timeout=None,\n)\nWait for a tensor to be saved to disk"
},
{
"objectID": "docs/api/monkeypatch.gradient_checkpointing.offload_disk.html#classes",
"href": "docs/api/monkeypatch.gradient_checkpointing.offload_disk.html#classes",
"title": "monkeypatch.gradient_checkpointing.offload_disk",
"section": "",
"text": "Name\nDescription\n\n\n\n\nDisco\nDisco: DIsk-based Storage and Checkpointing with Optimized prefetching\n\n\nDiskOffloadManager\nManages offloaded tensors and handles prefetching in a separate thread.\n\n\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.Disco()\nDisco: DIsk-based Storage and Checkpointing with Optimized prefetching\nAdvanced disk-based gradient checkpointer with prefetching.\n\n\n\n\n\nName\nDescription\n\n\n\n\nbackward\nBackward pass that loads activations from disk with prefetching\n\n\nforward\nForward pass that offloads activations to disk asynchronously\n\n\nget_instance\nGet or create the offload manager\n\n\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.Disco.backward(\n ctx,\n *grad_outputs,\n)\nBackward pass that loads activations from disk with prefetching\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.Disco.forward(\n ctx,\n forward_function,\n hidden_states,\n *args,\n prefetch_size=1,\n prefetch_to_gpu=True,\n save_workers=4,\n)\nForward pass that offloads activations to disk asynchronously\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.Disco.get_instance(\n prefetch_size=1,\n prefetch_to_gpu=True,\n save_workers=4,\n)\nGet or create the offload manager\n\n\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager(\n prefetch_size=3,\n prefetch_to_gpu=True,\n save_workers=4,\n)\nManages offloaded tensors and handles prefetching in a separate thread.\nIncludes synchronization to prevent race conditions.\n\n\n\n\n\nName\nDescription\n\n\n\n\ncleanup\nClean up all temp files and stop prefetch thread with proper synchronization\n\n\ncleanup_tensor\nClean up a specific tensor file after it's been used\n\n\nload_tensor\nLoad tensor from disk or prefetch cache with proper synchronization\n\n\nsave_tensor\nSave tensor to disk asynchronously and return file path with thread-safe operations\n\n\ntrigger_prefetch\nTrigger prefetching of the next N tensors with proper 
synchronization\n\n\nwait_for_save\nWait for a tensor to be saved to disk\n\n\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager.cleanup()\nClean up all temp files and stop prefetch thread with proper synchronization\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager.cleanup_tensor(\n file_path,\n)\nClean up a specific tensor file after it's been used\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager.load_tensor(\n file_path,\n target_device='cuda',\n)\nLoad tensor from disk or prefetch cache with proper synchronization\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager.save_tensor(\n tensor,\n)\nSave tensor to disk asynchronously and return file path with thread-safe operations\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager.trigger_prefetch(\n n=None,\n)\nTrigger prefetching of the next N tensors with proper synchronization\n\n\n\nmonkeypatch.gradient_checkpointing.offload_disk.DiskOffloadManager.wait_for_save(\n file_path,\n timeout=None,\n)\nWait for a tensor to be saved to disk"
},
{
"objectID": "docs/api/core.datasets.chat.html",
"href": "docs/api/core.datasets.chat.html",
"title": "core.datasets.chat",
"section": "",
"text": "core.datasets.chat\nchat dataset module\n\n\n\n\n\nName\nDescription\n\n\n\n\nTokenizedChatDataset\nTokenized chat dataset\n\n\n\n\n\ncore.datasets.chat.TokenizedChatDataset(\n data,\n model_transform,\n *args,\n message_transform=None,\n formatter=None,\n process_count=None,\n keep_in_memory=False,\n **kwargs,\n)\nTokenized chat dataset"
},
{
"objectID": "docs/api/core.datasets.chat.html#classes",
"href": "docs/api/core.datasets.chat.html#classes",
"title": "core.datasets.chat",
"section": "",
"text": "Name\nDescription\n\n\n\n\nTokenizedChatDataset\nTokenized chat dataset\n\n\n\n\n\ncore.datasets.chat.TokenizedChatDataset(\n data,\n model_transform,\n *args,\n message_transform=None,\n formatter=None,\n process_count=None,\n keep_in_memory=False,\n **kwargs,\n)\nTokenized chat dataset"
},
{
"objectID": "docs/api/core.chat.format.shared.html",
"href": "docs/api/core.chat.format.shared.html",
"title": "core.chat.format.shared",
"section": "",
"text": "core.chat.format.shared\ncore.chat.format.shared\nshared functions for format transforms"
},
{
"objectID": "docs/api/logging_config.html",
"href": "docs/api/logging_config.html",
"title": "logging_config",
"section": "",
"text": "logging_config\nCommon logging module for axolotl.\n\n\n\n\n\nName\nDescription\n\n\n\n\nAxolotlLogger\nLogger that applies filtering to non-axolotl loggers.\n\n\nAxolotlOrWarnErrorFilter\nAllows ANY WARNING or higher (unless overridden by LOG_LEVEL). Allows axolotl.* at\n\n\nColorfulFormatter\nFormatter to add coloring to log messages by log type\n\n\n\n\n\nlogging_config.AxolotlLogger(name, level=logging.NOTSET)\nLogger that applies filtering to non-axolotl loggers.\n\n\n\nlogging_config.AxolotlOrWarnErrorFilter(**kwargs)\nAllows ANY WARNING or higher (unless overridden by LOG_LEVEL). Allows axolotl.* at\nINFO or higher (unless overridden by AXOLOTL_LOG_LEVEL). Drops all other records\n(i.e. non-axolotl.INFO, DEBUG, etc. by default).\n\n\n\nlogging_config.ColorfulFormatter()\nFormatter to add coloring to log messages by log type\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\nconfigure_logging\nConfigure with default logging\n\n\n\n\n\nlogging_config.configure_logging()\nConfigure with default logging"
},
{
"objectID": "docs/api/logging_config.html#classes",
"href": "docs/api/logging_config.html#classes",
"title": "logging_config",
"section": "",
"text": "Name\nDescription\n\n\n\n\nAxolotlLogger\nLogger that applies filtering to non-axolotl loggers.\n\n\nAxolotlOrWarnErrorFilter\nAllows ANY WARNING or higher (unless overridden by LOG_LEVEL). Allows axolotl.* at\n\n\nColorfulFormatter\nFormatter to add coloring to log messages by log type\n\n\n\n\n\nlogging_config.AxolotlLogger(name, level=logging.NOTSET)\nLogger that applies filtering to non-axolotl loggers.\n\n\n\nlogging_config.AxolotlOrWarnErrorFilter(**kwargs)\nAllows ANY WARNING or higher (unless overridden by LOG_LEVEL). Allows axolotl.* at\nINFO or higher (unless overridden by AXOLOTL_LOG_LEVEL). Drops all other records\n(i.e. non-axolotl.INFO, DEBUG, etc. by default).\n\n\n\nlogging_config.ColorfulFormatter()\nFormatter to add coloring to log messages by log type"
},
{
"objectID": "docs/api/logging_config.html#functions",
"href": "docs/api/logging_config.html#functions",
"title": "logging_config",
"section": "",
"text": "Name\nDescription\n\n\n\n\nconfigure_logging\nConfigure with default logging\n\n\n\n\n\nlogging_config.configure_logging()\nConfigure with default logging"
},
{
"objectID": "docs/api/prompt_strategies.chat_template.html",
"href": "docs/api/prompt_strategies.chat_template.html",
"title": "prompt_strategies.chat_template",
"section": "",
"text": "prompt_strategies.chat_template\nHF Chat Templates prompt strategy\n\n\n\n\n\nName\nDescription\n\n\n\n\nChatTemplatePrompter\nPrompter for HF chat templates\n\n\nChatTemplateStrategy\nTokenizing strategy for instruction-based prompts.\n\n\nMistralPrompter\nMistral prompter for chat template.\n\n\nMistralStrategy\nMistral strategy for chat template.\n\n\nStrategyLoader\nLoad chat template strategy based on configuration.\n\n\n\n\n\nprompt_strategies.chat_template.ChatTemplatePrompter(\n tokenizer,\n chat_template,\n processor=None,\n max_length=2048,\n message_property_mappings=None,\n message_field_training=None,\n message_field_training_detail=None,\n field_messages='messages',\n field_system='system',\n field_tools='tools',\n field_thinking='reasoning_content',\n roles=None,\n template_thinking_key='reasoning_content',\n chat_template_kwargs=None,\n drop_system_message=False,\n)\nPrompter for HF chat templates\n\n\n\n\n\nName\nDescription\n\n\n\n\nbuild_prompt\nBuild a prompt from a conversation.\n\n\n\n\n\nprompt_strategies.chat_template.ChatTemplatePrompter.build_prompt(\n conversation,\n add_generation_prompt=False,\n images=None,\n tools=None,\n real_last_index=None,\n)\nBuild a prompt from a conversation.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconversation\nlist[dict]\nA list of messages.\nrequired\n\n\nadd_generation_prompt\n\nWhether to add a generation prompt.\nFalse\n\n\nimages\n\nA list of images. (optional)\nNone\n\n\ntools\n\nA list of tools. 
(optional)\nNone\n\n\n\n\n\n\n\n\n\nprompt_strategies.chat_template.ChatTemplateStrategy(\n prompter,\n tokenizer,\n train_on_inputs,\n sequence_len,\n roles_to_train=None,\n train_on_eos=None,\n train_on_eot=None,\n eot_tokens=None,\n split_thinking=False,\n)\nTokenizing strategy for instruction-based prompts.\n\n\n\n\n\nName\nDescription\n\n\n\n\nfind_first_eot_token\nFind the first EOT token in the input_ids starting from start_idx.\n\n\nfind_turn\nLocate the starting and ending indices of the specified turn in a conversation.\n\n\ntokenize_prompt\nPublic method that can handle either a single prompt or a batch of prompts.\n\n\n\n\n\nprompt_strategies.chat_template.ChatTemplateStrategy.find_first_eot_token(\n input_ids,\n start_idx,\n)\nFind the first EOT token in the input_ids starting from start_idx.\n\n\n\nprompt_strategies.chat_template.ChatTemplateStrategy.find_turn(\n turns,\n turn_idx,\n tools=None,\n)\nLocate the starting and ending indices of the specified turn in a conversation.\n\n\n\nprompt_strategies.chat_template.ChatTemplateStrategy.tokenize_prompt(prompt)\nPublic method that can handle either a single prompt or a batch of prompts.\n\n\n\n\n\nprompt_strategies.chat_template.MistralPrompter(*args, **kwargs)\nMistral prompter for chat template.\n\n\n\nprompt_strategies.chat_template.MistralStrategy(\n prompter,\n tokenizer,\n train_on_inputs,\n sequence_len,\n roles_to_train=None,\n train_on_eos=None,\n train_on_eot=None,\n eot_tokens=None,\n split_thinking=False,\n)\nMistral strategy for chat template.\n\n\n\n\n\nName\nDescription\n\n\n\n\nfind_first_eot_token\nFind the first EOT token in the input_ids starting from start_idx.\n\n\n\n\n\nprompt_strategies.chat_template.MistralStrategy.find_first_eot_token(\n input_ids,\n start_idx,\n)\nFind the first EOT token in the input_ids starting from start_idx.\n\n\n\n\n\nprompt_strategies.chat_template.StrategyLoader()\nLoad chat template strategy based on configuration."
},
{
"objectID": "docs/api/prompt_strategies.chat_template.html#classes",
"href": "docs/api/prompt_strategies.chat_template.html#classes",
"title": "prompt_strategies.chat_template",
"section": "",
"text": "Name\nDescription\n\n\n\n\nChatTemplatePrompter\nPrompter for HF chat templates\n\n\nChatTemplateStrategy\nTokenizing strategy for instruction-based prompts.\n\n\nMistralPrompter\nMistral prompter for chat template.\n\n\nMistralStrategy\nMistral strategy for chat template.\n\n\nStrategyLoader\nLoad chat template strategy based on configuration.\n\n\n\n\n\nprompt_strategies.chat_template.ChatTemplatePrompter(\n tokenizer,\n chat_template,\n processor=None,\n max_length=2048,\n message_property_mappings=None,\n message_field_training=None,\n message_field_training_detail=None,\n field_messages='messages',\n field_system='system',\n field_tools='tools',\n field_thinking='reasoning_content',\n roles=None,\n template_thinking_key='reasoning_content',\n chat_template_kwargs=None,\n drop_system_message=False,\n)\nPrompter for HF chat templates\n\n\n\n\n\nName\nDescription\n\n\n\n\nbuild_prompt\nBuild a prompt from a conversation.\n\n\n\n\n\nprompt_strategies.chat_template.ChatTemplatePrompter.build_prompt(\n conversation,\n add_generation_prompt=False,\n images=None,\n tools=None,\n real_last_index=None,\n)\nBuild a prompt from a conversation.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconversation\nlist[dict]\nA list of messages.\nrequired\n\n\nadd_generation_prompt\n\nWhether to add a generation prompt.\nFalse\n\n\nimages\n\nA list of images. (optional)\nNone\n\n\ntools\n\nA list of tools. 
(optional)\nNone\n\n\n\n\n\n\n\n\n\nprompt_strategies.chat_template.ChatTemplateStrategy(\n prompter,\n tokenizer,\n train_on_inputs,\n sequence_len,\n roles_to_train=None,\n train_on_eos=None,\n train_on_eot=None,\n eot_tokens=None,\n split_thinking=False,\n)\nTokenizing strategy for instruction-based prompts.\n\n\n\n\n\nName\nDescription\n\n\n\n\nfind_first_eot_token\nFind the first EOT token in the input_ids starting from start_idx.\n\n\nfind_turn\nLocate the starting and ending indices of the specified turn in a conversation.\n\n\ntokenize_prompt\nPublic method that can handle either a single prompt or a batch of prompts.\n\n\n\n\n\nprompt_strategies.chat_template.ChatTemplateStrategy.find_first_eot_token(\n input_ids,\n start_idx,\n)\nFind the first EOT token in the input_ids starting from start_idx.\n\n\n\nprompt_strategies.chat_template.ChatTemplateStrategy.find_turn(\n turns,\n turn_idx,\n tools=None,\n)\nLocate the starting and ending indices of the specified turn in a conversation.\n\n\n\nprompt_strategies.chat_template.ChatTemplateStrategy.tokenize_prompt(prompt)\nPublic method that can handle either a single prompt or a batch of prompts.\n\n\n\n\n\nprompt_strategies.chat_template.MistralPrompter(*args, **kwargs)\nMistral prompter for chat template.\n\n\n\nprompt_strategies.chat_template.MistralStrategy(\n prompter,\n tokenizer,\n train_on_inputs,\n sequence_len,\n roles_to_train=None,\n train_on_eos=None,\n train_on_eot=None,\n eot_tokens=None,\n split_thinking=False,\n)\nMistral strategy for chat template.\n\n\n\n\n\nName\nDescription\n\n\n\n\nfind_first_eot_token\nFind the first EOT token in the input_ids starting from start_idx.\n\n\n\n\n\nprompt_strategies.chat_template.MistralStrategy.find_first_eot_token(\n input_ids,\n start_idx,\n)\nFind the first EOT token in the input_ids starting from start_idx.\n\n\n\n\n\nprompt_strategies.chat_template.StrategyLoader()\nLoad chat template strategy based on configuration."
},
{
"objectID": "docs/api/utils.collators.mamba.html",
"href": "docs/api/utils.collators.mamba.html",
"title": "utils.collators.mamba",
"section": "",
"text": "utils.collators.mamba\ncollators for Mamba\n\n\n\n\n\nName\nDescription\n\n\n\n\nMambaDataCollator\nCollator for State Space Models (Mamba)\n\n\n\n\n\nutils.collators.mamba.MambaDataCollator(tokenizer)\nCollator for State Space Models (Mamba)"
},
{
"objectID": "docs/api/utils.collators.mamba.html#classes",
"href": "docs/api/utils.collators.mamba.html#classes",
"title": "utils.collators.mamba",
"section": "",
"text": "Name\nDescription\n\n\n\n\nMambaDataCollator\nCollator for State Space Models (Mamba)\n\n\n\n\n\nutils.collators.mamba.MambaDataCollator(tokenizer)\nCollator for State Space Models (Mamba)"
},
{
"objectID": "docs/api/cli.config.html",
"href": "docs/api/cli.config.html",
"title": "cli.config",
"section": "",
"text": "cli.config\nConfiguration loading and processing.\n\n\n\n\n\nName\nDescription\n\n\n\n\ncheck_remote_config\nFirst, determines if the passed config is a valid HTTPS URL. Then, attempts to query\n\n\nchoose_config\nHelper method for choosing an axolotl config YAML file (considering only files\n\n\nload_cfg\nLoads the axolotl configuration stored at config, validates it, and performs\n\n\nprepare_plugins\nRegisters the plugins for the given configuration.\n\n\n\n\n\ncli.config.check_remote_config(config)\nFirst, determines if the passed config is a valid HTTPS URL. Then, attempts to query\nfor it and parse its content, first as JSON, then as YAML (YAML is preferred).\nFinally, the parsed content is written to a local file and its path is returned.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nUnion[str, Path]\nHTTPS URL to a YAML or JSON file.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nUnion[str, Path]\nEither the original config if it's not a valid HTTPS URL, or the path to the\n\n\n\nUnion[str, Path]\ndownloaded remote config.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nValueError\nIf the remote configuration is neither valid JSON nor YAML.\n\n\n\nRuntimeError\nIf some request-related exception occurs from the file download.\n\n\n\nException\nCatch-all for any other exception.\n\n\n\n\n\n\n\ncli.config.choose_config(path)\nHelper method for choosing an axolotl config YAML file (considering only files\nending with .yml or .yaml). 
If more than one config file exists in the passed\npath, the user is prompted to choose one.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\npath\nPath\nDirectory in which config file(s) are stored.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nstr\nPath to either (1) the sole YAML file, or (2) if more than one YAML file exists,\n\n\n\nstr\nthe user-selected YAML file.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nValueError\nIf no YAML files are found in the given path.\n\n\n\n\n\n\n\ncli.config.load_cfg(config=Path('examples/'), **kwargs)\nLoads the axolotl configuration stored at config, validates it, and performs\nvarious setup.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nstr | Path | DictDefault\nPath (local or remote) to axolotl config YAML file.\nPath('examples/')\n\n\nkwargs\n\nAdditional keyword arguments to override config file values.\n{}\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nDictDefault\nDictDefault mapping configuration keys to values.\n\n\n\n\n\n\n\ncli.config.prepare_plugins(cfg)\nRegisters the plugins for the given configuration.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired"
},
{
"objectID": "docs/api/cli.config.html#functions",
"href": "docs/api/cli.config.html#functions",
"title": "cli.config",
"section": "",
"text": "Name\nDescription\n\n\n\n\ncheck_remote_config\nFirst, determines if the passed config is a valid HTTPS URL. Then, attempts to query\n\n\nchoose_config\nHelper method for choosing an axolotl config YAML file (considering only files\n\n\nload_cfg\nLoads the axolotl configuration stored at config, validates it, and performs\n\n\nprepare_plugins\nRegisters the plugins for the given configuration.\n\n\n\n\n\ncli.config.check_remote_config(config)\nFirst, determines if the passed config is a valid HTTPS URL. Then, attempts to query\nfor it and parse its content, first as JSON, then as YAML (YAML is preferred).\nFinally, the parsed content is written to a local file and its path is returned.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nUnion[str, Path]\nHTTPS URL to a YAML or JSON file.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nUnion[str, Path]\nEither the original config if it's not a valid HTTPS URL, or the path to the\n\n\n\nUnion[str, Path]\ndownloaded remote config.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nValueError\nIf the remote configuration is neither valid JSON nor YAML.\n\n\n\nRuntimeError\nIf some request-related exception occurs from the file download.\n\n\n\nException\nCatch-all for any other exception.\n\n\n\n\n\n\n\ncli.config.choose_config(path)\nHelper method for choosing an axolotl config YAML file (considering only files\nending with .yml or .yaml). 
If more than one config file exists in the passed\npath, the user is prompted to choose one.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\npath\nPath\nDirectory in which config file(s) are stored.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nstr\nPath to either (1) the sole YAML file, or (2) if more than one YAML file exists,\n\n\n\nstr\nthe user-selected YAML file.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nValueError\nIf no YAML files are found in the given path.\n\n\n\n\n\n\n\ncli.config.load_cfg(config=Path('examples/'), **kwargs)\nLoads the axolotl configuration stored at config, validates it, and performs\nvarious setup.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nstr | Path | DictDefault\nPath (local or remote) to axolotl config YAML file.\nPath('examples/')\n\n\nkwargs\n\nAdditional keyword arguments to override config file values.\n{}\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nDictDefault\nDictDefault mapping configuration keys to values.\n\n\n\n\n\n\n\ncli.config.prepare_plugins(cfg)\nRegisters the plugins for the given configuration.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired"
},
{
"objectID": "docs/api/loaders.model.html",
"href": "docs/api/loaders.model.html",
"title": "loaders.model",
"section": "",
"text": "loaders.model\nModel loader class implementation for loading, configuring, and patching various models.\n\n\n\n\n\nName\nDescription\n\n\n\n\nModelLoader\nManages model configuration, initialization and application of patches during\n\n\n\n\n\nloaders.model.ModelLoader(\n cfg,\n tokenizer,\n *,\n inference=False,\n reference_model=False,\n **kwargs,\n)\nManages model configuration, initialization and application of patches during\nmodel loading.\nThis class orchestrates the entire process of loading a model from configuration to\nfinal preparation. It handles device mapping, quantization, attention mechanisms,\nadapter integration, and various optimizations.\n\n\n\nLoading and validating model configuration\nApplying monkey patches for optimizations / fixes\nSetting up device mapping (including multi-GPU configurations)\nConfiguring quantization\nSetting attention mechanisms (Flash Attention, SDPA, etc.)\nLoading and initializing the model\nApplying adapters (LoRA, QLoRA, etc.)\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\nmodel\nPreTrainedModel | PeftModel | PeftMixedModel\nThe loaded model instance (available after load() is called).\n\n\nmodel_kwargs\ndict[str, Any]\nDictionary of keyword arguments passed to model initialization.\n\n\nbase_model\n\nName or path of the base model to load.\n\n\nmodel_type\n\nType of model to load (e.g., AutoModelForCausalLM).\n\n\nmodel_config\n\nConfiguration object for the model.\n\n\nauto_model_loader\n\nclass used for loading the model (default: AutoModelForCausalLM).\n\n\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\nload\nLoad and prepare the model with all configurations and patches.\n\n\n\n\n\nloaders.model.ModelLoader.load()\nLoad and prepare the model with all configurations and patches.\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]\nA tuple with the loaded model and its LoRA configuration (if applicable)."
},
{
"objectID": "docs/api/loaders.model.html#classes",
"href": "docs/api/loaders.model.html#classes",
"title": "loaders.model",
"section": "",
"text": "Name\nDescription\n\n\n\n\nModelLoader\nManages model configuration, initialization and application of patches during\n\n\n\n\n\nloaders.model.ModelLoader(\n cfg,\n tokenizer,\n *,\n inference=False,\n reference_model=False,\n **kwargs,\n)\nManages model configuration, initialization and application of patches during\nmodel loading.\nThis class orchestrates the entire process of loading a model from configuration to\nfinal preparation. It handles device mapping, quantization, attention mechanisms,\nadapter integration, and various optimizations.\n\n\n\nLoading and validating model configuration\nApplying monkey patches for optimizations / fixes\nSetting up device mapping (including multi-GPU configurations)\nConfiguring quantization\nSetting attention mechanisms (Flash Attention, SDPA, etc.)\nLoading and initializing the model\nApplying adapters (LoRA, QLoRA, etc.)\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\nmodel\nPreTrainedModel | PeftModel | PeftMixedModel\nThe loaded model instance (available after load() is called).\n\n\nmodel_kwargs\ndict[str, Any]\nDictionary of keyword arguments passed to model initialization.\n\n\nbase_model\n\nName or path of the base model to load.\n\n\nmodel_type\n\nType of model to load (e.g., AutoModelForCausalLM).\n\n\nmodel_config\n\nConfiguration object for the model.\n\n\nauto_model_loader\n\nclass used for loading the model (default: AutoModelForCausalLM).\n\n\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\nload\nLoad and prepare the model with all configurations and patches.\n\n\n\n\n\nloaders.model.ModelLoader.load()\nLoad and prepare the model with all configurations and patches.\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]\nA tuple with the loaded model and its LoRA configuration (if applicable)."
},
{
"objectID": "docs/api/prompt_strategies.kto.chatml.html",
"href": "docs/api/prompt_strategies.kto.chatml.html",
"title": "prompt_strategies.kto.chatml",
"section": "",
"text": "prompt_strategies.kto.chatml\nKTO strategies for chatml\n\n\n\n\n\nName\nDescription\n\n\n\n\nargilla_chat\nfor argilla/kto-mix-15k conversations\n\n\nintel\nFor Intel Orca KTO\n\n\nultra\nfor ultrafeedback binarized conversations\n\n\n\n\n\nprompt_strategies.kto.chatml.argilla_chat(cfg, **kwargs)\nfor argilla/kto-mix-15k conversations\n\n\n\nprompt_strategies.kto.chatml.intel(cfg, **kwargs)\nFor Intel Orca KTO\nex: argilla/distilabel-intel-orca-kto\n\n\n\nprompt_strategies.kto.chatml.ultra(cfg, **kwargs)\nfor ultrafeedback binarized conversations\nex: argilla/ultrafeedback-binarized-preferences-cleaned-kto"
},
{
"objectID": "docs/api/prompt_strategies.kto.chatml.html#functions",
"href": "docs/api/prompt_strategies.kto.chatml.html#functions",
"title": "prompt_strategies.kto.chatml",
"section": "",
"text": "Name\nDescription\n\n\n\n\nargilla_chat\nfor argilla/kto-mix-15k conversations\n\n\nintel\nFor Intel Orca KTO\n\n\nultra\nfor ultrafeedback binarized conversations\n\n\n\n\n\nprompt_strategies.kto.chatml.argilla_chat(cfg, **kwargs)\nfor argilla/kto-mix-15k conversations\n\n\n\nprompt_strategies.kto.chatml.intel(cfg, **kwargs)\nFor Intel Orca KTO\nex: argilla/distilabel-intel-orca-kto\n\n\n\nprompt_strategies.kto.chatml.ultra(cfg, **kwargs)\nfor ultrafeedback binarized conversations\nex: argilla/ultrafeedback-binarized-preferences-cleaned-kto"
},
{
"objectID": "docs/api/cli.quantize.html",
"href": "docs/api/cli.quantize.html",
"title": "cli.quantize",
"section": "",
"text": "cli.quantize\nCLI to post-training quantize a model using torchao\n\n\n\n\n\nName\nDescription\n\n\n\n\ndo_quantize\nQuantizes a model's weights\n\n\n\n\n\ncli.quantize.do_quantize(config, cli_args)\nQuantizes a model's weights\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nUnion[Path, str]\nThe path to the config file\nrequired\n\n\ncli_args\ndict\nAdditional command-line arguments\nrequired"
},
{
"objectID": "docs/api/cli.quantize.html#functions",
"href": "docs/api/cli.quantize.html#functions",
"title": "cli.quantize",
"section": "",
"text": "Name\nDescription\n\n\n\n\ndo_quantize\nQuantizes a model's weights\n\n\n\n\n\ncli.quantize.do_quantize(config, cli_args)\nQuantizes a model's weights\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nUnion[Path, str]\nThe path to the config file\nrequired\n\n\ncli_args\ndict\nAdditional command-line arguments\nrequired"
},
{
"objectID": "docs/api/prompt_strategies.bradley_terry.llama3.html",
"href": "docs/api/prompt_strategies.bradley_terry.llama3.html",
"title": "prompt_strategies.bradley_terry.llama3",
"section": "",
"text": "prompt_strategies.bradley_terry.llama3\nchatml transforms for datasets with system, input, chosen, rejected to match llama3 chat template\n\n\n\n\n\nName\nDescription\n\n\n\n\nicr\nchatml transforms for datasets with system, input, chosen, rejected\n\n\n\n\n\nprompt_strategies.bradley_terry.llama3.icr(cfg, **kwargs)\nchatml transforms for datasets with system, input, chosen, rejected\nex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs"
},
{
"objectID": "docs/api/prompt_strategies.bradley_terry.llama3.html#functions",
"href": "docs/api/prompt_strategies.bradley_terry.llama3.html#functions",
"title": "prompt_strategies.bradley_terry.llama3",
"section": "",
"text": "Name\nDescription\n\n\n\n\nicr\nchatml transforms for datasets with system, input, chosen, rejected\n\n\n\n\n\nprompt_strategies.bradley_terry.llama3.icr(cfg, **kwargs)\nchatml transforms for datasets with system, input, chosen, rejected\nex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs"
},
{
"objectID": "docs/api/integrations.spectrum.args.html",
"href": "docs/api/integrations.spectrum.args.html",
"title": "integrations.spectrum.args",
"section": "",
"text": "integrations.spectrum.args\nModule for handling Spectrum input arguments.\n\n\n\n\n\nName\nDescription\n\n\n\n\nSpectrumArgs\nInput args for Spectrum.\n\n\n\n\n\nintegrations.spectrum.args.SpectrumArgs()\nInput args for Spectrum."
},
{
"objectID": "docs/api/integrations.spectrum.args.html#classes",
"href": "docs/api/integrations.spectrum.args.html#classes",
"title": "integrations.spectrum.args",
"section": "",
"text": "Name\nDescription\n\n\n\n\nSpectrumArgs\nInput args for Spectrum.\n\n\n\n\n\nintegrations.spectrum.args.SpectrumArgs()\nInput args for Spectrum."
},
{
"objectID": "docs/api/prompt_strategies.messages.chat.html",
"href": "docs/api/prompt_strategies.messages.chat.html",
"title": "prompt_strategies.messages.chat",
"section": "",
"text": "prompt_strategies.messages.chat\nChat dataset wrapping strategy for new internal messages representations\n\n\n\n\n\nName\nDescription\n\n\n\n\nChatMessageDatasetWrappingStrategy\nChat dataset wrapping strategy for new internal messages representations\n\n\n\n\n\nprompt_strategies.messages.chat.ChatMessageDatasetWrappingStrategy(\n processor,\n message_transform=None,\n formatter=None,\n **kwargs,\n)\nChat dataset wrapping strategy for new internal messages representations"
},
{
"objectID": "docs/api/prompt_strategies.messages.chat.html#classes",
"href": "docs/api/prompt_strategies.messages.chat.html#classes",
"title": "prompt_strategies.messages.chat",
"section": "",
"text": "Name\nDescription\n\n\n\n\nChatMessageDatasetWrappingStrategy\nChat dataset wrapping strategy for new internal messages representations\n\n\n\n\n\nprompt_strategies.messages.chat.ChatMessageDatasetWrappingStrategy(\n processor,\n message_transform=None,\n formatter=None,\n **kwargs,\n)\nChat dataset wrapping strategy for new internal messages representations"
},
{
"objectID": "docs/api/utils.callbacks.perplexity.html",
"href": "docs/api/utils.callbacks.perplexity.html",
"title": "utils.callbacks.perplexity",
"section": "",
"text": "utils.callbacks.perplexity\nCallback to calculate perplexity as an evaluation metric.\n\n\n\n\n\nName\nDescription\n\n\n\n\nPerplexity\nCalculate perplexity as defined in https://huggingface.co/docs/transformers/en/perplexity.\n\n\n\n\n\nutils.callbacks.perplexity.Perplexity(tokenizer, max_seq_len, stride=512)\nCalculate perplexity as defined in https://huggingface.co/docs/transformers/en/perplexity.\nThis is a custom variant that doesn't re-tokenize the input or re-load the model.\n\n\n\n\n\nName\nDescription\n\n\n\n\ncompute\nCompute perplexity in a fixed length sliding window across the sequence.\n\n\n\n\n\nutils.callbacks.perplexity.Perplexity.compute(model, references=None)\nCompute perplexity in a fixed length sliding window across the sequence."
},
{
"objectID": "docs/api/utils.callbacks.perplexity.html#classes",
"href": "docs/api/utils.callbacks.perplexity.html#classes",
"title": "utils.callbacks.perplexity",
"section": "",
"text": "Name\nDescription\n\n\n\n\nPerplexity\nCalculate perplexity as defined in https://huggingface.co/docs/transformers/en/perplexity.\n\n\n\n\n\nutils.callbacks.perplexity.Perplexity(tokenizer, max_seq_len, stride=512)\nCalculate perplexity as defined in https://huggingface.co/docs/transformers/en/perplexity.\nThis is a custom variant that doesn't re-tokenize the input or re-load the model.\n\n\n\n\n\nName\nDescription\n\n\n\n\ncompute\nCompute perplexity in a fixed length sliding window across the sequence.\n\n\n\n\n\nutils.callbacks.perplexity.Perplexity.compute(model, references=None)\nCompute perplexity in a fixed length sliding window across the sequence."
},
{
"objectID": "docs/api/monkeypatch.lora_kernels.html",
"href": "docs/api/monkeypatch.lora_kernels.html",
"title": "monkeypatch.lora_kernels",
"section": "",
"text": "monkeypatch.lora_kernels\nModule for patching custom LoRA Triton kernels and torch.autograd functions.\n\n\n\n\n\nName\nDescription\n\n\n\n\nFakeMLP\nplaceholder MLP for triton patching\n\n\n\n\n\nmonkeypatch.lora_kernels.FakeMLP(gate_proj, up_proj, down_proj)\nplaceholder MLP for triton patching\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\napply_lora_kernel_patches\nApplies optimized Triton kernel patches to a PEFT model.\n\n\nget_attention_cls_from_config\nGet the appropriate attention class by inspecting the model config.\n\n\nget_layers\nGet the layers of the model. Handles text-only and multimodal models.\n\n\noriginal_apply_o\nOriginal implementation of output projection without optimizations.\n\n\noriginal_apply_qkv\nOriginal implementation of QKV projection without optimizations.\n\n\npatch_self_attn_lora\nGiven an axolotl config, this method patches the inferred attention class forward\n\n\n\n\n\nmonkeypatch.lora_kernels.apply_lora_kernel_patches(model, cfg)\nApplies optimized Triton kernel patches to a PEFT model.\nPatches a PEFT model with optimized implementations for MLP and attention\ncomputations. 
The optimizations include custom Triton kernels for activation\nfunctions and specialized autograd functions for LoRA computations.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nmodel\nPeftModelForCausalLM\nA PEFT model to be patched with optimized kernels.\nrequired\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\nPeftModelForCausalLM\nPeftModelForCausalLM\nThe patched model with optimized kernels.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nTypeError\nIf the provided model is not a PeftModelForCausalLM.\n\n\n\nNotImplementedError\nIf the model type is not supported.\n\n\n\nAssertionError\nIf multiple adapters are active (currently unsupported).\n\n\n\n\n\n\nThe optimizations require LoRA adapters with no dropout and no bias terms. The\nfunction will skip patching if these conditions aren't met.\n\n\n\n\nmonkeypatch.lora_kernels.get_attention_cls_from_config(cfg)\nGet the appropriate attention class by inspecting the model config.\nUses dynamic import to support any model architecture that follows\nthe standard transformers naming convention.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nType[nn.Module]\nThe appropriate attention class for the model.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nValueError\nIf base_model not specified or attention class cannot be imported\n\n\n\nImportError\nIf the model module or attention class doesn't exist\n\n\n\n\n\n\n\nmonkeypatch.lora_kernels.get_layers(model)\nGet the layers of the model. 
Handles text-only and multimodal models.\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nmodel\nPeftModelForCausalLM\nA PEFT model.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nlist[nn.Module]\nA list of layers.\n\n\n\n\n\n\n\nmonkeypatch.lora_kernels.original_apply_o(self, hidden_states)\nOriginal implementation of output projection without optimizations.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nself\nnn.Module\nThe attention module instance.\nrequired\n\n\nhidden_states\ntorch.Tensor\nInput tensor of shape [batch_size, seq_len, hidden_dim].\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntorch.Tensor\nThe output projection result.\n\n\n\n\n\n\n\nmonkeypatch.lora_kernels.original_apply_qkv(self, hidden_states)\nOriginal implementation of QKV projection without optimizations.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nself\nnn.Module\nThe attention module instance.\nrequired\n\n\nhidden_states\ntorch.Tensor\nInput tensor of shape [batch_size, seq_len, hidden_dim].\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[torch.Tensor, torch.Tensor, torch.Tensor]\nA tuple (query_states, key_states, value_states) containing the projected states for query, key, and value.\n\n\n\n\n\n\n\nmonkeypatch.lora_kernels.patch_self_attn_lora(cfg)\nGiven an axolotl config, this method patches the inferred attention class forward\npass with optimized LoRA implementations.\nIt modifies the attention class to use optimized QKV and output projections. The\noriginal implementation is preserved and can be restored if needed.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nAssertionError\nIf the required code blocks are not found in the attention implementation."
},
{
"objectID": "docs/api/monkeypatch.lora_kernels.html#classes",
"href": "docs/api/monkeypatch.lora_kernels.html#classes",
"title": "monkeypatch.lora_kernels",
"section": "",
"text": "Name\nDescription\n\n\n\n\nFakeMLP\nplaceholder MLP for triton patching\n\n\n\n\n\nmonkeypatch.lora_kernels.FakeMLP(gate_proj, up_proj, down_proj)\nplaceholder MLP for triton patching"
},
{
"objectID": "docs/api/monkeypatch.lora_kernels.html#functions",
"href": "docs/api/monkeypatch.lora_kernels.html#functions",
"title": "monkeypatch.lora_kernels",
"section": "",
"text": "Name\nDescription\n\n\n\n\napply_lora_kernel_patches\nApplies optimized Triton kernel patches to a PEFT model.\n\n\nget_attention_cls_from_config\nGet the appropriate attention class by inspecting the model config.\n\n\nget_layers\nGet the layers of the model. Handles text-only and multimodal models.\n\n\noriginal_apply_o\nOriginal implementation of output projection without optimizations.\n\n\noriginal_apply_qkv\nOriginal implementation of QKV projection without optimizations.\n\n\npatch_self_attn_lora\nGiven an axolotl config, this method patches the inferred attention class forward\n\n\n\n\n\nmonkeypatch.lora_kernels.apply_lora_kernel_patches(model, cfg)\nApplies optimized Triton kernel patches to a PEFT model.\nPatches a PEFT model with optimized implementations for MLP and attention\ncomputations. The optimizations include custom Triton kernels for activation\nfunctions and specialized autograd functions for LoRA computations.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nmodel\nPeftModelForCausalLM\nA PEFT model to be patched with optimized kernels.\nrequired\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\nPeftModelForCausalLM\nPeftModelForCausalLM\nThe patched model with optimized kernels.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nTypeError\nIf the provided model is not a PeftModelForCausalLM.\n\n\n\nNotImplementedError\nIf the model type is not supported.\n\n\n\nAssertionError\nIf multiple adapters are active (currently unsupported).\n\n\n\n\n\n\nThe optimizations require LoRA adapters with no dropout and no bias terms. 
The\nfunction will skip patching if these conditions aren't met.\n\n\n\n\nmonkeypatch.lora_kernels.get_attention_cls_from_config(cfg)\nGet the appropriate attention class by inspecting the model config.\nUses dynamic import to support any model architecture that follows\nthe standard transformers naming convention.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nType[nn.Module]\nThe appropriate attention class for the model.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nValueError\nIf base_model not specified or attention class cannot be imported\n\n\n\nImportError\nIf the model module or attention class doesn't exist\n\n\n\n\n\n\n\nmonkeypatch.lora_kernels.get_layers(model)\nGet the layers of the model. Handles text-only and multimodal models.\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nmodel\nPeftModelForCausalLM\nA PEFT model.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nlist[nn.Module]\nA list of layers.\n\n\n\n\n\n\n\nmonkeypatch.lora_kernels.original_apply_o(self, hidden_states)\nOriginal implementation of output projection without optimizations.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nself\nnn.Module\nThe attention module instance.\nrequired\n\n\nhidden_states\ntorch.Tensor\nInput tensor of shape [batch_size, seq_len, hidden_dim].\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntorch.Tensor\nThe output projection result.\n\n\n\n\n\n\n\nmonkeypatch.lora_kernels.original_apply_qkv(self, hidden_states)\nOriginal implementation of QKV projection without optimizations.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nself\nnn.Module\nThe attention module instance.\nrequired\n\n\nhidden_states\ntorch.Tensor\nInput tensor of shape [batch_size, seq_len, 
hidden_dim].\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[torch.Tensor, torch.Tensor, torch.Tensor]\nA tuple (query_states, key_states, value_states) containing the projected states for query, key, and value.\n\n\n\n\n\n\n\nmonkeypatch.lora_kernels.patch_self_attn_lora(cfg)\nGiven an axolotl config, this method patches the inferred attention class forward\npass with optimized LoRA implementations.\nIt modifies the attention class to use optimized QKV and output projections. The\noriginal implementation is preserved and can be restored if needed.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nAssertionError\nIf the required code blocks are not found in the attention implementation."
},
{
"objectID": "docs/api/monkeypatch.data.batch_dataset_fetcher.html",
"href": "docs/api/monkeypatch.data.batch_dataset_fetcher.html",
"title": "monkeypatch.data.batch_dataset_fetcher",
"section": "",
"text": "monkeypatch.data.batch_dataset_fetcher\nMonkey patches for the dataset fetcher to handle batches of packed indexes.\n\n\n\n\n\nName\nDescription\n\n\n\n\napply_multipack_dataloader_patch\nThis patch allows DataLoader to correctly process batches that contain multiple bins\n\n\npatch_fetchers\nApply patches to PyTorch's DataLoader components.\n\n\npatched_worker_loop\nWorker loop that ensures patches are applied in worker processes.\n\n\nremove_multipack_dataloader_patch\nRemove the monkeypatch and restore original PyTorch DataLoader behavior.\n\n\n\n\n\nmonkeypatch.data.batch_dataset_fetcher.apply_multipack_dataloader_patch()\nThis patch allows DataLoader to correctly process batches that contain multiple bins\nof packed sequences.\n\n\n\nmonkeypatch.data.batch_dataset_fetcher.patch_fetchers()\nApply patches to PyTorch's DataLoader components.\n\n\n\nmonkeypatch.data.batch_dataset_fetcher.patched_worker_loop(*args, **kwargs)\nWorker loop that ensures patches are applied in worker processes.\n\n\n\nmonkeypatch.data.batch_dataset_fetcher.remove_multipack_dataloader_patch()\nRemove the monkeypatch and restore original PyTorch DataLoader behavior."
},
{
"objectID": "docs/api/monkeypatch.data.batch_dataset_fetcher.html#functions",
"href": "docs/api/monkeypatch.data.batch_dataset_fetcher.html#functions",
"title": "monkeypatch.data.batch_dataset_fetcher",
"section": "",
"text": "Name\nDescription\n\n\n\n\napply_multipack_dataloader_patch\nThis patch allows DataLoader to correctly process batches that contain multiple bins\n\n\npatch_fetchers\nApply patches to PyTorch's DataLoader components.\n\n\npatched_worker_loop\nWorker loop that ensures patches are applied in worker processes.\n\n\nremove_multipack_dataloader_patch\nRemove the monkeypatch and restore original PyTorch DataLoader behavior.\n\n\n\n\n\nmonkeypatch.data.batch_dataset_fetcher.apply_multipack_dataloader_patch()\nThis patch allows DataLoader to correctly process batches that contain multiple bins\nof packed sequences.\n\n\n\nmonkeypatch.data.batch_dataset_fetcher.patch_fetchers()\nApply patches to PyTorch's DataLoader components.\n\n\n\nmonkeypatch.data.batch_dataset_fetcher.patched_worker_loop(*args, **kwargs)\nWorker loop that ensures patches are applied in worker processes.\n\n\n\nmonkeypatch.data.batch_dataset_fetcher.remove_multipack_dataloader_patch()\nRemove the monkeypatch and restore original PyTorch DataLoader behavior."
},
{
"objectID": "docs/api/loaders.patch_manager.html",
"href": "docs/api/loaders.patch_manager.html",
"title": "loaders.patch_manager",
"section": "",
"text": "loaders.patch_manager\nPatch manager class implementation to complement axolotl.loaders.ModelLoader.\nApplies pre- and post-model load patches for various fixes and optimizations.\n\n\n\n\n\nName\nDescription\n\n\n\n\nPatchManager\nManages the application of patches during the model loading process.\n\n\n\n\n\nloaders.patch_manager.PatchManager(cfg, model_config, inference=False)\nManages the application of patches during the model loading process.\n\n\n\n\n\nName\nDescription\n\n\n\n\nhas_flash_attn\nCheck if flash attention is installed.\n\n\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\napply_post_model_build_patches\nApply patches right after model build, before post-load setup.\n\n\napply_post_model_load_patches\nApply patches that require the model instance.\n\n\napply_post_plugin_pre_model_load_patches\nApply post plugin-pre_model_load load patches based on config.\n\n\napply_pre_config_load_patches\nApply patches that must be set up before config loading.\n\n\napply_pre_model_load_patches\nApply pre-model load patches based on config.\n\n\napply_pre_tokenizer_load_patches\nApply patches that must be set up before tokenizer loading.\n\n\n\n\n\nloaders.patch_manager.PatchManager.apply_post_model_build_patches(model)\nApply patches right after model build, before post-load setup.\n\n\n\nloaders.patch_manager.PatchManager.apply_post_model_load_patches(model)\nApply patches that require the model instance.\n\n\n\nloaders.patch_manager.PatchManager.apply_post_plugin_pre_model_load_patches()\nApply post plugin-pre_model_load load patches based on config.\n\n\n\nloaders.patch_manager.PatchManager.apply_pre_config_load_patches(cfg)\nApply patches that must be set up before config loading.\nThis is for patches that intercept remote code loading from HuggingFace,\nwhich needs to be in place before AutoConfig.from_pretrained() is called.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nConfiguration dictionary with model and 
training settings.\nrequired\n\n\n\n\n\n\n\nloaders.patch_manager.PatchManager.apply_pre_model_load_patches()\nApply pre-model load patches based on config.\n\n\n\nloaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches(cfg)\nApply patches that must be set up before tokenizer loading.\nThis is for patches that intercept remote code loading from HuggingFace,\nwhich needs to be in place before AutoTokenizer.from_pretrained() is called.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nConfiguration dictionary with model and training settings.\nrequired"
},
{
"objectID": "docs/api/loaders.patch_manager.html#classes",
"href": "docs/api/loaders.patch_manager.html#classes",
"title": "loaders.patch_manager",
"section": "",
"text": "Name\nDescription\n\n\n\n\nPatchManager\nManages the application of patches during the model loading process.\n\n\n\n\n\nloaders.patch_manager.PatchManager(cfg, model_config, inference=False)\nManages the application of patches during the model loading process.\n\n\n\n\n\nName\nDescription\n\n\n\n\nhas_flash_attn\nCheck if flash attention is installed.\n\n\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\napply_post_model_build_patches\nApply patches right after model build, before post-load setup.\n\n\napply_post_model_load_patches\nApply patches that require the model instance.\n\n\napply_post_plugin_pre_model_load_patches\nApply post plugin-pre_model_load load patches based on config.\n\n\napply_pre_config_load_patches\nApply patches that must be set up before config loading.\n\n\napply_pre_model_load_patches\nApply pre-model load patches based on config.\n\n\napply_pre_tokenizer_load_patches\nApply patches that must be set up before tokenizer loading.\n\n\n\n\n\nloaders.patch_manager.PatchManager.apply_post_model_build_patches(model)\nApply patches right after model build, before post-load setup.\n\n\n\nloaders.patch_manager.PatchManager.apply_post_model_load_patches(model)\nApply patches that require the model instance.\n\n\n\nloaders.patch_manager.PatchManager.apply_post_plugin_pre_model_load_patches()\nApply post plugin-pre_model_load load patches based on config.\n\n\n\nloaders.patch_manager.PatchManager.apply_pre_config_load_patches(cfg)\nApply patches that must be set up before config loading.\nThis is for patches that intercept remote code loading from HuggingFace,\nwhich needs to be in place before AutoConfig.from_pretrained() is called.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nConfiguration dictionary with model and training settings.\nrequired\n\n\n\n\n\n\n\nloaders.patch_manager.PatchManager.apply_pre_model_load_patches()\nApply pre-model load patches based on 
config.\n\n\n\nloaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches(cfg)\nApply patches that must be set up before tokenizer loading.\nThis is for patches that intercept remote code loading from HuggingFace,\nwhich needs to be in place before AutoTokenizer.from_pretrained() is called.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nConfiguration dictionary with model and training settings.\nrequired"
},
{
"objectID": "docs/api/utils.model_shard_quant.html",
"href": "docs/api/utils.model_shard_quant.html",
"title": "utils.model_shard_quant",
"section": "",
"text": "utils.model_shard_quant\nmodule to handle loading model on cpu/meta device for FSDP\n\n\n\n\n\nName\nDescription\n\n\n\n\nload_and_quantize\nLoads value tensor into submodule of module, optionally skipping skip_names and converting to dtype.\n\n\n\n\n\nutils.model_shard_quant.load_and_quantize(\n module,\n name,\n value,\n device=None,\n dtype=None,\n skip_names=None,\n to_cpu=False,\n to_meta=False,\n verbose=False,\n quant_method='bnb',\n)\nLoads value tensor into submodule of module, optionally skipping skip_names and converting to dtype.\nQuantizes Params4bit on device then places on “cpu” if to_cpu=True or “meta” if to_meta=True."
},
{
"objectID": "docs/api/utils.model_shard_quant.html#functions",
"href": "docs/api/utils.model_shard_quant.html#functions",
"title": "utils.model_shard_quant",
"section": "",
"text": "Name\nDescription\n\n\n\n\nload_and_quantize\nLoads value tensor into submodule of module, optionally skipping skip_names and converting to dtype.\n\n\n\n\n\nutils.model_shard_quant.load_and_quantize(\n module,\n name,\n value,\n device=None,\n dtype=None,\n skip_names=None,\n to_cpu=False,\n to_meta=False,\n verbose=False,\n quant_method='bnb',\n)\nLoads value tensor into submodule of module, optionally skipping skip_names and converting to dtype.\nQuantizes Params4bit on device then places on “cpu” if to_cpu=True or “meta” if to_meta=True."
},
{
"objectID": "docs/api/utils.schemas.multimodal.html",
"href": "docs/api/utils.schemas.multimodal.html",
"title": "utils.schemas.multimodal",
"section": "",
"text": "utils.schemas.multimodal\nPydantic models for multimodal-related configuration\n\n\n\n\n\nName\nDescription\n\n\n\n\nMultiModalConfig\nMulti-modal configuration subset\n\n\n\n\n\nutils.schemas.multimodal.MultiModalConfig()\nMulti-modal configuration subset\n\n\n\n\n\nName\nDescription\n\n\n\n\nconvert_image_resize_algorithm\nConvert the image resize algorithm to a PIL.Image.Resampling enum.\n\n\n\n\n\nutils.schemas.multimodal.MultiModalConfig.convert_image_resize_algorithm(\n image_resize_algorithm,\n)\nConvert the image resize algorithm to a PIL.Image.Resampling enum."
},
{
"objectID": "docs/api/utils.schemas.multimodal.html#classes",
"href": "docs/api/utils.schemas.multimodal.html#classes",
"title": "utils.schemas.multimodal",
"section": "",
"text": "Name\nDescription\n\n\n\n\nMultiModalConfig\nMulti-modal configuration subset\n\n\n\n\n\nutils.schemas.multimodal.MultiModalConfig()\nMulti-modal configuration subset\n\n\n\n\n\nName\nDescription\n\n\n\n\nconvert_image_resize_algorithm\nConvert the image resize algorithm to a PIL.Image.Resampling enum.\n\n\n\n\n\nutils.schemas.multimodal.MultiModalConfig.convert_image_resize_algorithm(\n image_resize_algorithm,\n)\nConvert the image resize algorithm to a PIL.Image.Resampling enum."
},
{
"objectID": "docs/api/utils.callbacks.profiler.html",
"href": "docs/api/utils.callbacks.profiler.html",
"title": "utils.callbacks.profiler",
"section": "",
"text": "utils.callbacks.profiler\nHF Trainer callback for creating pytorch profiling snapshots\n\n\n\n\n\nName\nDescription\n\n\n\n\nPytorchProfilerCallback\nPyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.\n\n\n\n\n\nutils.callbacks.profiler.PytorchProfilerCallback(\n steps_to_profile=5,\n profiler_steps_start=0,\n)\nPyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.\nAlso runs torch.profiler to produce a Chrome trace for timing analysis."
},
{
"objectID": "docs/api/utils.callbacks.profiler.html#classes",
"href": "docs/api/utils.callbacks.profiler.html#classes",
"title": "utils.callbacks.profiler",
"section": "",
"text": "Name\nDescription\n\n\n\n\nPytorchProfilerCallback\nPyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.\n\n\n\n\n\nutils.callbacks.profiler.PytorchProfilerCallback(\n steps_to_profile=5,\n profiler_steps_start=0,\n)\nPyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.\nAlso runs torch.profiler to produce a Chrome trace for timing analysis."
},
{
"objectID": "docs/api/convert.html",
"href": "docs/api/convert.html",
"title": "convert",
"section": "",
"text": "convert\nModule containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes\n\n\n\n\n\nName\nDescription\n\n\n\n\nFileReader\nReads a file and returns its contents as a string\n\n\nFileWriter\nWrites a string to a file\n\n\nJsonParser\nParses a string as JSON and returns the result\n\n\nJsonToJsonlConverter\nConverts a JSON file to JSONL\n\n\nJsonlSerializer\nSerializes a list of JSON objects into a JSONL string\n\n\nStdoutWriter\nWrites a string to stdout\n\n\n\n\n\nconvert.FileReader()\nReads a file and returns its contents as a string\n\n\n\nconvert.FileWriter(file_path)\nWrites a string to a file\n\n\n\nconvert.JsonParser()\nParses a string as JSON and returns the result\n\n\n\nconvert.JsonToJsonlConverter(\n file_reader,\n file_writer,\n json_parser,\n jsonl_serializer,\n)\nConverts a JSON file to JSONL\n\n\n\nconvert.JsonlSerializer()\nSerializes a list of JSON objects into a JSONL string\n\n\n\nconvert.StdoutWriter()\nWrites a string to stdout"
},
{
"objectID": "docs/api/convert.html#classes",
"href": "docs/api/convert.html#classes",
"title": "convert",
"section": "",
"text": "Name\nDescription\n\n\n\n\nFileReader\nReads a file and returns its contents as a string\n\n\nFileWriter\nWrites a string to a file\n\n\nJsonParser\nParses a string as JSON and returns the result\n\n\nJsonToJsonlConverter\nConverts a JSON file to JSONL\n\n\nJsonlSerializer\nSerializes a list of JSON objects into a JSONL string\n\n\nStdoutWriter\nWrites a string to stdout\n\n\n\n\n\nconvert.FileReader()\nReads a file and returns its contents as a string\n\n\n\nconvert.FileWriter(file_path)\nWrites a string to a file\n\n\n\nconvert.JsonParser()\nParses a string as JSON and returns the result\n\n\n\nconvert.JsonToJsonlConverter(\n file_reader,\n file_writer,\n json_parser,\n jsonl_serializer,\n)\nConverts a JSON file to JSONL\n\n\n\nconvert.JsonlSerializer()\nSerializes a list of JSON objects into a JSONL string\n\n\n\nconvert.StdoutWriter()\nWrites a string to stdout"
},
{
"objectID": "docs/api/cli.utils.html",
"href": "docs/api/cli.utils.html",
"title": "cli.utils",
"section": "",
"text": "cli.utils\ncli.utils\nInit for axolotl.cli.utils module."
},
{
"objectID": "docs/api/kernels.lora.html",
"href": "docs/api/kernels.lora.html",
"title": "kernels.lora",
"section": "",
"text": "kernels.lora\nModule for definition of Low-Rank Adaptation (LoRA) Triton kernels.\nSee “LoRA: Low-Rank Adaptation of Large Language Models”\n(https://arxiv.org/abs/2106.09685).\nAlso supports DoRA (Weight-Decomposed Low-Rank Adaptation):\nSee “DoRA: Weight-Decomposed Low-Rank Adaptation” (https://arxiv.org/abs/2402.09353).\nCredit to unsloth (https://unsloth.ai/) for inspiration for this implementation.\n\n\n\n\n\nName\nDescription\n\n\n\n\nLoRA_Embedding\nFused LoRA embedding: F.embedding(x, W) + s * F.embedding(x, A^T) @ B^T.\n\n\nLoRA_MLP\nOptimized LoRA MLP implementation.\n\n\nLoRA_O\nOptimized LoRA implementation for output projection.\n\n\nLoRA_QKV\nOptimized LoRA QKV implementation with quantization support.\n\n\n\n\n\nkernels.lora.LoRA_Embedding()\nFused LoRA embedding: F.embedding(x, W) + s * F.embedding(x, A^T) @ B^T.\nSupports dropout and DoRA.\n\n\n\nkernels.lora.LoRA_MLP()\nOptimized LoRA MLP implementation.\nSupports bias, dropout, and DoRA. Dropout is applied to the input for\ngate/up projections. 
The down projection uses hidden states (post-activation)\nas input, so dropout is not applied there.\n\n\n\nkernels.lora.LoRA_O()\nOptimized LoRA implementation for output projection.\nSupports bias, dropout, and DoRA.\n\n\n\nkernels.lora.LoRA_QKV()\nOptimized LoRA QKV implementation with quantization support.\nSupports bias, dropout, and DoRA (Weight-Decomposed Low-Rank Adaptation).\nDropout is applied outside this Function so autograd handles its backward.\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\napply_lora_embedding\nApplies LoRA to embedding layer.\n\n\napply_lora_mlp_geglu\nApplies LoRA to MLP layer with GEGLU activation.\n\n\napply_lora_mlp_swiglu\nApplies LoRA to MLP layer with SwiGLU activation.\n\n\napply_lora_o\nApplies LoRA to output projection layer.\n\n\napply_lora_qkv\nApplies LoRA to compute Query, Key, Value projections.\n\n\nget_embedding_lora_parameters\nExtract LoRA parameters from a PEFT Embedding module.\n\n\nget_lora_parameters\nGets LoRA parameters from a projection module.\n\n\nmatmul_lora\nEfficient fused matmul + LoRA computation.\n\n\n\n\n\nkernels.lora.apply_lora_embedding(self, x)\nApplies LoRA to embedding layer.\n\n\n\nkernels.lora.apply_lora_mlp_geglu(self, X, inplace=True)\nApplies LoRA to MLP layer with GEGLU activation.\nSupports bias, dropout, and DoRA.\n\n\n\nkernels.lora.apply_lora_mlp_swiglu(self, X, inplace=True)\nApplies LoRA to MLP layer with SwiGLU activation.\nSupports bias, dropout, and DoRA.\n\n\n\nkernels.lora.apply_lora_o(self, X)\nApplies LoRA to output projection layer.\nSupports bias, dropout, and DoRA.\n\n\n\nkernels.lora.apply_lora_qkv(self, X, inplace=True)\nApplies LoRA to compute Query, Key, Value projections.\nSupports bias, dropout, and DoRA. Dropout is applied outside the autograd\nFunction so PyTorch handles its backward automatically. 
A single shared\ndropout mask is used across Q, K, V projections for memory efficiency.\n\n\n\nkernels.lora.get_embedding_lora_parameters(embed)\nExtract LoRA parameters from a PEFT Embedding module.\n\n\n\nkernels.lora.get_lora_parameters(proj)\nGets LoRA parameters from a projection module.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nproj\nnn.Module\nThe projection module to extract parameters from.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[torch.Tensor, torch.Tensor | None, QuantState | torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, float | None, torch.Tensor | None, nn.Module | None, torch.Tensor | None]\nA tuple containing:\n\n\n\ntorch.Tensor\n- W: base weight tensor\n\n\n\ntorch.Tensor | None\n- b: base layer bias (or None)\n\n\n\nQuantState | torch.Tensor | None\n- quant_state: quantization state (or None)\n\n\n\ntorch.Tensor | None\n- A: LoRA A weight (or None)\n\n\n\ntorch.Tensor | None\n- B: LoRA B weight (or None)\n\n\n\nfloat | None\n- s: LoRA scaling factor (or None)\n\n\n\ntorch.Tensor | None\n- lora_bias: LoRA B bias (or None)\n\n\n\nnn.Module | None\n- dropout: dropout module (or None)\n\n\n\ntorch.Tensor | None\n- magnitude: DoRA magnitude vector (or None)\n\n\n\n\n\n\n\nkernels.lora.matmul_lora(\n X,\n W,\n b,\n W_quant,\n A,\n B,\n s,\n out=None,\n X_drop=None,\n lora_bias=None,\n)\nEfficient fused matmul + LoRA computation.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nX\ntorch.Tensor\nInput tensor [*, in_features]\nrequired\n\n\nW\ntorch.Tensor\nBase weight matrix [out_features, in_features]\nrequired\n\n\nW_quant\nQuantState | torch.Tensor | None\nQuantization state for W\nrequired\n\n\nA\ntorch.Tensor | None\nLoRA A matrix [rank, in_features]\nrequired\n\n\nB\ntorch.Tensor | None\nLoRA B matrix [out_features, rank]\nrequired\n\n\ns\nfloat | None\nLoRA scaling factor\nrequired\n\n\nout\ntorch.Tensor | None\nOptional output tensor for inplace operations\nNone\n\n\nX_drop\ntorch.Tensor | None\nOptional dropout-applied input for LoRA path (if None, uses X)\nNone\n\n\nlora_bias\ntorch.Tensor | None\nOptional LoRA B layer bias [out_features]\nNone\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntorch.Tensor\nResult of X @ W^T + s * X_drop @ A^T @ B^T + b + s * lora_bias"
},
{
"objectID": "docs/api/kernels.lora.html#classes",
"href": "docs/api/kernels.lora.html#classes",
"title": "kernels.lora",
"section": "",
"text": "Name\nDescription\n\n\n\n\nLoRA_Embedding\nFused LoRA embedding: F.embedding(x, W) + s * F.embedding(x, A^T) @ B^T.\n\n\nLoRA_MLP\nOptimized LoRA MLP implementation.\n\n\nLoRA_O\nOptimized LoRA implementation for output projection.\n\n\nLoRA_QKV\nOptimized LoRA QKV implementation with quantization support.\n\n\n\n\n\nkernels.lora.LoRA_Embedding()\nFused LoRA embedding: F.embedding(x, W) + s * F.embedding(x, A^T) @ B^T.\nSupports dropout and DoRA.\n\n\n\nkernels.lora.LoRA_MLP()\nOptimized LoRA MLP implementation.\nSupports bias, dropout, and DoRA. Dropout is applied to the input for\ngate/up projections. The down projection uses hidden states (post-activation)\nas input, so dropout is not applied there.\n\n\n\nkernels.lora.LoRA_O()\nOptimized LoRA implementation for output projection.\nSupports bias, dropout, and DoRA.\n\n\n\nkernels.lora.LoRA_QKV()\nOptimized LoRA QKV implementation with quantization support.\nSupports bias, dropout, and DoRA (Weight-Decomposed Low-Rank Adaptation).\nDropout is applied outside this Function so autograd handles its backward."
},
{
"objectID": "docs/api/kernels.lora.html#functions",
"href": "docs/api/kernels.lora.html#functions",
"title": "kernels.lora",
"section": "",
"text": "Name\nDescription\n\n\n\n\napply_lora_embedding\nApplies LoRA to embedding layer.\n\n\napply_lora_mlp_geglu\nApplies LoRA to MLP layer with GEGLU activation.\n\n\napply_lora_mlp_swiglu\nApplies LoRA to MLP layer with SwiGLU activation.\n\n\napply_lora_o\nApplies LoRA to output projection layer.\n\n\napply_lora_qkv\nApplies LoRA to compute Query, Key, Value projections.\n\n\nget_embedding_lora_parameters\nExtract LoRA parameters from a PEFT Embedding module.\n\n\nget_lora_parameters\nGets LoRA parameters from a projection module.\n\n\nmatmul_lora\nEfficient fused matmul + LoRA computation.\n\n\n\n\n\nkernels.lora.apply_lora_embedding(self, x)\nApplies LoRA to embedding layer.\n\n\n\nkernels.lora.apply_lora_mlp_geglu(self, X, inplace=True)\nApplies LoRA to MLP layer with GEGLU activation.\nSupports bias, dropout, and DoRA.\n\n\n\nkernels.lora.apply_lora_mlp_swiglu(self, X, inplace=True)\nApplies LoRA to MLP layer with SwiGLU activation.\nSupports bias, dropout, and DoRA.\n\n\n\nkernels.lora.apply_lora_o(self, X)\nApplies LoRA to output projection layer.\nSupports bias, dropout, and DoRA.\n\n\n\nkernels.lora.apply_lora_qkv(self, X, inplace=True)\nApplies LoRA to compute Query, Key, Value projections.\nSupports bias, dropout, and DoRA. Dropout is applied outside the autograd\nFunction so PyTorch handles its backward automatically. 
A single shared\ndropout mask is used across Q, K, V projections for memory efficiency.\n\n\n\nkernels.lora.get_embedding_lora_parameters(embed)\nExtract LoRA parameters from a PEFT Embedding module.\n\n\n\nkernels.lora.get_lora_parameters(proj)\nGets LoRA parameters from a projection module.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nproj\nnn.Module\nThe projection module to extract parameters from.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[torch.Tensor, torch.Tensor | None, QuantState | torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, float | None, torch.Tensor | None, nn.Module | None, torch.Tensor | None]\nA tuple containing:\n\n\n\ntorch.Tensor\n- W: base weight tensor\n\n\n\ntorch.Tensor | None\n- b: base layer bias (or None)\n\n\n\nQuantState | torch.Tensor | None\n- quant_state: quantization state (or None)\n\n\n\ntorch.Tensor | None\n- A: LoRA A weight (or None)\n\n\n\ntorch.Tensor | None\n- B: LoRA B weight (or None)\n\n\n\nfloat | None\n- s: LoRA scaling factor (or None)\n\n\n\ntorch.Tensor | None\n- lora_bias: LoRA B bias (or None)\n\n\n\nnn.Module | None\n- dropout: dropout module (or None)\n\n\n\ntorch.Tensor | None\n- magnitude: DoRA magnitude vector (or None)\n\n\n\n\n\n\n\nkernels.lora.matmul_lora(\n X,\n W,\n b,\n W_quant,\n A,\n B,\n s,\n out=None,\n X_drop=None,\n lora_bias=None,\n)\nEfficient fused matmul + LoRA computation.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nX\ntorch.Tensor\nInput tensor [*, in_features]\nrequired\n\n\nW\ntorch.Tensor\nBase weight matrix [out_features, in_features]\nrequired\n\n\nW_quant\nQuantState | torch.Tensor | None\nQuantization state for W\nrequired\n\n\nA\ntorch.Tensor | None\nLoRA A matrix [rank, in_features]\nrequired\n\n\nB\ntorch.Tensor | None\nLoRA B matrix [out_features, rank]\nrequired\n\n\ns\nfloat | None\nLoRA scaling factor\nrequired\n\n\nout\ntorch.Tensor | None\nOptional output tensor for inplace operations\nNone\n\n\nX_drop\ntorch.Tensor | None\nOptional dropout-applied input for LoRA path (if None, uses X)\nNone\n\n\nlora_bias\ntorch.Tensor | None\nOptional LoRA B layer bias [out_features]\nNone\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntorch.Tensor\nResult of X @ W^T + s * X_drop @ A^T @ B^T + b + s * lora_bias"
},
{
"objectID": "docs/api/monkeypatch.utils.html",
"href": "docs/api/monkeypatch.utils.html",
"title": "monkeypatch.utils",
"section": "",
"text": "monkeypatch.utils\nShared utils for the monkeypatches\n\n\n\n\n\nName\nDescription\n\n\n\n\nget_cu_seqlens\ngenerate a cumulative sequence length mask for flash attention using attn mask\n\n\nget_cu_seqlens_from_pos_ids\ngenerate a cumulative sequence length mask for flash attention using pos ids\n\n\n\n\n\nmonkeypatch.utils.get_cu_seqlens(attn_mask)\ngenerate a cumulative sequence length mask for flash attention using attn mask\n\n\n\nmonkeypatch.utils.get_cu_seqlens_from_pos_ids(position_ids)\ngenerate a cumulative sequence length mask for flash attention using pos ids"
},
{
"objectID": "docs/api/monkeypatch.utils.html#functions",
"href": "docs/api/monkeypatch.utils.html#functions",
"title": "monkeypatch.utils",
"section": "",
"text": "Name\nDescription\n\n\n\n\nget_cu_seqlens\ngenerate a cumulative sequence length mask for flash attention using attn mask\n\n\nget_cu_seqlens_from_pos_ids\ngenerate a cumulative sequence length mask for flash attention using pos ids\n\n\n\n\n\nmonkeypatch.utils.get_cu_seqlens(attn_mask)\ngenerate a cumulative sequence length mask for flash attention using attn mask\n\n\n\nmonkeypatch.utils.get_cu_seqlens_from_pos_ids(position_ids)\ngenerate a cumulative sequence length mask for flash attention using pos ids"
},
{
"objectID": "docs/api/common.const.html",
"href": "docs/api/common.const.html",
"title": "common.const",
"section": "",
"text": "common.const\ncommon.const\nVarious shared constants"
},
{
"objectID": "docs/api/utils.freeze.html",
"href": "docs/api/utils.freeze.html",
"title": "utils.freeze",
"section": "",
"text": "utils.freeze\nmodule to freeze/unfreeze parameters by name\n\n\n\n\n\nName\nDescription\n\n\n\n\nLayerNamePattern\nRepresents a regex pattern for layer names, potentially including a parameter index range.\n\n\n\n\n\nutils.freeze.LayerNamePattern(pattern)\nRepresents a regex pattern for layer names, potentially including a parameter index range.\n\n\n\n\n\nName\nDescription\n\n\n\n\nmatch\nChecks if the given layer name matches the regex pattern.\n\n\n\n\n\nutils.freeze.LayerNamePattern.match(name)\nChecks if the given layer name matches the regex pattern.\nParameters:\n- name (str): The layer name to check.\nReturns:\n- bool: True if the layer name matches the pattern, False otherwise.\n\n\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\nfreeze_layers_except\nFreezes all layers of the given model except for the layers that match given regex patterns.\n\n\n\n\n\nutils.freeze.freeze_layers_except(model, regex_patterns)\nFreezes all layers of the given model except for the layers that match given regex patterns.\nPeriods in the patterns are treated as literal periods, not as wildcard characters.\nParameters:\n- model (nn.Module): The PyTorch model to be modified.\n- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.\nNote that you cannot use a dot as a wildcard character in the patterns since it is reserved for separating layer names.\nAlso, to match the entire layer name, the pattern should start with “^” and end with “\\(\", otherwise it will match any part of the layer name.\n The range pattern part is optional and it is not compiled as a regex pattern which means you must put \"\\)” before the range pattern if you want to match the entire layer name.\nE.g., [“^model.embed_tokens.weight\\([:32000]\", \"layers.2[0-9]+.block_sparse_moe.gate.[a-z]+\\)”]\nReturns:\nNone; the model is modified in place."
},
{
"objectID": "docs/api/utils.freeze.html#classes",
"href": "docs/api/utils.freeze.html#classes",
"title": "utils.freeze",
"section": "",
"text": "Name\nDescription\n\n\n\n\nLayerNamePattern\nRepresents a regex pattern for layer names, potentially including a parameter index range.\n\n\n\n\n\nutils.freeze.LayerNamePattern(pattern)\nRepresents a regex pattern for layer names, potentially including a parameter index range.\n\n\n\n\n\nName\nDescription\n\n\n\n\nmatch\nChecks if the given layer name matches the regex pattern.\n\n\n\n\n\nutils.freeze.LayerNamePattern.match(name)\nChecks if the given layer name matches the regex pattern.\nParameters:\n- name (str): The layer name to check.\nReturns:\n- bool: True if the layer name matches the pattern, False otherwise."
},
{
"objectID": "docs/api/utils.freeze.html#functions",
"href": "docs/api/utils.freeze.html#functions",
"title": "utils.freeze",
"section": "",
"text": "Name\nDescription\n\n\n\n\nfreeze_layers_except\nFreezes all layers of the given model except for the layers that match given regex patterns.\n\n\n\n\n\nutils.freeze.freeze_layers_except(model, regex_patterns)\nFreezes all layers of the given model except for the layers that match given regex patterns.\nPeriods in the patterns are treated as literal periods, not as wildcard characters.\nParameters:\n- model (nn.Module): The PyTorch model to be modified.\n- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.\nNote that you cannot use a dot as a wildcard character in the patterns since it is reserved for separating layer names.\nAlso, to match the entire layer name, the pattern should start with “^” and end with “\\(\", otherwise it will match any part of the layer name.\n The range pattern part is optional and it is not compiled as a regex pattern which means you must put \"\\)” before the range pattern if you want to match the entire layer name.\nE.g., [“^model.embed_tokens.weight\\([:32000]\", \"layers.2[0-9]+.block_sparse_moe.gate.[a-z]+\\)”]\nReturns:\nNone; the model is modified in place."
},
{
"objectID": "docs/api/utils.schemas.utils.html",
"href": "docs/api/utils.schemas.utils.html",
"title": "utils.schemas.utils",
"section": "",
"text": "utils.schemas.utils\nUtilities for Axolotl Pydantic models\n\n\n\n\n\nName\nDescription\n\n\n\n\nhandle_legacy_message_fields_logic\nHandle backwards compatibility between legacy message field mapping and new property mapping system.\n\n\n\n\n\nutils.schemas.utils.handle_legacy_message_fields_logic(data)\nHandle backwards compatibility between legacy message field mapping and new property mapping system.\nPreviously, the config only supported mapping role and content fields via dedicated config options:\n- message_field_role: Mapped to the role field\n- message_field_content: Mapped to the content field\nThe new system uses message_property_mappings to support arbitrary field mappings:\nmessage_property_mappings:\nrole: source_role_field\ncontent: source_content_field\nadditional_field: source_field\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ndata\ndict\nDictionary containing configuration data\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ndict\nUpdated dictionary with message field mappings consolidated\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nValueError\nIf there are conflicts between legacy and new mappings"
},
{
"objectID": "docs/api/utils.schemas.utils.html#functions",
"href": "docs/api/utils.schemas.utils.html#functions",
"title": "utils.schemas.utils",
"section": "",
"text": "Name\nDescription\n\n\n\n\nhandle_legacy_message_fields_logic\nHandle backwards compatibility between legacy message field mapping and new property mapping system.\n\n\n\n\n\nutils.schemas.utils.handle_legacy_message_fields_logic(data)\nHandle backwards compatibility between legacy message field mapping and new property mapping system.\nPreviously, the config only supported mapping role and content fields via dedicated config options:\n- message_field_role: Mapped to the role field\n- message_field_content: Mapped to the content field\nThe new system uses message_property_mappings to support arbitrary field mappings:\nmessage_property_mappings:\nrole: source_role_field\ncontent: source_content_field\nadditional_field: source_field\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ndata\ndict\nDictionary containing configuration data\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ndict\nUpdated dictionary with message field mappings consolidated\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nValueError\nIf there are conflicts between legacy and new mappings"
},
{
"objectID": "docs/api/utils.callbacks.qat.html",
"href": "docs/api/utils.callbacks.qat.html",
"title": "utils.callbacks.qat",
"section": "",
"text": "utils.callbacks.qat\nQAT Callback for HF Causal Trainer\n\n\n\n\n\nName\nDescription\n\n\n\n\nQATCallback\nCallback to toggle fake quantization for the model.\n\n\n\n\n\nutils.callbacks.qat.QATCallback(cfg)\nCallback to toggle fake quantization for the model.\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\ntoggle_fake_quant\nToggle fake quantization for any fake quantized linear or embedding layers in the model.\n\n\n\n\n\nutils.callbacks.qat.toggle_fake_quant(mod, enable)\nToggle fake quantization for any fake quantized linear or embedding layers in the model.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nmod\nnn.Module\nThe module to toggle fake quantization for.\nrequired\n\n\nenable\nbool\nWhether to enable or disable fake quantization.\nrequired"
},
{
"objectID": "docs/api/utils.callbacks.qat.html#classes",
"href": "docs/api/utils.callbacks.qat.html#classes",
"title": "utils.callbacks.qat",
"section": "",
"text": "Name\nDescription\n\n\n\n\nQATCallback\nCallback to toggle fake quantization for the model.\n\n\n\n\n\nutils.callbacks.qat.QATCallback(cfg)\nCallback to toggle fake quantization for the model."
},
{
"objectID": "docs/api/utils.callbacks.qat.html#functions",
"href": "docs/api/utils.callbacks.qat.html#functions",
"title": "utils.callbacks.qat",
"section": "",
"text": "Name\nDescription\n\n\n\n\ntoggle_fake_quant\nToggle fake quantization for any fake quantized linear or embedding layers in the model.\n\n\n\n\n\nutils.callbacks.qat.toggle_fake_quant(mod, enable)\nToggle fake quantization for any fake quantized linear or embedding layers in the model.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nmod\nnn.Module\nThe module to toggle fake quantization for.\nrequired\n\n\nenable\nbool\nWhether to enable or disable fake quantization.\nrequired"
},
{
"objectID": "docs/api/utils.data.sft.html",
"href": "docs/api/utils.data.sft.html",
"title": "utils.data.sft",
"section": "",
"text": "utils.data.sft\nData handling specific to SFT.\n\n\n\n\n\nName\nDescription\n\n\n\n\nprepare_datasets\nPrepare training and evaluation datasets based on configuration.\n\n\n\n\n\nutils.data.sft.prepare_datasets(cfg, tokenizer, processor=None)\nPrepare training and evaluation datasets based on configuration.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ntokenizer\nPreTrainedTokenizer\nTokenizer to use for processing text.\nrequired\n\n\nprocessor\nProcessorMixin | None\nOptional processor for multimodal datasets.\nNone\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[IterableDataset | Dataset, Dataset | None, int, list[Prompter | None]]\nTuple of (train_dataset, eval_dataset, total_steps, prompters)."
},
{
"objectID": "docs/api/utils.data.sft.html#functions",
"href": "docs/api/utils.data.sft.html#functions",
"title": "utils.data.sft",
"section": "",
"text": "Name\nDescription\n\n\n\n\nprepare_datasets\nPrepare training and evaluation datasets based on configuration.\n\n\n\n\n\nutils.data.sft.prepare_datasets(cfg, tokenizer, processor=None)\nPrepare training and evaluation datasets based on configuration.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ntokenizer\nPreTrainedTokenizer\nTokenizer to use for processing text.\nrequired\n\n\nprocessor\nProcessorMixin | None\nOptional processor for multimodal datasets.\nNone\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[IterableDataset | Dataset, Dataset | None, int, list[Prompter | None]]\nTuple of (train_dataset, eval_dataset, total_steps, prompters)."
},
{
"objectID": "docs/api/monkeypatch.llama_attn_hijack_xformers.html",
"href": "docs/api/monkeypatch.llama_attn_hijack_xformers.html",
"title": "monkeypatch.llama_attn_hijack_xformers",
"section": "",
"text": "monkeypatch.llama_attn_hijack_xformers\nmonkeypatch.llama_attn_hijack_xformers\nDirectly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments"
},
{
"objectID": "docs/api/core.trainers.grpo.sampler.html",
"href": "docs/api/core.trainers.grpo.sampler.html",
"title": "core.trainers.grpo.sampler",
"section": "",
"text": "core.trainers.grpo.sampler\nRepeat random sampler (similar to the one implemented in\nhttps://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) that adds\nsequence parallelism functionality; i.e., duplicating data across ranks in the same\nsequence parallel group.\n\n\n\n\n\nName\nDescription\n\n\n\n\nSequenceParallelRepeatRandomSampler\nSampler for GRPO training with sequence parallelism.\n\n\n\n\n\ncore.trainers.grpo.sampler.SequenceParallelRepeatRandomSampler(\n dataset,\n mini_repeat_count,\n world_size,\n rank,\n batch_size=1,\n repeat_count=1,\n context_parallel_size=1,\n shuffle=True,\n seed=0,\n drop_last=False,\n)\nSampler for GRPO training with sequence parallelism.\nThis sampler ensures:\n- Ranks in the same sequence parallel (SP) group receive identical data.\n- Each index is repeated multiple times for sampling different completions.\n- Entire batches are repeated for reuse in multiple updates.\n- Data is properly distributed across SP groups.\nIn the table below, the values represent dataset indices. Each SP group has\ncontext_parallel_size = 2 GPUs working together on the same data. 
There are 2\nSP groups (SP0 and SP1), with world_size = 4 total GPUs.\n Sequence Parallel Groups\n | SP0 | SP1 |\n | GPU 0 | GPU 1 | GPU 2 | GPU 3 |\n global_step step <---> mini_repeat_count=3\n <----------> batch_size=2 per SP group\ngrad_accum=2 ▲ ▲ 0 0 [0 0 0 1 1 1] [2 2 2 3 3 3] <- SP groups get different data\n▼ | 0 1 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Same data for each SP group GPU\n|\n| 1 2 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Repeat same indices for iterations\nnum_iterations=2 ▼ 1 3 [0 0 0 1 1 1] [2 2 2 3 3 3] <- When using gradient accumulation\n 2 4 [4 4 4 5 5 5] [6 6 6 7 7 7] <- New batch of data indices\n 2 5 [4 4 4 5 5 5] [6 6 6 7 7 7]\n ...\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ndataset\nSized\nDataset to sample from.\nrequired\n\n\nmini_repeat_count\nint\nHow many times to repeat each sample immediately.\nrequired\n\n\nworld_size\nint\nTotal number of processes.\nrequired\n\n\nrank\nint\nRank of current process.\nrequired\n\n\nbatch_size\nint\nNumber of samples per batch.\n1\n\n\nrepeat_count\nint\nHow many times to repeat the full sampling process.\n1\n\n\ncontext_parallel_size\nint\nNumber of ranks in a sequence parallel group.\n1\n\n\nshuffle\nbool\nWhether to shuffle the dataset.\nTrue\n\n\nseed\nint\nRandom seed for shuffling.\n0\n\n\ndrop_last\nbool\nWhether to drop the last incomplete batch.\nFalse\n\n\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\nset_epoch\nSets the epoch for this sampler.\n\n\n\n\n\ncore.trainers.grpo.sampler.SequenceParallelRepeatRandomSampler.set_epoch(epoch)\nSets the epoch for this sampler.\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nepoch\nint\nEpoch number to use for shuffling.\nrequired"
},
{
"objectID": "docs/api/core.trainers.grpo.sampler.html#classes",
"href": "docs/api/core.trainers.grpo.sampler.html#classes",
"title": "core.trainers.grpo.sampler",
"section": "",
"text": "Name\nDescription\n\n\n\n\nSequenceParallelRepeatRandomSampler\nSampler for GRPO training with sequence parallelism.\n\n\n\n\n\ncore.trainers.grpo.sampler.SequenceParallelRepeatRandomSampler(\n dataset,\n mini_repeat_count,\n world_size,\n rank,\n batch_size=1,\n repeat_count=1,\n context_parallel_size=1,\n shuffle=True,\n seed=0,\n drop_last=False,\n)\nSampler for GRPO training with sequence parallelism.\nThis sampler ensures:\n- Ranks in the same sequence parallel (SP) group receive identical data.\n- Each index is repeated multiple times for sampling different completions.\n- Entire batches are repeated for reuse in multiple updates.\n- Data is properly distributed across SP groups.\nIn the table below, the values represent dataset indices. Each SP group has\ncontext_parallel_size = 2 GPUs working together on the same data. There are 2\nSP groups (SP0 and SP1), with world_size = 4 total GPUs.\n Sequence Parallel Groups\n | SP0 | SP1 |\n | GPU 0 | GPU 1 | GPU 2 | GPU 3 |\n global_step step <---> mini_repeat_count=3\n <----------> batch_size=2 per SP group\ngrad_accum=2 ▲ ▲ 0 0 [0 0 0 1 1 1] [2 2 2 3 3 3] <- SP groups get different data\n▼ | 0 1 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Same data for each SP group GPU\n|\n| 1 2 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Repeat same indices for iterations\nnum_iterations=2 ▼ 1 3 [0 0 0 1 1 1] [2 2 2 3 3 3] <- When using gradient accumulation\n 2 4 [4 4 4 5 5 5] [6 6 6 7 7 7] <- New batch of data indices\n 2 5 [4 4 4 5 5 5] [6 6 6 7 7 7]\n ...\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ndataset\nSized\nDataset to sample from.\nrequired\n\n\nmini_repeat_count\nint\nHow many times to repeat each sample immediately.\nrequired\n\n\nworld_size\nint\nTotal number of processes.\nrequired\n\n\nrank\nint\nRank of current process.\nrequired\n\n\nbatch_size\nint\nNumber of samples per batch.\n1\n\n\nrepeat_count\nint\nHow many times to repeat the full sampling process.\n1\n\n\ncontext_parallel_size\nint\nNumber of ranks 
in a sequence parallel group.\n1\n\n\nshuffle\nbool\nWhether to shuffle the dataset.\nTrue\n\n\nseed\nint\nRandom seed for shuffling.\n0\n\n\ndrop_last\nbool\nWhether to drop the last incomplete batch.\nFalse\n\n\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\nset_epoch\nSets the epoch for this sampler.\n\n\n\n\n\ncore.trainers.grpo.sampler.SequenceParallelRepeatRandomSampler.set_epoch(epoch)\nSets the epoch for this sampler.\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nepoch\nint\nEpoch number to use for shuffling.\nrequired"
},
{
"objectID": "docs/api/core.chat.messages.html",
"href": "docs/api/core.chat.messages.html",
"title": "core.chat.messages",
"section": "",
"text": "core.chat.messages\ninternal message representations of chat messages\n\n\n\n\n\nName\nDescription\n\n\n\n\nChatFormattedChats\nChat formatted chats with formatter and optional train on inputs\n\n\nChats\ntop level data structure for chat conversations\n\n\nMessageContentTypes\nMessage content types for text, image, audio, tool calls, and tool responses\n\n\nMessageContents\nMessage contents with type, value, metadata, weight, newline, and end of contents\n\n\nMessageRoles\nMessage roles for the system, user, assistant, and tools\n\n\nMessages\nMessages with role, content, metadata, weight, and chat formatting\n\n\nPreferenceChats\nrepresentation for preference data for chat\n\n\nSpecialToken\nSpecial tokens for beginning of string and end of string\n\n\nTool\nTool with description, function, and parameters\n\n\nToolCallContents\nTool call contents with name, arguments, and optional id\n\n\nToolCallFunction\nTool call function with name and arguments\n\n\nToolResponseContents\nTool response contents with name, content, and optional id\n\n\n\n\n\ncore.chat.messages.ChatFormattedChats()\nChat formatted chats with formatter and optional train on inputs\n\n\n\ncore.chat.messages.Chats()\ntop level data structure for chat conversations\n\n\n\ncore.chat.messages.MessageContentTypes()\nMessage content types for text, image, audio, tool calls, and tool responses\n\n\n\ncore.chat.messages.MessageContents()\nMessage contents with type, value, metadata, weight, newline, and end of contents\n\n\n\ncore.chat.messages.MessageRoles()\nMessage roles for the system, user, assistant, and tools\n\n\n\ncore.chat.messages.Messages()\nMessages with role, content, metadata, weight, and chat formatting\n\n\n\ncore.chat.messages.PreferenceChats()\nrepresentation for preference data for chat\n\n\n\ncore.chat.messages.SpecialToken()\nSpecial tokens for beginning of string and end of string\n\n\n\ncore.chat.messages.Tool()\nTool with description, function, and 
parameters\n\n\n\ncore.chat.messages.ToolCallContents()\nTool call contents with name, arguments, and optional id\n\n\n\ncore.chat.messages.ToolCallFunction()\nTool call function with name and arguments\n\n\n\ncore.chat.messages.ToolResponseContents()\nTool response contents with name, content, and optional id"
},
{
"objectID": "docs/api/core.chat.messages.html#classes",
"href": "docs/api/core.chat.messages.html#classes",
"title": "core.chat.messages",
"section": "",
"text": "Name\nDescription\n\n\n\n\nChatFormattedChats\nChat formatted chats with formatter and optional train on inputs\n\n\nChats\ntop level data structure for chat conversations\n\n\nMessageContentTypes\nMessage content types for text, image, audio, tool calls, and tool responses\n\n\nMessageContents\nMessage contents with type, value, metadata, weight, newline, and end of contents\n\n\nMessageRoles\nMessage roles for the system, user, assistant, and tools\n\n\nMessages\nMessages with role, content, metadata, weight, and chat formatting\n\n\nPreferenceChats\nrepresentation for preference data for chat\n\n\nSpecialToken\nSpecial tokens for beginning of string and end of string\n\n\nTool\nTool with description, function, and parameters\n\n\nToolCallContents\nTool call contents with name, arguments, and optional id\n\n\nToolCallFunction\nTool call function with name and arguments\n\n\nToolResponseContents\nTool response contents with name, content, and optional id\n\n\n\n\n\ncore.chat.messages.ChatFormattedChats()\nChat formatted chats with formatter and optional train on inputs\n\n\n\ncore.chat.messages.Chats()\ntop level data structure for chat conversations\n\n\n\ncore.chat.messages.MessageContentTypes()\nMessage content types for text, image, audio, tool calls, and tool responses\n\n\n\ncore.chat.messages.MessageContents()\nMessage contents with type, value, metadata, weight, newline, and end of contents\n\n\n\ncore.chat.messages.MessageRoles()\nMessage roles for the system, user, assistant, and tools\n\n\n\ncore.chat.messages.Messages()\nMessages with role, content, metadata, weight, and chat formatting\n\n\n\ncore.chat.messages.PreferenceChats()\nrepresentation for preference data for chat\n\n\n\ncore.chat.messages.SpecialToken()\nSpecial tokens for beginning of string and end of string\n\n\n\ncore.chat.messages.Tool()\nTool with description, function, and parameters\n\n\n\ncore.chat.messages.ToolCallContents()\nTool call contents with name, arguments, and 
optional id\n\n\n\ncore.chat.messages.ToolCallFunction()\nTool call function with name and arguments\n\n\n\ncore.chat.messages.ToolResponseContents()\nTool response contents with name, content, and optional id"
},
{
"objectID": "docs/api/core.trainers.mamba.html",
"href": "docs/api/core.trainers.mamba.html",
"title": "core.trainers.mamba",
"section": "",
"text": "core.trainers.mamba\nModule for mamba trainer\n\n\n\n\n\nName\nDescription\n\n\n\n\nAxolotlMambaTrainer\nMamba specific trainer to handle loss calculation\n\n\n\n\n\ncore.trainers.mamba.AxolotlMambaTrainer(\n *_args,\n bench_data_collator=None,\n eval_data_collator=None,\n dataset_tags=None,\n **kwargs,\n)\nMamba specific trainer to handle loss calculation"
},
{
"objectID": "docs/api/core.trainers.mamba.html#classes",
"href": "docs/api/core.trainers.mamba.html#classes",
"title": "core.trainers.mamba",
"section": "",
"text": "Name\nDescription\n\n\n\n\nAxolotlMambaTrainer\nMamba specific trainer to handle loss calculation\n\n\n\n\n\ncore.trainers.mamba.AxolotlMambaTrainer(\n *_args,\n bench_data_collator=None,\n eval_data_collator=None,\n dataset_tags=None,\n **kwargs,\n)\nMamba specific trainer to handle loss calculation"
},
{
"objectID": "docs/api/prompt_strategies.dpo.passthrough.html",
"href": "docs/api/prompt_strategies.dpo.passthrough.html",
"title": "prompt_strategies.dpo.passthrough",
"section": "",
"text": "prompt_strategies.dpo.passthrough\nprompt_strategies.dpo.passthrough\nDPO prompt strategies passthrough/zero-processing strategy"
},
{
"objectID": "docs/api/kernels.swiglu.html",
"href": "docs/api/kernels.swiglu.html",
"title": "kernels.swiglu",
"section": "",
"text": "kernels.swiglu\nModule for definition of SwiGLU Triton kernels.\nSee “GLU Variants Improve Transformer” (https://arxiv.org/abs/2002.05202).\nCredit to unsloth (https://unsloth.ai/) for inspiration for this implementation.\n\n\n\n\n\nName\nDescription\n\n\n\n\nswiglu_backward\nSwiGLU backward pass using in-place operations.\n\n\nswiglu_forward\nSwiGLU forward pass. Computes SwiGLU activation: x * sigmoid(x) * up, where\n\n\n\n\n\nkernels.swiglu.swiglu_backward(grad_output, gate, up)\nSwiGLU backward pass using in-place operations.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ngrad_output\ntorch.Tensor\nGradient of loss with respect to output, shape [batch, seq_len, hidden_dim].\nrequired\n\n\ngate\ntorch.Tensor\nGate tensor from forward pass, shape [batch, seq_len, hidden_dim].\nrequired\n\n\nup\ntorch.Tensor\nUp-projection tensor from forward pass, shape [batch, seq_len, hidden_dim].\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[torch.Tensor, torch.Tensor, torch.Tensor]\nTuple containing: - Forward pass output (h) - Gradient with respect to gate (df) - Gradient with respect to up-projection (de)\n\n\n\n\n\n\n\nkernels.swiglu.swiglu_forward(gate, up)\nSwiGLU forward pass. Computes SwiGLU activation: x * sigmoid(x) * up, where\nx is the gate tensor.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ngate\ntorch.Tensor\nInput gate tensor of shape [batch, seq_len, hidden_dim].\nrequired\n\n\nup\ntorch.Tensor\nUp-projection tensor of shape [batch, seq_len, hidden_dim].\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntorch.Tensor\nOutput tensor of shape [batch, seq_len, hidden_dim]."
},
{
"objectID": "docs/api/kernels.swiglu.html#functions",
"href": "docs/api/kernels.swiglu.html#functions",
"title": "kernels.swiglu",
"section": "",
"text": "Name\nDescription\n\n\n\n\nswiglu_backward\nSwiGLU backward pass using in-place operations.\n\n\nswiglu_forward\nSwiGLU forward pass. Computes SwiGLU activation: x * sigmoid(x) * up, where\n\n\n\n\n\nkernels.swiglu.swiglu_backward(grad_output, gate, up)\nSwiGLU backward pass using in-place operations.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ngrad_output\ntorch.Tensor\nGradient of loss with respect to output, shape [batch, seq_len, hidden_dim].\nrequired\n\n\ngate\ntorch.Tensor\nGate tensor from forward pass, shape [batch, seq_len, hidden_dim].\nrequired\n\n\nup\ntorch.Tensor\nUp-projection tensor from forward pass, shape [batch, seq_len, hidden_dim].\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[torch.Tensor, torch.Tensor, torch.Tensor]\nTuple containing: - Forward pass output (h) - Gradient with respect to gate (df) - Gradient with respect to up-projection (de)\n\n\n\n\n\n\n\nkernels.swiglu.swiglu_forward(gate, up)\nSwiGLU forward pass. Computes SwiGLU activation: x * sigmoid(x) * up, where\nx is the gate tensor.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ngate\ntorch.Tensor\nInput gate tensor of shape [batch, seq_len, hidden_dim].\nrequired\n\n\nup\ntorch.Tensor\nUp-projection tensor of shape [batch, seq_len, hidden_dim].\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntorch.Tensor\nOutput tensor of shape [batch, seq_len, hidden_dim]."
},
{
"objectID": "docs/api/prompt_strategies.pygmalion.html",
"href": "docs/api/prompt_strategies.pygmalion.html",
"title": "prompt_strategies.pygmalion",
"section": "",
"text": "prompt_strategies.pygmalion\nModule containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class\n\n\n\n\n\nName\nDescription\n\n\n\n\nPygmalionPromptTokenizingStrategy\nTokenizing strategy for Pygmalion.\n\n\nPygmalionPrompter\nPrompter for Pygmalion.\n\n\n\n\n\nprompt_strategies.pygmalion.PygmalionPromptTokenizingStrategy(\n prompter,\n tokenizer,\n *args,\n **kwargs,\n)\nTokenizing strategy for Pygmalion.\n\n\n\nprompt_strategies.pygmalion.PygmalionPrompter(*args, **kwargs)\nPrompter for Pygmalion."
},
{
"objectID": "docs/api/prompt_strategies.pygmalion.html#classes",
"href": "docs/api/prompt_strategies.pygmalion.html#classes",
"title": "prompt_strategies.pygmalion",
"section": "",
"text": "Name\nDescription\n\n\n\n\nPygmalionPromptTokenizingStrategy\nTokenizing strategy for Pygmalion.\n\n\nPygmalionPrompter\nPrompter for Pygmalion.\n\n\n\n\n\nprompt_strategies.pygmalion.PygmalionPromptTokenizingStrategy(\n prompter,\n tokenizer,\n *args,\n **kwargs,\n)\nTokenizing strategy for Pygmalion.\n\n\n\nprompt_strategies.pygmalion.PygmalionPrompter(*args, **kwargs)\nPrompter for Pygmalion."
},
{
"objectID": "docs/api/utils.schemas.peft.html",
"href": "docs/api/utils.schemas.peft.html",
"title": "utils.schemas.peft",
"section": "",
"text": "utils.schemas.peft\nPydantic models for PEFT-related configuration\n\n\n\n\n\nName\nDescription\n\n\n\n\nLoftQConfig\nLoftQ configuration subset\n\n\nLoraConfig\nPeft / LoRA configuration subset\n\n\nPeftConfig\npeftq configuration subset\n\n\nReLoRAConfig\nReLoRA configuration subset\n\n\n\n\n\nutils.schemas.peft.LoftQConfig()\nLoftQ configuration subset\n\n\n\nutils.schemas.peft.LoraConfig()\nPeft / LoRA configuration subset\n\n\n\nutils.schemas.peft.PeftConfig()\npeftq configuration subset\n\n\n\nutils.schemas.peft.ReLoRAConfig()\nReLoRA configuration subset"
},
{
"objectID": "docs/api/utils.schemas.peft.html#classes",
"href": "docs/api/utils.schemas.peft.html#classes",
"title": "utils.schemas.peft",
"section": "",
"text": "Name\nDescription\n\n\n\n\nLoftQConfig\nLoftQ configuration subset\n\n\nLoraConfig\nPeft / LoRA configuration subset\n\n\nPeftConfig\npeftq configuration subset\n\n\nReLoRAConfig\nReLoRA configuration subset\n\n\n\n\n\nutils.schemas.peft.LoftQConfig()\nLoftQ configuration subset\n\n\n\nutils.schemas.peft.LoraConfig()\nPeft / LoRA configuration subset\n\n\n\nutils.schemas.peft.PeftConfig()\npeftq configuration subset\n\n\n\nutils.schemas.peft.ReLoRAConfig()\nReLoRA configuration subset"
},
{
"objectID": "docs/api/utils.schemas.trl.html",
"href": "docs/api/utils.schemas.trl.html",
"title": "utils.schemas.trl",
"section": "",
"text": "utils.schemas.trl\nPydantic models for TRL trainer configuration\n\n\n\n\n\nName\nDescription\n\n\n\n\nTRLConfig\nInput args for TRL.\n\n\n\n\n\nutils.schemas.trl.TRLConfig()\nInput args for TRL."
},
{
"objectID": "docs/api/utils.schemas.trl.html#classes",
"href": "docs/api/utils.schemas.trl.html#classes",
"title": "utils.schemas.trl",
"section": "",
"text": "Name\nDescription\n\n\n\n\nTRLConfig\nInput args for TRL.\n\n\n\n\n\nutils.schemas.trl.TRLConfig()\nInput args for TRL."
},
{
"objectID": "docs/api/prompt_strategies.completion.html",
"href": "docs/api/prompt_strategies.completion.html",
"title": "prompt_strategies.completion",
"section": "",
"text": "prompt_strategies.completion\nBasic completion text\n\n\n\n\n\nName\nDescription\n\n\n\n\nCompletionPromptTokenizingStrategy\nTokenizing strategy for Completion prompts.\n\n\nCompletionPrompter\nPrompter for completion\n\n\n\n\n\nprompt_strategies.completion.CompletionPromptTokenizingStrategy(\n *args,\n max_length=None,\n **kwargs,\n)\nTokenizing strategy for Completion prompts.\n\n\n\nprompt_strategies.completion.CompletionPrompter()\nPrompter for completion"
},
{
"objectID": "docs/api/prompt_strategies.completion.html#classes",
"href": "docs/api/prompt_strategies.completion.html#classes",
"title": "prompt_strategies.completion",
"section": "",
"text": "Name\nDescription\n\n\n\n\nCompletionPromptTokenizingStrategy\nTokenizing strategy for Completion prompts.\n\n\nCompletionPrompter\nPrompter for completion\n\n\n\n\n\nprompt_strategies.completion.CompletionPromptTokenizingStrategy(\n *args,\n max_length=None,\n **kwargs,\n)\nTokenizing strategy for Completion prompts.\n\n\n\nprompt_strategies.completion.CompletionPrompter()\nPrompter for completion"
},
{
"objectID": "docs/api/cli.vllm_serve.html",
"href": "docs/api/cli.vllm_serve.html",
"title": "cli.vllm_serve",
"section": "",
"text": "cli.vllm_serve\nCLI to start the vllm server for online RL\n\n\n\n\n\nName\nDescription\n\n\n\n\nAxolotlScriptArguments\nAdditional arguments for the VLLM server\n\n\n\n\n\ncli.vllm_serve.AxolotlScriptArguments(\n reasoning_parser='',\n enable_reasoning=None,\n)\nAdditional arguments for the VLLM server\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\ndo_vllm_serve\nStarts the VLLM server for serving LLM models used for online RL\n\n\n\n\n\ncli.vllm_serve.do_vllm_serve(config, cli_args)\nStarts the VLLM server for serving LLM models used for online RL\nArgs\n:param cfg: Parsed doct of the YAML config\n:param cli_args: dict of additional command-line arguments of type VllmServeCliArgs\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\nprocess_id\n\nthe process id of the started VLLM server"
},
{
"objectID": "docs/api/cli.vllm_serve.html#classes",
"href": "docs/api/cli.vllm_serve.html#classes",
"title": "cli.vllm_serve",
"section": "",
"text": "Name\nDescription\n\n\n\n\nAxolotlScriptArguments\nAdditional arguments for the VLLM server\n\n\n\n\n\ncli.vllm_serve.AxolotlScriptArguments(\n reasoning_parser='',\n enable_reasoning=None,\n)\nAdditional arguments for the VLLM server"
},
{
"objectID": "docs/api/cli.vllm_serve.html#functions",
"href": "docs/api/cli.vllm_serve.html#functions",
"title": "cli.vllm_serve",
"section": "",
"text": "Name\nDescription\n\n\n\n\ndo_vllm_serve\nStarts the VLLM server for serving LLM models used for online RL\n\n\n\n\n\ncli.vllm_serve.do_vllm_serve(config, cli_args)\nStarts the VLLM server for serving LLM models used for online RL\nArgs\n:param cfg: Parsed doct of the YAML config\n:param cli_args: dict of additional command-line arguments of type VllmServeCliArgs\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\nprocess_id\n\nthe process id of the started VLLM server"
},
{
"objectID": "docs/api/utils.trainer.html",
"href": "docs/api/utils.trainer.html",
"title": "utils.trainer",
"section": "",
"text": "utils.trainer\nModule containing the Trainer class and related functions\n\n\n\n\n\nName\nDescription\n\n\n\n\nadd_pose_position_ids\nuse the PoSE technique to extend the context length by randomly skipping\n\n\nadd_position_ids\nHandle both single-example and batched data.\n\n\nfilter_sequences_by_length\nFilter sequences outside valid length range [min_sequence_len, sequence_len].\n\n\nsetup_trainer\nHelper method for instantiating and building a (causal or RLHF) trainer.\n\n\n\n\n\nutils.trainer.add_pose_position_ids(\n sample,\n max_context_len=32768,\n split_on_token_ids=None,\n chunks=2,\n)\nuse the PoSE technique to extend the context length by randomly skipping\npositions in the context. We only want to skip right before tokens in\nthe split_on_token_ids list. We should attempt to randomly distribute\nthe skips, but we dont need the final position_ids to be the full\ncontext_len. There may be multiple turns in the context, so we want to\nmake sure we take into account the maximum possible number of skips\nremaining in each sample.\n\n\n\nutils.trainer.add_position_ids(sample)\nHandle both single-example and batched data.\n- single example: sample[input_ids] is a list[int]\n- batched data: sample[input_ids] is a list[list[int]]\n\n\n\nutils.trainer.filter_sequences_by_length(\n sample,\n sequence_len=2048,\n min_sequence_len=2,\n raise_on_drop=False,\n)\nFilter sequences outside valid length range [min_sequence_len, sequence_len].\nDrops samples that are either too short (< min_sequence_len) or too long (> sequence_len).\nWorks for both single-example (list[int]) or batched (list[list[int]]).\nIf raise_on_drop is set, the code raises a ValueError if a sample is\nencountered that is too long and would have been dropped.\n\n\n\nutils.trainer.setup_trainer(\n cfg,\n train_dataset,\n eval_dataset,\n model,\n tokenizer,\n processor,\n total_num_steps,\n model_ref=None,\n peft_config=None,\n)\nHelper method for instantiating and building a (causal or 
RLHF) trainer.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\n\nAxolotl config object containing training parameters.\nrequired\n\n\ntrain_dataset\n\nDataset to use for training.\nrequired\n\n\neval_dataset\n\nDataset to use for evaluation.\nrequired\n\n\nmodel\n\nThe model to train.\nrequired\n\n\ntokenizer\n\nTokenizer for processing text input.\nrequired\n\n\nprocessor\n\nProcessor for data preparation.\nrequired\n\n\ntotal_num_steps\n\nThe total number of training steps.\nrequired\n\n\nmodel_ref\n\nOptional reference model for RLHF training. Default is None.\nNone\n\n\npeft_config\n\nOptional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None.\nNone\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\n\nA trainer instance (either HFRLTrainer or HFCausalTrainer) configured based on the provided parameters."
},
{
"objectID": "docs/api/utils.trainer.html#functions",
"href": "docs/api/utils.trainer.html#functions",
"title": "utils.trainer",
"section": "",
"text": "Name\nDescription\n\n\n\n\nadd_pose_position_ids\nuse the PoSE technique to extend the context length by randomly skipping\n\n\nadd_position_ids\nHandle both single-example and batched data.\n\n\nfilter_sequences_by_length\nFilter sequences outside valid length range [min_sequence_len, sequence_len].\n\n\nsetup_trainer\nHelper method for instantiating and building a (causal or RLHF) trainer.\n\n\n\n\n\nutils.trainer.add_pose_position_ids(\n sample,\n max_context_len=32768,\n split_on_token_ids=None,\n chunks=2,\n)\nuse the PoSE technique to extend the context length by randomly skipping\npositions in the context. We only want to skip right before tokens in\nthe split_on_token_ids list. We should attempt to randomly distribute\nthe skips, but we dont need the final position_ids to be the full\ncontext_len. There may be multiple turns in the context, so we want to\nmake sure we take into account the maximum possible number of skips\nremaining in each sample.\n\n\n\nutils.trainer.add_position_ids(sample)\nHandle both single-example and batched data.\n- single example: sample[input_ids] is a list[int]\n- batched data: sample[input_ids] is a list[list[int]]\n\n\n\nutils.trainer.filter_sequences_by_length(\n sample,\n sequence_len=2048,\n min_sequence_len=2,\n raise_on_drop=False,\n)\nFilter sequences outside valid length range [min_sequence_len, sequence_len].\nDrops samples that are either too short (< min_sequence_len) or too long (> sequence_len).\nWorks for both single-example (list[int]) or batched (list[list[int]]).\nIf raise_on_drop is set, the code raises a ValueError if a sample is\nencountered that is too long and would have been dropped.\n\n\n\nutils.trainer.setup_trainer(\n cfg,\n train_dataset,\n eval_dataset,\n model,\n tokenizer,\n processor,\n total_num_steps,\n model_ref=None,\n peft_config=None,\n)\nHelper method for instantiating and building a (causal or RLHF) 
trainer.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\n\nAxolotl config object containing training parameters.\nrequired\n\n\ntrain_dataset\n\nDataset to use for training.\nrequired\n\n\neval_dataset\n\nDataset to use for evaluation.\nrequired\n\n\nmodel\n\nThe model to train.\nrequired\n\n\ntokenizer\n\nTokenizer for processing text input.\nrequired\n\n\nprocessor\n\nProcessor for data preparation.\nrequired\n\n\ntotal_num_steps\n\nThe total number of training steps.\nrequired\n\n\nmodel_ref\n\nOptional reference model for RLHF training. Default is None.\nNone\n\n\npeft_config\n\nOptional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None.\nNone\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\n\nA trainer instance (either HFRLTrainer or HFCausalTrainer) configured based on the provided parameters."
},
{
"objectID": "docs/api/utils.ctx_managers.sequence_parallel.html",
"href": "docs/api/utils.ctx_managers.sequence_parallel.html",
"title": "utils.ctx_managers.sequence_parallel",
"section": "",
"text": "utils.ctx_managers.sequence_parallel\nModule for Axolotl trainer sequence parallelism manager and utilities\n\n\n\n\n\nName\nDescription\n\n\n\n\nAllGatherWithGrad\nCustom autograd function for all-gather to preserve gradients.\n\n\nSequenceParallelContextManager\nContext manager for sequence parallelism operations.\n\n\n\n\n\nutils.ctx_managers.sequence_parallel.AllGatherWithGrad()\nCustom autograd function for all-gather to preserve gradients.\n\n\n\n\n\nName\nDescription\n\n\n\n\nbackward\nBackward pass for all-gather operation.\n\n\nforward\nForward pass of all-gather of data with sequence dimension.\n\n\n\n\n\nutils.ctx_managers.sequence_parallel.AllGatherWithGrad.backward(\n ctx,\n grad_output,\n)\nBackward pass for all-gather operation.\nExtracts the gradient slice corresponding to this ranks original input\nfrom the full gradient tensor.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nctx\ntorch.autograd.function.FunctionCtx\ntorch.autograd function context.\nrequired\n\n\ngrad_output\ntorch.Tensor\nGradient from subsequent layers with respect to the concatenated output tensor.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[torch.Tensor, None]\nTuple containing the gradient slice for this ranks input tensor and None for the process group parameter which doesnt require gradients.\n\n\n\n\n\n\n\nutils.ctx_managers.sequence_parallel.AllGatherWithGrad.forward(\n ctx,\n input_tensor,\n group,\n)\nForward pass of all-gather of data with sequence dimension.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nctx\ntorch.autograd.function.FunctionCtx\ntorch.autograd function context.\nrequired\n\n\ninput_tensor\ntorch.Tensor\nTensor from model output with sequence dimension.\nrequired\n\n\ngroup\ndist.ProcessGroup\ntorch.distributed process group.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntorch.Tensor\nTensor from gathering the input_tensor from across the process 
group and concatenating along the sequence dimension.\n\n\n\n\n\n\n\n\n\nutils.ctx_managers.sequence_parallel.SequenceParallelContextManager(\n models,\n context_parallel_size,\n gradient_accumulation_steps,\n ring_attn_func,\n heads_k_stride,\n gather_outputs,\n device_mesh=None,\n)\nContext manager for sequence parallelism operations.\nThis class provides a context that will automatically apply sequence parallelism\nduring model forward passes using a pre-forward hook, and gather outputs from\nacross the sequence parallelism group using a post-forward hook.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nmodels\nlist[nn.Module]\nList of models to apply sequence parallelism to pre- and post- forward hooks.\nrequired\n\n\ncontext_parallel_size\nint\nNumber of processes to split sequences over.\nrequired\n\n\ngradient_accumulation_steps\nint\nNumber of steps to accumulate gradients over.\nrequired\n\n\nring_attn_func\nRingAttnFunc\nWhich ring attention function to use. Currently unused.\nrequired\n\n\nheads_k_stride\nint | None\nSequence parallelism K head stride size. 
Passed through to varlen_llama3 ring_flash_attn implementation.\nrequired\n\n\ngather_outputs\nbool\nWhether to gather outputs after model forward pass across the sequence parallel group.\nrequired\n\n\n\n\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\napply_sequence_parallelism\nApply sequence parallelism slicing to a batch.\n\n\n\n\n\nutils.ctx_managers.sequence_parallel.apply_sequence_parallelism(\n batch,\n local_rank,\n local_world_size,\n gradient_accumulation_steps,\n ring_attn_func,\n)\nApply sequence parallelism slicing to a batch.\nSpecial handling is implemented for integer logits_to_keep, which indicates\nto only keep the last N tokens in the sequence during generation.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nbatch\ndict[str, torch.Tensor]\nBatch dictionary (e.g., input_ids, attention_mask, etc.).\nrequired\n\n\nlocal_rank\nint\nLocal rank in the sequence parallel group.\nrequired\n\n\nlocal_world_size\nint\nWorld size of the sequence parallel group.\nrequired\n\n\ngradient_accumulation_steps\nint\nNumber of steps to accumulate gradients over.\nrequired\n\n\nring_attn_func\nRingAttnFunc\nWhich ring attention function to use. Currently unused, but related to above TODO.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[dict[str, torch.Tensor], int, int]\ntuple of: - Batch dictionary with sliced tensors. - The original sequence length before padding. - The number of padding tokens added."
},
{
"objectID": "docs/api/utils.ctx_managers.sequence_parallel.html#classes",
"href": "docs/api/utils.ctx_managers.sequence_parallel.html#classes",
"title": "utils.ctx_managers.sequence_parallel",
"section": "",
"text": "Name\nDescription\n\n\n\n\nAllGatherWithGrad\nCustom autograd function for all-gather to preserve gradients.\n\n\nSequenceParallelContextManager\nContext manager for sequence parallelism operations.\n\n\n\n\n\nutils.ctx_managers.sequence_parallel.AllGatherWithGrad()\nCustom autograd function for all-gather to preserve gradients.\n\n\n\n\n\nName\nDescription\n\n\n\n\nbackward\nBackward pass for all-gather operation.\n\n\nforward\nForward pass of all-gather of data with sequence dimension.\n\n\n\n\n\nutils.ctx_managers.sequence_parallel.AllGatherWithGrad.backward(\n ctx,\n grad_output,\n)\nBackward pass for all-gather operation.\nExtracts the gradient slice corresponding to this ranks original input\nfrom the full gradient tensor.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nctx\ntorch.autograd.function.FunctionCtx\ntorch.autograd function context.\nrequired\n\n\ngrad_output\ntorch.Tensor\nGradient from subsequent layers with respect to the concatenated output tensor.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[torch.Tensor, None]\nTuple containing the gradient slice for this ranks input tensor and None for the process group parameter which doesnt require gradients.\n\n\n\n\n\n\n\nutils.ctx_managers.sequence_parallel.AllGatherWithGrad.forward(\n ctx,\n input_tensor,\n group,\n)\nForward pass of all-gather of data with sequence dimension.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nctx\ntorch.autograd.function.FunctionCtx\ntorch.autograd function context.\nrequired\n\n\ninput_tensor\ntorch.Tensor\nTensor from model output with sequence dimension.\nrequired\n\n\ngroup\ndist.ProcessGroup\ntorch.distributed process group.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntorch.Tensor\nTensor from gathering the input_tensor from across the process group and concatenating along the sequence 
dimension.\n\n\n\n\n\n\n\n\n\nutils.ctx_managers.sequence_parallel.SequenceParallelContextManager(\n models,\n context_parallel_size,\n gradient_accumulation_steps,\n ring_attn_func,\n heads_k_stride,\n gather_outputs,\n device_mesh=None,\n)\nContext manager for sequence parallelism operations.\nThis class provides a context that will automatically apply sequence parallelism\nduring model forward passes using a pre-forward hook, and gather outputs from\nacross the sequence parallelism group using a post-forward hook.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nmodels\nlist[nn.Module]\nList of models to apply sequence parallelism to pre- and post- forward hooks.\nrequired\n\n\ncontext_parallel_size\nint\nNumber of processes to split sequences over.\nrequired\n\n\ngradient_accumulation_steps\nint\nNumber of steps to accumulate gradients over.\nrequired\n\n\nring_attn_func\nRingAttnFunc\nWhich ring attention function to use. Currently unused.\nrequired\n\n\nheads_k_stride\nint | None\nSequence parallelism K head stride size. Passed through to varlen_llama3 ring_flash_attn implementation.\nrequired\n\n\ngather_outputs\nbool\nWhether to gather outputs after model forward pass across the sequence parallel group.\nrequired"
},
{
"objectID": "docs/api/utils.ctx_managers.sequence_parallel.html#functions",
"href": "docs/api/utils.ctx_managers.sequence_parallel.html#functions",
"title": "utils.ctx_managers.sequence_parallel",
"section": "",
"text": "Name\nDescription\n\n\n\n\napply_sequence_parallelism\nApply sequence parallelism slicing to a batch.\n\n\n\n\n\nutils.ctx_managers.sequence_parallel.apply_sequence_parallelism(\n batch,\n local_rank,\n local_world_size,\n gradient_accumulation_steps,\n ring_attn_func,\n)\nApply sequence parallelism slicing to a batch.\nSpecial handling is implemented for integer logits_to_keep, which indicates\nto only keep the last N tokens in the sequence during generation.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nbatch\ndict[str, torch.Tensor]\nBatch dictionary (e.g., input_ids, attention_mask, etc.).\nrequired\n\n\nlocal_rank\nint\nLocal rank in the sequence parallel group.\nrequired\n\n\nlocal_world_size\nint\nWorld size of the sequence parallel group.\nrequired\n\n\ngradient_accumulation_steps\nint\nNumber of steps to accumulate gradients over.\nrequired\n\n\nring_attn_func\nRingAttnFunc\nWhich ring attention function to use. Currently unused, but related to above TODO.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[dict[str, torch.Tensor], int, int]\ntuple of: - Batch dictionary with sliced tensors. - The original sequence length before padding. - The number of padding tokens added."
},
{
"objectID": "docs/api/core.training_args.html",
"href": "docs/api/core.training_args.html",
"title": "core.training_args",
"section": "",
"text": "core.training_args\nextra axolotl specific training args\n\n\n\n\n\nName\nDescription\n\n\n\n\nAxolotlCPOConfig\nCPO config for CPO training\n\n\nAxolotlKTOConfig\nKTO config for KTO training\n\n\nAxolotlORPOConfig\nORPO config for ORPO training\n\n\nAxolotlPRMConfig\nPRM config for PRM training\n\n\nAxolotlRewardConfig\nReward config for Reward training\n\n\nAxolotlTrainingArguments\nTraining arguments for Causal trainer\n\n\n\n\n\ncore.training_args.AxolotlCPOConfig(simpo_gamma=None)\nCPO config for CPO training\n\n\n\ncore.training_args.AxolotlKTOConfig()\nKTO config for KTO training\n\n\n\ncore.training_args.AxolotlORPOConfig()\nORPO config for ORPO training\n\n\n\ncore.training_args.AxolotlPRMConfig()\nPRM config for PRM training\n\n\n\ncore.training_args.AxolotlRewardConfig()\nReward config for Reward training\n\n\n\ncore.training_args.AxolotlTrainingArguments()\nTraining arguments for Causal trainer\nThis code is duplicated due to HF TrainingArguments not setting output_dir with a\ndefault value so it cant be used as a mixin."
},
{
"objectID": "docs/api/core.training_args.html#classes",
"href": "docs/api/core.training_args.html#classes",
"title": "core.training_args",
"section": "",
"text": "Name\nDescription\n\n\n\n\nAxolotlCPOConfig\nCPO config for CPO training\n\n\nAxolotlKTOConfig\nKTO config for KTO training\n\n\nAxolotlORPOConfig\nORPO config for ORPO training\n\n\nAxolotlPRMConfig\nPRM config for PRM training\n\n\nAxolotlRewardConfig\nReward config for Reward training\n\n\nAxolotlTrainingArguments\nTraining arguments for Causal trainer\n\n\n\n\n\ncore.training_args.AxolotlCPOConfig(simpo_gamma=None)\nCPO config for CPO training\n\n\n\ncore.training_args.AxolotlKTOConfig()\nKTO config for KTO training\n\n\n\ncore.training_args.AxolotlORPOConfig()\nORPO config for ORPO training\n\n\n\ncore.training_args.AxolotlPRMConfig()\nPRM config for PRM training\n\n\n\ncore.training_args.AxolotlRewardConfig()\nReward config for Reward training\n\n\n\ncore.training_args.AxolotlTrainingArguments()\nTraining arguments for Causal trainer\nThis code is duplicated due to HF TrainingArguments not setting output_dir with a\ndefault value so it cant be used as a mixin."
},
{
"objectID": "docs/api/evaluate.html",
"href": "docs/api/evaluate.html",
"title": "evaluate",
"section": "",
"text": "evaluate\nModule for evaluating models.\n\n\n\n\n\nName\nDescription\n\n\n\n\nevaluate\nEvaluate a model on training and validation datasets.\n\n\nevaluate_dataset\nHelper function to evaluate a single dataset.\n\n\n\n\n\nevaluate.evaluate(cfg, dataset_meta)\nEvaluate a model on training and validation datasets.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ndataset_meta\nTrainDatasetMeta\nDataset metadata containing training and evaluation datasets.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nDict[str, float]\nDictionary mapping metric names to their values.\n\n\n\n\n\n\n\nevaluate.evaluate_dataset(trainer, dataset, dataset_type, flash_optimum=False)\nHelper function to evaluate a single dataset.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ntrainer\nTrainer\nThe trainer instance.\nrequired\n\n\ndataset\nDataset\nDataset to evaluate.\nrequired\n\n\ndataset_type\nstr\nType of dataset (train or eval).\nrequired\n\n\nflash_optimum\nbool\nWhether to use flash optimum.\nFalse\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nOptional[Dict[str, float]]\nDictionary of metrics or None if dataset is None."
},
{
"objectID": "docs/api/evaluate.html#functions",
"href": "docs/api/evaluate.html#functions",
"title": "evaluate",
"section": "",
"text": "Name\nDescription\n\n\n\n\nevaluate\nEvaluate a model on training and validation datasets.\n\n\nevaluate_dataset\nHelper function to evaluate a single dataset.\n\n\n\n\n\nevaluate.evaluate(cfg, dataset_meta)\nEvaluate a model on training and validation datasets.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ndataset_meta\nTrainDatasetMeta\nDataset metadata containing training and evaluation datasets.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nDict[str, float]\nDictionary mapping metric names to their values.\n\n\n\n\n\n\n\nevaluate.evaluate_dataset(trainer, dataset, dataset_type, flash_optimum=False)\nHelper function to evaluate a single dataset.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ntrainer\nTrainer\nThe trainer instance.\nrequired\n\n\ndataset\nDataset\nDataset to evaluate.\nrequired\n\n\ndataset_type\nstr\nType of dataset (train or eval).\nrequired\n\n\nflash_optimum\nbool\nWhether to use flash optimum.\nFalse\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nOptional[Dict[str, float]]\nDictionary of metrics or None if dataset is None."
},
{
"objectID": "docs/api/utils.callbacks.comet_.html",
"href": "docs/api/utils.callbacks.comet_.html",
"title": "utils.callbacks.comet_",
"section": "",
"text": "utils.callbacks.comet_\nComet module for trainer callbacks\n\n\n\n\n\nName\nDescription\n\n\n\n\nSaveAxolotlConfigtoCometCallback\nCallback to save axolotl config to comet\n\n\n\n\n\nutils.callbacks.comet_.SaveAxolotlConfigtoCometCallback(axolotl_config_path)\nCallback to save axolotl config to comet"
},
{
"objectID": "docs/api/utils.callbacks.comet_.html#classes",
"href": "docs/api/utils.callbacks.comet_.html#classes",
"title": "utils.callbacks.comet_",
"section": "",
"text": "Name\nDescription\n\n\n\n\nSaveAxolotlConfigtoCometCallback\nCallback to save axolotl config to comet\n\n\n\n\n\nutils.callbacks.comet_.SaveAxolotlConfigtoCometCallback(axolotl_config_path)\nCallback to save axolotl config to comet"
},
{
"objectID": "docs/api/loaders.tokenizer.html",
"href": "docs/api/loaders.tokenizer.html",
"title": "loaders.tokenizer",
"section": "",
"text": "loaders.tokenizer\nTokenizer loading functionality and associated utils\n\n\n\n\n\nName\nDescription\n\n\n\n\nload_tokenizer\nLoad and configure the tokenizer based on the provided config.\n\n\nmodify_tokenizer_files\nModify tokenizer files to replace added_tokens strings, save to output directory,\n\n\n\n\n\nloaders.tokenizer.load_tokenizer(cfg)\nLoad and configure the tokenizer based on the provided config.\n\n\n\nloaders.tokenizer.modify_tokenizer_files(\n tokenizer_path,\n token_mappings,\n output_dir,\n revision='main',\n)\nModify tokenizer files to replace added_tokens strings, save to output directory,\nand return the path to the modified tokenizer.\nThis only works with reserved tokens that were added to the tokenizer, not tokens\nalready part of the vocab.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ntokenizer_path\nstr\nPath or name of the original tokenizer\nrequired\n\n\ntoken_mappings\ndict[int, str]\nDict mapping {token_id (int): new_token_string}\nrequired\n\n\noutput_dir\nstr\nDirectory to save the modified tokenizer\nrequired\n\n\nrevision\nstr\nModel revision/branch/tag/commit to load from (HF Hub)\n'main'\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nstr\nPath to the modified tokenizer directory\n\n\n\nRef: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941"
},
{
"objectID": "docs/api/loaders.tokenizer.html#functions",
"href": "docs/api/loaders.tokenizer.html#functions",
"title": "loaders.tokenizer",
"section": "",
"text": "Name\nDescription\n\n\n\n\nload_tokenizer\nLoad and configure the tokenizer based on the provided config.\n\n\nmodify_tokenizer_files\nModify tokenizer files to replace added_tokens strings, save to output directory,\n\n\n\n\n\nloaders.tokenizer.load_tokenizer(cfg)\nLoad and configure the tokenizer based on the provided config.\n\n\n\nloaders.tokenizer.modify_tokenizer_files(\n tokenizer_path,\n token_mappings,\n output_dir,\n revision='main',\n)\nModify tokenizer files to replace added_tokens strings, save to output directory,\nand return the path to the modified tokenizer.\nThis only works with reserved tokens that were added to the tokenizer, not tokens\nalready part of the vocab.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ntokenizer_path\nstr\nPath or name of the original tokenizer\nrequired\n\n\ntoken_mappings\ndict[int, str]\nDict mapping {token_id (int): new_token_string}\nrequired\n\n\noutput_dir\nstr\nDirectory to save the modified tokenizer\nrequired\n\n\nrevision\nstr\nModel revision/branch/tag/commit to load from (HF Hub)\n'main'\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nstr\nPath to the modified tokenizer directory\n\n\n\nRef: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941"
},
{
"objectID": "docs/api/monkeypatch.llama_attn_hijack_flash.html",
"href": "docs/api/monkeypatch.llama_attn_hijack_flash.html",
"title": "monkeypatch.llama_attn_hijack_flash",
"section": "",
"text": "monkeypatch.llama_attn_hijack_flash\nFlash attention monkey patch for llama model\n\n\n\n\n\nName\nDescription\n\n\n\n\nflashattn_forward_with_s2attn\nInput shape: Batch x Time x Channel\n\n\n\n\n\nmonkeypatch.llama_attn_hijack_flash.flashattn_forward_with_s2attn(\n self,\n hidden_states,\n attention_mask=None,\n position_ids=None,\n past_key_value=None,\n output_attentions=False,\n use_cache=False,\n padding_mask=None,\n cu_seqlens=None,\n max_seqlen=None,\n)\nInput shape: Batch x Time x Channel\nFrom: https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py\nattention_mask: [bsz, q_len]\ncu_seqlens will be ignored if provided\nmax_seqlen will be ignored if provided"
},
{
"objectID": "docs/api/monkeypatch.llama_attn_hijack_flash.html#functions",
"href": "docs/api/monkeypatch.llama_attn_hijack_flash.html#functions",
"title": "monkeypatch.llama_attn_hijack_flash",
"section": "",
"text": "Name\nDescription\n\n\n\n\nflashattn_forward_with_s2attn\nInput shape: Batch x Time x Channel\n\n\n\n\n\nmonkeypatch.llama_attn_hijack_flash.flashattn_forward_with_s2attn(\n self,\n hidden_states,\n attention_mask=None,\n position_ids=None,\n past_key_value=None,\n output_attentions=False,\n use_cache=False,\n padding_mask=None,\n cu_seqlens=None,\n max_seqlen=None,\n)\nInput shape: Batch x Time x Channel\nFrom: https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py\nattention_mask: [bsz, q_len]\ncu_seqlens will be ignored if provided\nmax_seqlen will be ignored if provided"
},
{
"objectID": "docs/api/cli.cloud.modal_.html",
"href": "docs/api/cli.cloud.modal_.html",
"title": "cli.cloud.modal_",
"section": "",
"text": "cli.cloud.modal_\nModal Cloud support from CLI\n\n\n\n\n\nName\nDescription\n\n\n\n\nModalCloud\nModal Cloud implementation.\n\n\n\n\n\ncli.cloud.modal_.ModalCloud(config, app=None)\nModal Cloud implementation.\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\nrun_cmd\nRun a command inside a folder, with Modal Volume reloading before and commit on success.\n\n\n\n\n\ncli.cloud.modal_.run_cmd(cmd, run_folder, volumes=None)\nRun a command inside a folder, with Modal Volume reloading before and commit on success."
},
{
"objectID": "docs/api/cli.cloud.modal_.html#classes",
"href": "docs/api/cli.cloud.modal_.html#classes",
"title": "cli.cloud.modal_",
"section": "",
"text": "Name\nDescription\n\n\n\n\nModalCloud\nModal Cloud implementation.\n\n\n\n\n\ncli.cloud.modal_.ModalCloud(config, app=None)\nModal Cloud implementation."
},
{
"objectID": "docs/api/cli.cloud.modal_.html#functions",
"href": "docs/api/cli.cloud.modal_.html#functions",
"title": "cli.cloud.modal_",
"section": "",
"text": "Name\nDescription\n\n\n\n\nrun_cmd\nRun a command inside a folder, with Modal Volume reloading before and commit on success.\n\n\n\n\n\ncli.cloud.modal_.run_cmd(cmd, run_folder, volumes=None)\nRun a command inside a folder, with Modal Volume reloading before and commit on success."
},
{
"objectID": "docs/api/prompt_strategies.stepwise_supervised.html",
"href": "docs/api/prompt_strategies.stepwise_supervised.html",
"title": "prompt_strategies.stepwise_supervised",
"section": "",
"text": "prompt_strategies.stepwise_supervised\nModule for stepwise datasets, typically including a prompt and reasoning traces,\nand (optionally) per-step, or per-prompt-trace labels for reward modelling.\n\n\n\n\n\nName\nDescription\n\n\n\n\nStepwiseSupervisedPromptTokenizingStrategy\nTokenizing strategy for supervised stepwise datasets, typically used for COT-reasoning.\n\n\n\n\n\nprompt_strategies.stepwise_supervised.StepwiseSupervisedPromptTokenizingStrategy(\n tokenizer,\n sequence_len=2048,\n step_separator='\\n',\n max_completion_length=None,\n train_on_last_step_only=False,\n)\nTokenizing strategy for supervised stepwise datasets, typically used for COT-reasoning.\nThese datasets should include the following columns:\n- prompt: the prompt text\n- completions: a list of n completion steps\n- labels: a list of n labels indicating the “correctness” of each step"
},
{
"objectID": "docs/api/prompt_strategies.stepwise_supervised.html#classes",
"href": "docs/api/prompt_strategies.stepwise_supervised.html#classes",
"title": "prompt_strategies.stepwise_supervised",
"section": "",
"text": "Name\nDescription\n\n\n\n\nStepwiseSupervisedPromptTokenizingStrategy\nTokenizing strategy for supervised stepwise datasets, typically used for COT-reasoning.\n\n\n\n\n\nprompt_strategies.stepwise_supervised.StepwiseSupervisedPromptTokenizingStrategy(\n tokenizer,\n sequence_len=2048,\n step_separator='\\n',\n max_completion_length=None,\n train_on_last_step_only=False,\n)\nTokenizing strategy for supervised stepwise datasets, typically used for COT-reasoning.\nThese datasets should include the following columns:\n- prompt: the prompt text\n- completions: a list of n completion steps\n- labels: a list of n labels indicating the “correctness” of each step"
},
{
"objectID": "docs/api/monkeypatch.btlm_attn_hijack_flash.html",
"href": "docs/api/monkeypatch.btlm_attn_hijack_flash.html",
"title": "monkeypatch.btlm_attn_hijack_flash",
"section": "",
"text": "monkeypatch.btlm_attn_hijack_flash\nmonkeypatch.btlm_attn_hijack_flash\nFlash attention monkey patch for cerebras btlm model"
},
{
"objectID": "docs/api/core.chat.format.llama3x.html",
"href": "docs/api/core.chat.format.llama3x.html",
"title": "core.chat.format.llama3x",
"section": "",
"text": "core.chat.format.llama3x\ncore.chat.format.llama3x\nLlama 3.x chat formatting functions for MessageContents"
},
{
"objectID": "docs/api/utils.quantization.html",
"href": "docs/api/utils.quantization.html",
"title": "utils.quantization",
"section": "",
"text": "utils.quantization\nUtilities for quantization including QAT and PTQ using torchao.\n\n\n\n\n\nName\nDescription\n\n\n\n\nconvert_qat_model\nThis function converts a QAT model which has fake quantized layers back to the original model.\n\n\nget_quantization_config\nThis function is used to build a post-training quantization config.\n\n\nprepare_model_for_qat\nThis function is used to prepare a model for QAT by swapping the models linear\n\n\nquantize_model\nThis function is used to quantize a model.\n\n\n\n\n\nutils.quantization.convert_qat_model(model, quantize_embedding=False)\nThis function converts a QAT model which has fake quantized layers back to the original model.\n\n\n\nutils.quantization.get_quantization_config(\n weight_dtype,\n activation_dtype=None,\n group_size=None,\n)\nThis function is used to build a post-training quantization config.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nweight_dtype\nTorchAOQuantDType\nThe dtype to use for weight quantization.\nrequired\n\n\nactivation_dtype\nTorchAOQuantDType | None\nThe dtype to use for activation quantization.\nNone\n\n\ngroup_size\nint | None\nThe group size to use for weight quantization.\nNone\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nAOBaseConfig\nThe post-training quantization config.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nValueError\nIf the activation dtype is not specified and the weight dtype is not int8 or int4, or if the group size is not specified for int8 or int4 weight only quantization.\n\n\n\n\n\n\n\nutils.quantization.prepare_model_for_qat(\n model,\n weight_dtype,\n group_size=None,\n activation_dtype=None,\n quantize_embedding=False,\n)\nThis function is used to prepare a model for QAT by swapping the models linear\nlayers with fake quantized linear layers, and optionally the embedding weights with\nfake quantized embedding weights.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nmodel\n\nThe model to 
quantize.\nrequired\n\n\nweight_dtype\nTorchAOQuantDType\nThe dtype to use for weight quantization.\nrequired\n\n\ngroup_size\nint | None\nThe group size to use for weight quantization.\nNone\n\n\nactivation_dtype\nTorchAOQuantDType | None\nThe dtype to use for activation quantization.\nNone\n\n\nquantize_embedding\nbool\nWhether to quantize the model's embedding weights.\nFalse\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nValueError\nIf the activation/weight dtype combination is invalid.\n\n\n\n\n\n\n\nutils.quantization.quantize_model(\n model,\n weight_dtype,\n group_size=None,\n activation_dtype=None,\n quantize_embedding=None,\n)\nThis function is used to quantize a model.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nmodel\n\nThe model to quantize.\nrequired\n\n\nweight_dtype\nTorchAOQuantDType\nThe dtype to use for weight quantization.\nrequired\n\n\ngroup_size\nint | None\nThe group size to use for weight quantization.\nNone\n\n\nactivation_dtype\nTorchAOQuantDType | None\nThe dtype to use for activation quantization.\nNone\n\n\nquantize_embedding\nbool | None\nWhether to quantize the model's embedding weights.\nNone"
},
{
"objectID": "docs/api/utils.quantization.html#functions",
"href": "docs/api/utils.quantization.html#functions",
"title": "utils.quantization",
"section": "",
"text": "Name\nDescription\n\n\n\n\nconvert_qat_model\nThis function converts a QAT model which has fake quantized layers back to the original model.\n\n\nget_quantization_config\nThis function is used to build a post-training quantization config.\n\n\nprepare_model_for_qat\nThis function is used to prepare a model for QAT by swapping the models linear\n\n\nquantize_model\nThis function is used to quantize a model.\n\n\n\n\n\nutils.quantization.convert_qat_model(model, quantize_embedding=False)\nThis function converts a QAT model which has fake quantized layers back to the original model.\n\n\n\nutils.quantization.get_quantization_config(\n weight_dtype,\n activation_dtype=None,\n group_size=None,\n)\nThis function is used to build a post-training quantization config.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nweight_dtype\nTorchAOQuantDType\nThe dtype to use for weight quantization.\nrequired\n\n\nactivation_dtype\nTorchAOQuantDType | None\nThe dtype to use for activation quantization.\nNone\n\n\ngroup_size\nint | None\nThe group size to use for weight quantization.\nNone\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nAOBaseConfig\nThe post-training quantization config.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nValueError\nIf the activation dtype is not specified and the weight dtype is not int8 or int4, or if the group size is not specified for int8 or int4 weight only quantization.\n\n\n\n\n\n\n\nutils.quantization.prepare_model_for_qat(\n model,\n weight_dtype,\n group_size=None,\n activation_dtype=None,\n quantize_embedding=False,\n)\nThis function is used to prepare a model for QAT by swapping the models linear\nlayers with fake quantized linear layers, and optionally the embedding weights with\nfake quantized embedding weights.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nmodel\n\nThe model to quantize.\nrequired\n\n\nweight_dtype\nTorchAOQuantDType\nThe dtype to use for weight 
quantization.\nrequired\n\n\ngroup_size\nint | None\nThe group size to use for weight quantization.\nNone\n\n\nactivation_dtype\nTorchAOQuantDType | None\nThe dtype to use for activation quantization.\nNone\n\n\nquantize_embedding\nbool\nWhether to quantize the model's embedding weights.\nFalse\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nValueError\nIf the activation/weight dtype combination is invalid.\n\n\n\n\n\n\n\nutils.quantization.quantize_model(\n model,\n weight_dtype,\n group_size=None,\n activation_dtype=None,\n quantize_embedding=None,\n)\nThis function is used to quantize a model.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nmodel\n\nThe model to quantize.\nrequired\n\n\nweight_dtype\nTorchAOQuantDType\nThe dtype to use for weight quantization.\nrequired\n\n\ngroup_size\nint | None\nThe group size to use for weight quantization.\nNone\n\n\nactivation_dtype\nTorchAOQuantDType | None\nThe dtype to use for activation quantization.\nNone\n\n\nquantize_embedding\nbool | None\nWhether to quantize the model's embedding weights.\nNone"
},
{
"objectID": "docs/api/monkeypatch.unsloth_.html",
"href": "docs/api/monkeypatch.unsloth_.html",
"title": "monkeypatch.unsloth_",
"section": "",
"text": "monkeypatch.unsloth_\nmonkeypatch.unsloth_\nmodule for patching with unsloth optimizations"
},
{
"objectID": "docs/api/prompt_strategies.orpo.chat_template.html",
"href": "docs/api/prompt_strategies.orpo.chat_template.html",
"title": "prompt_strategies.orpo.chat_template",
"section": "",
"text": "prompt_strategies.orpo.chat_template\nchatml prompt tokenization strategy for ORPO\n\n\n\n\n\nName\nDescription\n\n\n\n\nMessage\nmessage/turn\n\n\nMessageList\nconversation\n\n\nORPODatasetParsingStrategy\nStrategy to parse chosen rejected dataset into messagelist\n\n\nORPOPrompter\nSingle Turn prompter for ORPO\n\n\nORPOTokenizingStrategy\nrejected_ids\n\n\n\n\n\nprompt_strategies.orpo.chat_template.Message()\nmessage/turn\n\n\n\nprompt_strategies.orpo.chat_template.MessageList()\nconversation\n\n\n\nprompt_strategies.orpo.chat_template.ORPODatasetParsingStrategy()\nStrategy to parse chosen rejected dataset into messagelist\n\n\n\n\n\nName\nDescription\n\n\n\n\nget_chosen_conversation_thread\nDataset structure mappings\n\n\nget_prompt\nMap the data to extract everything up to the last turn\n\n\nget_rejected_conversation_thread\nDataset structure mappings\n\n\n\n\n\nprompt_strategies.orpo.chat_template.ORPODatasetParsingStrategy.get_chosen_conversation_thread(\n prompt,\n)\nDataset structure mappings\n\n\n\nprompt_strategies.orpo.chat_template.ORPODatasetParsingStrategy.get_prompt(\n prompt,\n)\nMap the data to extract everything up to the last turn\n\n\n\nprompt_strategies.orpo.chat_template.ORPODatasetParsingStrategy.get_rejected_conversation_thread(\n prompt,\n)\nDataset structure mappings\n\n\n\n\n\nprompt_strategies.orpo.chat_template.ORPOPrompter(chat_template, tokenizer)\nSingle Turn prompter for ORPO\n\n\n\nprompt_strategies.orpo.chat_template.ORPOTokenizingStrategy(\n *args,\n dataset_parser=None,\n **kwargs,\n)\nrejected_ids\ninput_ids\nrejected_attention_mask\nattention_mask\nrejected_labels\nlabels\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\nload\nchatml transforms for datasets with system, input, chosen, rejected\n\n\n\n\n\nprompt_strategies.orpo.chat_template.load(tokenizer, cfg, ds_cfg=None, **kwargs)\nchatml transforms for datasets with system, input, chosen, rejected"
},
{
"objectID": "docs/api/prompt_strategies.orpo.chat_template.html#classes",
"href": "docs/api/prompt_strategies.orpo.chat_template.html#classes",
"title": "prompt_strategies.orpo.chat_template",
"section": "",
"text": "Name\nDescription\n\n\n\n\nMessage\nmessage/turn\n\n\nMessageList\nconversation\n\n\nORPODatasetParsingStrategy\nStrategy to parse chosen rejected dataset into messagelist\n\n\nORPOPrompter\nSingle Turn prompter for ORPO\n\n\nORPOTokenizingStrategy\nrejected_ids\n\n\n\n\n\nprompt_strategies.orpo.chat_template.Message()\nmessage/turn\n\n\n\nprompt_strategies.orpo.chat_template.MessageList()\nconversation\n\n\n\nprompt_strategies.orpo.chat_template.ORPODatasetParsingStrategy()\nStrategy to parse chosen rejected dataset into messagelist\n\n\n\n\n\nName\nDescription\n\n\n\n\nget_chosen_conversation_thread\nDataset structure mappings\n\n\nget_prompt\nMap the data to extract everything up to the last turn\n\n\nget_rejected_conversation_thread\nDataset structure mappings\n\n\n\n\n\nprompt_strategies.orpo.chat_template.ORPODatasetParsingStrategy.get_chosen_conversation_thread(\n prompt,\n)\nDataset structure mappings\n\n\n\nprompt_strategies.orpo.chat_template.ORPODatasetParsingStrategy.get_prompt(\n prompt,\n)\nMap the data to extract everything up to the last turn\n\n\n\nprompt_strategies.orpo.chat_template.ORPODatasetParsingStrategy.get_rejected_conversation_thread(\n prompt,\n)\nDataset structure mappings\n\n\n\n\n\nprompt_strategies.orpo.chat_template.ORPOPrompter(chat_template, tokenizer)\nSingle Turn prompter for ORPO\n\n\n\nprompt_strategies.orpo.chat_template.ORPOTokenizingStrategy(\n *args,\n dataset_parser=None,\n **kwargs,\n)\nrejected_ids\ninput_ids\nrejected_attention_mask\nattention_mask\nrejected_labels\nlabels"
},
{
"objectID": "docs/api/prompt_strategies.orpo.chat_template.html#functions",
"href": "docs/api/prompt_strategies.orpo.chat_template.html#functions",
"title": "prompt_strategies.orpo.chat_template",
"section": "",
"text": "Name\nDescription\n\n\n\n\nload\nchatml transforms for datasets with system, input, chosen, rejected\n\n\n\n\n\nprompt_strategies.orpo.chat_template.load(tokenizer, cfg, ds_cfg=None, **kwargs)\nchatml transforms for datasets with system, input, chosen, rejected"
},
{
"objectID": "docs/api/cli.art.html",
"href": "docs/api/cli.art.html",
"title": "cli.art",
"section": "",
"text": "cli.art\nAxolotl ASCII logo utils.\n\n\n\n\n\nName\nDescription\n\n\n\n\nprint_axolotl_text_art\nPrints axolotl ASCII art.\n\n\n\n\n\ncli.art.print_axolotl_text_art()\nPrints axolotl ASCII art."
},
{
"objectID": "docs/api/cli.art.html#functions",
"href": "docs/api/cli.art.html#functions",
"title": "cli.art",
"section": "",
"text": "Name\nDescription\n\n\n\n\nprint_axolotl_text_art\nPrints axolotl ASCII art.\n\n\n\n\n\ncli.art.print_axolotl_text_art()\nPrints axolotl ASCII art."
},
{
"objectID": "docs/api/loaders.processor.html",
"href": "docs/api/loaders.processor.html",
"title": "loaders.processor",
"section": "",
"text": "loaders.processor\nloaders.processor\nProcessor loading functionality for multi-modal models"
},
{
"objectID": "docs/api/cli.merge_sharded_fsdp_weights.html",
"href": "docs/api/cli.merge_sharded_fsdp_weights.html",
"title": "cli.merge_sharded_fsdp_weights",
"section": "",
"text": "cli.merge_sharded_fsdp_weights\nCLI to merge sharded FSDP model checkpoints into a single combined checkpoint.\n\n\n\n\n\nName\nDescription\n\n\n\n\nBFloat16CastPlanner\nA custom planner to cast tensors to bfloat16 on the fly during loading.\n\n\n\n\n\ncli.merge_sharded_fsdp_weights.BFloat16CastPlanner()\nA custom planner to cast tensors to bfloat16 on the fly during loading.\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\ndo_cli\nParses axolotl config, CLI args, and calls merge_fsdp_weights.\n\n\nmerge_fsdp_weights\nMerge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if\n\n\n\n\n\ncli.merge_sharded_fsdp_weights.do_cli(config=Path('examples/'), **kwargs)\nParses axolotl config, CLI args, and calls merge_fsdp_weights.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nUnion[Path, str]\nPath to axolotl config YAML file.\nPath('examples/')\n\n\nkwargs\n\nAdditional keyword arguments to override config file values.\n{}\n\n\n\n\n\n\n\ncli.merge_sharded_fsdp_weights.merge_fsdp_weights(\n checkpoint_dir,\n output_path,\n remove_checkpoint_dir=False,\n)\nMerge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if\nSHARDED_STATE_DICT was used for the model. Weights will be saved to {output_path}/model.safetensors.\nNote: this is a CPU-bound process.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncheckpoint_dir\nstr\nThe directory containing the FSDP checkpoints (can be either the model or optimizer).\nrequired\n\n\noutput_path\nstr\nThe path to save the merged checkpoint.\nrequired\n\n\nremove_checkpoint_dir\nbool, optional, defaults to False\nWhether to remove the checkpoint directory after merging.\nFalse\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nValueError\nIf torch version < 2.3.0, or if checkpoint_dir does not exist."
},
{
"objectID": "docs/api/cli.merge_sharded_fsdp_weights.html#classes",
"href": "docs/api/cli.merge_sharded_fsdp_weights.html#classes",
"title": "cli.merge_sharded_fsdp_weights",
"section": "",
"text": "Name\nDescription\n\n\n\n\nBFloat16CastPlanner\nA custom planner to cast tensors to bfloat16 on the fly during loading.\n\n\n\n\n\ncli.merge_sharded_fsdp_weights.BFloat16CastPlanner()\nA custom planner to cast tensors to bfloat16 on the fly during loading."
},
{
"objectID": "docs/api/cli.merge_sharded_fsdp_weights.html#functions",
"href": "docs/api/cli.merge_sharded_fsdp_weights.html#functions",
"title": "cli.merge_sharded_fsdp_weights",
"section": "",
"text": "Name\nDescription\n\n\n\n\ndo_cli\nParses axolotl config, CLI args, and calls merge_fsdp_weights.\n\n\nmerge_fsdp_weights\nMerge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if\n\n\n\n\n\ncli.merge_sharded_fsdp_weights.do_cli(config=Path('examples/'), **kwargs)\nParses axolotl config, CLI args, and calls merge_fsdp_weights.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nUnion[Path, str]\nPath to axolotl config YAML file.\nPath('examples/')\n\n\nkwargs\n\nAdditional keyword arguments to override config file values.\n{}\n\n\n\n\n\n\n\ncli.merge_sharded_fsdp_weights.merge_fsdp_weights(\n checkpoint_dir,\n output_path,\n remove_checkpoint_dir=False,\n)\nMerge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if\nSHARDED_STATE_DICT was used for the model. Weights will be saved to {output_path}/model.safetensors.\nNote: this is a CPU-bound process.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncheckpoint_dir\nstr\nThe directory containing the FSDP checkpoints (can be either the model or optimizer).\nrequired\n\n\noutput_path\nstr\nThe path to save the merged checkpoint.\nrequired\n\n\nremove_checkpoint_dir\nbool, optional, defaults to False\nWhether to remove the checkpoint directory after merging.\nFalse\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nValueError\nIf torch version < 2.3.0, or if checkpoint_dir does not exist."
},
{
"objectID": "docs/api/kernels.quantize.html",
"href": "docs/api/kernels.quantize.html",
"title": "kernels.quantize",
"section": "",
"text": "kernels.quantize\nDequantization utilities for bitsandbytes and FP8 integration.\n\n\n\n\n\nName\nDescription\n\n\n\n\ndequantize\nFast NF4 dequantization using bitsandbytes CUDA kernels.\n\n\ndequantize_fp8\nDequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv.\n\n\n\n\n\nkernels.quantize.dequantize(W, quant_state=None, out=None)\nFast NF4 dequantization using bitsandbytes CUDA kernels.\nPerforms efficient dequantization of weights from NF4 format using bitsandbytes\noptimized CUDA implementations. Supports both legacy list and new QuantState\nformats.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nW\ntorch.Tensor\nQuantized weight tensor to dequantize\nrequired\n\n\nquant_state\nQuantState | list | torch.Tensor | None\nQuantization state containing metadata needed for dequantization. Can be either a QuantState object or legacy list format. If None, returns W unchanged.\nNone\n\n\nout\ntorch.Tensor | None\nOptional output tensor for storing dequantized results. Must match expected shape and dtype if provided.\nNone\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntorch.Tensor\nDequantized tensor in the specified dtype (fp16 or bf16). 
Will be transposed if input W was transposed.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nAssertionError\nIf the provided output tensor doesn't match expected shape / dtype.\n\n\n\n\n\n\nUses CUDA streams for better performance when available in newer bitsandbytes\nversions (>0.43.3).\n\n\n\n\nkernels.quantize.dequantize_fp8(W, scale_inv, dtype=torch.bfloat16)\nDequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nW\ntorch.Tensor\nFP8 weight tensor [out_features, in_features] in float8_e4m3fn.\nrequired\n\n\nscale_inv\ntorch.Tensor\nPer-block inverse scale [ceil(out/block), ceil(in/block)] or per-tensor scalar.\nrequired\n\n\ndtype\ntorch.dtype\nOutput dtype (default bf16).\ntorch.bfloat16\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntorch.Tensor\nDequantized tensor in the specified dtype."
},
{
"objectID": "docs/api/kernels.quantize.html#functions",
"href": "docs/api/kernels.quantize.html#functions",
"title": "kernels.quantize",
"section": "",
"text": "Name\nDescription\n\n\n\n\ndequantize\nFast NF4 dequantization using bitsandbytes CUDA kernels.\n\n\ndequantize_fp8\nDequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv.\n\n\n\n\n\nkernels.quantize.dequantize(W, quant_state=None, out=None)\nFast NF4 dequantization using bitsandbytes CUDA kernels.\nPerforms efficient dequantization of weights from NF4 format using bitsandbytes\noptimized CUDA implementations. Supports both legacy list and new QuantState\nformats.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nW\ntorch.Tensor\nQuantized weight tensor to dequantize\nrequired\n\n\nquant_state\nQuantState | list | torch.Tensor | None\nQuantization state containing metadata needed for dequantization. Can be either a QuantState object or legacy list format. If None, returns W unchanged.\nNone\n\n\nout\ntorch.Tensor | None\nOptional output tensor for storing dequantized results. Must match expected shape and dtype if provided.\nNone\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntorch.Tensor\nDequantized tensor in the specified dtype (fp16 or bf16). 
Will be transposed if input W was transposed.\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nAssertionError\nIf the provided output tensor doesn't match expected shape / dtype.\n\n\n\n\n\n\nUses CUDA streams for better performance when available in newer bitsandbytes\nversions (>0.43.3).\n\n\n\n\nkernels.quantize.dequantize_fp8(W, scale_inv, dtype=torch.bfloat16)\nDequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nW\ntorch.Tensor\nFP8 weight tensor [out_features, in_features] in float8_e4m3fn.\nrequired\n\n\nscale_inv\ntorch.Tensor\nPer-block inverse scale [ceil(out/block), ceil(in/block)] or per-tensor scalar.\nrequired\n\n\ndtype\ntorch.dtype\nOutput dtype (default bf16).\ntorch.bfloat16\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntorch.Tensor\nDequantized tensor in the specified dtype."
},
{
"objectID": "docs/api/core.trainers.utils.html",
"href": "docs/api/core.trainers.utils.html",
"title": "core.trainers.utils",
"section": "",
"text": "core.trainers.utils\ncore.trainers.utils\nUtils for Axolotl trainers"
},
{
"objectID": "docs/api/prompt_strategies.dpo.chat_template.html",
"href": "docs/api/prompt_strategies.dpo.chat_template.html",
"title": "prompt_strategies.dpo.chat_template",
"section": "",
"text": "prompt_strategies.dpo.chat_template\nDPO prompt strategies for using tokenizer chat templates.\n\n\n\n\n\nName\nDescription\n\n\n\n\nargilla_chat\nDPO chat template strategy for argilla-style datasets.\n\n\n\n\n\nprompt_strategies.dpo.chat_template.argilla_chat(cfg, dataset_idx=0, **kwargs)\nDPO chat template strategy for argilla-style datasets.\nFor argilla-style datasets where chosen/rejected contain full conversations\ninstead of single response messages. Extracts the conversation history from\nthe chosen field and formats both chosen/rejected responses using the\nconfigured chat template.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\n\nConfiguration object containing chat_template and dataset settings\nrequired\n\n\ndataset_idx\n\nIndex of the dataset in the config (default: 0)\n0\n\n\n**kwargs\n\nAdditional keyword arguments (unused)\n{}\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\ntuple\n\n(transform_fn, dataset_kwargs) where: - transform_fn: Function to transform dataset samples - dataset_kwargs: Dict with remove_columns specifying columns to drop\n\n\n\n\n\n\n{\n“chosen”: [\n{“role”: “user”, “content”: “…”},\n{“role”: “assistant”, “content”: “…”}\n],\n“rejected”: [\n{“role”: “user”, “content”: “…”},\n{“role”: “assistant”, “content”: “…”}\n]\n}"
},
{
"objectID": "docs/api/prompt_strategies.dpo.chat_template.html#functions",
"href": "docs/api/prompt_strategies.dpo.chat_template.html#functions",
"title": "prompt_strategies.dpo.chat_template",
"section": "",
"text": "Name\nDescription\n\n\n\n\nargilla_chat\nDPO chat template strategy for argilla-style datasets.\n\n\n\n\n\nprompt_strategies.dpo.chat_template.argilla_chat(cfg, dataset_idx=0, **kwargs)\nDPO chat template strategy for argilla-style datasets.\nFor argilla-style datasets where chosen/rejected contain full conversations\ninstead of single response messages. Extracts the conversation history from\nthe chosen field and formats both chosen/rejected responses using the\nconfigured chat template.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\n\nConfiguration object containing chat_template and dataset settings\nrequired\n\n\ndataset_idx\n\nIndex of the dataset in the config (default: 0)\n0\n\n\n**kwargs\n\nAdditional keyword arguments (unused)\n{}\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\ntuple\n\n(transform_fn, dataset_kwargs) where: - transform_fn: Function to transform dataset samples - dataset_kwargs: Dict with remove_columns specifying columns to drop\n\n\n\n\n\n\n{\n“chosen”: [\n{“role”: “user”, “content”: “…”},\n{“role”: “assistant”, “content”: “…”}\n],\n“rejected”: [\n{“role”: “user”, “content”: “…”},\n{“role”: “assistant”, “content”: “…”}\n]\n}"
},
{
"objectID": "docs/api/cli.delinearize_llama4.html",
"href": "docs/api/cli.delinearize_llama4.html",
"title": "cli.delinearize_llama4",
"section": "",
"text": "cli.delinearize_llama4\nCLI tool to delinearize quantized/Linearized Llama-4 models.\n\n\n\n\n\nName\nDescription\n\n\n\n\ndo_cli\nConvert a patched HF format Llama4 model (with separated projections)\n\n\n\n\n\ncli.delinearize_llama4.do_cli(model, output)\nConvert a patched HF format Llama4 model (with separated projections)\nback to the original HF format (with fused projections).\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nmodel\nUnion[Path, str]\nPath to the patched HF model\nrequired\n\n\noutput\nUnion[Path, str]\nPath to save the converted model\nrequired"
},
{
"objectID": "docs/api/cli.delinearize_llama4.html#functions",
"href": "docs/api/cli.delinearize_llama4.html#functions",
"title": "cli.delinearize_llama4",
"section": "",
"text": "Name\nDescription\n\n\n\n\ndo_cli\nConvert a patched HF format Llama4 model (with separated projections)\n\n\n\n\n\ncli.delinearize_llama4.do_cli(model, output)\nConvert a patched HF format Llama4 model (with separated projections)\nback to the original HF format (with fused projections).\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nmodel\nUnion[Path, str]\nPath to the patched HF model\nrequired\n\n\noutput\nUnion[Path, str]\nPath to save the converted model\nrequired"
},
{
"objectID": "docs/faq.html",
"href": "docs/faq.html",
"title": "FAQ",
"section": "",
"text": "General\nQ: The trainer stopped and hasnt progressed in several minutes.\n\nA: Usually an issue with the GPUs communicating with each other. See the NCCL doc\n\nQ: exitcode: -9\n\nA: This usually happens when you run out of system RAM.\n\nQ: exitcode: -7 while using deepspeed\n\nA: Try upgrading deepspeed w: pip install -U deepspeed\n\nQ: AttributeError: DummyOptim object has no attribute step\nQ: ModuleNotFoundError: No module named mpi4py using single GPU with deepspeed\n\nA: You may be using deepspeed with single gpu. Please remove the deepspeed: section in the yaml file or --deepspeed CLI flag.\n\nQ: The codes is stuck on saving preprocessed datasets.\n\nA: This is usually an issue with the GPU. This can be resolved through setting the os environment variable CUDA_VISIBLE_DEVICES=0. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.\n\nQ: Received mismatch error on merge adapters / loading adapters between torch.Size of checkpoint and model.\n\nA: This is likely due to vocab size mismatch. By default, Axolotl expands the models embeddings if the tokenizer has more tokens than the model. Please use the axolotl merge-lora command to merge the adapters instead of using your own scripts.\n\n\nOn the other hand, if the model has more tokens than the tokenizer, Axolotl does not shrink the models embeddings unless shrink_embeddings: true is set in the config.\n\nQ: How to call Axolotl via custom python scripts?\n\nA: Since Axolotl is just Python, please see src/axolotl/cli/main.py on how each command is called.\n\nQ: How to know the value to use for fsdp_transformer_layer_cls_to_wrap?\n\nA: This is the class name of the transformer layer to wrap with FSDP. For example, for LlamaForCausalLM, the value is LlamaDecoderLayer. 
To find this for a specific model, check the model's PreTrainedModel definition and look for the _no_split_modules variable in the modeling_<model_name>.py file within the transformers library.\n\nQ: ValueError: Asking to pad but the tokenizer does not have a padding token. Please select a token to use as pad_token\n\nA: This is because the tokenizer does not have a padding token. Please add a padding token to the tokenizer via:\n\n\nspecial_tokens:\n # str. If you're not sure, set to same as `eos_token`.\n pad_token: \"...\"\n\nQ: IterableDataset error or KeyError: 'input_ids' when using preprocess CLI\n\nA: This is because you may be using the preprocess CLI with pretraining_dataset: or skip_prepare_dataset: true respectively. Please use the axolotl train CLI directly instead, as these datasets are prepared on demand.\n\nQ: vLLM is not working with Axolotl\n\nA: We currently recommend torch 2.6.0 for use with vllm. Please ensure you use the right version. For Docker, please use the main-py3.11-cu124-2.6.0 tag.\n\nQ: FA2 2.8.0 undefined symbol runtime error on CUDA 12.4\n\nA: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.\n\nQ: Can we mix text and text+image datasets for VLM training?\n\nA: Yes, you can for newer VLM architectures. The ones that would not work are the LLaVA / Pixtral architectures. If you notice one not working, please let us know!\n\nQ: Why is memory/max_* different from nvidia-smi?\n\nA: We use torch APIs to retrieve this information. You can see https://docs.pytorch.org/docs/stable/notes/cuda.html#cuda-memory-management for more information.\n\n\n\nChat templates\nQ: jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____\n\nA: This means that the property mapping for the stated attribute does not exist when building the chat_template prompt. 
For example, if the error reads no attribute 'content', please check you have added the correct mapping for content under message_property_mappings.\n\nQ: Empty template generated for turn ___\n\nA: The content is empty for that turn.\n\nQ: Could not find content start/end boundary for turn __\n\nA: The specific turn's start/end could not be detected. Please ensure you have set the eos_token following your chat_template. Otherwise, this could be a chat_template which doesn't use proper boundaries for each turn (like system). In rare cases, make sure your content is not [[dummy_message]]. Please let us know about this.\n\nQ: Content end boundary is before start boundary for turn ___\n\nA: This is an edge case which should not occur. Please create an Issue if this happens.\n\nQ: Content end boundary is the same as start boundary for turn ___. This is likely an empty turn.\n\nA: This is likely an empty turn.\n\nQ: The EOS token is incorrectly being masked or not being masked / EOS token __ not found in chat template.\n\nA: There can be two reasons:\n\n\n\nThis is because of a mismatch between tokenizer.eos_token and the EOS token in the template. Please make sure to set eos_token: under special_tokens: to the same EOS token as in the template.\n\n\n\n\nThe EOS token is not in the template. Please check if your template is correct. As an example, the phi_35 template does not use its dedicated EOS token <|endoftext|> at the end.\n\n\nQ: “chat_template choice is tokenizer_default but tokenizer's chat_template is null. Please add a chat_template in tokenizer config”\n\nA: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See chat_template for more details.\n\nQ: The EOT token(s) are incorrectly being masked or not being masked / EOT token __ not found in chat template.\n\nA: There can be two reasons:\n\n\n\nThe EOT token is different from the EOS token and was not specified under eot_tokens:. 
Please set eot_tokens: to the same EOT token(s) as in the template.\n\n\n\n\nThere is more than one EOT token per turn in the template. Please raise an issue with examples as we recognize this as an edge case.\n\n\nQ: EOT token encoding failed. Please check if the token is valid and can be encoded.\n\nA: There could be some issue with the tokenizer or unicode encoding. Please raise an issue with examples of the EOT token and tokenizer causing the issue.\n\nQ: EOT token __ is encoded as multiple tokens.\n\nA: This is because the EOT token is encoded as multiple tokens, which can cause unexpected behavior. Please add it under tokens: or (recommended) override unused added_tokens via added_tokens_overrides:.\n\nQ: Conflict between train_on_eos and train_on_eot. eos_token is in eot_tokens and train_on_eos != train_on_eot\n\nA: This is because the EOS token is listed in eot_tokens: while train_on_eos: and train_on_eot: are set to different values. This will cause one to override the other. Please ensure that train_on_eos: and train_on_eot: are the same or remove the EOS token from eot_tokens:.\n\nQ: If eot_tokens: is not provided, what happens?\n\nA: If eot_tokens: is not provided, the default behavior is the same as before. EOS tokens used to delimit turns are masked/unmasked depending on whether the turn is trainable.\n\n\nInternally, eot_tokens: tokenizer.eos_token and train_on_eot: train_on_eos (which defaults to turn). 
This transition helps clarify the naming and behavior of EOT/EOS tokens.\n\nQ: Data processing error: CAS service error\n\nA: Try disabling XET with export HF_HUB_DISABLE_XET=1\n\nQ: torch._inductor.exc.LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice.\n\nA: Depending on the version of torch, you may need to include this in your YAML:\n\n\nflex_attn_compile_kwargs:\n dynamic: false\n mode: max-autotune-no-cudagraphs\n\nQ: ValueError(\"Backward pass should have cleared tracker of all tensors\")\n\nA: This may happen due to edge cases in using the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with offload_activations: legacy in your YAML.\n\nQ: Error parsing tool_calls arguments as JSON.\n\nA: There is an error parsing string arguments to a dict. Please check your dataset and the error message for more details.",
"crumbs": [
"Troubleshooting",
"FAQ"
]
},
{
"objectID": "docs/expert_quantization.html",
"href": "docs/expert_quantization.html",
"title": "MoE Expert Quantization",
"section": "",
"text": "Transformers v5 changed MoE expert layers from nn.Linear to fused nn.Parameter (3D+ tensors).\nThis means bitsandbytes can no longer quantize them during model loading, resulting in all expert\nweights being loaded in full bf16 precision and causing massive VRAM usage.\nquantize_moe_experts solves this by quantizing expert weights during model loading.\nIt intercepts the weight loading process, quantizes each expert tensor on the fly, and\nimmediately frees the original bf16 tensor from VRAM. This dramatically reduces peak memory.\nFor example, GLM-4.7-Flash QLoRA drops from ~127GiB to ~23GiB reserved memory.",
"crumbs": [
"Advanced Features",
"MoE Expert Quantization"
]
},
{
"objectID": "docs/expert_quantization.html#usage",
"href": "docs/expert_quantization.html#usage",
"title": "MoE Expert Quantization",
"section": "Usage",
"text": "Usage\nEnable expert quantization in your Axolotl config:\nquantize_moe_experts: true\nThis works with both 4-bit (QLoRA) and 8-bit (LoRA) quantization.\n\nExpert LoRA targeting\nYou can optionally apply LoRA adapters directly to expert weights using lora_target_parameters:\nlora_target_parameters:\n - mlp.experts.gate_up_proj\n - mlp.experts.down_proj\n # - mlp.gate.weight # router\n\n\n\n\n\n\nNote\n\n\n\nlora_dropout must be 0 when using lora_target_parameters.",
"crumbs": [
"Advanced Features",
"MoE Expert Quantization"
]
},
{
"objectID": "docs/expert_quantization.html#requirements",
"href": "docs/expert_quantization.html#requirements",
"title": "MoE Expert Quantization",
"section": "Requirements",
"text": "Requirements\n\nRequires (adapter: lora and load_in_8bit: true) or (adapter: qlora and load_in_4bit: true)\nCUDA GPUs only (not tested with ROCm or other backends)\nFSDP2 compatible for distributed training",
"crumbs": [
"Advanced Features",
"MoE Expert Quantization"
]
},
{
"objectID": "docs/expert_quantization.html#limitations",
"href": "docs/expert_quantization.html#limitations",
"title": "MoE Expert Quantization",
"section": "Limitations",
"text": "Limitations\n\nlora_target_linear is not compatible with quantize_moe_experts. See Expert LoRA targeting instead.\ncpu_ram_efficient_loading hangs or takes a long time with FSDP2 + QLoRA.\nTotal model parameter count may display incorrectly (trainable param count is correct).\nFSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps, which then drops. QLoRA does not exhibit this.\nFSDP2 may use more VRAM per GPU than single GPU training because not all layers are properly sharded across ranks.\nModel loading takes longer due to on-demand quantization, even on consecutive runs.\nDeepSpeed has not been tested.",
"crumbs": [
"Advanced Features",
"MoE Expert Quantization"
]
},
{
"objectID": "docs/expert_quantization.html#implementation-details",
"href": "docs/expert_quantization.html#implementation-details",
"title": "MoE Expert Quantization",
"section": "Implementation details",
"text": "Implementation details\nThe quantization is applied by patching transformers to intercept weight loading.\nWhen a 3D+ CUDA tensor with “expert” in its name is detected:\n\n4-bit mode: Uses bitsandbytes NF4 parametrization (configurable via bnb_4bit_quant_type).\n8-bit mode: Uses a custom row-wise int8 parametrization with bitsandbytes dequantization.\n\nThe original bf16 tensor is freed immediately after quantization. Multiple sub-patches are applied to\ntransformers, PEFT and accelerate FSDP2 to support these parametrized expert modules.\nFor full implementation details, see PR #3439.",
"crumbs": [
"Advanced Features",
"MoE Expert Quantization"
]
},
{
"objectID": "docs/checkpoint_saving.html",
"href": "docs/checkpoint_saving.html",
"title": "Checkpoint Saving",
"section": "",
"text": "Axolotl supports on-demand checkpoint saving during training. You can trigger checkpoints via file-based triggers (for programmatic control) or Control+C (for interactive use)."
},
{
"objectID": "docs/checkpoint_saving.html#overview",
"href": "docs/checkpoint_saving.html#overview",
"title": "Checkpoint Saving",
"section": "",
"text": "Axolotl supports on-demand checkpoint saving during training. You can trigger checkpoints via file-based triggers (for programmatic control) or Control+C (for interactive use)."
},
{
"objectID": "docs/checkpoint_saving.html#file-based-checkpoint-trigger",
"href": "docs/checkpoint_saving.html#file-based-checkpoint-trigger",
"title": "Checkpoint Saving",
"section": "2 File-Based Checkpoint Trigger",
"text": "2 File-Based Checkpoint Trigger\n\n2.1 Configuration\nEnable in your config:\ndynamic_checkpoint:\n enabled: true\n check_interval: 100 # Optional: check every N steps (default: 100)\n trigger_file_path: \"axolotl_checkpoint.save\" # Optional: custom filename\nOptions:\n- enabled: true to enable (required)\n- check_interval: Steps between file checks. Default: 100. Lower = faster response, higher I/O overhead.\n- trigger_file_path: Custom trigger filename. Default: axolotl_checkpoint.save\n\n\n2.2 How It Works\n\nRank 0 checks for trigger file every check_interval steps in output_dir\nWhen detected, file is deleted and checkpoint is saved\nIn distributed training, rank 0 broadcasts to synchronize all ranks\n\n\n\n2.3 Usage\nCommand line:\ntouch /path/to/output_dir/axolotl_checkpoint.save\nProgrammatic:\nfrom pathlib import Path\nPath(\"/path/to/output_dir/axolotl_checkpoint.save\").touch()\nCheckpoint saves within the next check_interval steps. The trigger file is auto-deleted after detection, so you can create it multiple times.\nCustom filename:\ndynamic_checkpoint:\n enabled: true\n trigger_file_path: \"my_trigger.save\"\ntouch /path/to/output_dir/my_trigger.save"
},
{
"objectID": "docs/checkpoint_saving.html#controlc-sigint-checkpoint",
"href": "docs/checkpoint_saving.html#controlc-sigint-checkpoint",
"title": "Checkpoint Saving",
"section": "3 Control+C (SIGINT) Checkpoint",
"text": "3 Control+C (SIGINT) Checkpoint\nPressing Ctrl+C during training saves the model state and exits gracefully. Note: This saves only the model weights, not optimizer state. For resumable checkpoints, use the file-based trigger."
},
{
"objectID": "docs/checkpoint_saving.html#best-practices",
"href": "docs/checkpoint_saving.html#best-practices",
"title": "Checkpoint Saving",
"section": "4 Best Practices",
"text": "4 Best Practices\n\nCheck interval: Lower values (10-50) for fast training, default 100 for slower training\nDistributed training: Create trigger file once; rank 0 handles synchronization\nResume: Dynamic checkpoints can be resumed like regular checkpoints via resume_from_checkpoint"
},
{
"objectID": "docs/checkpoint_saving.html#example",
"href": "docs/checkpoint_saving.html#example",
"title": "Checkpoint Saving",
"section": "5 Example",
"text": "5 Example\noutput_dir: ./outputs/lora-out\nsave_steps: 500 # Scheduled checkpoints\n\ndynamic_checkpoint:\n enabled: true\n check_interval: 50\nThis enables scheduled checkpoints every 500 steps plus on-demand saves via file trigger (checked every 50 steps)."
},
{
"objectID": "docs/agents/pretraining.html",
"href": "docs/agents/pretraining.html",
"title": "Pretraining / Continual Pretraining — Agent Reference",
"section": "",
"text": "Train on raw text with no input masking. Two approaches depending on dataset size.\n\n\n\nContinual pretraining on domain-specific corpora\nAdapting a base model to a new language or domain before fine-tuning\nPretraining-style data where the entire text is the training signal\n\n\n\n\n\n\n\n\n\n\n\n\n\nNon-streaming (type: completion)\nStreaming (pretraining_dataset)\n\n\n\n\nDataset size\nFits in memory\nToo large to fit in memory\n\n\nTokenization\nPre-tokenized before training\nOn-demand during training\n\n\nConfig key\ndatasets:\npretraining_dataset:\n\n\nLong text handling\nSplits texts exceeding sequence_len\nConcatenates into fixed-length sequences\n\n\nBenefit\nCan preprocess on CPU, transfer to GPU\nStart training immediately, no preprocessing\n\n\n\n\n\n\nFor smaller datasets that fit in memory. Pre-tokenizes the entire dataset.\ndatasets:\n - path: my_corpus\n type: completion\n # field: text # Column name (default: \"text\")\n\n\n\nFor large corpora. Streams data on-demand without loading everything into memory.\npretraining_dataset:\n - path: HuggingFaceFW/fineweb-edu\n type: pretrain\n text_column: text\n split: train\n\nmax_steps: 1000 # Required — axolotl can't infer dataset size\nstreaming_multipack_buffer_size: 10000 # Buffer for sample packing\npretrain_multipack_attn: true # Prevent cross-attention between packed samples\nmax_steps is required for streaming — one step = sequence_len * micro_batch_size * gradient_accumulation_steps * num_gpus tokens.\nFull streaming docs: streaming.qmd\n\n\n\n{\"text\": \"The complete document text goes here.\"}\n\n\n\n\nsample_packing: true + pad_to_sequence_len: true — pack documents into fixed-length sequences\nflash_attention: true — required for sample packing\nNo adapter — typically full fine-tune for pretraining\ntrain_on_inputs: true — default for completion (all tokens trained on)\n\n\n\n\nsrc/axolotl/\n prompt_strategies/completion.py # Non-streaming: completion prompt strategy (no masking)\n 
utils/data/sft.py # Non-streaming: dataset loading and processing\n utils/data/streaming.py # Streaming: encode_streaming(), wrap_streaming_dataset()\n utils/schemas/config.py # Config fields: pretraining_dataset, pretrain_multipack_attn, etc.\n\nexamples/streaming/pretrain.yaml # Full streaming pretraining example config"
},
{
"objectID": "docs/agents/pretraining.html#when-to-use",
"href": "docs/agents/pretraining.html#when-to-use",
"title": "Pretraining / Continual Pretraining — Agent Reference",
"section": "",
"text": "Continual pretraining on domain-specific corpora\nAdapting a base model to a new language or domain before fine-tuning\nPretraining-style data where the entire text is the training signal"
},
{
"objectID": "docs/agents/pretraining.html#choosing-an-approach",
"href": "docs/agents/pretraining.html#choosing-an-approach",
"title": "Pretraining / Continual Pretraining — Agent Reference",
"section": "",
"text": "Non-streaming (type: completion)\nStreaming (pretraining_dataset)\n\n\n\n\nDataset size\nFits in memory\nToo large to fit in memory\n\n\nTokenization\nPre-tokenized before training\nOn-demand during training\n\n\nConfig key\ndatasets:\npretraining_dataset:\n\n\nLong text handling\nSplits texts exceeding sequence_len\nConcatenates into fixed-length sequences\n\n\nBenefit\nCan preprocess on CPU, transfer to GPU\nStart training immediately, no preprocessing"
},
{
"objectID": "docs/agents/pretraining.html#non-streaming-type-completion",
"href": "docs/agents/pretraining.html#non-streaming-type-completion",
"title": "Pretraining / Continual Pretraining — Agent Reference",
"section": "",
"text": "For smaller datasets that fit in memory. Pre-tokenizes the entire dataset.\ndatasets:\n - path: my_corpus\n type: completion\n # field: text # Column name (default: \"text\")"
},
{
"objectID": "docs/agents/pretraining.html#streaming-pretraining_dataset",
"href": "docs/agents/pretraining.html#streaming-pretraining_dataset",
"title": "Pretraining / Continual Pretraining — Agent Reference",
"section": "",
"text": "For large corpora. Streams data on-demand without loading everything into memory.\npretraining_dataset:\n - path: HuggingFaceFW/fineweb-edu\n type: pretrain\n text_column: text\n split: train\n\nmax_steps: 1000 # Required — axolotl can't infer dataset size\nstreaming_multipack_buffer_size: 10000 # Buffer for sample packing\npretrain_multipack_attn: true # Prevent cross-attention between packed samples\nmax_steps is required for streaming — one step = sequence_len * micro_batch_size * gradient_accumulation_steps * num_gpus tokens.\nFull streaming docs: streaming.qmd"
},
{
"objectID": "docs/agents/pretraining.html#dataset-format",
"href": "docs/agents/pretraining.html#dataset-format",
"title": "Pretraining / Continual Pretraining — Agent Reference",
"section": "",
"text": "{\"text\": \"The complete document text goes here.\"}"
},
{
"objectID": "docs/agents/pretraining.html#key-settings",
"href": "docs/agents/pretraining.html#key-settings",
"title": "Pretraining / Continual Pretraining — Agent Reference",
"section": "",
"text": "sample_packing: true + pad_to_sequence_len: true — pack documents into fixed-length sequences\nflash_attention: true — required for sample packing\nNo adapter — typically full fine-tune for pretraining\ntrain_on_inputs: true — default for completion (all tokens trained on)"
},
{
"objectID": "docs/agents/pretraining.html#file-map",
"href": "docs/agents/pretraining.html#file-map",
"title": "Pretraining / Continual Pretraining — Agent Reference",
"section": "",
"text": "src/axolotl/\n prompt_strategies/completion.py # Non-streaming: completion prompt strategy (no masking)\n utils/data/sft.py # Non-streaming: dataset loading and processing\n utils/data/streaming.py # Streaming: encode_streaming(), wrap_streaming_dataset()\n utils/schemas/config.py # Config fields: pretraining_dataset, pretrain_multipack_attn, etc.\n\nexamples/streaming/pretrain.yaml # Full streaming pretraining example config"
},
{
"objectID": "docs/agents/grpo.html",
"href": "docs/agents/grpo.html",
"title": "GRPO — Agent Reference",
"section": "",
"text": "Online RL with verifiable reward functions. For full config reference, async features, and scaling, see grpo.qmd. For vLLM setup, see vllm_serving.qmd.\n\n\nTerminal 1 (GPU 0) Terminal 2 (GPU 1)\n┌──────────────────────┐ ┌──────────────────────────────────┐\n│ vLLM Server │ HTTP │ Trainer │\n│ Serves base model │◄────────────►│ 1. Send prompts to vLLM │\n│ + LoRA adapter │ /generate │ 2. Score completions (rewards) │\n│ │ /set_lora │ 3. Compute advantages │\n│ Punica kernels for │ │ 4. PPO-clip gradient update │\n│ LoRA inference │ │ 5. Sync LoRA weights to vLLM │\n└──────────────────────┘ └──────────────────────────────────┘\n\n\n\n\nA YAML config with rl: grpo\nA reward module (Python file with reward functions)\nA running vLLM server (axolotl vllm-serve config.yaml)\n\n\n\n\ndef my_reward(completions, **kwargs) -> list[float]:\n # completions[i][0][\"content\"] = text of i-th completion\n # **kwargs contains dataset columns not removed by transform\n return [score_for_each_completion]\nMultiple rewards: reward_funcs: [r1, r2] with reward_weights: [1.0, 0.5].\n\n\n\n\n\n\n\n\n\n\n\nFeature\nConfig\nPurpose\n\n\n\n\nAsync prefetch\nasync_prefetch: true\nOverlap generation with training\n\n\nLoRA sync\nvllm_lora_sync: true\nFast adapter sync via filesystem\n\n\nStreaming scoring\nstreaming_partial_batch: true\nScore one group at a time\n\n\nZero-adv skip\nskip_zero_advantage_batches: true\nSkip batches with no learning signal\n\n\nReplay buffer\nreplay_buffer_size: 100\nCache high-signal groups\n\n\nIS correction\nvllm_importance_sampling_correction: true\nFix off-policy distribution shift\n\n\n\n\n\n\n\nrewards/*/mean > 0.15 within 20 steps (else: test reward function standalone)\nreward_std > 0 on most steps (else: no learning signal)\nentropy 0.05-0.5 (< 0.01 = mode collapse)\ngrad_norm 0.001-1.0 (> 10 = unstable, 0.0 = zero-advantage skip)\n\nSee training_stability.qmd for detailed diagnostics.\n\n\n\nsrc/axolotl/\n cli/train.py # Entry point\n 
cli/vllm_serve.py # Entry point for vLLM server\n core/trainers/grpo/\n trainer.py # AxolotlGRPOTrainer\n sampler.py # Sampling utilities\n core/builders/rl.py # HFRLTrainerBuilder — routes rl type → trainer\n scripts/vllm_serve_lora.py # vLLM serve script with LoRA sync support\n utils/schemas/trl.py # TRL config schema (all trl: options)\n\ndocs/grpo.qmd # Full user docs: async, rewards, scaling, config reference\ndocs/vllm_serving.qmd # vLLM server modes, LoRA sync, weight sync"
},
{
"objectID": "docs/agents/grpo.html#architecture",
"href": "docs/agents/grpo.html#architecture",
"title": "GRPO — Agent Reference",
"section": "",
"text": "Terminal 1 (GPU 0) Terminal 2 (GPU 1)\n┌──────────────────────┐ ┌──────────────────────────────────┐\n│ vLLM Server │ HTTP │ Trainer │\n│ Serves base model │◄────────────►│ 1. Send prompts to vLLM │\n│ + LoRA adapter │ /generate │ 2. Score completions (rewards) │\n│ │ /set_lora │ 3. Compute advantages │\n│ Punica kernels for │ │ 4. PPO-clip gradient update │\n│ LoRA inference │ │ 5. Sync LoRA weights to vLLM │\n└──────────────────────┘ └──────────────────────────────────┘"
},
{
"objectID": "docs/agents/grpo.html#components-required",
"href": "docs/agents/grpo.html#components-required",
"title": "GRPO — Agent Reference",
"section": "",
"text": "A YAML config with rl: grpo\nA reward module (Python file with reward functions)\nA running vLLM server (axolotl vllm-serve config.yaml)"
},
{
"objectID": "docs/agents/grpo.html#reward-function-signature",
"href": "docs/agents/grpo.html#reward-function-signature",
"title": "GRPO — Agent Reference",
"section": "",
"text": "def my_reward(completions, **kwargs) -> list[float]:\n # completions[i][0][\"content\"] = text of i-th completion\n # **kwargs contains dataset columns not removed by transform\n return [score_for_each_completion]\nMultiple rewards: reward_funcs: [r1, r2] with reward_weights: [1.0, 0.5]."
},
{
"objectID": "docs/agents/grpo.html#key-async-features",
"href": "docs/agents/grpo.html#key-async-features",
"title": "GRPO — Agent Reference",
"section": "",
"text": "Feature\nConfig\nPurpose\n\n\n\n\nAsync prefetch\nasync_prefetch: true\nOverlap generation with training\n\n\nLoRA sync\nvllm_lora_sync: true\nFast adapter sync via filesystem\n\n\nStreaming scoring\nstreaming_partial_batch: true\nScore one group at a time\n\n\nZero-adv skip\nskip_zero_advantage_batches: true\nSkip batches with no learning signal\n\n\nReplay buffer\nreplay_buffer_size: 100\nCache high-signal groups\n\n\nIS correction\nvllm_importance_sampling_correction: true\nFix off-policy distribution shift"
},
{
"objectID": "docs/agents/grpo.html#health-checks",
"href": "docs/agents/grpo.html#health-checks",
"title": "GRPO — Agent Reference",
"section": "",
"text": "rewards/*/mean > 0.15 within 20 steps (else: test reward function standalone)\nreward_std > 0 on most steps (else: no learning signal)\nentropy 0.05-0.5 (< 0.01 = mode collapse)\ngrad_norm 0.001-1.0 (> 10 = unstable, 0.0 = zero-advantage skip)\n\nSee training_stability.qmd for detailed diagnostics."
},
{
"objectID": "docs/agents/grpo.html#file-map",
"href": "docs/agents/grpo.html#file-map",
"title": "GRPO — Agent Reference",
"section": "",
"text": "src/axolotl/\n cli/train.py # Entry point\n cli/vllm_serve.py # Entry point for vLLM server\n core/trainers/grpo/\n trainer.py # AxolotlGRPOTrainer\n sampler.py # Sampling utilities\n core/builders/rl.py # HFRLTrainerBuilder — routes rl type → trainer\n scripts/vllm_serve_lora.py # vLLM serve script with LoRA sync support\n utils/schemas/trl.py # TRL config schema (all trl: options)\n\ndocs/grpo.qmd # Full user docs: async, rewards, scaling, config reference\ndocs/vllm_serving.qmd # vLLM server modes, LoRA sync, weight sync"
},
{
"objectID": "docs/agents/sft.html",
"href": "docs/agents/sft.html",
"title": "SFT — Agent Reference",
"section": "",
"text": "Supervised fine-tuning pipeline reference. For config templates and dataset format examples, see getting-started.qmd and dataset-formats/.\n\n\nYAML Config → axolotl train config.yaml\n\n 1. Load base model (+ quantization if QLoRA/8-bit)\n 2. Apply adapter layers (LoRA/QLoRA) if configured\n 3. Load + tokenize dataset(s)\n - Apply prompt template (chat_template / alpaca / custom)\n - Mask inputs (train_on_inputs: false)\n - Pack samples into sequences (sample_packing: true)\n 4. Training loop (HuggingFace Trainer)\n - forward → loss → backward → optimizer step → lr scheduler step\n 5. Save model / adapter weights + tokenizer\n\nMulti-GPU: FSDP or DeepSpeed shards model across GPUs automatically.\n\n\n\n\nA YAML config — model, dataset(s), adapter settings, hyperparameters\nA dataset — HuggingFace Hub, local JSONL/JSON/Parquet, or S3/GCS path\n(Optional) A custom prompt strategy — for non-standard dataset formats\n\nNo external server processes needed (unlike GRPO which requires vLLM).\n\n\n\nIs your data in chat/message format?\n ├─ YES: OpenAI message format (role/content)?\n │ ├─ YES ──────────────────────> type: chat_template (recommended)\n │ └─ NO (custom field names) ──> type: chat_template + message_property_mappings\n └─ NO: Instruction/response pairs?\n ├─ YES ──> type: alpaca (instruction, input, output)\n └─ NO: Raw text?\n ├─ YES with segments ─────> type: input_output (template-free masking)\n └─ YES continuous ────────> type: completion (pretraining-style)\nFull format specs: dataset-formats/\n\n\n\n\n\n\n\n\n\n\n\n\n\nModel Size\nLoRA\nQLoRA (4-bit)\nFull Fine-Tune\nVRAM (approx)\n\n\n\n\n1-3B\nPreferred\nLow-budget option\nSingle GPU OK\n8-16 GB (LoRA)\n\n\n7-8B\nPreferred\nGood balance\nNeeds multi-GPU\n16-24 GB (LoRA)\n\n\n13-14B\nPreferred\nGood balance\nMulti-GPU required\n24-40 GB (LoRA)\n\n\n30-70B\nLoRA or QLoRA\nPreferred for single GPU\nMulti-node\n40-80 GB (QLoRA)\n\n\n\n\n\n\n\n\n\nParameter\nLoRA\nQLoRA\nFull 
FT\n\n\n\n\nlearning_rate\n1e-4 to 3e-4\n1e-4 to 3e-4\n1e-5 to 5e-5\n\n\nlora_r\n16-64\n16-64\nN/A\n\n\nlora_alpha\n1-2x lora_r\n1-2x lora_r\nN/A\n\n\nmicro_batch_size\n2-8\n2-4\n1-2\n\n\ngradient_accumulation_steps\n2-8\n4-16\n4-16\n\n\nnum_epochs\n1-3\n1-3\n1-3\n\n\noptimizer\nadamw_8bit\nadamw_bnb_8bit\nadamw_torch_fused\n\n\n\nEffective batch = micro_batch * grad_accum * num_gpus. Lower LR for larger models.\n\n\n\n\n\n\n\n\n\n\n\nMetric\nHealthy\nProblem\n\n\n\n\ntrain_loss\nDecreasing, starting ~2-4 for chat models\nFlat or increasing from step 1 — data or LR issue\n\n\neval_loss\nDecreasing, tracks train_loss\nIncreasing while train_loss decreases — overfitting\n\n\ngrad_norm\n0.1-10, relatively stable\nSpikes >100 — instability. 0.0 — frozen weights\n\n\nlearning_rate\nFollows scheduler curve\nFlat or NaN — config issue\n\n\n\nWatch for: loss never decreasing (check train_on_inputs, dataset, LR), loss goes to 0 quickly (overfitting), eval_loss diverging (reduce epochs, add regularization). 
See training_stability.qmd.\n\n\n\n\n\n\n\n\n\n\nIssue\nFix\n\n\n\n\nOOM during training\nReduce micro_batch_size, enable gradient_checkpointing, reduce sequence_len\n\n\nsample_packing + SDPA + bf16 = 0.0 loss\nUse flash_attention: true or disable sample_packing\n\n\nMissing chat template error\nSet chat_template: chatml explicitly\n\n\nLabel masking wrong\nRun axolotl preprocess config.yaml --debug and inspect labels\n\n\nLoss NaN\nUse bf16: auto, lower LR, check data for empty samples\n\n\nTokenizer pad token / infinite loss\nSet special_tokens: pad_token: \"<|end_of_text|>\"\n\n\nFSDP save hangs\nUse fsdp_state_dict_type: FULL_STATE_DICT\n\n\nDeepSpeed CheckpointError\nSet use_reentrant: true in gradient_checkpointing_kwargs\n\n\n\nFull troubleshooting: training_stability.qmd, debugging.qmd\n\n\n\nsrc/axolotl/\n cli/train.py # Entry point for `axolotl train`\n cli/preprocess.py # Entry point for `axolotl preprocess`\n core/builders/causal.py # HFCausalTrainerBuilder — wires config → SFT trainer\n core/trainers/base.py # AxolotlTrainer — base trainer class\n core/trainers/mixins/ # Packing, optimizer, scheduler, checkpoints\n prompt_strategies/ # Format handlers: chat_template, alpaca, completion, input_output\n utils/schemas/config.py # AxolotlInputConfig — main config schema\n utils/schemas/datasets.py # SFTDataset, DatasetConfig\n utils/schemas/peft.py # LoraConfig — LoRA parameters\n integrations/liger/ # Liger kernel plugin\n\nexamples/llama-3/ # LoRA, QLoRA, full FT example configs\ndocs/getting-started.qmd # Quickstart with config templates\ndocs/optimizations.qmd # Flash attention, gradient checkpointing, sample packing\ndocs/multi-gpu.qmd # FSDP and DeepSpeed setup"
},
{
"objectID": "docs/agents/sft.html#architecture",
"href": "docs/agents/sft.html#architecture",
"title": "SFT — Agent Reference",
"section": "",
"text": "YAML Config → axolotl train config.yaml\n\n 1. Load base model (+ quantization if QLoRA/8-bit)\n 2. Apply adapter layers (LoRA/QLoRA) if configured\n 3. Load + tokenize dataset(s)\n - Apply prompt template (chat_template / alpaca / custom)\n - Mask inputs (train_on_inputs: false)\n - Pack samples into sequences (sample_packing: true)\n 4. Training loop (HuggingFace Trainer)\n - forward → loss → backward → optimizer step → lr scheduler step\n 5. Save model / adapter weights + tokenizer\n\nMulti-GPU: FSDP or DeepSpeed shards model across GPUs automatically."
},
{
"objectID": "docs/agents/sft.html#components-required",
"href": "docs/agents/sft.html#components-required",
"title": "SFT — Agent Reference",
"section": "",
"text": "A YAML config — model, dataset(s), adapter settings, hyperparameters\nA dataset — HuggingFace Hub, local JSONL/JSON/Parquet, or S3/GCS path\n(Optional) A custom prompt strategy — for non-standard dataset formats\n\nNo external server processes needed (unlike GRPO which requires vLLM)."
},
{
"objectID": "docs/agents/sft.html#dataset-format-decision-tree",
"href": "docs/agents/sft.html#dataset-format-decision-tree",
"title": "SFT — Agent Reference",
"section": "",
"text": "Is your data in chat/message format?\n ├─ YES: OpenAI message format (role/content)?\n │ ├─ YES ──────────────────────> type: chat_template (recommended)\n │ └─ NO (custom field names) ──> type: chat_template + message_property_mappings\n └─ NO: Instruction/response pairs?\n ├─ YES ──> type: alpaca (instruction, input, output)\n └─ NO: Raw text?\n ├─ YES with segments ─────> type: input_output (template-free masking)\n └─ YES continuous ────────> type: completion (pretraining-style)\nFull format specs: dataset-formats/"
},
{
"objectID": "docs/agents/sft.html#model-size-to-adapter-choice",
"href": "docs/agents/sft.html#model-size-to-adapter-choice",
"title": "SFT — Agent Reference",
"section": "",
"text": "Model Size\nLoRA\nQLoRA (4-bit)\nFull Fine-Tune\nVRAM (approx)\n\n\n\n\n1-3B\nPreferred\nLow-budget option\nSingle GPU OK\n8-16 GB (LoRA)\n\n\n7-8B\nPreferred\nGood balance\nNeeds multi-GPU\n16-24 GB (LoRA)\n\n\n13-14B\nPreferred\nGood balance\nMulti-GPU required\n24-40 GB (LoRA)\n\n\n30-70B\nLoRA or QLoRA\nPreferred for single GPU\nMulti-node\n40-80 GB (QLoRA)"
},
{
"objectID": "docs/agents/sft.html#hyperparameter-ranges",
"href": "docs/agents/sft.html#hyperparameter-ranges",
"title": "SFT — Agent Reference",
"section": "",
"text": "Parameter\nLoRA\nQLoRA\nFull FT\n\n\n\n\nlearning_rate\n1e-4 to 3e-4\n1e-4 to 3e-4\n1e-5 to 5e-5\n\n\nlora_r\n16-64\n16-64\nN/A\n\n\nlora_alpha\n1-2x lora_r\n1-2x lora_r\nN/A\n\n\nmicro_batch_size\n2-8\n2-4\n1-2\n\n\ngradient_accumulation_steps\n2-8\n4-16\n4-16\n\n\nnum_epochs\n1-3\n1-3\n1-3\n\n\noptimizer\nadamw_8bit\nadamw_bnb_8bit\nadamw_torch_fused\n\n\n\nEffective batch = micro_batch * grad_accum * num_gpus. Lower LR for larger models."
},
{
"objectID": "docs/agents/sft.html#healthy-training-indicators",
"href": "docs/agents/sft.html#healthy-training-indicators",
"title": "SFT — Agent Reference",
"section": "",
"text": "Metric\nHealthy\nProblem\n\n\n\n\ntrain_loss\nDecreasing, starting ~2-4 for chat models\nFlat or increasing from step 1 — data or LR issue\n\n\neval_loss\nDecreasing, tracks train_loss\nIncreasing while train_loss decreases — overfitting\n\n\ngrad_norm\n0.1-10, relatively stable\nSpikes >100 — instability. 0.0 — frozen weights\n\n\nlearning_rate\nFollows scheduler curve\nFlat or NaN — config issue\n\n\n\nWatch for: loss never decreasing (check train_on_inputs, dataset, LR), loss goes to 0 quickly (overfitting), eval_loss diverging (reduce epochs, add regularization). See training_stability.qmd."
},
{
"objectID": "docs/agents/sft.html#known-issues",
"href": "docs/agents/sft.html#known-issues",
"title": "SFT — Agent Reference",
"section": "",
"text": "Issue\nFix\n\n\n\n\nOOM during training\nReduce micro_batch_size, enable gradient_checkpointing, reduce sequence_len\n\n\nsample_packing + SDPA + bf16 = 0.0 loss\nUse flash_attention: true or disable sample_packing\n\n\nMissing chat template error\nSet chat_template: chatml explicitly\n\n\nLabel masking wrong\nRun axolotl preprocess config.yaml --debug and inspect labels\n\n\nLoss NaN\nUse bf16: auto, lower LR, check data for empty samples\n\n\nTokenizer pad token / infinite loss\nSet special_tokens: pad_token: \"<\\|end_of_text\\|>\"\n\n\nFSDP save hangs\nUse fsdp_state_dict_type: FULL_STATE_DICT\n\n\nDeepSpeed CheckpointError\nSet use_reentrant: true in gradient_checkpointing_kwargs\n\n\n\nFull troubleshooting: training_stability.qmd, debugging.qmd"
},
{
"objectID": "docs/agents/sft.html#file-map",
"href": "docs/agents/sft.html#file-map",
"title": "SFT — Agent Reference",
"section": "",
"text": "src/axolotl/\n cli/train.py # Entry point for `axolotl train`\n cli/preprocess.py # Entry point for `axolotl preprocess`\n core/builders/causal.py # HFCausalTrainerBuilder — wires config → SFT trainer\n core/trainers/base.py # AxolotlTrainer — base trainer class\n core/trainers/mixins/ # Packing, optimizer, scheduler, checkpoints\n prompt_strategies/ # Format handlers: chat_template, alpaca, completion, input_output\n utils/schemas/config.py # AxolotlInputConfig — main config schema\n utils/schemas/datasets.py # SFTDataset, DatasetConfig\n utils/schemas/peft.py # LoraConfig — LoRA parameters\n integrations/liger/ # Liger kernel plugin\n\nexamples/llama-3/ # LoRA, QLoRA, full FT example configs\ndocs/getting-started.qmd # Quickstart with config templates\ndocs/optimizations.qmd # Flash attention, gradient checkpointing, sample packing\ndocs/multi-gpu.qmd # FSDP and DeepSpeed setup"
},
{
"objectID": "docs/multi-gpu.html",
"href": "docs/multi-gpu.html",
"title": "Multi-GPU",
"section": "",
"text": "This guide covers advanced training configurations for multi-GPU setups using Axolotl.",
"crumbs": [
"Deployments",
"Multi-GPU"
]
},
{
"objectID": "docs/multi-gpu.html#sec-overview",
"href": "docs/multi-gpu.html#sec-overview",
"title": "Multi-GPU",
"section": "Overview",
"text": "Overview\nWhen training on multiple GPUs, Axolotl supports three sharding/parallelism strategies. Additionally, you can layer specific optimization features on top of that strategy.\nYou generally cannot combine these strategies; they are mutually exclusive.\n\nDeepSpeed: Powerful optimization library, supports ZeRO stages 1-3.\nFSDP (Fully Sharded Data Parallel): PyTorch's native sharding implementation (Recommended).\nDDP (Distributed Data Parallel): PyTorch's native parallelism implementation (Default if neither of the above is selected).\n\nThese features can often be combined with the strategies above:\n\nSequence Parallelism: Splits long sequences across GPUs (Compatible with DDP, DeepSpeed, and FSDP).\nFSDP + QLoRA: Combines 4-bit quantization with FSDP (Specific to FSDP).",
"crumbs": [
"Deployments",
"Multi-GPU"
]
},
{
"objectID": "docs/multi-gpu.html#sec-deepspeed",
"href": "docs/multi-gpu.html#sec-deepspeed",
"title": "Multi-GPU",
"section": "DeepSpeed",
"text": "DeepSpeed\n\nConfiguration\nAdd to your YAML config:\ndeepspeed: deepspeed_configs/zero1.json\n\n\nUsage\n# Fetch deepspeed configs (if not already present)\naxolotl fetch deepspeed_configs\n\n# Passing arg via config\naxolotl train config.yml\n\n# Passing arg via cli\naxolotl train config.yml --deepspeed deepspeed_configs/zero1.json\n\n\nZeRO Stages\nWe provide default configurations for:\n\nZeRO Stage 1 (zero1.json)\nZeRO Stage 1 with torch compile (zero1_torch_compile.json)\nZeRO Stage 2 (zero2.json)\nZeRO Stage 3 (zero3.json)\nZeRO Stage 3 with bf16 (zero3_bf16.json)\nZeRO Stage 3 with bf16 and CPU offload params (zero3_bf16_cpuoffload_params.json)\nZeRO Stage 3 with bf16 and CPU offload params and optimizer (zero3_bf16_cpuoffload_all.json)\n\n\n\n\n\n\n\nTip\n\n\n\nFor best performance, choose the configuration that offloads the least to CPU memory while still fitting in VRAM.\nTry Stage 1 first, then Stage 2, then Stage 3.",
"crumbs": [
"Deployments",
"Multi-GPU"
]
},
{
"objectID": "docs/multi-gpu.html#sec-fsdp",
"href": "docs/multi-gpu.html#sec-fsdp",
"title": "Multi-GPU",
"section": "Fully Sharded Data Parallel (FSDP)",
"text": "Fully Sharded Data Parallel (FSDP)\nFSDP allows you to shard model parameters, gradients, and optimizer states across data parallel workers.\n\n\n\n\n\n\nNote\n\n\n\nFSDP2 is recommended for new users. FSDP1 is deprecated and will be removed in an upcoming release of Axolotl.\n\n\n\nFSDP + QLoRA\nFor combining FSDP with QLoRA, see our dedicated guide.\n\n\nMigrating from FSDP1 to FSDP2\nTo migrate your config from FSDP1 to FSDP2, you must use the fsdp_version top-level config field to specify the FSDP version, and\nalso follow the config field mapping below to update field names.\n\nConfig mapping\n\n\n\nFSDP1\nFSDP2\n\n\n\n\nfsdp_sharding_strategy\nreshard_after_forward\n\n\nfsdp_backward_prefetch_policy\nREMOVED\n\n\nfsdp_backward_prefetch\nREMOVED\n\n\nfsdp_forward_prefetch\nREMOVED\n\n\nfsdp_sync_module_states\nREMOVED\n\n\nfsdp_cpu_ram_efficient_loading\ncpu_ram_efficient_loading\n\n\nfsdp_state_dict_type\nstate_dict_type\n\n\nfsdp_use_orig_params\nREMOVED\n\n\nfsdp_activation_checkpointing\nactivation_checkpointing\n\n\n\nFor more details, please see the migration guide in the torchtitan repo. In Axolotl,\nif you were using the following FSDP1 config:\nfsdp_version: 1\nfsdp_config:\n fsdp_offload_params: false\n fsdp_cpu_ram_efficient_loading: true\n fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer\n fsdp_state_dict_type: FULL_STATE_DICT\n fsdp_sharding_strategy: FULL_SHARD\nYou can migrate to the following FSDP2 config:\nfsdp_version: 2\nfsdp_config:\n offload_params: false\n cpu_ram_efficient_loading: true\n auto_wrap_policy: TRANSFORMER_BASED_WRAP\n transformer_layer_cls_to_wrap: Qwen3DecoderLayer\n state_dict_type: FULL_STATE_DICT\n reshard_after_forward: true\n\n\n\nFSDP1 (deprecated)\n\n\n\n\n\n\nNote\n\n\n\nUsing fsdp to configure FSDP is deprecated and will be removed in an upcoming release of Axolotl. 
Please use fsdp_config as above instead.\n\n\nfsdp:\n - full_shard\n - auto_wrap\nfsdp_config:\n fsdp_offload_params: true\n fsdp_state_dict_type: FULL_STATE_DICT\n fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer",
"crumbs": [
"Deployments",
"Multi-GPU"
]
},
{
"objectID": "docs/multi-gpu.html#sec-sequence-parallelism",
"href": "docs/multi-gpu.html#sec-sequence-parallelism",
"title": "Multi-GPU",
"section": "Sequence parallelism",
"text": "Sequence parallelism\nWe support sequence parallelism (SP) via the\nring-flash-attention project. This\nallows one to split up sequences across GPUs, which is useful in the event that a\nsingle sequence causes OOM errors during model training.\nSee our dedicated guide for more information.",
"crumbs": [
"Deployments",
"Multi-GPU"
]
},
{
"objectID": "docs/multi-gpu.html#sec-performance",
"href": "docs/multi-gpu.html#sec-performance",
"title": "Multi-GPU",
"section": "Performance Optimization",
"text": "Performance Optimization\n\nLiger Kernel Integration\nPlease see docs for more info.",
"crumbs": [
"Deployments",
"Multi-GPU"
]
},
{
"objectID": "docs/multi-gpu.html#sec-troubleshooting",
"href": "docs/multi-gpu.html#sec-troubleshooting",
"title": "Multi-GPU",
"section": "Troubleshooting",
"text": "Troubleshooting\n\nNCCL Issues\nFor NCCL-related problems, see our NCCL troubleshooting guide.\n\n\nCommon Problems\n\nMemory Issues\n\nReduce micro_batch_size\nReduce eval_batch_size\nAdjust gradient_accumulation_steps\nConsider using a higher ZeRO stage\n\nTraining Instability\n\nStart with DeepSpeed ZeRO-2\nMonitor loss values\nCheck learning rates\n\n\n\n\nFor more detailed troubleshooting, see our debugging guide.",
"crumbs": [
"Deployments",
"Multi-GPU"
]
},
{
"objectID": "docs/nd_parallelism.html",
"href": "docs/nd_parallelism.html",
"title": "N-D Parallelism (Beta)",
"section": "",
"text": "Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:\nor combinations of the above!",
"crumbs": [
"Advanced Features",
"N-D Parallelism (Beta)"
]
},
{
"objectID": "docs/nd_parallelism.html#core-concepts",
"href": "docs/nd_parallelism.html#core-concepts",
"title": "N-D Parallelism (Beta)",
"section": "Core Concepts",
"text": "Core Concepts\nParallelism strategies can be combined. The key is understanding how each one divides the workload. PyTorch's DeviceMesh is the modern way to manage these combinations, creating a logical grid of your GPUs and assigning different parallel strategies to different dimensions of the grid.\n\nData Parallelism\nData Parallelism focuses on splitting the global data batch across GPUs.\n\nDistributed Data Parallel (DDP): The classic approach. The full model is replicated on every GPU. Each GPU processes a different slice of the data batch. Gradients are then averaged across all GPUs after the backward pass to keep the models synchronized. This can substantially improve data throughput compared to single-device training, but requires that each GPU is able to hold the entire model, its gradients, and optimizer states.\nFully Sharded Data Parallel (FSDP): A highly memory-efficient form of data parallelism (inspired by DeepSpeed's ZeRO). Instead of replicating the model, FSDP shards the model's parameters, gradients, and optimizer states across the GPUs in the data-parallel group. During computation, each GPU receives the specific parameters it needs via an all_gather operation just before they are used, and they can be discarded immediately after (reshard-after-forward).\n\nFSDP maps to ZeRO stages:\n\nZeRO-2 (reshard_after_forward=False): Shards gradients and optimizer states. Model weights are replicated on each GPU.\nZeRO-3 (reshard_after_forward=True): Shards gradients, optimizer states, AND model parameters. This provides the most memory savings at the cost of more communication (re-gathering parameters for both forward and backward passes).\n\n\n\n\n\n[Experimental] Tensor Parallelism (TP)\nAlso known as “horizontal model parallelism,” as described in the Megatron-LM paper. 
Instead of splitting the batch, TP splits the model's layers themselves across GPUs.\n\nHow it works: For a linear layer Y = XA, the weight matrix A is split column-wise (A = [A_1, A_2]). The computation becomes Y_1 = XA_1 and Y_2 = XA_2, which can happen in parallel on different GPUs. The final output Y is simply the concatenation of Y_1 and Y_2. Check this comment for more detailed info.\nRequirement: TP involves frequent, small communications within a forward/backward pass. It requires a very fast interconnect between GPUs (e.g., NVLink) and is typically not recommended across different nodes.\n\n\n\nContext Parallelism (CP)\nContext Parallelism, also called Sequence Parallelism, addresses the memory bottleneck from long sequences. The input sequence itself is split along the sequence length dimension and distributed across GPUs.\n\nHow it works: If you have a sequence of 8192 tokens and a context_parallel_size of 4, each GPU will only handle a chunk of 2048 tokens.\nThe Challenge: Attention is not local; every token needs to “attend to” every other token. Splitting the sequence breaks this.\nThe Solution (ring-flash-attention): An efficient communication protocol is used. To compute attention for its local sequence chunk, each GPU passes its Key-Value (KV) cache to its neighbor in a “ring.” After N-1 steps, every GPU has seen the KV-cache from all other GPUs, allowing it to compute the correct attention values for its chunk. This is implemented using the highly optimized flash-attention kernel at each step.\n\n\n\nHybrid Sharding Data Parallel (HSDP)\nHSDP is a 2D strategy that intelligently combines FSDP and DDP, typically for multi-node training.\n\nIntra-Node (within a machine): Use FSDP. This is efficient because GPUs on the same node have fast interconnects (NVLink), making the all_gather operations for sharded parameters fast.\nInter-Node (across machines): Use DDP. 
The gradient synchronization between nodes is less frequent than FSDP's parameter gathering, making it a better fit for the slower node-to-node network (e.g., Ethernet/InfiniBand).\nExample: With 2 nodes of 8 GPUs each (16 total), you could have dp_shard_size=8 (FSDP within each node) and dp_replicate_size=2 (DDP across the two nodes).",
"crumbs": [
"Advanced Features",
"N-D Parallelism (Beta)"
]
},
{
"objectID": "docs/nd_parallelism.html#usage",
"href": "docs/nd_parallelism.html#usage",
"title": "N-D Parallelism (Beta)",
"section": "Usage",
"text": "Usage\n# FSDP config. See https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp\nfsdp_version: 2\nfsdp_config:\n # ...\n\n# The number of GPUs to shard the model parameters across (FSDP dimension).\ndp_shard_size: 4\n\n# The number of times to replicate the sharded model (DDP dimension).\ndp_replicate_size: 2\n\n# Number of GPUs for Tensor Parallelism.\ntensor_parallel_size: 1 # (default is 1, no TP)\n\n# Number of GPUs for Context/Sequence Parallelism.\ncontext_parallel_size: 1 # (default is 1, no CP)\nNote: We recommend FSDP. DeepSpeed is only compatible with tensor_parallel_size.",
"crumbs": [
"Advanced Features",
"N-D Parallelism (Beta)"
]
},
{
"objectID": "docs/nd_parallelism.html#examples",
"href": "docs/nd_parallelism.html#examples",
"title": "N-D Parallelism (Beta)",
"section": "Examples",
"text": "Examples\n\n\n\n\n\n\nTip\n\n\n\nSee our example configs here.\n\n\n\nHSDP on 2 nodes with 4 GPUs each (8 GPUs total):\n\nYou want FSDP within each node and DDP across nodes.\nSet dp_shard_size: 4 and dp_replicate_size: 2.\n\nFSDP + TP on a single 8-GPU node:\n\nYou want to split the model across 4 GPUs using FSDP, and further split each layer across 2 GPUs with TP.\nSet dp_shard_size: 4 and tensor_parallel_size: 2.\n\nFSDP + CP on a single 8-GPU node for long context:\n\nYou want to shard the model across all 8 GPUs and also split the sequence length across all 8 GPUs.\nSet dp_shard_size: 8 and context_parallel_size: 8. Note: this means the data parallel group and context parallel group are the same. A more common setup might be to shard across a smaller group.",
"crumbs": [
"Advanced Features",
"N-D Parallelism (Beta)"
]
},
{
"objectID": "docs/nd_parallelism.html#support-matrix",
"href": "docs/nd_parallelism.html#support-matrix",
"title": "N-D Parallelism (Beta)",
"section": "Support Matrix",
"text": "Support Matrix\nThis matrix describes how different parallelism methods can be combined in Axolotl.\n\n\n\n\n\n\n\n\n\n\n\nCombination\ndp_replicate_size\ndp_shard_size\ntp_size\ncp_size\nStatus & Notes\n\n\n\n\nFSDP (ZeRO-3)\n1\n>1\n1\n1\n✅ Fully supported. Shards model across all GPUs.\n\n\nHSDP\n>1\n>1\n1\n1\n✅ Fully supported. FSDP intra-node, DDP inter-node.\n\n\nFSDP + TP\n1\n>1\n>1\n1\n✅ 2D Parallelism. Shards the model across a dp_shard group, and TP-splits layers within the tp group.\n\n\nHSDP + TP\n>1\n>1\n>1\n1\n✅ 3D Parallelism. A powerful but complex combination.\n\n\nFSDP + CP\n1\n>1\n1\n>1\n✅ 2D Parallelism. Combines FSDP with context parallelism.\n\n\nFSDP + TP + CP\n1\n>1\n>1\n>1\n✅ 3D Parallelism. Another advanced combination.\n\n\nDDP + TP/CP\n>1\n1\n>1\n>1\n❌ Not Supported. The ParallelismConfig explicitly prevents this, as composing pure DDP with TP or CP is currently not supported. You should use FSDP + TP/CP instead (dp_shard_size > 1).\n\n\nJust TP / CP\n1\n1\n>1\n>1\n✅ Supported. Useful for inference or when the model fits on one GPU but context is too long.\n\n\n\n\ntp_size refers to tensor_parallel_size\ncp_size refers to context_parallel_size",
"crumbs": [
"Advanced Features",
"N-D Parallelism (Beta)"
]
},
{
"objectID": "docs/mac.html",
"href": "docs/mac.html",
"title": "Mac M-series",
"section": "",
"text": "Axolotl on Mac is currently only partially usable: many of Axolotl's dependencies, including PyTorch, do not support MPS or have incomplete support.\nCurrent support:\n\nSupport for all models\nFull training of models\nLoRA training\nSample packing\nFP16 and BF16 (awaiting AMP support for MPS in PyTorch)\nTri Dao's flash-attn (until it is supported, use sdp_attention as an alternative)\nxformers\nbitsandbytes (meaning no 4/8 bits loading and bnb optimizers)\nqlora\nDeepSpeed\n\nUntested:\n\nFSDP",
"crumbs": [
"Deployments",
"Mac M-series"
]
},
{
"objectID": "docs/reward_modelling.html",
"href": "docs/reward_modelling.html",
"title": "Reward Modelling",
"section": "",
"text": "Overview\nReward modelling is a technique used to train models to predict the reward or value of a given input. This is particularly useful in reinforcement learning scenarios where the model needs to evaluate the quality of its actions or predictions.\nWe support the reward modelling techniques supported by trl.\n\n\n(Outcome) Reward Models\nOutcome reward models are trained using data which contains preference annotations for an entire interaction between the user and model (e.g. rather than per-turn or per-step).\nFor improved training stability, you can use the center_rewards_coefficient parameter to encourage mean-zero reward outputs (see TRL docs).\nbase_model: google/gemma-2-2b\nmodel_type: AutoModelForSequenceClassification\nnum_labels: 1\ntokenizer_type: AutoTokenizer\n\nreward_model: true\nchat_template: gemma\ndatasets:\n - path: argilla/distilabel-intel-orca-dpo-pairs\n type: bradley_terry.chat_template\n\nval_set_size: 0.1\neval_steps: 100\nBradley-Terry chat templates expect single-turn conversations in the following format:\n{\n \"system\": \"...\", // optional\n \"input\": \"...\",\n \"chosen\": \"...\",\n \"rejected\": \"...\"\n}\n\n\nProcess Reward Models (PRM)\n\n\n\n\n\n\nTip\n\n\n\nCheck out our PRM blog.\n\n\nProcess reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.\nbase_model: Qwen/Qwen2.5-3B\nmodel_type: AutoModelForTokenClassification\nnum_labels: 2\n\nprocess_reward_model: true\ndatasets:\n - path: trl-lib/math_shepherd\n type: stepwise_supervised\n split: train\n\nval_set_size: 0.1\neval_steps: 100\nPlease see stepwise_supervised for more details on the dataset format.",
"crumbs": [
"How To Guides",
"Reward Modelling"
]
},
{
"objectID": "docs/models/ministral3.html",
"href": "docs/models/ministral3.html",
"title": "Ministral3",
"section": "",
"text": "Ministral3 is a family of open-weight models from MistralAI found on HuggingFace. This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.\nPlease see Thinking and Vision for their respective fine-tuning.\nThanks to the team at MistralAI for giving us early access to prepare for these releases.\nNote: This is still experimental given it is based on transformers v5 RC.",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral3",
"Ministral3"
]
},
{
"objectID": "docs/models/ministral3.html#getting-started",
"href": "docs/models/ministral3.html#getting-started",
"title": "Ministral3",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl from source following the installation guide.\nInstall Cut Cross Entropy to reduce training VRAM usage.\nSwap to the Axolotl transformers v5 branch\ncp examples/ministral3/ministral3-3b-qlora.yaml ministral3-3b-qlora.yaml\n\ngit fetch\ngit checkout transformers-v5\n\n# Install packages for transformers v5\npip install -e .\nRun the fine-tuning:\naxolotl train ministral3-3b-qlora.yaml\n\nLet us know how it goes. Happy finetuning! 🚀\n\nTips\n\nWe recommend adding the same or a similar system prompt to the one the model was tuned with. You can find it in the model repository, in the file titled SYSTEM_PROMPT.txt.\nYou can run a full finetuning by removing the adapter: qlora and load_in_4bit: true from the config.\nRead more on how to load your own dataset at docs.\nThe text dataset format follows the OpenAI Messages format as seen here.\n\n\n\nThinking\nThe Ministral3 2512 model supports thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps.\n📚 See the Thinking fine-tuning guide →\n\n\nVision\nThe Ministral3 2512 model also supports vision capabilities.\n📚 See the Vision fine-tuning guide →",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral3",
"Ministral3"
]
},
{
"objectID": "docs/models/ministral3.html#optimization-guides",
"href": "docs/models/ministral3.html#optimization-guides",
"title": "Ministral3",
"section": "Optimization Guides",
"text": "Optimization Guides\nPlease check the Optimizations doc.",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral3",
"Ministral3"
]
},
{
"objectID": "docs/models/ministral3.html#limitations",
"href": "docs/models/ministral3.html#limitations",
"title": "Ministral3",
"section": "Limitations",
"text": "Limitations\nWe only support the mistral-common tokenizer for Supervised Fine-tuning at the moment and for type: chat_template only.\nIn addition, we do not support overriding tokens yet.",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral3",
"Ministral3"
]
},
{
"objectID": "docs/models/ministral3.html#related-resources",
"href": "docs/models/ministral3.html#related-resources",
"title": "Ministral3",
"section": "Related Resources",
"text": "Related Resources\n\nMistralAI Mistral3 Blog\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral3",
"Ministral3"
]
},
{
"objectID": "docs/models/ministral3.html#future-work",
"href": "docs/models/ministral3.html#future-work",
"title": "Ministral3",
"section": "Future Work",
"text": "Future Work\n\nAdd parity to Preference Tuning, RL, etc.\nAdd parity to other tokenizer configs like overriding tokens.",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral3",
"Ministral3"
]
},
{
"objectID": "docs/models/hunyuan.html",
"href": "docs/models/hunyuan.html",
"title": "Hunyuan",
"section": "",
"text": "Tencent released a family of open-source models called HunYuan at 0.5B, 1.8B, 4B, and 7B parameter scales, in both Pre-trained and Instruct variants. The models can be found at HuggingFace. This guide shows how to fine-tune them with Axolotl with multi-turn conversations and proper masking.",
"crumbs": [
"Getting Started",
"Model Guides",
"Hunyuan"
]
},
{
"objectID": "docs/models/hunyuan.html#getting-started",
"href": "docs/models/hunyuan.html#getting-started",
"title": "Hunyuan",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the installation guide. HunYuan support is currently only on main, so install from main or use our latest Docker images.\nHere is an example of how to install from main for pip:\n\n# Ensure you have PyTorch installed (PyTorch 2.6.0 min)\ngit clone https://github.com/axolotl-ai-cloud/axolotl.git\ncd axolotl\n\npip3 install packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation -e '.[flash-attn]'\n\n# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy\npython scripts/cutcrossentropy_install.py | sh\n\nRun the finetuning example:\n\naxolotl train examples/hunyuan/hunyuan-v1-dense-qlora.yaml\nThis config uses about 4.7 GB VRAM.\nLet us know how it goes. Happy finetuning! 🚀\n\nDataset\nHunYuan Instruct models can choose to enter a slow think or fast think pattern. For best performance on fine-tuning their Instruct models, your dataset should be adjusted to match their pattern.\n# fast think pattern\nmessages = [\n {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n {\"role\": \"user\", \"content\": \"/no_think What color is the sun?\" },\n {\"role\": \"assistant\", \"content\": \"<think>\\n\\n</think>\\n<answer>\\nThe sun is yellow.\\n</answer>\"}\n]\n\n# slow think pattern\nmessages = [\n {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n {\"role\": \"user\", \"content\": \"/think What color is the sun?\" },\n {\"role\": \"assistant\", \"content\": \"<think>\\nThe user is asking about the color of the sun. 
I need to ...\\n</think>\\n<answer>\\nThe sun is yellow.\\n</answer>\"}\n]\n\n\nTIPS\n\nFor inference, the official Tencent team recommends\n\n\n{\n \"do_sample\": true,\n \"top_k\": 20,\n \"top_p\": 0.8,\n \"repetition_penalty\": 1.05,\n \"temperature\": 0.7\n}\n\nYou can run a full finetuning by removing the adapter: qlora and load_in_4bit: true from the config.\nRead more on how to load your own dataset at docs.\nThe dataset format follows the OpenAI Messages format as seen here.",
"crumbs": [
"Getting Started",
"Model Guides",
"Hunyuan"
]
},
{
"objectID": "docs/models/hunyuan.html#optimization-guides",
"href": "docs/models/hunyuan.html#optimization-guides",
"title": "Hunyuan",
"section": "Optimization Guides",
"text": "Optimization Guides\n\nMulti-GPU Training\nMulti-Node Training\nLoRA Optimizations",
"crumbs": [
"Getting Started",
"Model Guides",
"Hunyuan"
]
},
{
"objectID": "docs/models/hunyuan.html#related-resources",
"href": "docs/models/hunyuan.html#related-resources",
"title": "Hunyuan",
"section": "Related Resources",
"text": "Related Resources\n\nTencent HunYuan Blog\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"Hunyuan"
]
},
{
"objectID": "docs/models/smolvlm2.html",
"href": "docs/models/smolvlm2.html",
"title": "SmolVLM 2",
"section": "",
"text": "SmolVLM2 is a family of lightweight, open-source multimodal models from HuggingFace designed to analyze and understand video, image, and text content.\nThese models are built for efficiency, making them well-suited for on-device applications where computational resources are limited. Models are available in multiple sizes, including 2.2B, 500M, and 256M.\nThis guide shows how to fine-tune SmolVLM2 models with Axolotl.",
"crumbs": [
"Getting Started",
"Model Guides",
"SmolVLM 2"
]
},
{
"objectID": "docs/models/smolvlm2.html#getting-started",
"href": "docs/models/smolvlm2.html#getting-started",
"title": "SmolVLM 2",
"section": "Getting Started",
"text": "Getting Started\n\nInstall Axolotl following the installation guide.\nHere is an example of how to install from pip:\n# Ensure you have a compatible version of Pytorch installed\npip3 install packaging setuptools wheel ninja\npip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'\nInstall an extra dependency:\npip3 install num2words==0.5.14\nRun the finetuning example:\n# LoRA SFT (1x48GB @ 6.8GiB)\naxolotl train examples/smolvlm2/smolvlm2-2B-lora.yaml",
"crumbs": [
"Getting Started",
"Model Guides",
"SmolVLM 2"
]
},
{
"objectID": "docs/models/smolvlm2.html#tips",
"href": "docs/models/smolvlm2.html#tips",
"title": "SmolVLM 2",
"section": "TIPS",
"text": "TIPS\n\nDataset Format: For video finetuning, your dataset must be compatible with the multi-content Messages format. For more details, see our documentation on Multimodal Formats.\nDataset Loading: Read more on how to prepare and load your own datasets in our documentation.",
"crumbs": [
"Getting Started",
"Model Guides",
"SmolVLM 2"
]
},
{
"objectID": "docs/models/smolvlm2.html#optimization-guides",
"href": "docs/models/smolvlm2.html#optimization-guides",
"title": "SmolVLM 2",
"section": "Optimization Guides",
"text": "Optimization Guides\nPlease check the Optimizations doc.",
"crumbs": [
"Getting Started",
"Model Guides",
"SmolVLM 2"
]
},
{
"objectID": "docs/models/smolvlm2.html#related-resources",
"href": "docs/models/smolvlm2.html#related-resources",
"title": "SmolVLM 2",
"section": "Related Resources",
"text": "Related Resources\n\nSmolVLM2 Blog\nAxolotl Docs\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"SmolVLM 2"
]
},
{
"objectID": "docs/models/ministral3/vision.html",
"href": "docs/models/ministral3/vision.html",
"title": "Ministral 3 Vision",
"section": "",
"text": "This guide covers fine-tuning Ministral3 2512 with vision capabilities using Axolotl.",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral3",
"Ministral 3 Vision"
]
},
{
"objectID": "docs/models/ministral3/vision.html#prerequisites",
"href": "docs/models/ministral3/vision.html#prerequisites",
"title": "Ministral 3 Vision",
"section": "Prerequisites",
"text": "Prerequisites\nBefore starting, ensure you have:\n\nInstalled Axolotl from source (see main README)",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral3",
"Ministral 3 Vision"
]
},
{
"objectID": "docs/models/ministral3/vision.html#getting-started",
"href": "docs/models/ministral3/vision.html#getting-started",
"title": "Ministral 3 Vision",
"section": "Getting started",
"text": "Getting started\n\nInstall the required vision lib:\npip install 'mistral-common[opencv]==1.8.6'\nDownload the example dataset image:\nwget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg\nRun the fine-tuning:\naxolotl train examples/ministral3/vision/ministral3-3b-vision-qlora.yml\n\nWARNING: The loss and grad norm will be much higher than normal at first. We currently suspect this is inherent to the model. If anyone would like to submit a fix for this, we are happy to take a look.\n\nTips\nKey differences from text-only model:\n- Multi-modal dataset format required\n- Sample packing not supported",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral3",
"Ministral 3 Vision"
]
},
{
"objectID": "docs/models/ministral3/vision.html#dataset-format",
"href": "docs/models/ministral3/vision.html#dataset-format",
"title": "Ministral 3 Vision",
"section": "Dataset Format",
"text": "Dataset Format\nThe vision model requires the multi-modal dataset format as documented here.\nOne exception: passing \"image\": PIL.Image is not supported. MistralTokenizer only supports path, url, and base64 for now.\nExample:\n{\n \"messages\": [\n {\"role\": \"system\", \"content\": [{ \"type\": \"text\", \"text\": \"{SYSTEM_PROMPT}\"}]},\n {\"role\": \"user\", \"content\": [\n { \"type\": \"text\", \"text\": \"What's in this image?\"},\n {\"type\": \"image\", \"path\": \"path/to/image.jpg\" }\n ]},\n {\"role\": \"assistant\", \"content\": [{ \"type\": \"text\", \"text\": \"...\" }]}\n ]\n}",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral3",
"Ministral 3 Vision"
]
},
{
"objectID": "docs/models/ministral3/vision.html#limitations",
"href": "docs/models/ministral3/vision.html#limitations",
"title": "Ministral 3 Vision",
"section": "Limitations",
"text": "Limitations\n\nSample Packing is not supported for multi-modality training currently.",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral3",
"Ministral 3 Vision"
]
},
{
"objectID": "docs/models/voxtral.html",
"href": "docs/models/voxtral.html",
"title": "Voxtral",
"section": "",
"text": "Voxtral is a 3B/24B parameter open-source model from MistralAI found on HuggingFace. This guide shows how to fine-tune it with Axolotl.\nThanks to the team at MistralAI for giving us early access to prepare for this release.",
"crumbs": [
"Getting Started",
"Model Guides",
"Voxtral"
]
},
{
"objectID": "docs/models/voxtral.html#getting-started",
"href": "docs/models/voxtral.html#getting-started",
"title": "Voxtral",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the installation guide.\nHere is an example of how to install from pip:\n\n# Ensure you have PyTorch installed (PyTorch 2.6.0 min)\npip3 install packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'\n\nInstall the additional dependencies below:\n\n# audio\npip3 install librosa==0.11.0\npip3 install 'mistral_common[audio]==1.8.3'\n\n# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy\npython scripts/cutcrossentropy_install.py | sh\n\nDownload the sample dataset files:\n\n# for text + audio only\nwget https://huggingface.co/datasets/Nanobit/text-audio-2k-test/resolve/main/En-us-African_elephant.oga\n\nRun the finetuning example:\n\n# text only\naxolotl train examples/voxtral/voxtral-mini-qlora.yml\n\n# text + audio\naxolotl train examples/voxtral/voxtral-mini-audio-qlora.yml\nThese configs use about 4.8 GB VRAM.\nLet us know how it goes. Happy finetuning! 🚀\n\nTIPS\n\nFor inference, the official MistralAI team recommends temperature: 0.2 and top_p: 0.95 for audio understanding and temperature: 0.0 for transcription.\nYou can run a full finetuning by removing the adapter: qlora and load_in_4bit: true from the config.\nRead more on how to load your own dataset at docs.\nThe text dataset format follows the OpenAI Messages format as seen here.\nThe multimodal dataset format follows the OpenAI multi-content Messages format as seen here.",
"crumbs": [
"Getting Started",
"Model Guides",
"Voxtral"
]
},
{
"objectID": "docs/models/voxtral.html#optimization-guides",
"href": "docs/models/voxtral.html#optimization-guides",
"title": "Voxtral",
"section": "Optimization Guides",
"text": "Optimization Guides\n\nMulti-GPU Training\nMulti-Node Training\nLoRA Optimizations",
"crumbs": [
"Getting Started",
"Model Guides",
"Voxtral"
]
},
{
"objectID": "docs/models/voxtral.html#limitations",
"href": "docs/models/voxtral.html#limitations",
"title": "Voxtral",
"section": "Limitations",
"text": "Limitations\nWe only support the mistral-common tokenizer for Supervised Fine-tuning at the moment and for type: chat_template only.\nIn addition, we do not support overriding tokens yet.",
"crumbs": [
"Getting Started",
"Model Guides",
"Voxtral"
]
},
{
"objectID": "docs/models/voxtral.html#related-resources",
"href": "docs/models/voxtral.html#related-resources",
"title": "Voxtral",
"section": "Related Resources",
"text": "Related Resources\n\nMistralAI Magistral Blog\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"Voxtral"
]
},
{
"objectID": "docs/models/voxtral.html#future-work",
"href": "docs/models/voxtral.html#future-work",
"title": "Voxtral",
"section": "Future Work",
"text": "Future Work\n\nAdd parity to Preference Tuning, RL, etc.\nAdd parity to other tokenizer configs like overriding tokens.",
"crumbs": [
"Getting Started",
"Model Guides",
"Voxtral"
]
},
{
"objectID": "docs/models/ministral.html",
"href": "docs/models/ministral.html",
"title": "Ministral",
"section": "",
"text": "Ministral is a family of open-weight models from MistralAI found on HuggingFace. This guide shows how to fine-tune them with Axolotl with multi-turn conversations and proper masking.",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral"
]
},
{
"objectID": "docs/models/ministral.html#getting-started",
"href": "docs/models/ministral.html#getting-started",
"title": "Ministral",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the installation guide.\nInstall Cut Cross Entropy to reduce training VRAM usage.\nRun the finetuning example:\naxolotl train examples/ministral/ministral-small-qlora.yaml\n\nThis config uses about 8.76 GiB VRAM.\nLet us know how it goes. Happy finetuning! 🚀\n\nTips\n\nWe recommend adding the same or a similar system prompt to the one the model was tuned with. You can find it in the repo's files under the name SYSTEM_PROMPT.txt.\nYou can run a full finetuning by removing the adapter: qlora and load_in_4bit: true from the config.\nRead more on how to load your own dataset at docs.\nThe text dataset format follows the OpenAI Messages format as seen here.",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral"
]
},
{
"objectID": "docs/models/ministral.html#optimization-guides",
"href": "docs/models/ministral.html#optimization-guides",
"title": "Ministral",
"section": "Optimization Guides",
"text": "Optimization Guides\nPlease check the Optimizations doc.",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral"
]
},
{
"objectID": "docs/models/ministral.html#limitations",
"href": "docs/models/ministral.html#limitations",
"title": "Ministral",
"section": "Limitations",
"text": "Limitations\nWe only support the mistral-common tokenizer for Supervised Fine-tuning at the moment and for type: chat_template only.\nIn addition, we do not support overriding tokens yet.",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral"
]
},
{
"objectID": "docs/models/ministral.html#related-resources",
"href": "docs/models/ministral.html#related-resources",
"title": "Ministral",
"section": "Related Resources",
"text": "Related Resources\n\nMistralAI Ministral Blog\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral"
]
},
{
"objectID": "docs/models/ministral.html#future-work",
"href": "docs/models/ministral.html#future-work",
"title": "Ministral",
"section": "Future Work",
"text": "Future Work\n\nAdd parity to Preference Tuning, RL, etc.\nAdd parity to other tokenizer configs like overriding tokens.",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral"
]
},
{
"objectID": "docs/models/granite4.html",
"href": "docs/models/granite4.html",
"title": "Granite 4",
"section": "",
"text": "Granite 4.0 is a family of open-source models trained by IBM Research.\nThis guide shows how to fine-tune them with Axolotl with multi-turn conversations and proper masking.",
"crumbs": [
"Getting Started",
"Model Guides",
"Granite 4"
]
},
{
"objectID": "docs/models/granite4.html#getting-started",
"href": "docs/models/granite4.html#getting-started",
"title": "Granite 4",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the installation guide. Granite 4 is only available on nightly, so you need to install from main or use our latest Docker images.\nHere is an example of how to install from main with pip:\n\n# Ensure you have PyTorch installed (PyTorch 2.7.1 min)\ngit clone https://github.com/axolotl-ai-cloud/axolotl.git\ncd axolotl\n\npip3 install packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation -e '.[flash-attn]'\n\n# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy\npython scripts/cutcrossentropy_install.py | sh\n\nRun the finetuning example:\n\naxolotl train examples/granite4/granite-4.0-tiny-fft.yaml\nThis config uses about 40.8 GiB VRAM.\nLet us know how it goes. Happy finetuning! 🚀\n\nTIPS\n\nRead more on how to load your own dataset at docs.\nThe dataset format follows the OpenAI Messages format as seen here.\n\n\n\nLimitation\nAdapter finetuning does not work at the moment. It fails with:\nRuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x3072 and 1x1179648)\nIn addition, even if adapter training worked, lora_target_linear: true would not work due to:\nValueError: Target module GraniteMoeHybridParallelExperts() is not supported.",
"crumbs": [
"Getting Started",
"Model Guides",
"Granite 4"
]
},
{
"objectID": "docs/models/granite4.html#optimization-guides",
"href": "docs/models/granite4.html#optimization-guides",
"title": "Granite 4",
"section": "Optimization Guides",
"text": "Optimization Guides\n\nMulti-GPU Training\nMulti-Node Training\nLoRA Optimizations",
"crumbs": [
"Getting Started",
"Model Guides",
"Granite 4"
]
},
{
"objectID": "docs/models/granite4.html#related-resources",
"href": "docs/models/granite4.html#related-resources",
"title": "Granite 4",
"section": "Related Resources",
"text": "Related Resources\n\nGranite Docs\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"Granite 4"
]
},
{
"objectID": "docs/models/phi.html",
"href": "docs/models/phi.html",
"title": "Phi",
"section": "",
"text": "Due to some nuances with the phi code, please use deepspeed when training phi for full finetune.\naccelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed_configs/zero1.json\n\n# OR\n\npython -m axolotl.cli.train examples/phi/phi-qlora.yml",
"crumbs": [
"Getting Started",
"Model Guides",
"Phi"
]
},
{
"objectID": "docs/models/internvl3_5.html",
"href": "docs/models/internvl3_5.html",
"title": "InternVL 3.5",
"section": "",
"text": "InternVL 3.5 is a family of powerful vision-language models supporting dynamic resolution and multi-image understanding by OpenGV. It features a ViT-style vision encoder and strong language model backbone for tasks like visual question answering, OCR, and scene text understanding.\nThis guide shows how to fine-tune it with Axolotl.",
"crumbs": [
"Getting Started",
"Model Guides",
"InternVL 3.5"
]
},
{
"objectID": "docs/models/internvl3_5.html#getting-started",
"href": "docs/models/internvl3_5.html#getting-started",
"title": "InternVL 3.5",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the installation guide.\nInstall timm for vision model support:\npip install timm==1.0.19\nInstall Cut Cross Entropy to reduce training VRAM usage.\nRun the finetuning example:\naxolotl train examples/internvl3_5/internvl3_5-8b-qlora.yml\n\nThis config uses about 8.21 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀\n\nTips\n\nYou can run a full finetuning by removing the adapter: qlora and load_in_4bit: true from the config.\nRead more on how to load your own dataset at docs.\nThe dataset format follows the multi-modal format as seen here.",
"crumbs": [
"Getting Started",
"Model Guides",
"InternVL 3.5"
]
},
{
"objectID": "docs/models/internvl3_5.html#optimization-guides",
"href": "docs/models/internvl3_5.html#optimization-guides",
"title": "InternVL 3.5",
"section": "Optimization Guides",
"text": "Optimization Guides\nPlease check the Optimizations doc.",
"crumbs": [
"Getting Started",
"Model Guides",
"InternVL 3.5"
]
},
{
"objectID": "docs/models/internvl3_5.html#related-resources",
"href": "docs/models/internvl3_5.html#related-resources",
"title": "InternVL 3.5",
"section": "Related Resources",
"text": "Related Resources\n\nInternVL Paper\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"InternVL 3.5"
]
},
{
"objectID": "docs/models/magistral/think.html",
"href": "docs/models/magistral/think.html",
"title": "Magistral Thinking",
"section": "",
"text": "This guide covers fine-tuning Magistral Small 2507 with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections.",
"crumbs": [
"Getting Started",
"Model Guides",
"Magistral",
"Magistral Thinking"
]
},
{
"objectID": "docs/models/magistral/think.html#prerequisites",
"href": "docs/models/magistral/think.html#prerequisites",
"title": "Magistral Thinking",
"section": "Prerequisites",
"text": "Prerequisites\nBefore starting, ensure you have:\n\nInstalled Axolotl (see main README)",
"crumbs": [
"Getting Started",
"Model Guides",
"Magistral",
"Magistral Thinking"
]
},
{
"objectID": "docs/models/magistral/think.html#getting-started",
"href": "docs/models/magistral/think.html#getting-started",
"title": "Magistral Thinking",
"section": "Getting Started",
"text": "Getting Started\nRun the thinking model fine-tuning:\naxolotl train examples/magistral/think/magistral-small-think-qlora.yaml\nThis config uses about 19.1 GiB VRAM.\n\nTips\n\nThe dataset uses the multi-content format with type: thinking support. See Dataset Format below.\nDo not mix content: str and content: list[dict] in the same dataset; otherwise dataset loading will fail. Keep it consistent.",
"crumbs": [
"Getting Started",
"Model Guides",
"Magistral",
"Magistral Thinking"
]
},
{
"objectID": "docs/models/magistral/think.html#dataset-format",
"href": "docs/models/magistral/think.html#dataset-format",
"title": "Magistral Thinking",
"section": "Dataset Format",
"text": "Dataset Format\nThe thinking model requires the multi-content dataset format with support for an extra type: thinking content entry within system and assistant messages.\nExample format:\n{\n \"messages\": [\n {\n \"role\": \"system\",\n \"content\": [\n { \"type\": \"text\", \"text\": \"{SYSTEM_PROMPT}\"}\n ]\n },\n {\n \"role\": \"user\",\n \"content\": [\n { \"type\": \"text\", \"text\": \"Solve this step by step: What is 15% of 240?\"}\n ]\n },\n {\n \"role\": \"assistant\",\n \"content\": [\n {\n \"type\": \"thinking\",\n \"thinking\": \"I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. Then multiply: 0.15 × 240 = 36.\"\n },\n {\n \"type\": \"text\",\n \"text\": \"To find 15% of 240, I'll multiply 240 by 0.15:\\n\\n240 × 0.15 = 36\\n\\nTherefore, 15% of 240 is 36.\"\n }\n ]\n }\n ]\n}\n\nAdvanced Options\nThe thinking section supports an optional closed parameter:\n{\n \"type\": \"thinking\",\n \"thinking\": \"Internal reasoning here...\",\n \"closed\": true // Default: true, controls adding the closing [/THINK] tag\n}",
"crumbs": [
"Getting Started",
"Model Guides",
"Magistral",
"Magistral Thinking"
]
},
{
"objectID": "docs/models/mistral-small.html",
"href": "docs/models/mistral-small.html",
"title": "Mistral Small 3.1/3.2",
"section": "",
"text": "This guide covers fine-tuning Mistral Small 3.1 and Mistral Small 3.2 with vision capabilities using Axolotl.",
"crumbs": [
"Getting Started",
"Model Guides",
"Mistral Small 3.1/3.2"
]
},
{
"objectID": "docs/models/mistral-small.html#prerequisites",
"href": "docs/models/mistral-small.html#prerequisites",
"title": "Mistral Small 3.1/3.2",
"section": "Prerequisites",
"text": "Prerequisites\nBefore starting, ensure you have:\n\nInstalled Axolotl (see Installation docs)",
"crumbs": [
"Getting Started",
"Model Guides",
"Mistral Small 3.1/3.2"
]
},
{
"objectID": "docs/models/mistral-small.html#getting-started",
"href": "docs/models/mistral-small.html#getting-started",
"title": "Mistral Small 3.1/3.2",
"section": "Getting Started",
"text": "Getting Started\n\nInstall the required vision library:\npip install 'mistral-common[opencv]==1.8.5'\nDownload the example dataset image:\nwget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg\nRun the fine-tuning:\naxolotl train examples/mistral/mistral-small/mistral-small-3.1-24B-lora.yml\n\nThis config uses about 29.4 GiB VRAM.",
"crumbs": [
"Getting Started",
"Model Guides",
"Mistral Small 3.1/3.2"
]
},
{
"objectID": "docs/models/mistral-small.html#dataset-format",
"href": "docs/models/mistral-small.html#dataset-format",
"title": "Mistral Small 3.1/3.2",
"section": "Dataset Format",
"text": "Dataset Format\nThe vision model requires the multi-modal dataset format as documented here.\nOne exception: passing \"image\": PIL.Image is not supported. MistralTokenizer only supports path, url, and base64 for now.\nExample:\n{\n \"messages\": [\n {\"role\": \"system\", \"content\": [{ \"type\": \"text\", \"text\": \"{SYSTEM_PROMPT}\"}]},\n {\"role\": \"user\", \"content\": [\n { \"type\": \"text\", \"text\": \"What's in this image?\"},\n {\"type\": \"image\", \"path\": \"path/to/image.jpg\" }\n ]},\n {\"role\": \"assistant\", \"content\": [{ \"type\": \"text\", \"text\": \"...\" }]}\n ]\n}",
"crumbs": [
"Getting Started",
"Model Guides",
"Mistral Small 3.1/3.2"
]
},
{
"objectID": "docs/models/mistral-small.html#limitations",
"href": "docs/models/mistral-small.html#limitations",
"title": "Mistral Small 3.1/3.2",
"section": "Limitations",
"text": "Limitations\n\nSample Packing is not supported for multi-modality training currently.",
"crumbs": [
"Getting Started",
"Model Guides",
"Mistral Small 3.1/3.2"
]
},
{
"objectID": "docs/models/gemma3n.html",
"href": "docs/models/gemma3n.html",
"title": "Gemma 3n",
"section": "",
"text": "Gemma-3n is a family of multimodal models from Google found on HuggingFace. This guide shows how to fine-tune it with Axolotl.",
"crumbs": [
"Getting Started",
"Model Guides",
"Gemma 3n"
]
},
{
"objectID": "docs/models/gemma3n.html#getting-started",
"href": "docs/models/gemma3n.html#getting-started",
"title": "Gemma 3n",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the installation guide.\nHere is an example of how to install from pip:\n\n# Ensure you have PyTorch installed (PyTorch 2.6.0 min)\npip3 install packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'\n\nIn addition to Axolotl's requirements, Gemma-3n requires:\n\npip3 install timm==1.0.17\n\n# for loading audio data\npip3 install librosa==0.11.0\n\nDownload the sample dataset files:\n\n# for text + vision + audio only\nwget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/African_elephant.jpg\nwget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/En-us-African_elephant.oga\n\nRun the finetuning example:\n\n# text only\naxolotl train examples/gemma3n/gemma-3n-e2b-qlora.yml\n\n# text + vision\naxolotl train examples/gemma3n/gemma-3n-e2b-vision-qlora.yml\n\n# text + vision + audio\naxolotl train examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml\nLet us know how it goes. Happy finetuning! 🚀\nWARNING: The loss and grad norm will be much higher than normal. We suspect this is inherent to the model at the moment. If anyone would like to submit a fix for this, we are happy to take a look.\n\nTIPS\n\nYou can run a full finetuning by removing the adapter: qlora and load_in_4bit: true from the config.\nRead more on how to load your own dataset at docs.\nThe text dataset format follows the OpenAI Messages format as seen here.\nThe multimodal dataset format follows the OpenAI multi-content Messages format as seen here.",
"crumbs": [
"Getting Started",
"Model Guides",
"Gemma 3n"
]
},
{
"objectID": "docs/models/gemma3n.html#optimization-guides",
"href": "docs/models/gemma3n.html#optimization-guides",
"title": "Gemma 3n",
"section": "Optimization Guides",
"text": "Optimization Guides\n\nMulti-GPU Training\nMulti-Node Training\nLoRA Optimizations",
"crumbs": [
"Getting Started",
"Model Guides",
"Gemma 3n"
]
},
{
"objectID": "docs/models/gemma3n.html#related-resources",
"href": "docs/models/gemma3n.html#related-resources",
"title": "Gemma 3n",
"section": "Related Resources",
"text": "Related Resources\n\nGemma 3n Blog\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"Gemma 3n"
]
},
{
"objectID": "docs/models/arcee.html",
"href": "docs/models/arcee.html",
"title": "Arcee AFM",
"section": "",
"text": "Arcee Foundation Models (AFM) are a family of 4.5B parameter open-weight models trained by Arcee.ai.\nThis guide shows how to fine-tune them with Axolotl with multi-turn conversations and proper masking.\nThanks to the team at Arcee.ai for using Axolotl for supervised fine-tuning of the AFM model.",
"crumbs": [
"Getting Started",
"Model Guides",
"Arcee AFM"
]
},
{
"objectID": "docs/models/arcee.html#getting-started",
"href": "docs/models/arcee.html#getting-started",
"title": "Arcee AFM",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the installation guide. AFM is only available on nightly, so you need to install from main or use our latest Docker images.\nHere is an example of how to install from main with pip:\n\n# Ensure you have PyTorch installed (PyTorch 2.6.0 min)\ngit clone https://github.com/axolotl-ai-cloud/axolotl.git\ncd axolotl\n\npip3 install packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation -e '.[flash-attn]'\n\n# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy\npython scripts/cutcrossentropy_install.py | sh\n\nRun the finetuning example:\n\naxolotl train examples/arcee/afm-4.5b-qlora.yaml\nThis config uses about 7.8 GiB VRAM.\nLet us know how it goes. Happy finetuning! 🚀\n\nTIPS\n\nFor inference, the official Arcee.ai team recommends top_p: 0.95, temperature: 0.5, top_k: 50, and repeat_penalty: 1.1.\nYou can run a full finetuning by removing the adapter: qlora and load_in_4bit: true from the config.\nRead more on how to load your own dataset at docs.\nThe dataset format follows the OpenAI Messages format as seen here.",
"crumbs": [
"Getting Started",
"Model Guides",
"Arcee AFM"
]
},
{
"objectID": "docs/models/arcee.html#optimization-guides",
"href": "docs/models/arcee.html#optimization-guides",
"title": "Arcee AFM",
"section": "Optimization Guides",
"text": "Optimization Guides\n\nMulti-GPU Training\nMulti-Node Training\nLoRA Optimizations",
"crumbs": [
"Getting Started",
"Model Guides",
"Arcee AFM"
]
},
{
"objectID": "docs/models/arcee.html#related-resources",
"href": "docs/models/arcee.html#related-resources",
"title": "Arcee AFM",
"section": "Related Resources",
"text": "Related Resources\n\nAFM Blog\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"Arcee AFM"
]
},
{
"objectID": "docs/models/llama-2.html",
"href": "docs/models/llama-2.html",
"title": "Llama 2",
"section": "",
"text": "This is an example of a llama-2 configuration for 7b and 13b. The yaml file contains configuration for the 7b variant, but you can just as well use the same settings for 13b.\nThe 7b variant fits on any 24GB VRAM GPU and will take up about 17 GB of VRAM during training if using qlora and 20 GB if using lora. On an RTX 4090 it trains 3 epochs of the default dataset in about 15 minutes.\nThe 13b variant will fit if you change these settings to these values:\ngradient_accumulation_steps: 2\nmicro_batch_size: 1\naccelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml\nor\naccelerate launch -m axolotl.cli.train examples/llama-2/lora.yml\nTo launch a full finetuning with 16-bit precision:\naccelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml",
"crumbs": [
"Getting Started",
"Model Guides",
"Llama 2"
]
},
{
"objectID": "docs/models/llama-4.html",
"href": "docs/models/llama-4.html",
"title": "Llama 4",
"section": "",
"text": "While Flash Attention support is “enabled” for Llama-4, the upstream implementation is not correct, so using Flex Attention is recommended.",
"crumbs": [
"Getting Started",
"Model Guides",
"Llama 4"
]
},
{
"objectID": "docs/models/llama-4.html#flash-attention-vs-flex-attention",
"href": "docs/models/llama-4.html#flash-attention-vs-flex-attention",
"title": "Llama 4",
"section": "",
"text": "While Flash Attention support is “enabled” for Llama-4, the upstream implementation is not correct, so using Flex Attention is recommended.",
"crumbs": [
"Getting Started",
"Model Guides",
"Llama 4"
]
},
{
"objectID": "docs/models/llama-4.html#available-examples",
"href": "docs/models/llama-4.html#available-examples",
"title": "Llama 4",
"section": "Available Examples",
"text": "Available Examples\n\nLlama 4 Scout 17Bx16Experts (109B)\nFlex Attention\n- Text Single GPU (H100) QLoRA\n- Text Multi GPU QLoRA w/ FSDP2\nOur single-H100 implementation for Llama 4 Scout uses only 64.5GB VRAM for post-training with 4k context length @ 519 tokens/second. WandB logs here\nMulti-GPU (4xH100) for Llama 4 Scout uses 62.8GB VRAM/GPU @ 4k context length @ 280 tokens/second per GPU. WandB logs here\n\n\nLlama 4 Maverick 17Bx128Experts (400B)\nComing Soon",
"crumbs": [
"Getting Started",
"Model Guides",
"Llama 4"
]
},
{
"objectID": "docs/models/llama-4.html#delinearized-llama-4-models",
"href": "docs/models/llama-4.html#delinearized-llama-4-models",
"title": "Llama 4",
"section": "Delinearized Llama 4 Models",
"text": "Delinearized Llama 4 Models\nWe provide a script to delinearize Llama 4 linearized models into regular HuggingFace Llama 4 models.\naxolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir\nNote: This only works with the non-quantized linearized model. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.",
"crumbs": [
"Getting Started",
"Model Guides",
"Llama 4"
]
},
{
"objectID": "docs/models/seed-oss.html",
"href": "docs/models/seed-oss.html",
"title": "Seed-OSS",
"section": "",
"text": "Seed-OSS is a series of 36B parameter open-source models trained by ByteDance's Seed Team.\nThis guide shows how to fine-tune them with Axolotl with multi-turn conversations and proper masking.",
"crumbs": [
"Getting Started",
"Model Guides",
"Seed-OSS"
]
},
{
"objectID": "docs/models/seed-oss.html#getting-started",
"href": "docs/models/seed-oss.html#getting-started",
"title": "Seed-OSS",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the installation guide.\nHere is an example of how to install from pip:\n# Ensure you have a compatible version of PyTorch installed\npip3 install packaging setuptools wheel ninja\npip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'\n\n# Install Cut Cross Entropy\npython scripts/cutcrossentropy_install.py | sh\nRun the finetuning example:\n\naxolotl train examples/seed-oss/seed-oss-36b-qlora.yaml\nThis config uses about 27.7 GiB VRAM.\nLet us know how it goes. Happy finetuning! 🚀\n\nTIPS\n\nFor inference, the official Seed Team recommends top_p=0.95 and temperature=1.1.\nYou can run a full finetuning by removing the adapter: qlora and load_in_4bit: true from the config.\nRead more on how to load your own dataset at docs.\nThe dataset format follows the OpenAI Messages format as seen here.",
"crumbs": [
"Getting Started",
"Model Guides",
"Seed-OSS"
]
},
{
"objectID": "docs/models/seed-oss.html#optimization-guides",
"href": "docs/models/seed-oss.html#optimization-guides",
"title": "Seed-OSS",
"section": "Optimization Guides",
"text": "Optimization Guides\nPlease check the Optimizations doc.",
"crumbs": [
"Getting Started",
"Model Guides",
"Seed-OSS"
]
},
{
"objectID": "docs/models/seed-oss.html#related-resources",
"href": "docs/models/seed-oss.html#related-resources",
"title": "Seed-OSS",
"section": "Related Resources",
"text": "Related Resources\n\nByteDance Seed Website\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"Seed-OSS"
]
},
{
"objectID": "docs/models/jamba.html",
"href": "docs/models/jamba.html",
"title": "Jamba",
"section": "",
"text": "✅ qlora w/ deepspeed Zero-2 needs at least 2x GPUs and\n\n35GiB VRAM per GPU w minimal context length\n56GiB VRAM per GPU (w multipack enabled)\n\n✅ qlora w/ deepspeed Zero-3 needs at least 2x GPUs and 67GiB VRAM (wtf?)\n✅ qlora single-gpu, ~51GiB VRAM\n✅ multipack\n✅ FSDP\n❓ 8-bit LoRA",
"crumbs": [
"Getting Started",
"Model Guides",
"Jamba"
]
},
{
"objectID": "docs/nccl.html",
"href": "docs/nccl.html",
"title": "NCCL",
"section": "",
"text": "NVIDIA NCCL is a library to facilitate and optimize multi-GPU communication operations, such as broadcast, all-gather, reduce, all-reduce, etc. Broadly, NCCL configuration is highly environment-specific and is configured via several environment variables. A common NCCL-related problem occurs when a long-running operation times out, causing the training process to abort:\nWatchdog caught collective operation timeout: WorkNCCL(SeqNum=42, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1806948 milliseconds before timing out.\nOften, this timeout will happen after 30 minutes (the default setting) and is accompanied by below-average power consumption with near 100% GPU utilization before the error is raised. NVIDIA recommends disabling PCI access control services (ACS) as a possible solution if this is available to you.\nForcing cross-GPU communication via NVLink may help without increasing timeouts. To verify that your configuration is leveraging NVLink, run the following command:\nnvidia-smi nvlink --status\nTo force NCCL to use NVLink, simply set this in the environment:\nexport NCCL_P2P_LEVEL=NVL\nIf NVLink is not available in your environment, there are other options for NCCL_P2P_LEVEL in the table below:\n\n\n\n\n\n\n\nNCCL_P2P_LEVEL\nDescription\n\n\n\n\nPIX\nP2P data transfers through no more than a single PCIe bridge. Faster data transfer rates than paths involving multiple bridges, but slower than direct GPU-to-GPU communication.\n\n\nPXB\nP2P data transfers through multiple PCIe bridges but not going through the PCIe Host Bridge; this path involves a complex routing process, potentially incurring a moderate level of latency.\n\n\nPHB\nP2P data transfers occur over PCIe and through a PCIe Host Bridge, typically involving the CPU, which can facilitate direct memory access but might introduce additional latency compared to more direct paths (e.g. PIX, NVL).\n\n\n\nTo validate that acceptable data transfer speeds exist for your training job, running NCCL Tests can help pinpoint bottlenecks, for example:\n./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3\nWhen debugging NCCL communication timeouts, it can be useful to activate additional logging in both PyTorch and NCCL:\nexport NCCL_DEBUG=INFO\nexport NCCL_DEBUG_SUBSYS=ALL\nexport TORCH_DISTRIBUTED_DEBUG=INFO\nexport TORCHELASTIC_ERROR_FILE=/PATH/TO/torcherror.log\nFinally, if you believe your training job needs more time, you can increase the timeout past 30 minutes by setting the ddp_timeout value in the Axolotl configuration. See PyTorch init_process_group for documentation on this value.",
"crumbs": [
"Troubleshooting",
"NCCL"
]
},
{
"objectID": "docs/multipack.html",
"href": "docs/multipack.html",
"title": "Multipack (Sample Packing)",
"section": "",
"text": "Because Flash Attention simply drops the attention mask, we do not need to\nconstruct a 4d attention mask. We only need to concatenate the sequences into\na single batch and let flash attention know where each new sequence begins.\n4k context, bsz=4,\neach character represents 256 tokens\nX represents a padding token\n 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5\n[[ A A A A A A A A A A A ]\n [ B B B B B B ]\n [ C C C C C C C ]\n [ D D D D ]]\n\n[[ E E E E E E E E ]\n [ F F F F ]\n [ G G G ]\n [ H H H H ]]\n\n[[ I I I ]\n [ J J J ]\n [ K K K K K ]\n [ L L L ]]\nAfter padding to the longest input in each step:\n 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5\n[[ A A A A A A A A A A A ]\n [ B B B B B B X X X X X X ]\n [ C C C C C C C X X X X ]\n [ D D D D X X X X X X X ]]\n\n[[ E E E E E E E E ]\n [ F F F F X X X X ]\n [ G G G X X X X X ]\n [ H H H H X X X X ]]\n\n[[ I I I X X ]\n [ J J J X X ]\n [ K K K K K ]\n [ L L L X X ]]\nWith packing (note it's the same effective number of tokens per step, but a true bsz of 1):\n 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5\n[[ A A A A A A A A A A A B B B B B\n B C C C C C C C D D D D E E E E\n E E E E F F F F F G G G H H H H\n I I I J J J J K K K K K L L L X ]]\ncu_seqlens:\n[[ 0, 11, 17, 24, 28, 36, 41, 44, 48, 51, 55, 60, 64]]",
"crumbs": [
"Core Concepts",
"Multipack (Sample Packing)"
]
},
{
"objectID": "docs/multipack.html#visualization-of-multipack-with-flash-attention",
"href": "docs/multipack.html#visualization-of-multipack-with-flash-attention",
"title": "Multipack (Sample Packing)",
"section": "",
"text": "Because Flash Attention simply drops the attention mask, we do not need to\nconstruct a 4d attention mask. We only need to concatenate the sequences into\na single batch and let flash attention know where each new sequence begins.\n4k context, bsz =4,\neach character represents 256 tokens\nX represents a padding token\n 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5\n[[ A A A A A A A A A A A ]\n B B B B B B ]\n C C C C C C C ]\n D D D D ]]\n\n[[ E E E E E E E E ]\n [ F F F F ]\n [ G G G ]\n [ H H H H ]]\n\n[[ I I I ]\n [ J J J ]\n [ K K K K K]\n [ L L L ]]\nafter padding to longest input in each step\n 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5\n[[ A A A A A A A A A A A ]\n B B B B B B X X X X X X ]\n C C C C C C C X X X X ]\n D D D D X X X X X X X ]]\n\n[[ E E E E E E E E ]\n [ F F F F X X X X ]\n [ G G G X X X X X ]\n [ H H H H X X X X ]]\n\n[[ I I I X X ]\n [ J J J X X ]\n [ K K K K K ]\n [ L L L X X ]]\nw packing ( note its the same effective number of tokens per step, but a true bsz of 1)\n 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5\n[[ A A A A A A A A A A A B B B B B\n B C C C C C C C D D D D E E E E\n E E E E F F F F F G G G H H H H\n I I I J J J J K K K K K L L L X ]]\ncu_seqlens:\n[[ 0, 11, 17, 24, 28, 36, 41 44, 48, 51, 55, 60, 64]]",
"crumbs": [
"Core Concepts",
"Multipack (Sample Packing)"
]
},
{
"objectID": "docs/multipack.html#multipack-without-flash-attention",
"href": "docs/multipack.html#multipack-without-flash-attention",
"title": "Multipack (Sample Packing)",
"section": "Multipack without Flash Attention",
"text": "Multipack without Flash Attention\nMultipack can still be achieved without Flash attention, but with lower packing\nefficiency as we are not able to join multiple batches into a single batch due to\ncontext length limits without flash attention. We can use either Pytorchs Scaled\nDot Product Attention implementation or native Pytorch attention implementation\nalong with 4d attention masks\nto pack sequences together and avoid cross attention.",
"crumbs": [
"Core Concepts",
"Multipack (Sample Packing)"
]
},
{
"objectID": "docs/debugging.html",
"href": "docs/debugging.html",
"title": "Debugging",
"section": "",
"text": "This document provides some tips and tricks for debugging Axolotl. It also provides an example configuration for debugging with VSCode. A good debugging setup is essential to understanding how Axolotl code works behind the scenes.",
"crumbs": [
"Troubleshooting",
"Debugging"
]
},
{
"objectID": "docs/debugging.html#table-of-contents",
"href": "docs/debugging.html#table-of-contents",
"title": "Debugging",
"section": "Table of Contents",
"text": "Table of Contents\n\nGeneral Tips\nDebugging with VSCode\n\nBackground\nConfiguration\nCustomizing your debugger\nVideo Tutorial\n\nDebugging With Docker\n\nSetup\nAttach To Container\nVideo - Attaching To Docker On Remote Host",
"crumbs": [
"Troubleshooting",
"Debugging"
]
},
{
"objectID": "docs/debugging.html#general-tips",
"href": "docs/debugging.html#general-tips",
"title": "Debugging",
"section": "General Tips",
"text": "General Tips\nWhile debugging its helpful to simplify your test scenario as much as possible. Here are some tips for doing so:\n\n[!Important]\nAll of these tips are incorporated into the example configuration for debugging with VSCode below.\n\n\nMake sure you are using the latest version of axolotl: This project changes often and bugs get fixed fast. Check your git branch and make sure you have pulled the latest changes from main.\nEliminate concurrency: Restrict the number of processes to 1 for both training and data preprocessing:\n\nSet CUDA_VISIBLE_DEVICES to a single GPU, ex: export CUDA_VISIBLE_DEVICES=0.\nSet dataset_num_proc: 1 in your axolotl config or run the training command with --dataset_num_proc=1.\n\nUse a small dataset: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure sample_packing: False and eval_sample_packing: False to avoid errors. If you are in a pinch and dont have time to construct a small dataset but want to use from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training. For example, to shard the dataset into 20 pieces, add the following to your axolotl config):\ndatasets:\n ...\n shards: 20\nUse a small model: A good example of a small model is TinyLlama/TinyLlama-1.1B-Chat-v1.0.\nMinimize iteration time: Make sure the training loop finishes as fast as possible, with these settings.\n\nmicro_batch_size: 1\nmax_steps: 1\nval_set_size: 0\n\nClear Caches: Axolotl caches certain steps and so does the underlying HuggingFace trainer. You may want to clear some of these caches when debugging.\n\nData preprocessing: When debugging data preprocessing, which includes prompt template formation, you may want to delete the directory set in dataset_prepared_path: in your axolotl config. 
If you didn't set this value, the default is last_run_prepared.\nHF Hub: If you are debugging data preprocessing, you should clear the relevant HuggingFace cache by deleting the appropriate ~/.cache/huggingface/datasets/... folder(s).\nThe recommended approach is to redirect all outputs and caches to a temporary folder and delete selected subfolders before each run. This is demonstrated in the example configuration below.",
"crumbs": [
"Troubleshooting",
"Debugging"
]
},
{
"objectID": "docs/debugging.html#debugging-with-vscode",
"href": "docs/debugging.html#debugging-with-vscode",
"title": "Debugging",
"section": "Debugging with VSCode",
"text": "Debugging with VSCode\n\nBackground\nThe below example shows how to configure VSCode to debug data preprocessing of the chat_template format. This is the format used when you have the following in your axolotl config:\ndatasets:\n - path: <path to your chat_template formatted dataset> # example on HF Hub: fozziethebeat/alpaca_messages_2k_test\n type: chat_template\n\n[!Important]\nIf you are already familiar with advanced VSCode debugging, you can skip the below explanation and look at the files .vscode/launch.json and .vscode/tasks.json for an example configuration.\n\n\n[!Tip]\nIf you prefer to watch a video, rather than read, you can skip to the video tutorial below (but doing both is recommended).\n\n\n\nSetup\nMake sure you have an editable install of Axolotl, which ensures that changes you make to the code are reflected at runtime. Run the following commands from the root of this project:\npip3 install packaging\npip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'\n\nRemote Hosts\nIf you developing on a remote host, you can easily use VSCode to debug remotely. To do so, you will need to follow this remote - SSH guide. You can also see the video below on Docker and Remote SSH debugging.\n\n\n\nConfiguration\nThe easiest way to get started is to modify the .vscode/launch.json file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.\nFor example, to mimic the command cd devtools && CUDA_VISIBLE_DEVICES=0 axolotl train dev_chat_template.yml, you would use the below configuration1. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to devtools and set the env variable HF_HOME to a temporary folder that is later partially deleted. 
This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.\n// .vscode/launch.json\n{\n \"version\": \"0.2.0\",\n \"configurations\": [\n {\n \"name\": \"Debug axolotl prompt - chat_template\",\n \"type\": \"python\",\n \"module\": \"accelerate.commands.launch\",\n \"request\": \"launch\",\n \"args\": [\n \"-m\", \"axolotl.cli.train\", \"dev_chat_template.yml\",\n // The flags below simplify debugging by overriding the axolotl config\n // with the debugging tips above. Modify as needed.\n \"--dataset_num_proc=1\", // limits data preprocessing to one process\n \"--max_steps=1\", // limits training to just one step\n \"--batch_size=1\", // minimizes batch size\n \"--micro_batch_size=1\", // minimizes batch size\n \"--val_set_size=0\", // disables validation\n \"--sample_packing=False\", // disables sample packing which is necessary for small datasets\n \"--eval_sample_packing=False\",// disables sample packing on eval set\n \"--dataset_prepared_path=temp_debug/axolotl_outputs/data\", // send data outputs to a temp folder\n \"--output_dir=temp_debug/axolotl_outputs/model\" // send model outputs to a temp folder\n ],\n \"console\": \"integratedTerminal\", // show output in the integrated terminal\n \"cwd\": \"${workspaceFolder}/devtools\", // set working directory to devtools from the root of the project\n \"justMyCode\": true, // step through only axolotl code\n \"env\": {\"CUDA_VISIBLE_DEVICES\": \"0\", // Since we aren't doing distributed training, we need to limit to one GPU\n \"HF_HOME\": \"${workspaceFolder}/devtools/temp_debug/.hf-cache\"}, // send HF cache to a temp folder\n \"preLaunchTask\": \"cleanup-for-dataprep\", // delete temp folders (see below)\n }\n ]\n}\nAdditional notes about this configuration:\n\nThe argument justMyCode is set to true such that you step through only the axolotl code. 
If you want to step into dependencies, set this to false.\nThe preLaunchTask: cleanup-for-dataprep is defined in .vscode/tasks.json and is used to delete the following folders before debugging, which is essential to ensure that the data pre-processing code is run from scratch:\n\n./devtools/temp_debug/axolotl_outputs\n./devtools/temp_debug/.hf-cache/datasets\n\n\n\n[!Tip]\nYou may not want to delete these folders. For example, if you are debugging model training instead of data pre-processing, you may NOT want to delete the cache or output folders. You may also need to add additional tasks to the tasks.json file depending on your use case.\n\nBelow is the .vscode/tasks.json file that defines the cleanup-for-dataprep task. This task is run before each debugging session when you use the above configuration. Note how there are two tasks that delete the two folders mentioned above. The third task, cleanup-for-dataprep, is a composite task that combines the two tasks. A composite task is necessary because VSCode does not allow you to specify multiple tasks in the preLaunchTask argument of the launch.json file.\n// .vscode/tasks.json\n// this file is used by launch.json\n{\n \"version\": \"2.0.0\",\n \"tasks\": [\n // this task changes into the devtools directory and deletes the temp_debug/axolotl_outputs folder\n {\n \"label\": \"delete-outputs\",\n \"type\": \"shell\",\n \"command\": \"rm -rf temp_debug/axolotl_outputs\",\n \"options\":{ \"cwd\": \"${workspaceFolder}/devtools\"},\n \"problemMatcher\": []\n },\n // this task changes into the devtools directory and deletes the `temp_debug/.hf-cache/datasets` folder\n {\n \"label\": \"delete-temp-hf-dataset-cache\",\n \"type\": \"shell\",\n \"command\": \"rm -rf temp_debug/.hf-cache/datasets\",\n \"options\":{ \"cwd\": \"${workspaceFolder}/devtools\"},\n \"problemMatcher\": []\n },\n // this task combines the two tasks above\n {\n \"label\": \"cleanup-for-dataprep\",\n \"dependsOn\": [\"delete-outputs\", 
\"delete-temp-hf-dataset-cache\"],\n }\n ]\n}\n\n\nCustomizing your debugger\nYour debugging use case may differ from the example above. The easiest thing to do is to put your own axolotl config in the devtools folder and modify the launch.json file to use your config. You may also want to modify the preLaunchTask to delete different folders or not delete anything at all.\n\n\nVideo Tutorial\nThe following video tutorial walks through the above configuration and demonstrates how to debug with VSCode, (click the image below to watch):\n\n\n\nHamel Husains tutorial: Debugging Axolotl w/VSCode",
"crumbs": [
"Troubleshooting",
"Debugging"
]
},
{
"objectID": "docs/debugging.html#debugging-with-docker",
"href": "docs/debugging.html#debugging-with-docker",
"title": "Debugging",
"section": "Debugging With Docker",
"text": "Debugging With Docker\nUsing official Axolotl Docker images is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.\n\nSetup\nOn the host that is running axolotl (ex: if you are using a remote host), clone the axolotl repo and change your current directory to the root:\ngit clone https://github.com/axolotl-ai-cloud/axolotl\ncd axolotl\n\n[!Tip]\nIf you already have axolotl cloned on your host, make sure you have the latest changes and change into the root of the project.\n\nNext, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:2\ndocker run --privileged --gpus '\"all\"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src=\"${PWD}\",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1\n\n[!Tip]\nTo understand which containers are available, see the Docker section of the README and the DockerHub repo. For details of how the Docker containers are built, see axolotls Docker CI builds.\n\nYou will now be in the container. Next, perform an editable install of Axolotl:\npip3 install packaging\npip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'\n\n\nAttach To Container\nNext, if you are using a remote host, Remote into this host with VSCode. If you are using a local host, you can skip this step.\nNext, select Dev Containers: Attach to Running Container... using the command palette (CMD + SHIFT + P) in VSCode. You will be prompted to select a container to attach to. Select the container you just created. You will now be in the container with a working directory that is at the root of the project. 
Any changes you make to the code will be reflected both in the container and on the host.\nNow you are ready to debug as described above (see Debugging with VSCode).\n\n\nVideo - Attaching To Docker On Remote Host\nHere is a short video that demonstrates how to attach to a Docker container on a remote host:\n\n\n\nHamel Husain's tutorial: Debugging Axolotl Part 2: Attaching to Docker on a Remote Host",
"crumbs": [
"Troubleshooting",
"Debugging"
]
},
{
"objectID": "docs/debugging.html#footnotes",
"href": "docs/debugging.html#footnotes",
"title": "Debugging",
"section": "Footnotes",
"text": "Footnotes\n\n\nThe VSCode config uses accelerate.commands.launch as the Python module entry point, which is what axolotl train invokes under the hood.↩︎\nMany of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags here.↩︎",
"crumbs": [
"Troubleshooting",
"Debugging"
]
},
{
"objectID": "docs/dataset_preprocessing.html",
"href": "docs/dataset_preprocessing.html",
"title": "Dataset Preprocessing",
"section": "",
"text": "Dataset pre-processing is the step where Axolotl takes each dataset youve configured alongside\nthe dataset format and prompt strategies to:\n\nparse the dataset based on the dataset format\ntransform the dataset to how you would interact with the model based on the prompt strategy\ntokenize the dataset based on the configured model & tokenizer\nshuffle and merge multiple datasets together if using more than one\n\nThe processing of the datasets can happen one of two ways:\n\nBefore kicking off training by calling axolotl preprocess config.yaml --debug\nWhen training is started\n\n\n\nWhen training interactively or for sweeps\n(e.g. you are restarting the trainer often), processing the datasets can oftentimes be frustratingly\nslow. Pre-processing will cache the tokenized/formatted datasets according to a hash of dependent\ntraining parameters so that it will intelligently pull from its cache when possible.\nThe path of the cache is controlled by dataset_prepared_path: and is often left blank in example\nYAMLs as this leads to a more robust solution that prevents unexpectedly reusing cached data.\nIf dataset_prepared_path: is left empty, when training, the processed dataset will be cached in a\ndefault path of ./last_run_prepared/, but will ignore anything already cached there. By explicitly\nsetting dataset_prepared_path: ./last_run_prepared, the trainer will use whatever pre-processed\ndata is in the cache.\n\n\n\nLets say you are writing a custom prompt strategy or using a user-defined\nprompt template. Because the trainer cannot readily detect these changes, we cannot change the\ncalculated hash value for the pre-processed dataset.\nIf you have dataset_prepared_path: ... set\nand change your prompt templating logic, it may not pick up the changes you made and you will be\ntraining over the old prompt.",
"crumbs": [
"Core Concepts",
"Dataset Preprocessing"
]
},
{
"objectID": "docs/dataset_preprocessing.html#overview",
"href": "docs/dataset_preprocessing.html#overview",
"title": "Dataset Preprocessing",
"section": "",
"text": "Dataset pre-processing is the step where Axolotl takes each dataset youve configured alongside\nthe dataset format and prompt strategies to:\n\nparse the dataset based on the dataset format\ntransform the dataset to how you would interact with the model based on the prompt strategy\ntokenize the dataset based on the configured model & tokenizer\nshuffle and merge multiple datasets together if using more than one\n\nThe processing of the datasets can happen one of two ways:\n\nBefore kicking off training by calling axolotl preprocess config.yaml --debug\nWhen training is started\n\n\n\nWhen training interactively or for sweeps\n(e.g. you are restarting the trainer often), processing the datasets can oftentimes be frustratingly\nslow. Pre-processing will cache the tokenized/formatted datasets according to a hash of dependent\ntraining parameters so that it will intelligently pull from its cache when possible.\nThe path of the cache is controlled by dataset_prepared_path: and is often left blank in example\nYAMLs as this leads to a more robust solution that prevents unexpectedly reusing cached data.\nIf dataset_prepared_path: is left empty, when training, the processed dataset will be cached in a\ndefault path of ./last_run_prepared/, but will ignore anything already cached there. By explicitly\nsetting dataset_prepared_path: ./last_run_prepared, the trainer will use whatever pre-processed\ndata is in the cache.\n\n\n\nLets say you are writing a custom prompt strategy or using a user-defined\nprompt template. Because the trainer cannot readily detect these changes, we cannot change the\ncalculated hash value for the pre-processed dataset.\nIf you have dataset_prepared_path: ... set\nand change your prompt templating logic, it may not pick up the changes you made and you will be\ntraining over the old prompt.",
"crumbs": [
"Core Concepts",
"Dataset Preprocessing"
]
},
{
"objectID": "docs/vllm_serving.html",
"href": "docs/vllm_serving.html",
"title": "vLLM Serving for GRPO Training",
"section": "",
"text": "GRPO (Group Relative Policy Optimization) trains a language model by generating completions, scoring them with reward functions, and updating the policy to favor higher-reward outputs. The generation step is the bottleneck: producing thousands of tokens per training step with the policy model is slow using standard HuggingFace generation.\nAxolotl uses vLLM as a high-throughput generation backend. vLLM runs as a separate process (either on a dedicated GPU or colocated on the training GPU) and serves completions via an HTTP API. The trainer sends prompts to vLLM, receives completions, scores them, and performs gradient updates.\n┌──────────────────────┐ HTTP ┌──────────────────────┐\n│ Trainer (GPU 1) │ ───────────────── │ vLLM Server (GPU 0)│\n│ │ prompts/compls │ │\n│ - Policy model │ ◄──────────────── │ - Same base model │\n│ - Reward scoring │ │ - Fast generation │\n│ - Gradient updates │ weight sync │ - LoRA adapter │\n│ - LoRA adapter │ ─────────────────►│ (periodically │\n│ │ (every N steps) │ updated) │\n└──────────────────────┘ └──────────────────────┘\n\n\n\n\n\n\nImportant\n\n\n\nvLLM must serve the same base model specified in your training config. If the models do not match, weight synchronization will silently produce incorrect results.",
"crumbs": [
"How To Guides",
"vLLM Serving for GRPO Training"
]
},
{
"objectID": "docs/vllm_serving.html#sec-overview",
"href": "docs/vllm_serving.html#sec-overview",
"title": "vLLM Serving for GRPO Training",
"section": "",
"text": "GRPO (Group Relative Policy Optimization) trains a language model by generating completions, scoring them with reward functions, and updating the policy to favor higher-reward outputs. The generation step is the bottleneck: producing thousands of tokens per training step with the policy model is slow using standard HuggingFace generation.\nAxolotl uses vLLM as a high-throughput generation backend. vLLM runs as a separate process (either on a dedicated GPU or colocated on the training GPU) and serves completions via an HTTP API. The trainer sends prompts to vLLM, receives completions, scores them, and performs gradient updates.\n┌──────────────────────┐ HTTP ┌──────────────────────┐\n│ Trainer (GPU 1) │ ───────────────── │ vLLM Server (GPU 0)│\n│ │ prompts/compls │ │\n│ - Policy model │ ◄──────────────── │ - Same base model │\n│ - Reward scoring │ │ - Fast generation │\n│ - Gradient updates │ weight sync │ - LoRA adapter │\n│ - LoRA adapter │ ─────────────────►│ (periodically │\n│ │ (every N steps) │ updated) │\n└──────────────────────┘ └──────────────────────┘\n\n\n\n\n\n\nImportant\n\n\n\nvLLM must serve the same base model specified in your training config. If the models do not match, weight synchronization will silently produce incorrect results.",
"crumbs": [
"How To Guides",
"vLLM Serving for GRPO Training"
]
},
{
"objectID": "docs/vllm_serving.html#sec-server-mode",
"href": "docs/vllm_serving.html#sec-server-mode",
"title": "vLLM Serving for GRPO Training",
"section": "2 Server Mode",
"text": "2 Server Mode\nServer mode runs vLLM as an external process on dedicated GPU(s). This is the recommended configuration for most setups.\n\n2.1 Starting the Server\nUse the axolotl vllm-serve command with your training config:\n# Terminal 1: Start vLLM on GPU 0\nCUDA_VISIBLE_DEVICES=0 axolotl vllm-serve grpo_config.yaml\n# Terminal 2: Start training on GPU 1\nCUDA_VISIBLE_DEVICES=1 axolotl train grpo_config.yaml\nThe server reads vLLM settings from the vllm: section of your config and starts an HTTP server (default: http://0.0.0.0:8000).\n\n\n\n\n\n\nTip\n\n\n\nUse tmux or screen to manage the vLLM server process. Typical startup time is 30-90 seconds depending on model size and whether CUDA graphs are captured.\n\n\n\n\n2.2 Minimal Server Config\nbase_model: Qwen/Qwen2.5-1.5B-Instruct\n\nvllm:\n host: 0.0.0.0\n port: 8000\n gpu_memory_utilization: 0.85\n dtype: auto\n max_model_len: 4096\n\nrl: grpo\ntrl:\n use_vllm: true\n vllm_server_host: 0.0.0.0\n vllm_server_port: 8000\n vllm_server_timeout: 300\n\n\n2.3 Multi-GPU vLLM\nFor larger models, use tensor parallelism across multiple GPUs:\nvllm:\n tensor_parallel_size: 2\n gpu_memory_utilization: 0.85\n# vLLM on GPUs 2,3; training on GPUs 0,1\nCUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve grpo_config.yaml\nCUDA_VISIBLE_DEVICES=0,1 axolotl train grpo_config.yaml --num-processes 2\n\n\n\n\n\n\nNote\n\n\n\nDue to how TRL maps vLLM device indices, the vLLM instance should use the last N GPUs (highest device indices), while training uses the first N.",
"crumbs": [
"How To Guides",
"vLLM Serving for GRPO Training"
]
},
{
"objectID": "docs/vllm_serving.html#sec-colocate-mode",
"href": "docs/vllm_serving.html#sec-colocate-mode",
"title": "vLLM Serving for GRPO Training",
"section": "3 Colocate Mode",
"text": "3 Colocate Mode\nColocate mode runs vLLM on the same GPU as the trainer. This is useful when you only have a single GPU.\ntrl:\n use_vllm: true\n vllm_mode: colocate\n vllm_enable_sleep_mode: true\nWith vllm_enable_sleep_mode: true, vLLM offloads its VRAM allocation when not actively generating, freeing memory for training. When the trainer needs new completions, vLLM wakes up and reclaims VRAM.\n\n\n\n\n\n\nWarning\n\n\n\nColocate mode is significantly slower than server mode because generation and training cannot overlap. The GPU alternates between the two workloads. This mode is practical only for smaller models (up to ~3B on a 24 GB GPU).\n\n\nWhen to use colocate mode:\n\nYou have exactly one GPU\nThe model fits in memory with both vLLM and training active (with sleep mode), or is small enough to time-share\nYou accept the performance tradeoff for simpler setup (no separate vLLM process to manage)\n\nWhen to use server mode:\n\nYou have two or more GPUs\nYou want maximum throughput (generation overlaps with training via async prefetch)\nYou are running larger models (7B+)",
"crumbs": [
"How To Guides",
"vLLM Serving for GRPO Training"
]
},
{
"objectID": "docs/vllm_serving.html#sec-lora-sync",
"href": "docs/vllm_serving.html#sec-lora-sync",
"title": "vLLM Serving for GRPO Training",
"section": "4 LoRA Sync",
"text": "4 LoRA Sync\nLoRA sync is the recommended weight synchronization method when training with LoRA adapters. Instead of merging adapter weights into the base model and broadcasting the full merged weights over NCCL, it saves only the LoRA adapter files to the filesystem and tells vLLM to load them natively.\n\n4.1 How It Works\n\nThe trainer calls model.save_pretrained() to write the LoRA adapter weights to a temporary directory\nThe trainer sends an HTTP POST to /set_lora_adapter/ on the vLLM server\nvLLM loads the adapter using its native LoRA support (Punica kernels)\nGeneration uses the updated adapter on the next request\n\n\n\n4.2 Benefits\n\nSmaller sync payload: Transfers ~40 MB of LoRA weights instead of ~1.4 GB+ of merged model weights (for a typical 0.5-3B model)\nNo NCCL communicator: Eliminates the need for a cross-GPU NCCL communication channel, removing GPU contention between vLLM generation and weight sync\nFaster sync: ~200 ms per sync vs. 350 ms to 5+ seconds for NCCL merge sync\nSimpler multi-GPU: No need to set up NCCL groups between trainer and vLLM processes\n\n\n\n4.3 Configuration\nadapter: lora\nlora_r: 32\nlora_alpha: 64\nlora_target_linear: true\n\ntrl:\n vllm_lora_sync: true # Enables LoRA sync mode\n vllm_sync_interval: 5 # Sync every 5 training steps\nSetting vllm_lora_sync: true automatically selects the LoRA-aware vLLM serve script (axolotl.scripts.vllm_serve_lora). You do not need to set vllm.serve_module manually.\n\n\n\n\n\n\nImportant\n\n\n\nLoRA sync requires that you are training with a LoRA adapter (adapter: lora or adapter: qlora). It is not applicable to full fine-tuning.",
"crumbs": [
"How To Guides",
"vLLM Serving for GRPO Training"
]
},
{
"objectID": "docs/vllm_serving.html#sec-weight-sync",
"href": "docs/vllm_serving.html#sec-weight-sync",
"title": "vLLM Serving for GRPO Training",
"section": "5 Weight Synchronization",
"text": "5 Weight Synchronization\nDuring GRPO training, the policy model on the trainer is continuously updated via gradient steps. The vLLM server, however, still holds the old weights. Periodically, the trainer must push updated weights to vLLM so that future generations reflect the improved policy.\n\n5.1 Sync Interval\nThe vllm_sync_interval parameter controls how often weights are synced:\ntrl:\n vllm_sync_interval: 5 # Sync every 5 optimizer steps\nTradeoffs:\n\nLower interval (e.g., 1-3): Fresher generations, better on-policy data, but more sync overhead per step\nHigher interval (e.g., 5-10): Less overhead, but generations become increasingly off-policy between syncs\nRecommended: 3-5 for most setups. Axolotl includes importance sampling correction (vllm_importance_sampling_correction: true) to handle mild distribution mismatch from stale vLLM weights.\n\n\n\n5.2 Sync Methods\n\n\n\n\n\n\n\n\n\n\nMethod\nConfig\nPayload\nMechanism\nTypical Time\n\n\n\n\nLoRA sync\nvllm_lora_sync: true\nLoRA adapter only (~40 MB)\nFilesystem + HTTP\n~200 ms\n\n\nNCCL merge sync\nDefault (no lora_sync)\nFull merged weights (~1.4 GB+)\nHTTP trigger + NCCL broadcast\n350 ms - 5 s\n\n\n\n\n\n\n\n\n\nTip\n\n\n\nIf you are training with LoRA (which is recommended for GRPO), always enable vllm_lora_sync: true. The performance difference is substantial, especially as training progresses and NCCL contention increases.\n\n\n\n\n5.3 Importance Sampling Correction\nWhen vLLM weights are stale (between syncs), the generated data is slightly off-policy. 
Axolotl can correct for this:\ntrl:\n vllm_importance_sampling_correction: true\n importance_sampling_level: token # 'token' or 'sequence'\n off_policy_mask_threshold: 0.5 # KL threshold for masking stale sequences\n\nToken-level IS is recommended when using Liger kernel (sequence-level has numerical issues with chunked computation)\nOff-policy sequence masking (OPSM) drops sequences that have diverged too far from the current policy, providing a safety net against stale data",
"crumbs": [
"How To Guides",
"vLLM Serving for GRPO Training"
]
},
{
"objectID": "docs/vllm_serving.html#sec-restart",
"href": "docs/vllm_serving.html#sec-restart",
"title": "vLLM Serving for GRPO Training",
"section": "6 Restart Requirements",
"text": "6 Restart Requirements\n\n\n\n\n\n\nWarning\n\n\n\nvLLM must be restarted between training runs. Weight syncs from a previous run leave the server in a corrupted state. If you start a new training run against a stale vLLM server, the model may fail to learn.\n\n\n\n6.1 When to Restart\n\nBefore every new training experiment\nAfter a training run crashes or is interrupted\nIf you change the base model in your config\n\n\n\n6.2 How to Restart\nKilling vLLM reliably requires terminating both the main process and its background EngineCore subprocess:\n# Kill all vLLM-related processes\npkill -9 -f \"vllm|EngineCore\"\n\n# Verify GPU memory is freed\nnvidia-smi\n\n# Restart the server\nCUDA_VISIBLE_DEVICES=0 axolotl vllm-serve grpo_config.yaml\n\n\n\n\n\n\nTip\n\n\n\nA single kill often does not fully stop vLLM. Always use kill -9 and verify with nvidia-smi that GPU memory has been released before restarting.\n\n\n\n\n6.3 Health Check\nThe vLLM server exposes a health endpoint. Wait for it to return 200 before starting training:\n# For the LoRA serve script (trailing slash required)\ncurl http://localhost:8000/health/\n\n# For the default TRL serve script\ncurl http://localhost:8000/health",
"crumbs": [
"How To Guides",
"vLLM Serving for GRPO Training"
]
},
{
"objectID": "docs/vllm_serving.html#sec-config-reference",
"href": "docs/vllm_serving.html#sec-config-reference",
"title": "vLLM Serving for GRPO Training",
"section": "7 Configuration Reference",
"text": "7 Configuration Reference\n\n7.1 vLLM Server Options (vllm: section)\nThese control the vLLM server process started by axolotl vllm-serve.\n\n\n\n\n\n\n\n\n\nOption\nType\nDefault\nDescription\n\n\n\n\nhost\nstr\n0.0.0.0\nHost address for the vLLM server\n\n\nport\nint\n8000\nPort for the vLLM server\n\n\ndevice\nstr\nauto\nDevice to use for vLLM\n\n\ntensor_parallel_size\nint\nNone\nNumber of GPUs for tensor parallelism\n\n\ndata_parallel_size\nint\nNone\nNumber of data parallel replicas\n\n\ngpu_memory_utilization\nfloat\n0.9\nFraction of GPU memory for vLLM (0.0-1.0)\n\n\ndtype\nstr\nauto\nData type (auto, float16, bfloat16)\n\n\nmax_model_len\nint\nNone\nMaximum model context length. Set explicitly if the default is too large for your GPU\n\n\nenable_prefix_caching\nbool\nNone\nEnable prefix caching for repeated prompt prefixes\n\n\nenable_reasoning\nbool\nNone\nEnable reasoning mode for models with thinking tokens\n\n\nreasoning_parser\nstr\nNone\nParser for reasoning output\n\n\nenforce_eager\nbool\nNone\nDisable CUDA graph capture (required for some architectures like Qwen3.5 hybrid attention)\n\n\nserve_module\nstr\nNone\nPython module for vLLM serve script. 
Auto-set when vllm_lora_sync: true\n\n\nworker_extension_cls\nstr\nNone\nvLLM worker extension class for weight sync\n\n\n\n\n\n7.2 Trainer vLLM Options (trl: section)\nThese control how the trainer interacts with vLLM.\n\n\n\n\n\n\n\n\n\nOption\nType\nDefault\nDescription\n\n\n\n\nuse_vllm\nbool\nfalse\nEnable vLLM for generation\n\n\nvllm_mode\nstr\nNone\nserver (external process) or colocate (same GPU)\n\n\nvllm_server_host\nstr\n0.0.0.0\nHost of the vLLM server to connect to\n\n\nvllm_server_port\nint\n8000\nPort of the vLLM server to connect to\n\n\nvllm_server_timeout\nint\nNone\nTimeout in seconds for vLLM requests\n\n\nvllm_lora_sync\nbool\nfalse\nSync LoRA adapters via filesystem instead of NCCL merge\n\n\nvllm_sync_interval\nint\nNone\nSync weights every N optimizer steps\n\n\nvllm_enable_sleep_mode\nbool\nNone\nOffload vLLM VRAM when idle (colocate mode)\n\n\nvllm_guided_decoding_regex\nstr\nNone\nRegex constraint for guided decoding\n\n\n\nFor async pipeline and off-policy correction options, see the GRPO Configuration Reference.",
"crumbs": [
"How To Guides",
"vLLM Serving for GRPO Training"
]
},
{
"objectID": "docs/vllm_serving.html#sec-complete-example",
"href": "docs/vllm_serving.html#sec-complete-example",
"title": "vLLM Serving for GRPO Training",
"section": "8 Complete Example",
"text": "8 Complete Example\nFor a full working GRPO config including vLLM, LoRA sync, async generation, rewards, and dataset setup, see the GRPO Quick Start. That config includes all the vLLM settings covered in this guide.\n# Terminal 1: Start vLLM\nCUDA_VISIBLE_DEVICES=0 axolotl vllm-serve grpo_config.yaml\n\n# Wait for health check to pass\ncurl http://localhost:8000/health/\n\n# Terminal 2: Start training\nCUDA_VISIBLE_DEVICES=1 axolotl train grpo_config.yaml",
"crumbs": [
"How To Guides",
"vLLM Serving for GRPO Training"
]
},
{
"objectID": "docs/vllm_serving.html#sec-troubleshooting",
"href": "docs/vllm_serving.html#sec-troubleshooting",
"title": "vLLM Serving for GRPO Training",
"section": "9 Troubleshooting",
"text": "9 Troubleshooting\n\n\n\n\n\n\n\n\nProblem\nLikely Cause\nSolution\n\n\n\n\nTraining hangs waiting for vLLM\nServer not started or wrong port\nCheck curl http://localhost:8000/health/ and verify vllm_server_host/vllm_server_port match\n\n\nOOM on vLLM GPU\ngpu_memory_utilization too high or max_model_len too large\nReduce gpu_memory_utilization to 0.7 or set max_model_len explicitly\n\n\nOOM on training GPU\nBatch too large for policy logprobs\nReduce micro_batch_size or num_generations\n\n\nAccuracy stays at zero\nStale vLLM from previous run\nRestart vLLM: pkill -9 -f \"vllm\\|EngineCore\", verify with nvidia-smi, restart\n\n\nResponseValidationError from vLLM\nMissing logprobs in response\nEnsure you are using the correct serve module (auto-selected with vllm_lora_sync: true)\n\n\nWeight sync takes 5+ seconds\nNCCL contention with vLLM generation\nSwitch to vllm_lora_sync: true to eliminate NCCL\n\n\nasync_prefetch deadlocks with FSDP\nBackground threads run unsynchronized FSDP collectives\nSet async_prefetch: false when using FSDP or DeepSpeed multi-GPU",
"crumbs": [
"How To Guides",
"vLLM Serving for GRPO Training"
]
},
{
"objectID": "docs/optimizers.html",
"href": "docs/optimizers.html",
"title": "Optimizers",
"section": "",
"text": "Axolotl supports all optimizers supported by transformers OptimizerNames\nHere is a list of optimizers supported by transformers as of v4.54.0:\n\nadamw_torch\nadamw_torch_fused\nadamw_torch_xla\nadamw_torch_npu_fused\nadamw_apex_fused\nadafactor\nadamw_anyprecision\nadamw_torch_4bit\nadamw_torch_8bit\nademamix\nsgd\nadagrad\nadamw_bnb_8bit\nadamw_8bit # alias for adamw_bnb_8bit\nademamix_8bit\nlion_8bit\nlion_32bit\npaged_adamw_32bit\npaged_adamw_8bit\npaged_ademamix_32bit\npaged_ademamix_8bit\npaged_lion_32bit\npaged_lion_8bit\nrmsprop\nrmsprop_bnb\nrmsprop_bnb_8bit\nrmsprop_bnb_32bit\ngalore_adamw\ngalore_adamw_8bit\ngalore_adafactor\ngalore_adamw_layerwise\ngalore_adamw_8bit_layerwise\ngalore_adafactor_layerwise\nlomo\nadalomo\ngrokadamw\nschedule_free_radam\nschedule_free_adamw\nschedule_free_sgd\napollo_adamw\napollo_adamw_layerwise\nstable_adamw",
"crumbs": [
"Core Concepts",
"Optimizers"
]
},
{
"objectID": "docs/optimizers.html#overview",
"href": "docs/optimizers.html#overview",
"title": "Optimizers",
"section": "",
"text": "Axolotl supports all optimizers supported by transformers OptimizerNames\nHere is a list of optimizers supported by transformers as of v4.54.0:\n\nadamw_torch\nadamw_torch_fused\nadamw_torch_xla\nadamw_torch_npu_fused\nadamw_apex_fused\nadafactor\nadamw_anyprecision\nadamw_torch_4bit\nadamw_torch_8bit\nademamix\nsgd\nadagrad\nadamw_bnb_8bit\nadamw_8bit # alias for adamw_bnb_8bit\nademamix_8bit\nlion_8bit\nlion_32bit\npaged_adamw_32bit\npaged_adamw_8bit\npaged_ademamix_32bit\npaged_ademamix_8bit\npaged_lion_32bit\npaged_lion_8bit\nrmsprop\nrmsprop_bnb\nrmsprop_bnb_8bit\nrmsprop_bnb_32bit\ngalore_adamw\ngalore_adamw_8bit\ngalore_adafactor\ngalore_adamw_layerwise\ngalore_adamw_8bit_layerwise\ngalore_adafactor_layerwise\nlomo\nadalomo\ngrokadamw\nschedule_free_radam\nschedule_free_adamw\nschedule_free_sgd\napollo_adamw\napollo_adamw_layerwise\nstable_adamw",
"crumbs": [
"Core Concepts",
"Optimizers"
]
},
{
"objectID": "docs/optimizers.html#custom-optimizers",
"href": "docs/optimizers.html#custom-optimizers",
"title": "Optimizers",
"section": "Custom Optimizers",
"text": "Custom Optimizers\nEnable custom optimizers by passing a string to the optimizer argument. Each optimizer will receive beta and epsilon args, however, some may accept additional args which are detailed below.\n\noptimi_adamw\noptimizer: optimi_adamw\n\n\nao_adamw_4bit\nDeprecated: Please use adamw_torch_4bit.\n\n\nao_adamw_8bit\nDeprecated: Please use adamw_torch_8bit.\n\n\nao_adamw_fp8\noptimizer: ao_adamw_fp8\n\n\nadopt_adamw\nGitHub: https://github.com/iShohei220/adopt\nPaper: https://arxiv.org/abs/2411.02853\noptimizer: adopt_adamw\n\n\ncame_pytorch\nGitHub: https://github.com/yangluo7/CAME/tree/master\nPaper: https://arxiv.org/abs/2307.02047\noptimizer: came_pytorch\n\n# optional args (defaults below)\nadam_beta1: 0.9\nadam_beta2: 0.999\nadam_beta3: 0.9999\nadam_epsilon: 1e-30\nadam_epsilon2: 1e-16\n\n\nmuon\nBlog: https://kellerjordan.github.io/posts/muon/\nPaper: https://arxiv.org/abs/2502.16982v1\noptimizer: muon\n\n\ndion\nMicrosofts Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient\northonormalizing optimizer that uses low-rank approximations to reduce gradient communication.\nGitHub: https://github.com/microsoft/dion\nPaper: https://arxiv.org/pdf/2504.05295\nNote: Implementation written for PyTorch 2.7+ for DTensor\noptimizer: dion\ndion_lr: 0.01\ndion_momentum: 0.95\nlr: 0.00001 # learning rate for embeddings and parameters that fallback to AdamW",
"crumbs": [
"Core Concepts",
"Optimizers"
]
},
{
"objectID": "docs/ebft.html",
"href": "docs/ebft.html",
"title": "EBFT Training",
"section": "",
"text": "Energy-Based Fine-Tuning (EBFT) is a training method that optimizes language models by matching the internal feature representations of generated text to those of ground-truth completions. Instead of relying on external reward models or hand-crafted reward functions, EBFT extracts hidden states from intermediate layers of a frozen copy of the model and uses cosine similarity between generated and reference features as the reward signal.\nPaper: “Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models” (Jelassi et al., 2026)\n\n\n\n\n\n\n\n\n\n\n\nMethod\nReward Signal\nRequires\nBest For\n\n\n\n\nGRPO\nExternal reward function(s)\nCustom reward code or reward model\nTasks with verifiable answers (math, code)\n\n\nDPO\nPreference pairs (chosen vs rejected)\nPaired preference data\nAlignment with human preferences\n\n\nEBFT\nFeature similarity to ground truth\nGround-truth completions\nAny task with reference outputs\n\n\n\nEBFTs key advantage is that it needs only ground-truth completions no reward engineering, no preference annotation, and no reward model training. The models own internal representations serve as the reward signal. This makes it particularly effective for:\n\nCode generation (match features of known-good solutions)\nInstruction following with reference outputs\nContinual pretraining on unstructured text (strided mode)\nMulti-turn dialogue with reference conversations\n\n\n\n\nThe EBFT reward for each generated completion is:\nreward = alignment_coef * cosine_similarity(gen_features, gt_features)\n - diversity_coef * mean_pairwise_similarity(gen_features)\n\nAlignment: How closely the generated outputs internal representations match the ground truth. Higher is better.\nDiversity: Penalizes generated samples that are too similar to each other (prevents mode collapse). Lower is better.\nCFM loss (Cross-Feature Matching): Tracks ||mean(gen_features) - gt_features||^2 as a diagnostic. 
This is the quantity that EBFT ultimately minimizes.",
"crumbs": [
"How To Guides",
"EBFT Training"
]
},
{
"objectID": "docs/ebft.html#overview",
"href": "docs/ebft.html#overview",
"title": "EBFT Training",
"section": "",
"text": "Energy-Based Fine-Tuning (EBFT) is a training method that optimizes language models by matching the internal feature representations of generated text to those of ground-truth completions. Instead of relying on external reward models or hand-crafted reward functions, EBFT extracts hidden states from intermediate layers of a frozen copy of the model and uses cosine similarity between generated and reference features as the reward signal.\nPaper: “Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models” (Jelassi et al., 2026)\n\n\n\n\n\n\n\n\n\n\n\nMethod\nReward Signal\nRequires\nBest For\n\n\n\n\nGRPO\nExternal reward function(s)\nCustom reward code or reward model\nTasks with verifiable answers (math, code)\n\n\nDPO\nPreference pairs (chosen vs rejected)\nPaired preference data\nAlignment with human preferences\n\n\nEBFT\nFeature similarity to ground truth\nGround-truth completions\nAny task with reference outputs\n\n\n\nEBFTs key advantage is that it needs only ground-truth completions no reward engineering, no preference annotation, and no reward model training. The models own internal representations serve as the reward signal. This makes it particularly effective for:\n\nCode generation (match features of known-good solutions)\nInstruction following with reference outputs\nContinual pretraining on unstructured text (strided mode)\nMulti-turn dialogue with reference conversations\n\n\n\n\nThe EBFT reward for each generated completion is:\nreward = alignment_coef * cosine_similarity(gen_features, gt_features)\n - diversity_coef * mean_pairwise_similarity(gen_features)\n\nAlignment: How closely the generated outputs internal representations match the ground truth. Higher is better.\nDiversity: Penalizes generated samples that are too similar to each other (prevents mode collapse). Lower is better.\nCFM loss (Cross-Feature Matching): Tracks ||mean(gen_features) - gt_features||^2 as a diagnostic. 
This is the quantity that EBFT ultimately minimizes.",
"crumbs": [
"How To Guides",
"EBFT Training"
]
},
{
"objectID": "docs/ebft.html#modes",
"href": "docs/ebft.html#modes",
"title": "EBFT Training",
"section": "Modes",
"text": "Modes\nEBFT supports three operational modes, each suited to different use cases.\n\nStructured Mode (Sync)\nUses vLLM on a separate GPU for generation, with sequential generate-score-train steps. This is the simplest mode and recommended for getting started.\nGPU 0: vLLM Server (generates completions, receives weight syncs)\nGPU 1: Trainer (feature extraction, reward computation, GRPO training)\nWhen to use: Standard instruction-following or QA datasets where you have prompt/completion pairs. Requires 2 GPUs.\n\n\nStructured Mode (Async)\nSame architecture as sync, but overlaps generation of the next batch with training on the current batch. Faster throughput at the cost of slightly stale weights during generation.\nWhen to use: Same data as sync mode, but when you want faster training and can tolerate weight staleness (controlled by vllm_sync_interval).\n\n\nStrided Mode\nRuns entirely on a single GPU with no vLLM dependency. Places anchor points throughout a document and generates short rollouts at each anchor using block-parallel attention patterns.\nSingle GPU: Base model + LoRA adapter\n - Strided block-parallel generation (flex_attention)\n - Feature extraction via disable_adapter()\n - No vLLM needed\nWhen to use: Unstructured text data (raw code, prose, documents) where there is no natural prompt/completion split. Also works with structured data that includes prompt boundaries. Requires only 1 GPU.",
"crumbs": [
"How To Guides",
"EBFT Training"
]
},
{
"objectID": "docs/ebft.html#quick-start",
"href": "docs/ebft.html#quick-start",
"title": "EBFT Training",
"section": "Quick Start",
"text": "Quick Start\n\nStructured Mode\nThis minimal example fine-tunes Qwen2-0.5B on code data using EBFT with vLLM generation.\nStep 1: Create a config file ebft_quickstart.yaml:\nbase_model: Qwen/Qwen2-0.5B-Instruct\n\nrl: ebft\n\nebft:\n feature_layers: [0.25, 0.5, 0.75]\n embed_method: last_token\n alignment_coef: 1.0\n diversity_coef: 1.0\n\ntrl:\n num_generations: 4\n max_completion_length: 256\n temperature: 0.7\n use_vllm: true\n vllm_server_host: 0.0.0.0\n vllm_server_port: 8000\n vllm_lora_sync: true\n vllm_sync_interval: 3\n use_data_producer: true\n async_prefetch: false\n scale_rewards: true\n loss_type: grpo\n\nvllm:\n gpu_memory_utilization: 0.5\n max_model_len: 1024\n\ndatasets:\n - path: nvidia/OpenCodeInstruct\n type: ebft_opencode.transform\n split: train[:500]\n\n# Standard training settings (see getting-started.qmd for details)\nadapter: lora\nlora_r: 16\nlora_alpha: 32\nlora_target_linear: true\nsequence_len: 1024\nmicro_batch_size: 2\ngradient_accumulation_steps: 4\nmax_steps: 20\nlearning_rate: 5.0e-6\nbf16: auto\nflash_attention: true\ngradient_checkpointing: true\noutput_dir: ./outputs/ebft-quickstart\nStep 2: Start vLLM on GPU 0:\nCUDA_VISIBLE_DEVICES=0 axolotl vllm-serve ebft_quickstart.yaml\nStep 3: Wait approximately 30 seconds for vLLM to initialize, then start training on GPU 1:\nCUDA_VISIBLE_DEVICES=1 axolotl train ebft_quickstart.yaml\n\n\n\n\n\n\nImportant\n\n\n\nThe micro_batch_size must be divisible by num_generations. 
For example, with num_generations: 4, valid values are 4, 8, 12, etc.\n\n\n\n\nDataset Format\nStructured mode datasets must produce two fields after the transform:\n\nprompt: Either a string or a list of chat messages ([{\"role\": \"user\", \"content\": \"...\"}])\nground_truth: A string containing the reference completion\n\nExample raw dataset row:\n{\n \"input\": \"Write a function to compute fibonacci numbers.\",\n \"output\": \"def fibonacci(n):\\n if n <= 1:\\n return n\\n return fibonacci(n-1) + fibonacci(n-2)\"\n}\nThe ebft_opencode.transform converts this to the required {prompt, ground_truth} format automatically.",
"crumbs": [
"How To Guides",
"EBFT Training"
]
},
{
"objectID": "docs/ebft.html#feature-extraction",
"href": "docs/ebft.html#feature-extraction",
"title": "EBFT Training",
"section": "Feature Extraction",
"text": "Feature Extraction\nEBFT extracts hidden states from intermediate transformer layers and pools them into per-sequence embeddings. These embeddings are compared between generated and ground-truth completions to compute rewards.\n\nFeature Layers\nThe feature_layers parameter specifies which layers to extract, as fractions of total model depth:\nebft:\n feature_layers: [0.25, 0.5, 0.75] # Quarter, middle, three-quarter depth\nFor a 32-layer model, this extracts layers 8, 16, and 24. The hidden states from all selected layers are concatenated along the feature dimension, producing embeddings of size num_layers * hidden_dim.\n\n\n\n\n\n\nTip\n\n\n\nUsing multiple layers captures both low-level syntactic features (early layers) and high-level semantic features (later layers). The default [0.25, 0.5, 0.75] works well across model sizes.\n\n\n\n\nEmbed Methods\nThe embed_method controls how per-token hidden states are pooled into a single vector per sequence:\n\n\n\n\n\n\n\n\n\nMethod\nDescription\nOutput Shape\nNotes\n\n\n\n\nlast_token\nHidden state at the last non-padding token\n(B, D)\nDefault. Good for autoregressive models where the last token summarizes the sequence.\n\n\nmean_pooling\nMean of all non-padding token states\n(B, D)\nConsiders the entire sequence equally.\n\n\ncompletion_mean\nMean over completion tokens only (excludes prompt)\n(B, D)\nFocuses reward signal on generated content. Requires prompt length information.\n\n\nconcat\nConcatenation of states at 25%, 50%, 75% positions\n(B, 3*D)\nCaptures positional structure. Higher dimensional.\n\n\n\nebft:\n embed_method: completion_mean # Focus on completion features\n\n\nSVD Whitening\nWhitening decorrelates the feature dimensions so that no single direction dominates the feature-matching loss. 
This is computed via SVD on the generated embeddings, with the same transform applied to the ground-truth embeddings.\nebft:\n use_whitening: true\nWhen whitening is enabled, the reward computation applies a whitening matrix W = U @ diag(1/S) @ U^T derived from the SVD of generated embeddings. This ensures all feature dimensions contribute equally to the alignment reward.\n\n\n\n\n\n\nNote\n\n\n\nSingular values scale with sqrt(batch_size), so reward magnitudes are batch-size dependent. This is acceptable because the number of samples per prompt (n_samples_per_prompt or num_generations) is fixed during training.\n\n\n\n\nAlignment and Diversity Coefficients\nThe two reward components are weighted by coefficients:\nebft:\n alignment_coef: 1.0 # Weight for cosine similarity with ground truth\n diversity_coef: 1.0 # Weight for pairwise similarity penalty\nBoth values are scaled by 2 internally (per paper equation 7). The final reward per sample is:\nreward_j = 2 * alignment_coef * cos(gen_j, gt)\n - 2 * diversity_coef * (1/(n-1)) * sum_{j' != j} dot(gen_j, gen_j')\nSetting diversity_coef: 0.0 disables the diversity penalty entirely, which may be appropriate when num_generations is small (e.g., 2).",
"crumbs": [
"How To Guides",
"EBFT Training"
]
},
{
"objectID": "docs/ebft.html#strided-mode-1",
"href": "docs/ebft.html#strided-mode-1",
"title": "EBFT Training",
"section": "Strided Mode",
"text": "Strided Mode\nStrided mode is designed for training on unstructured text data where there is no natural prompt/completion boundary. Instead of generating full completions with vLLM, it places anchor points at regular intervals throughout each document and generates short rollouts at each anchor using block-parallel attention.\n\nHow Block-Parallel Generation Works\nGiven a document of length S tokens:\n\nAnchor placement: Starting at position anchor_offset, place anchors every stride tokens. Each anchor defines a block.\nContext window: Each block sees context_length tokens of preceding context from the original document.\nGeneration: At each anchor, generate generate_max_len tokens autoregressively, conditioned only on the context window.\nParallelism: All blocks are processed in a single forward pass using a specialized attention mask that prevents information leakage between blocks.\n\nDocument: [tok0, tok1, ..., tok_S]\n | | |\n anchor_0 anchor_1 anchor_2\n | | |\n [ctx][gen] [ctx][gen] [ctx][gen]\nThe attention mask ensures:\n\nPrompt tokens use standard causal attention\nEach generated block attends to its own context window and its own preceding generated tokens\nBlocks do not attend to each others generated tokens\n\nWhen flex_attention is available (PyTorch >= 2.5), the mask is compiled into efficient fused kernels. 
Otherwise, a dense 4D attention mask is used as a fallback.\n\n\nStrided Mode Configuration\nbase_model: meta-llama/Llama-3.2-1B\nrl: ebft\n\nebft:\n mode: strided\n stride: 8 # Tokens between anchor points\n context_length: 8 # Context window per block\n generate_max_len: 8 # Tokens to generate per block\n n_samples_per_prompt: 4 # Independent rollouts per document\n temperature: 0.6\n feature_layers: [0.25, 0.5, 0.75]\n embed_method: last_token\n use_whitening: true\n alignment_coef: 1.0\n diversity_coef: 1.0\n rl_coef: 1.0 # RL policy gradient loss weight\n ce_coef: 0.03 # Cross-entropy loss on GT tokens\n advantage_estimator: rloo # rloo, group_norm, or reinforce\n min_completion_prefix: 8 # Skip anchors in prompt region\n\ndatasets:\n - path: nvidia/OpenCodeInstruct\n type: ebft_strided_structured.transform\n split: train[:1%]\n\nsequence_len: 2048\nmicro_batch_size: 1\ngradient_accumulation_steps: 2\n\nadapter: lora\nlora_r: 16\nlora_alpha: 32\nlora_target_linear: true\n\nbf16: auto\nflex_attention: true\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n use_reentrant: true # Required with flex_attention\nRun with a single command (no vLLM needed):\nCUDA_VISIBLE_DEVICES=0 axolotl train config.yaml\n\n\nAdvantage Estimators\nStrided mode supports three advantage estimation methods:\n\n\n\n\n\n\n\n\nEstimator\nFormula\nRequirements\n\n\n\n\nrloo\nLeave-one-out baseline: reward_j - mean(rewards_{-j})\nn_samples_per_prompt >= 2\n\n\ngroup_norm\nGroup normalization: (reward_j - mean) / std\nn_samples_per_prompt >= 2\n\n\nreinforce\nRaw reward as advantage (no baseline)\nWorks with n_samples_per_prompt = 1\n\n\n\n\n\n\n\n\n\nWarning\n\n\n\nWhen n_samples_per_prompt: 1, the trainer automatically falls back to reinforce and disables the diversity penalty (which requires multiple samples).\n\n\n\n\nStrided Mode Constraints\n\nflex_attention: true is strongly recommended. 
Without it, dense 4D masks consume significantly more memory.\ntorch_compile: true must NOT be set. flex_attention compiles its own kernels internally; adding torch_compile causes conflicts and OOM.\nGradient checkpointing must use use_reentrant: true. Non-reentrant checkpointing causes CheckpointError with flex_attention block masks.\nactivation_offloading is incompatible with flex_attention.\n\n\n\nCross-Entropy Loss\nStrided mode supports an optional cross-entropy loss term on ground-truth tokens. This acts as a regularizer to prevent the model from drifting too far from the original distribution:\nebft:\n ce_coef: 0.03 # Small CE coefficient\n rl_coef: 1.0 # RL loss coefficient\nThe total loss is rl_coef * rl_loss + ce_coef * ce_loss. For structured mode, ce_coef is typically 0.0 since vLLM generation provides sufficient learning signal.",
"crumbs": [
"How To Guides",
"EBFT Training"
]
},
{
"objectID": "docs/ebft.html#dataset-formats",
"href": "docs/ebft.html#dataset-formats",
"title": "EBFT Training",
"section": "Dataset Formats",
"text": "Dataset Formats\nEBFT provides several built-in dataset transforms in src/axolotl/prompt_strategies/ebft/.\n\nBuilt-In Transforms\n\n\n\n\n\n\n\n\n\nTransform\nInput Format\nOutput Fields\nUse Case\n\n\n\n\nebft_opencode.transform\n{input, output}\n{prompt, ground_truth}\nOpenCodeInstruct, structured QA\n\n\nebft_strided_structured.transform\n{input, output}\n{input_ids, labels, prompt_length}\nStrided mode with structured data\n\n\nebft_strided_chat.transform\n{messages: [...]}\n{input_ids, labels, prompt_length}\nStrided mode with chat data\n\n\nebft_chat_multiturn.transform\n{messages: [...]}\n{prompt, ground_truth, remaining_turns}\nMulti-turn: first-turn target\n\n\nebft_chat_multiturn.transform_last_turn\n{messages: [...]}\n{prompt, ground_truth}\nMulti-turn: last-turn target\n\n\nebft_chat_multiturn.transform_all_turns\n{messages: [...]}\n{prompt[], ground_truth[]}\nMulti-turn: one example per turn\n\n\nebft_reasoning.transform\n{messages: [...]} (with <think>)\n{prompt, ground_truth}\nReasoning/thinking datasets\n\n\n\n\n\nStructured Mode Datasets\nFor structured (sync/async) mode, the transform must produce prompt and ground_truth fields:\ndatasets:\n - path: nvidia/OpenCodeInstruct\n type: ebft_opencode.transform\n split: train[:500]\n\n\nMulti-Turn Datasets\nMulti-turn transforms extract conversation data for sequential rollout. 
The transform variant targets the first assistant turn, while transform_last_turn targets the final turn:\ndatasets:\n - path: your/multiturn-dataset\n type: ebft_chat_multiturn.transform\nWhen remaining_turns is present in the dataset output, the trainer performs sequential rollouts: it generates the first assistant turn with vLLM, then continues generating subsequent turns by building up the conversation history.\n\n\nStrided Mode Datasets\nStrided transforms tokenize the full document and produce input_ids, labels, and prompt_length:\ndatasets:\n - path: nvidia/OpenCodeInstruct\n type: ebft_strided_structured.transform\n split: train[:1%]\n\n\nCustom Transforms\nTo use your own dataset format, write a transform function:\ndef transform(cfg, **kwargs):\n def transform_fn(example, tokenizer=None):\n return {\n \"prompt\": [{\"role\": \"user\", \"content\": example[\"question\"]}],\n \"ground_truth\": example[\"answer\"],\n }\n return transform_fn, {\"remove_columns\": \"__all__\"}\nThe \"__all__\" sentinel removes all original dataset columns after the mapping step. Reference this transform in your config:\ndatasets:\n - path: your/dataset\n type: your_module.transform",
"crumbs": [
"How To Guides",
"EBFT Training"
]
},
{
"objectID": "docs/ebft.html#configuration-reference",
"href": "docs/ebft.html#configuration-reference",
"title": "EBFT Training",
"section": "Configuration Reference",
"text": "Configuration Reference\n\nCommon Parameters (All Modes)\nThese parameters are set under the ebft: key in the YAML config.\n\n\n\n\n\n\n\n\n\nParameter\nType\nDefault\nDescription\n\n\n\n\nmode\n\"structured\" or \"strided\"\n\"structured\"\nEBFT operating mode\n\n\nfeature_layers\nlist[float]\n[0.25, 0.5, 0.75]\nFractional layer depths for feature extraction\n\n\nembed_method\nstring\n\"last_token\"\nPooling method: last_token, mean_pooling, completion_mean, or concat\n\n\nuse_whitening\nbool\nfalse\nApply SVD whitening to feature embeddings before reward computation\n\n\nalignment_coef\nfloat\n1.0\nWeight for alignment reward (cosine similarity with ground truth)\n\n\ndiversity_coef\nfloat\n1.0\nWeight for diversity penalty (pairwise dot product between samples)\n\n\nce_coef\nfloat\n0.0\nCross-entropy loss coefficient on ground-truth tokens\n\n\nadaptive_max_tokens\nbool\ntrue\nDynamically set vLLM max_tokens based on ground-truth length (structured mode)\n\n\ngt_length_multiplier\nfloat\n1.5\nMultiplier for ground-truth token count when computing adaptive max tokens (min 0.1)\n\n\n\n\n\nStrided Mode Parameters\nThese additional parameters apply only when mode: strided.\n\n\n\n\n\n\n\n\n\nParameter\nType\nDefault\nDescription\n\n\n\n\nstride\nint\n8\nNumber of tokens between anchor points (must be >= 1)\n\n\ncontext_length\nint\n8\nContext window size for each generated block (must be >= 1)\n\n\ngenerate_max_len\nint\n8\nNumber of tokens to generate per block (must be >= 1)\n\n\nn_samples_per_prompt\nint\n4\nNumber of independent rollouts per document (must be >= 1)\n\n\ntemperature\nfloat\n0.6\nSampling temperature for strided generation\n\n\ntop_p\nfloat\n1.0\nTop-p nucleus sampling threshold\n\n\nrl_coef\nfloat\n1.0\nRL policy gradient loss coefficient\n\n\nadvantage_estimator\nstring\n\"rloo\"\nAdvantage estimation method: rloo, group_norm, or reinforce\n\n\nmin_completion_prefix\nint\n0\nMinimum tokens into the completion span before placing 
anchors\n\n\n\n\n\nStructured Mode TRL Parameters\nThese are set under the trl: key and control the GRPO training loop.\n\n\n\n\n\n\n\n\n\nParameter\nType\nDefault\nDescription\n\n\n\n\nnum_generations\nint\n\nNumber of completions generated per prompt\n\n\nmax_completion_length\nint\n\nMaximum tokens per generated completion\n\n\ntemperature\nfloat\n0.7\nSampling temperature for vLLM generation\n\n\nuse_vllm\nbool\n\nEnable vLLM generation backend\n\n\nvllm_lora_sync\nbool\nfalse\nSync LoRA adapters via filesystem (recommended)\n\n\nvllm_sync_interval\nint\n1\nSteps between weight syncs to vLLM\n\n\nuse_data_producer\nbool\n\nRequired for sync mode with LoRA sync\n\n\nasync_prefetch\nbool\nfalse\nEnable async generation (overlaps with training)\n\n\nstreaming_partial_batch\nbool\nfalse\nScore groups incrementally (async mode)\n\n\nskip_zero_advantage_batches\nbool\nfalse\nSkip micro-batches where all advantages are zero\n\n\nscale_rewards\nbool\n\nNormalize rewards within each prompt group\n\n\nloss_type\nstring\n\"grpo\"\nLoss type for policy optimization\n\n\nepsilon\nfloat\n0.2\nClipping parameter for importance sampling\n\n\n\n\n\nStop Tokens\nvLLM needs explicit stop token IDs for generation. Common configurations:\ntrl:\n generation_kwargs:\n stop_token_ids: [151645, 151643] # Qwen: <|im_end|>, <|endoftext|>\n\n\nMulti-Turn Chat Settings\nFor multi-turn conversations with Qwen3.5, disable thinking mode to prevent <think> tags in completions:\ntrl:\n chat_template_kwargs:\n enable_thinking: false",
"crumbs": [
"How To Guides",
"EBFT Training"
]
},
{
"objectID": "docs/ebft.html#monitoring",
"href": "docs/ebft.html#monitoring",
"title": "EBFT Training",
"section": "Monitoring",
"text": "Monitoring\n\nKey Metrics\nEBFT logs several custom metrics to wandb and the training console. Here is what to watch for:\n\n\n\n\n\n\n\n\nMetric\nHealthy Range\nInterpretation\n\n\n\n\nebft/alignment\n0.3 0.9, trending upward\nCosine similarity between generated and ground-truth features. Higher means the model is learning to produce representations that match the reference.\n\n\nebft/diversity\n0.01 0.1\nMean pairwise similarity between different generations for the same prompt. Values above 1.0 indicate mode collapse.\n\n\nebft/cfm_loss\nBelow 10, trending downward\nCross-Feature Matching loss. This is the core quantity being minimized. Consistently above 100 indicates instability.\n\n\nebft/reward\nTrending upward (may start negative)\nCombined reward signal. If stuck at -1.0, the diversity penalty is dominating alignment.\n\n\ngrad_norm\n0.1 3.0\nGradient magnitude. Values of 0.0 indicate zero-advantage skip (normal). Values above 10 suggest instability.\n\n\nentropy\n0.05 0.5\nPolicy entropy. Values below 0.01 suggest mode collapse.\n\n\nIS ratio min\nAbove 0.1\nImportance sampling ratio minimum. 
Near-zero values mean generation has drifted too far off-policy; decrease vllm_sync_interval so that weights are synced to vLLM more often.\n\n\n\n\n\nConsole Log Example\nDuring training, you will see periodic EBFT reward logs:\nebft reward | align +0.412 ^ | divers +0.023 v | cfm 4.231 v | reward +0.389 ^\nThe arrows indicate the desired direction: alignment and reward should trend upward, while diversity and CFM loss should trend downward.\n\n\nTroubleshooting\n\n\n\n\n\n\n\n\nSymptom\nLikely Cause\nFix\n\n\n\n\nalignment stays below 0.1\nFeature layers not capturing useful information\nTry different feature_layers or embed_method\n\n\ndiversity exceeds 1.0\nMode collapse: generations are too similar\nIncrease diversity_coef or temperature\n\n\nreward stuck at -1.0\nDiversity penalty dominates alignment\nReduce diversity_coef or increase alignment_coef\n\n\ngrad_norm consistently 0.0\nAll micro-batches have zero advantage\nIncrease num_generations or check data quality\n\n\nCheckpointError in strided mode\nIncompatible gradient checkpointing settings\nSet use_reentrant: true in gradient_checkpointing_kwargs\n\n\nOOM during training\nLogits tensor too large\nReduce sequence_len or micro_batch_size; strided mode uses chunked lm_head to mitigate this\n\n\nvLLM 500 errors\ntruncate_prompt_tokens not supported\nEnsure you are using axolotl vllm-serve (not trl vllm-serve)\n\n\n\n\n\nFeature Network Memory\nIn PEFT (LoRA) mode, the feature network shares base weights with the actor model by using the disable_adapter() context manager. This saves an entire model copy in VRAM (approximately 1 to 16 GB, depending on model size). For non-PEFT training, a separate frozen deepcopy is created.\n\n\n\n\n\n\nNote\n\n\n\nThe disable_adapter() approach relies on an invariant: merge_adapter() is never called on the base weights. All weight sync paths (LoRA sync, HTTP, NCCL) compute merged weights as new tensors or save the adapter to the filesystem, leaving base weights unmodified.",
"crumbs": [
"How To Guides",
"EBFT Training"
]
},
{
"objectID": "docs/ebft.html#examples",
"href": "docs/ebft.html#examples",
"title": "EBFT Training",
"section": "Examples",
"text": "Examples\nComplete example configurations are available in examples/ebft/:\n\n\n\n\n\n\n\n\n\nConfig\nModel\nMode\nDescription\n\n\n\n\nllama-1b-ebft-strided-structured.yaml\nLlama 3.2 1B\nStrided\nSingle-GPU strided training on code data\n\n\nqwen3-4b-ebft-structured.yaml\nQwen3 4B\nStructured (sync)\nTwo-GPU structured training\n\n\nqwen3-4b-ebft-structured-async.yaml\nQwen3 4B\nStructured (async)\nTwo-GPU async training with prefetch\n\n\nqwen3-8b-ebft-structured.yaml\nQwen3 8B\nStructured (sync)\nTwo-GPU structured training for larger model\n\n\nqwen35-4b-ebft-structured.yaml\nQwen3.5 4B\nStructured (sync)\nTwo-GPU with Qwen3.5\n\n\nqwen35-4b-ebft-structured-async.yaml\nQwen3.5 4B\nStructured (async)\nTwo-GPU async with Qwen3.5\n\n\nqwen35-9b-ebft-structured.yaml\nQwen3.5 9B\nStructured (sync)\nTwo-GPU structured for 9B model",
"crumbs": [
"How To Guides",
"EBFT Training"
]
},
{
"objectID": "docs/torchao.html",
"href": "docs/torchao.html",
"title": "PyTorch ao",
"section": "",
"text": "To use experimental optimizers (AdamWFp8, AdamW4bit, AdamW8bit) from Pytorch Ao, please install the package as shown below.\n\n\n\n\n\n\nTip\n\n\n\nSome experimental optimizers are already present in regular Pytorch, so please re-check if you actually need this package!\n\n\n\nInstallation\nStable Release from the PyTorch index\npip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124\nNightly release\npip install --pre torchao-nightly --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124",
"crumbs": [
"Advanced Features",
"PyTorch ao"
]
},
{
"objectID": "docs/lr_groups.html",
"href": "docs/lr_groups.html",
"title": "Learning Rate Groups",
"section": "",
"text": "Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for each module or groups of\nmodules in a model.",
"crumbs": [
"How To Guides",
"Learning Rate Groups"
]
},
{
"objectID": "docs/lr_groups.html#background",
"href": "docs/lr_groups.html#background",
"title": "Learning Rate Groups",
"section": "",
"text": "Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for each module or groups of\nmodules in a model.",
"crumbs": [
"How To Guides",
"Learning Rate Groups"
]
},
{
"objectID": "docs/lr_groups.html#example",
"href": "docs/lr_groups.html#example",
"title": "Learning Rate Groups",
"section": "Example",
"text": "Example\nlr_groups:\n - name: o_proj\n modules:\n - self_attn.o_proj.weight\n lr: 1e-6\n - name: q_proj\n modules:\n - model.layers.2.self_attn.q_proj.weight\n lr: 1e-5\n\nlearning_rate: 2e-5\nIn this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate\nof 1e-6 for all the self attention o_proj modules across all layers, and a learning are of 1e-5 to the 3rd layers\nself attention q_proj module.\n\n\n\n\n\n\nNote\n\n\n\nWe currently only support varying lr for now. If youre interested in adding support for others (weight_decay), we welcome PRs. See https://github.com/axolotl-ai-cloud/axolotl/blob/613bcf90e58f3ab81d3827e7fc572319908db9fb/src/axolotl/core/trainers/mixins/optimizer.py#L17",
"crumbs": [
"How To Guides",
"Learning Rate Groups"
]
},
{
"objectID": "docs/streaming.html",
"href": "docs/streaming.html",
"title": "Streaming Datasets",
"section": "",
"text": "Streaming enables memory-efficient training with large datasets by loading data\nincrementally rather than loading the entire dataset into memory at once.\nUse streaming when:\nStreaming works with both remote and locally stored datasets!",
"crumbs": [
"Core Concepts",
"Streaming Datasets"
]
},
{
"objectID": "docs/streaming.html#configuration",
"href": "docs/streaming.html#configuration",
"title": "Streaming Datasets",
"section": "Configuration",
"text": "Configuration\n\nBasic Streaming\nEnable streaming mode by setting the streaming flag:\nstreaming: true\n\n\nPretraining with Streaming\nFor pretraining tasks, streaming is automatically enabled when using pretraining_dataset:\npretraining_dataset:\n - path: HuggingFaceFW/fineweb-edu\n type: pretrain\n text_column: text\n split: train\n\n# Optionally, enable sample packing\nstreaming_multipack_buffer_size: 10000\nsample_packing: true\n\n\nSFT with Streaming\nFor supervised fine-tuning with streaming:\nstreaming: true\ndatasets:\n - path: tatsu-lab/alpaca\n type: alpaca\n split: train\n\n# Optionally, enable sample packing\nstreaming_multipack_buffer_size: 10000\nsample_packing: true",
"crumbs": [
"Core Concepts",
"Streaming Datasets"
]
},
{
"objectID": "docs/streaming.html#configuration-options",
"href": "docs/streaming.html#configuration-options",
"title": "Streaming Datasets",
"section": "Configuration Options",
"text": "Configuration Options\n\nstreaming_multipack_buffer_size\nControls the buffer size for multipack streaming (default: 10,000). This determines how\nmany samples are buffered before packing. Larger buffers can improve packing efficiency\nbut use more memory.\n\n\nshuffle_merged_datasets\nWhen enabled, shuffles the streaming dataset using the buffer. This requires additional\nmemory for the shuffle buffer.",
"crumbs": [
"Core Concepts",
"Streaming Datasets"
]
},
{
"objectID": "docs/streaming.html#sample-packing-with-streaming",
"href": "docs/streaming.html#sample-packing-with-streaming",
"title": "Streaming Datasets",
"section": "Sample Packing with Streaming",
"text": "Sample Packing with Streaming\nSample packing is supported for streaming datasets. When enabled, multiple samples are\npacked into a single sequence to maximize GPU utilization:\nsample_packing: true\nstreaming_multipack_buffer_size: 10000\n\n# For SFT: attention is automatically isolated between packed samples\n# For pretraining: control with pretrain_multipack_attn\npretrain_multipack_attn: true # prevent cross-attention between packed samples\nFor more information, see our documentation on multipacking.",
"crumbs": [
"Core Concepts",
"Streaming Datasets"
]
},
{
"objectID": "docs/streaming.html#important-considerations",
"href": "docs/streaming.html#important-considerations",
"title": "Streaming Datasets",
"section": "Important Considerations",
"text": "Important Considerations\n\nMemory Usage\nWhile streaming reduces memory usage compared to loading entire datasets, you still need\nto consider:\n\nYou can control the memory usage by adjusting streaming_multipack_buffer_size\nSample packing requires buffering multiple samples\nShuffling requires additional memory for the shuffle buffer\n\n\n\nPerformance\n\nStreaming may have slightly higher latency compared to preprocessed datasets, as samples are processed on-the-fly\nNetwork speed and disk read speed are important when streaming from remote sources or a local dataset, respectively\nConsider using axolotl preprocess for smaller or more frequently used datasets\n\n\n\nEvaluation Datasets\nEvaluation datasets are not streamed to ensure consistent evaluation metrics. Theyre\nloaded normally even when training uses streaming.",
"crumbs": [
"Core Concepts",
"Streaming Datasets"
]
},
{
"objectID": "docs/streaming.html#examples",
"href": "docs/streaming.html#examples",
"title": "Streaming Datasets",
"section": "Examples",
"text": "Examples\nSee the examples/streaming/ directory for complete configuration examples:\n\npretrain.yaml: Pretraining with streaming dataset\nsft.yaml: Supervised fine-tuning with streaming",
"crumbs": [
"Core Concepts",
"Streaming Datasets"
]
},
{
"objectID": "docs/amd_hpc.html",
"href": "docs/amd_hpc.html",
"title": "AMD GPUs on HPC Systems",
"section": "",
"text": "This guide provides step-by-step instructions for installing and configuring Axolotl on a High-Performance Computing (HPC) environment equipped with AMD GPUs.",
"crumbs": [
"Deployments",
"AMD GPUs on HPC Systems"
]
},
{
"objectID": "docs/amd_hpc.html#setup",
"href": "docs/amd_hpc.html#setup",
"title": "AMD GPUs on HPC Systems",
"section": "Setup",
"text": "Setup\n\n1. Install Python\nWe recommend using Miniforge, a minimal conda-based Python distribution:\ncurl -L -O \"https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh\"\nbash Miniforge3-$(uname)-$(uname -m).sh\n\n\n2. Configure Python Environment\nAdd Python to your PATH and ensure its available at login:\necho 'export PATH=~/miniforge3/bin:$PATH' >> ~/.bashrc\necho 'if [ -f ~/.bashrc ]; then . ~/.bashrc; fi' >> ~/.bash_profile\n\n\n3. Load AMD GPU Software\nLoad the ROCm module:\nmodule load rocm/5.7.1\nNote: The specific module name and version may vary depending on your HPC system. Consult your system documentation for the correct module name.\n\n\n4. Install PyTorch\nInstall PyTorch with ROCm support:\npip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7 --force-reinstall\n\n\n5. Install Flash Attention\nClone and install the Flash Attention repository:\ngit clone --recursive https://github.com/ROCmSoftwarePlatform/flash-attention.git\nexport GPU_ARCHS=\"gfx90a\"\ncd flash-attention\nexport PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])')\npatch \"${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py\" hipify_patch.patch\npip install --no-build-isolation .\n\n\n6. Install Axolotl\nClone and install Axolotl:\ngit clone https://github.com/axolotl-ai-cloud/axolotl\ncd axolotl\npip install packaging ninja\npip install --no-build-isolation -e .\n\n\n7. Apply xformers Workaround\nxformers appears to be incompatible with ROCm. Apply the following workarounds:\n- Edit $HOME/packages/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py modifying the code to always return False for SwiGLU availability from xformers.\n- Edit $HOME/miniforge3/lib/python3.10/site-packages/xformers/ops/swiglu_op.py replacing the “SwiGLU” function with a pass statement.\n\n\n8. 
Prepare Job Submission Script\nCreate a script for job submission using your HPC's particular software (e.g. Slurm, PBS). Include necessary environment setup and the command to run Axolotl training. If the compute nodes do not have internet access, we recommend also including:\nexport TRANSFORMERS_OFFLINE=1\nexport HF_DATASETS_OFFLINE=1\n\n\n9. Download Base Model\nDownload a base model using the Hugging Face CLI:\nhf download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B\n\n\n10. Create Axolotl Configuration\nCreate an Axolotl configuration file (YAML format) tailored to your specific training requirements and dataset. Use FSDP for multi-node training.\nNote: Deepspeed did not work at the time of testing. If you manage to get it working, please let us know.\n\n\n11. Preprocess Data\nRun preprocessing on the login node:\nCUDA_VISIBLE_DEVICES=\"\" python -m axolotl.cli.preprocess /path/to/your/config.yaml\n\n\n12. Train\nYou are now ready to submit your previously prepared job script. 🚂",
"crumbs": [
"Deployments",
"AMD GPUs on HPC Systems"
]
},
{
"objectID": "docs/installation.html",
"href": "docs/installation.html",
"title": "Installation",
"section": "",
"text": "This guide covers all the ways you can install and set up Axolotl for your environment.",
"crumbs": [
"Getting Started",
"Installation"
]
},
{
"objectID": "docs/installation.html#sec-requirements",
"href": "docs/installation.html#sec-requirements",
"title": "Installation",
"section": "1 Requirements",
"text": "1 Requirements\n\nNVIDIA GPU (Ampere architecture or newer for bf16 and Flash Attention) or AMD GPU\nPython ≥3.11\nPyTorch ≥2.6.0",
"crumbs": [
"Getting Started",
"Installation"
]
},
{
"objectID": "docs/installation.html#sec-installation-methods",
"href": "docs/installation.html#sec-installation-methods",
"title": "Installation",
"section": "2 Installation Methods",
"text": "2 Installation Methods\n\n\n\n\n\n\nImportant\n\n\n\nPlease make sure to have Pytorch installed before installing Axolotl in your local environment.\nFollow the instructions at: https://pytorch.org/get-started/locally/\n\n\n\n\n\n\n\n\nImportant\n\n\n\nFor Blackwell GPUs, please use Pytorch 2.9.1 and CUDA 12.8.\n\n\n\n2.1 PyPI Installation (Recommended)\npip3 install -U packaging setuptools wheel ninja\npip3 install --no-build-isolation axolotl[flash-attn,deepspeed]\nWe use --no-build-isolation in order to detect the installed PyTorch version (if\ninstalled) in order not to clobber it, and so that we set the correct version of\ndependencies that are specific to the PyTorch version or other installed\nco-dependencies.\n\n\n2.2 uv Installation\nuv is a fast, reliable Python package installer and resolver built in Rust. It offers significant performance improvements over pip and provides better dependency resolution, making it an excellent choice for complex environments.\nInstall uv if not already installed\ncurl -LsSf https://astral.sh/uv/install.sh | sh\nsource $HOME/.local/bin/env\nChoose your CUDA version to use with PyTorch; e.g. 
cu124, cu126, cu128,\nthen create the venv and activate\nexport UV_TORCH_BACKEND=cu126\nuv venv --no-project --relocatable\nsource .venv/bin/activate\nInstall PyTorch\n- PyTorch 2.6.0 recommended\nuv pip install packaging setuptools wheel\nuv pip install torch==2.6.0\nuv pip install awscli pydantic\nInstall axolotl from PyPI\nuv pip install --no-build-isolation axolotl[deepspeed,flash-attn]\n\n# optionally install with vLLM if you're using torch==2.6.0 and want to train w/ GRPO\nuv pip install --no-build-isolation axolotl[deepspeed,flash-attn,vllm]\n\n\n2.3 Edge/Development Build\nFor the latest features between releases:\ngit clone https://github.com/axolotl-ai-cloud/axolotl.git\ncd axolotl\npip3 install -U packaging setuptools wheel ninja\npip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'\n\n\n2.4 Docker\ndocker run --gpus '\"all\"' --rm -it axolotlai/axolotl:main-latest\nFor development with Docker:\ndocker compose up -d\n\n\n\n\n\n\nTip: Advanced Docker Configuration\n\n\n\ndocker run --privileged --gpus '\"all\"' --shm-size 10g --rm -it \\\n --name axolotl --ipc=host \\\n --ulimit memlock=-1 --ulimit stack=67108864 \\\n --mount type=bind,src=\"${PWD}\",target=/workspace/axolotl \\\n -v ${HOME}/.cache/huggingface:/root/.cache/huggingface \\\n axolotlai/axolotl:main-latest\n\n\n\n\n\n\n\n\nImportant\n\n\n\nFor Blackwell GPUs, please use axolotlai/axolotl:main-py3.11-cu128-2.9.1 or the cloud variant axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1.\n\n\nPlease refer to the Docker documentation for more information on the different Docker images that are available.",
"crumbs": [
"Getting Started",
"Installation"
]
},
{
"objectID": "docs/installation.html#sec-cloud",
"href": "docs/installation.html#sec-cloud",
"title": "Installation",
"section": "3 Cloud Environments",
"text": "3 Cloud Environments\n\n3.1 Cloud GPU Providers\nFor providers supporting Docker:\n\nUse axolotlai/axolotl-cloud:main-latest\nAvailable on:\n\nRunPod\nVast.ai\nPRIME Intellect\nModal\nNovita\nJarvisLabs.ai\nLatitude.sh\n\n\n\n\n3.2 Google Colab",
"crumbs": [
"Getting Started",
"Installation"
]
},
{
"objectID": "docs/installation.html#sec-platform-specific",
"href": "docs/installation.html#sec-platform-specific",
"title": "Installation",
"section": "4 Platform-Specific Instructions",
"text": "4 Platform-Specific Instructions\n\n4.1 macOS\npip3 install --no-build-isolation -e '.'\nSee Section 6 for Mac-specific issues.\n\n\n4.2 Windows\n\n\n\n\n\n\nImportant\n\n\n\nWe recommend using WSL2 (Windows Subsystem for Linux) or Docker.",
"crumbs": [
"Getting Started",
"Installation"
]
},
{
"objectID": "docs/installation.html#sec-env-managers",
"href": "docs/installation.html#sec-env-managers",
"title": "Installation",
"section": "5 Environment Managers",
"text": "5 Environment Managers\n\n5.1 Conda/Pip venv\n\nInstall Python ≥3.11\nInstall PyTorch: https://pytorch.org/get-started/locally/\nInstall Axolotl:\npip3 install -U packaging setuptools wheel ninja\npip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'\n(Optional) Login to Hugging Face:\nhf auth login",
"crumbs": [
"Getting Started",
"Installation"
]
},
{
"objectID": "docs/installation.html#sec-troubleshooting",
"href": "docs/installation.html#sec-troubleshooting",
"title": "Installation",
"section": "6 Troubleshooting",
"text": "6 Troubleshooting\nIf you encounter installation issues, see our FAQ and Debugging Guide.",
"crumbs": [
"Getting Started",
"Installation"
]
},
{
"objectID": "docs/inference.html",
"href": "docs/inference.html",
"title": "Inference and Merging",
"section": "",
"text": "This guide covers how to use your trained models for inference, including model loading, interactive testing, merging adapters, and common troubleshooting steps.",
"crumbs": [
"Getting Started",
"Inference and Merging"
]
},
{
"objectID": "docs/inference.html#sec-quickstart",
"href": "docs/inference.html#sec-quickstart",
"title": "Inference and Merging",
"section": "1 Quick Start",
"text": "1 Quick Start\n\n\n\n\n\n\nTip\n\n\n\nUse the same config used for training on inference/merging.\n\n\n\n1.1 Basic Inference\n\nLoRA ModelsFull Fine-tuned Models\n\n\naxolotl inference your_config.yml --lora-model-dir=\"./lora-output-dir\"\n\n\naxolotl inference your_config.yml --base-model=\"./completed-model\"",
"crumbs": [
"Getting Started",
"Inference and Merging"
]
},
{
"objectID": "docs/inference.html#sec-advanced",
"href": "docs/inference.html#sec-advanced",
"title": "Inference and Merging",
"section": "2 Advanced Usage",
"text": "2 Advanced Usage\n\n2.1 Gradio Interface\nLaunch an interactive web interface:\naxolotl inference your_config.yml --gradio\n\n\n2.2 File-based Prompts\nProcess prompts from a text file:\ncat /tmp/prompt.txt | axolotl inference your_config.yml \\\n --base-model=\"./completed-model\" --prompter=None\n\n\n2.3 Memory Optimization\nFor large models or limited memory:\naxolotl inference your_config.yml --load-in-8bit=True",
"crumbs": [
"Getting Started",
"Inference and Merging"
]
},
{
"objectID": "docs/inference.html#sec-merging",
"href": "docs/inference.html#sec-merging",
"title": "Inference and Merging",
"section": "3 Merging LoRA Weights",
"text": "3 Merging LoRA Weights\nMerge LoRA adapters with the base model:\naxolotl merge-lora your_config.yml --lora-model-dir=\"./completed-model\"\n\n3.1 Memory Management for Merging\n\nConfiguration OptionsForce CPU Merging\n\n\ngpu_memory_limit: 20GiB # Adjust based on your GPU\nlora_on_cpu: true # Process on CPU if needed\n\n\nCUDA_VISIBLE_DEVICES=\"\" axolotl merge-lora ...",
"crumbs": [
"Getting Started",
"Inference and Merging"
]
},
{
"objectID": "docs/inference.html#sec-tokenization",
"href": "docs/inference.html#sec-tokenization",
"title": "Inference and Merging",
"section": "4 Tokenization",
"text": "4 Tokenization\n\n4.1 Common Issues\n\n\n\n\n\n\nWarning\n\n\n\nTokenization mismatches between training and inference are a common source of problems.\n\n\nTo debug:\n\nCheck training tokenization:\n\naxolotl preprocess your_config.yml --debug\n\nVerify inference tokenization by decoding tokens before model input\nCompare token IDs between training and inference\n\n\n\n4.2 Special Tokens\nConfigure special tokens in your YAML:\nspecial_tokens:\n bos_token: \"<s>\"\n eos_token: \"</s>\"\n unk_token: \"<unk>\"\ntokens:\n - \"<|im_start|>\"\n - \"<|im_end|>\"",
"crumbs": [
"Getting Started",
"Inference and Merging"
]
},
{
"objectID": "docs/inference.html#sec-troubleshooting",
"href": "docs/inference.html#sec-troubleshooting",
"title": "Inference and Merging",
"section": "5 Troubleshooting",
"text": "5 Troubleshooting\n\n5.1 Common Problems\n\nMemory IssuesToken IssuesPerformance Issues\n\n\n\nUse 8-bit loading\nReduce batch sizes\nTry CPU offloading\n\n\n\n\nVerify special tokens\nCheck tokenizer settings\nCompare training and inference preprocessing\n\n\n\n\nVerify model loading\nCheck prompt formatting\nEnsure temperature/sampling settings\n\n\n\n\nFor more details, see our debugging guide.",
"crumbs": [
"Getting Started",
"Inference and Merging"
]
},
{
"objectID": "docs/getting-started.html",
"href": "docs/getting-started.html",
"title": "Quickstart",
"section": "",
"text": "This guide will walk you through your first model fine-tuning project with Axolotl.",
"crumbs": [
"Getting Started",
"Quickstart"
]
},
{
"objectID": "docs/getting-started.html#sec-quick-example",
"href": "docs/getting-started.html#sec-quick-example",
"title": "Quickstart",
"section": "1 Quick Example",
"text": "1 Quick Example\nLets start by fine-tuning a small language model using LoRA. This example uses a 1B parameter model to ensure it runs on most GPUs.\nAssuming axolotl is installed (if not, see our Installation Guide)\n\nDownload example configs:\n\naxolotl fetch examples\n\nRun the training:\n\naxolotl train examples/llama-3/lora-1b.yml\nThats it! Lets understand what just happened.",
"crumbs": [
"Getting Started",
"Quickstart"
]
},
{
"objectID": "docs/getting-started.html#sec-understanding",
"href": "docs/getting-started.html#sec-understanding",
"title": "Quickstart",
"section": "2 Understanding the Process",
"text": "2 Understanding the Process\n\n2.1 The Configuration File\nThe YAML configuration file controls everything about your training. Heres what (part of) our example config looks like:\nbase_model: NousResearch/Llama-3.2-1B\n\nload_in_8bit: true\nadapter: lora\n\ndatasets:\n - path: teknium/GPT4-LLM-Cleaned\n type: alpaca\ndataset_prepared_path: last_run_prepared\nval_set_size: 0.1\noutput_dir: ./outputs/lora-out\n\n\n\n\n\n\nTip\n\n\n\nload_in_8bit: true and adapter: lora enables LoRA adapter finetuning.\n\nTo perform Full finetuning, remove these two lines.\nTo perform QLoRA finetuning, replace with load_in_4bit: true and adapter: qlora.\n\n\n\nSee our config options for more details.\n\n\n2.2 Training\nWhen you run axolotl train, Axolotl:\n\nDownloads the base model\n(If specified) applies QLoRA/LoRA adapter layers\nLoads and processes the dataset\nRuns the training loop\nSaves the trained model and / or LoRA weights",
"crumbs": [
"Getting Started",
"Quickstart"
]
},
{
"objectID": "docs/getting-started.html#sec-custom",
"href": "docs/getting-started.html#sec-custom",
"title": "Quickstart",
"section": "3 Your First Custom Training",
"text": "3 Your First Custom Training\nLets modify the example for your own data:\n\nCreate a new config file my_training.yml:\n\nbase_model: NousResearch/Nous-Hermes-llama-1b-v1\n\nload_in_8bit: true\nadapter: lora\n\n# Training settings\nmicro_batch_size: 2\nnum_epochs: 3\nlearning_rate: 0.0003\n\n# Your dataset\ndatasets:\n - path: my_data.jsonl # Your local data file\n type: alpaca # Or other format\nThis specific config is for LoRA fine-tuning a model with instruction tuning data using\nthe alpaca dataset format, which has the following format:\n{\n \"instruction\": \"Write a description of alpacas.\",\n \"input\": \"\",\n \"output\": \"Alpacas are domesticated South American camelids...\"\n}\nPlease see our Dataset Formats for more dataset formats and how to\nformat them.\n\nPrepare your JSONL data in the specified format (in this case, the expected alpaca\nformat):\n\n{\"instruction\": \"Classify this text\", \"input\": \"I love this!\", \"output\": \"positive\"}\n{\"instruction\": \"Classify this text\", \"input\": \"Not good at all\", \"output\": \"negative\"}\n\nRun the training:\n\naxolotl train my_training.yml",
"crumbs": [
"Getting Started",
"Quickstart"
]
},
{
"objectID": "docs/getting-started.html#sec-common-tasks",
"href": "docs/getting-started.html#sec-common-tasks",
"title": "Quickstart",
"section": "4 Common Tasks",
"text": "4 Common Tasks\n\n\n\n\n\n\nTip\n\n\n\nThe same yaml file is used for training, inference, and merging.\n\n\n\n4.1 Testing Your Model\nAfter training, test your model:\naxolotl inference my_training.yml --lora-model-dir=\"./outputs/lora-out\"\nMore details can be found in Inference.\n\n\n4.2 Using a UI\nLaunch a Gradio interface:\naxolotl inference my_training.yml --lora-model-dir=\"./outputs/lora-out\" --gradio\n\n\n4.3 Preprocessing Data\nFor large datasets, preprocess first:\naxolotl preprocess my_training.yml\nPlease make sure to set dataset_prepared_path: in your config to set the path to save the prepared dataset.\nMore details can be found in Dataset Preprocessing.\n\n\n4.4 Merging LoRA weights\nTo merge the LoRA weights back into the base model, run:\naxolotl merge-lora my_training.yml --lora-model-dir=\"./outputs/lora-out\"\nThe merged model will be saved in the {output_dir}/merged directory.\nMore details can be found in Merging LoRA weights.",
"crumbs": [
"Getting Started",
"Quickstart"
]
},
{
"objectID": "docs/getting-started.html#sec-next-steps",
"href": "docs/getting-started.html#sec-next-steps",
"title": "Quickstart",
"section": "5 Next Steps",
"text": "5 Next Steps\nNow that you have the basics, explore these guides based on what you want to do:\nChoose your path:\n\nChoosing a Fine-Tuning Method — SFT vs LoRA vs QLoRA vs GRPO vs DPO, with hardware recommendations\n\nCore guides:\n\nDataset Loading — Loading datasets from various sources\nDataset Formats — Working with different data formats\nOptimizations — Flash attention, gradient checkpointing, sample packing\nTraining Stability & Debugging — Monitoring metrics, fixing NaN, OOM debugging\n\nAdvanced training methods:\n\nRLHF / Preference Learning — DPO, KTO, GRPO, EBFT\nGRPO Training — RL with custom rewards and vLLM generation\nvLLM Serving — Setting up vLLM for GRPO\n\nScaling up:\n\nMulti-GPU Training — DeepSpeed, FSDP, DDP\nMulti-Node Training — Distributed training across machines",
"crumbs": [
"Getting Started",
"Quickstart"
]
},
{
"objectID": "docs/telemetry.html",
"href": "docs/telemetry.html",
"title": "Telemetry",
"section": "",
"text": "Axolotl implements anonymous telemetry to help maintainers understand how the library\nis used and where users encounter issues. This data helps prioritize features, optimize\nperformance, and fix bugs.\n\n\nWe collect:\n\nSystem info: OS, Python version, Axolotl version, PyTorch version, Transformers\nversion, etc.\nHardware info: CPU count, memory, GPU count and models\nRuntime metrics: Training progress, memory usage, timing information\nUsage patterns: Models (from a whitelist) and configurations used\nError tracking: Stack traces and error messages (sanitized to remove personal\ninformation)\n\nPersonally identifiable information (PII) is not collected.\n\n\n\nTelemetry is implemented using PostHog and consists of:\n\naxolotl.telemetry.TelemetryManager: A singleton class that initializes the\ntelemetry system and provides methods for tracking events.\naxolotl.telemetry.errors.send_errors: A decorator that captures exceptions and\nsends sanitized stack traces.\naxolotl.telemetry.runtime_metrics.RuntimeMetricsTracker: A class that tracks\nruntime metrics during training.\naxolotl.telemetry.callbacks.TelemetryCallback: A Trainer callback that sends\nruntime metrics telemetry.\n\nThe telemetry system will block training startup for 10 seconds to ensure users are\naware of data collection, unless telemetry is explicitly enabled or disabled.\n\n\n\nTelemetry is enabled by default on an opt-out basis. To disable it, set\nAXOLOTL_DO_NOT_TRACK=1 or DO_NOT_TRACK=1.\nA warning message will be logged on start to clearly inform users about telemetry.\nWe will remove this after some period.\nTo hide the warning message about telemetry that is displayed on train, etc. 
startup,\nexplicitly set: AXOLOTL_DO_NOT_TRACK=0 (enable telemetry) or AXOLOTL_DO_NOT_TRACK=1\n(explicitly disable telemetry).\n\n\n\n\nAll path-like config information is automatically redacted from telemetry data\nModel information is only collected for whitelisted organizations\n\nSee axolotl/telemetry/whitelist.yaml for the set of whitelisted organizations\n\nEach run generates a unique anonymous ID\n\nThis allows us to link different telemetry events within the same training run\n\nTelemetry is only sent from the main process to avoid duplicate events",
"crumbs": [
"Getting Started",
"Telemetry"
]
},
{
"objectID": "docs/telemetry.html#data-collection",
"href": "docs/telemetry.html#data-collection",
"title": "Telemetry",
"section": "",
"text": "We collect:\n\nSystem info: OS, Python version, Axolotl version, PyTorch version, Transformers\nversion, etc.\nHardware info: CPU count, memory, GPU count and models\nRuntime metrics: Training progress, memory usage, timing information\nUsage patterns: Models (from a whitelist) and configurations used\nError tracking: Stack traces and error messages (sanitized to remove personal\ninformation)\n\nPersonally identifiable information (PII) is not collected.",
"crumbs": [
"Getting Started",
"Telemetry"
]
},
{
"objectID": "docs/telemetry.html#implementation",
"href": "docs/telemetry.html#implementation",
"title": "Telemetry",
"section": "",
"text": "Telemetry is implemented using PostHog and consists of:\n\naxolotl.telemetry.TelemetryManager: A singleton class that initializes the\ntelemetry system and provides methods for tracking events.\naxolotl.telemetry.errors.send_errors: A decorator that captures exceptions and\nsends sanitized stack traces.\naxolotl.telemetry.runtime_metrics.RuntimeMetricsTracker: A class that tracks\nruntime metrics during training.\naxolotl.telemetry.callbacks.TelemetryCallback: A Trainer callback that sends\nruntime metrics telemetry.\n\nThe telemetry system will block training startup for 10 seconds to ensure users are\naware of data collection, unless telemetry is explicitly enabled or disabled.",
"crumbs": [
"Getting Started",
"Telemetry"
]
},
{
"objectID": "docs/telemetry.html#opt-out-mechanism",
"href": "docs/telemetry.html#opt-out-mechanism",
"title": "Telemetry",
"section": "",
"text": "Telemetry is enabled by default on an opt-out basis. To disable it, set\nAXOLOTL_DO_NOT_TRACK=1 or DO_NOT_TRACK=1.\nA warning message will be logged on start to clearly inform users about telemetry.\nWe will remove this after some period.\nTo hide the warning message about telemetry that is displayed on train, etc. startup,\nexplicitly set: AXOLOTL_DO_NOT_TRACK=0 (enable telemetry) or AXOLOTL_DO_NOT_TRACK=1\n(explicitly disable telemetry).",
"crumbs": [
"Getting Started",
"Telemetry"
]
},
{
"objectID": "docs/telemetry.html#privacy",
"href": "docs/telemetry.html#privacy",
"title": "Telemetry",
"section": "",
"text": "All path-like config information is automatically redacted from telemetry data\nModel information is only collected for whitelisted organizations\n\nSee axolotl/telemetry/whitelist.yaml for the set of whitelisted organizations\n\nEach run generates a unique anonymous ID\n\nThis allows us to link different telemetry events in a single same training run\n\nTelemetry is only sent from the main process to avoid duplicate events",
"crumbs": [
"Getting Started",
"Telemetry"
]
},
{
"objectID": "src/axolotl/integrations/cut_cross_entropy/ACKNOWLEDGEMENTS.html",
"href": "src/axolotl/integrations/cut_cross_entropy/ACKNOWLEDGEMENTS.html",
"title": "Axolotl",
"section": "",
"text": "Acknowledgements\nPortions of this Cut Cross Entropy Software may utilize the following copyrighted\nmaterial, the use of which is hereby acknowledged.\n\nPyTorch\nFrom PyTorch:\n\nCopyright (c) 2016- Facebook, Inc (Adam Paszke)\nCopyright (c) 2014- Facebook, Inc (Soumith Chintala)\nCopyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)\nCopyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)\nCopyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)\nCopyright (c) 2011-2013 NYU (Clement Farabet)\nCopyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)\nCopyright (c) 2006 Idiap Research Institute (Samy Bengio)\nCopyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)\n\nFrom Caffe2:\n\nCopyright (c) 2016-present, Facebook Inc. All rights reserved.\n\nAll contributions by Facebook:\nCopyright (c) 2016 Facebook Inc.\n\nAll contributions by Google:\nCopyright (c) 2015 Google Inc.\nAll rights reserved.\n\nAll contributions by Yangqing Jia:\nCopyright (c) 2015 Yangqing Jia\nAll rights reserved.\n\nAll contributions by Kakao Brain:\nCopyright 2019-2020 Kakao Brain\n\nAll contributions by Cruise LLC:\nCopyright (c) 2022 Cruise LLC.\nAll rights reserved.\n\nAll contributions by Arm:\nCopyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates\n\nAll contributions from Caffe:\nCopyright(c) 2013, 2014, 2015, the respective contributors\nAll rights reserved.\n\nAll other contributions:\nCopyright(c) 2015, 2016 the respective contributors\nAll rights reserved.\n\nCaffe2 uses a copyright model similar to Caffe: each contributor holds\ncopyright over their contributions to Caffe2. The project versioning records\nall such contribution and copyright details. 
If a contributor wants to further\nmark their specific copyright on a particular contribution, they should\nindicate their copyright solely in the commit message of the change when it is\ncommitted.\n\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n1. Redistributions of source code must retain the above copyright\nnotice, this list of conditions and the following disclaimer.\n\n2. Redistributions in binary form must reproduce the above copyright\nnotice, this list of conditions and the following disclaimer in the\ndocumentation and/or other materials provided with the distribution.\n\n3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America\nand IDIAP Research Institute nor the names of its contributors may be\nused to endorse or promote products derived from this software without\nspecific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\nARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\nLIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\nCONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\nSUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\nINTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\nCONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\nARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\nPOSSIBILITY OF SUCH DAMAGE.\nTriton\n/*\n* Copyright 2018-2020 Philippe Tillet\n* Copyright 2020-2022 OpenAI\n*\n* Permission is hereby granted, free of charge, to any person obtaining\n* a copy of this software and associated documentation files\n* (the \"Software\"), to deal in the Software without restriction,\n* including without limitation the rights to use, copy, modify, merge,\n* publish, distribute, sublicense, and/or sell copies of the Software,\n* and to permit persons to whom the Software is furnished to do so,\n* subject to the following conditions:\n*\n* The above copyright notice and this permission notice shall be\n* included in all copies or substantial portions of the Software.\n*\n* THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND,\n* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF\n* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.\n* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY\n* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,\n* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE\n* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n*/\nTransformers\nCopyright 2018- The Hugging Face team. All rights reserved.\n\n Apache License\n Version 2.0, January 2004\n http://www.apache.org/licenses/\n\nTERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n1. 
Definitions.\n\n \"License\" shall mean the terms and conditions for use, reproduction,\n and distribution as defined by Sections 1 through 9 of this document.\n\n \"Licensor\" shall mean the copyright owner or entity authorized by\n the copyright owner that is granting the License.\n\n \"Legal Entity\" shall mean the union of the acting entity and all\n other entities that control, are controlled by, or are under common\n control with that entity. For the purposes of this definition,\n \"control\" means (i) the power, direct or indirect, to cause the\n direction or management of such entity, whether by contract or\n otherwise, or (ii) ownership of fifty percent (50%) or more of the\n outstanding shares, or (iii) beneficial ownership of such entity.\n\n \"You\" (or \"Your\") shall mean an individual or Legal Entity\n exercising permissions granted by this License.\n\n \"Source\" form shall mean the preferred form for making modifications,\n including but not limited to software source code, documentation\n source, and configuration files.\n\n \"Object\" form shall mean any form resulting from mechanical\n transformation or translation of a Source form, including but\n not limited to compiled object code, generated documentation,\n and conversions to other media types.\n\n \"Work\" shall mean the work of authorship, whether in Source or\n Object form, made available under the License, as indicated by a\n copyright notice that is included in or attached to the work\n (an example is provided in the Appendix below).\n\n \"Derivative Works\" shall mean any work, whether in Source or Object\n form, that is based on (or derived from) the Work and for which the\n editorial revisions, annotations, elaborations, or other modifications\n represent, as a whole, an original work of authorship. 
For the purposes\n of this License, Derivative Works shall not include works that remain\n separable from, or merely link (or bind by name) to the interfaces of,\n the Work and Derivative Works thereof.\n\n \"Contribution\" shall mean any work of authorship, including\n the original version of the Work and any modifications or additions\n to that Work or Derivative Works thereof, that is intentionally\n submitted to Licensor for inclusion in the Work by the copyright owner\n or by an individual or Legal Entity authorized to submit on behalf of\n the copyright owner. For the purposes of this definition, \"submitted\"\n means any form of electronic, verbal, or written communication sent\n to the Licensor or its representatives, including but not limited to\n communication on electronic mailing lists, source code control systems,\n and issue tracking systems that are managed by, or on behalf of, the\n Licensor for the purpose of discussing and improving the Work, but\n excluding communication that is conspicuously marked or otherwise\n designated in writing by the copyright owner as \"Not a Contribution.\"\n\n \"Contributor\" shall mean Licensor and any individual or Legal Entity\n on behalf of whom a Contribution has been received by Licensor and\n subsequently incorporated within the Work.\n\n2. Grant of Copyright License. Subject to the terms and conditions of\n this License, each Contributor hereby grants to You a perpetual,\n worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n copyright license to reproduce, prepare Derivative Works of,\n publicly display, publicly perform, sublicense, and distribute the\n Work and such Derivative Works in Source or Object form.\n\n3. Grant of Patent License. 
Subject to the terms and conditions of\n this License, each Contributor hereby grants to You a perpetual,\n worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n (except as stated in this section) patent license to make, have made,\n use, offer to sell, sell, import, and otherwise transfer the Work,\n where such license applies only to those patent claims licensable\n by such Contributor that are necessarily infringed by their\n Contribution(s) alone or by combination of their Contribution(s)\n with the Work to which such Contribution(s) was submitted. If You\n institute patent litigation against any entity (including a\n cross-claim or counterclaim in a lawsuit) alleging that the Work\n or a Contribution incorporated within the Work constitutes direct\n or contributory patent infringement, then any patent licenses\n granted to You under this License for that Work shall terminate\n as of the date such litigation is filed.\n\n4. Redistribution. You may reproduce and distribute copies of the\n Work or Derivative Works thereof in any medium, with or without\n modifications, and in Source or Object form, provided that You\n meet the following conditions:\n\n (a) You must give any other recipients of the Work or\n Derivative Works a copy of this License; and\n\n (b) You must cause any modified files to carry prominent notices\n stating that You changed the files; and\n\n (c) You must retain, in the Source form of any Derivative Works\n that You distribute, all copyright, patent, trademark, and\n attribution notices from the Source form of the Work,\n excluding those notices that do not pertain to any part of\n the Derivative Works; and\n\n (d) If the Work includes a \"NOTICE\" text file as part of its\n distribution, then any Derivative Works that You distribute must\n include a readable copy of the attribution notices contained\n within such NOTICE file, excluding those notices that do not\n pertain to any part of the Derivative Works, in at least one\n of 
the following places: within a NOTICE text file distributed\n as part of the Derivative Works; within the Source form or\n documentation, if provided along with the Derivative Works; or,\n within a display generated by the Derivative Works, if and\n wherever such third-party notices normally appear. The contents\n of the NOTICE file are for informational purposes only and\n do not modify the License. You may add Your own attribution\n notices within Derivative Works that You distribute, alongside\n or as an addendum to the NOTICE text from the Work, provided\n that such additional attribution notices cannot be construed\n as modifying the License.\n\n You may add Your own copyright statement to Your modifications and\n may provide additional or different license terms and conditions\n for use, reproduction, or distribution of Your modifications, or\n for any such Derivative Works as a whole, provided Your use,\n reproduction, and distribution of the Work otherwise complies with\n the conditions stated in this License.\n\n5. Submission of Contributions. Unless You explicitly state otherwise,\n any Contribution intentionally submitted for inclusion in the Work\n by You to the Licensor shall be under the terms and conditions of\n this License, without any additional terms or conditions.\n Notwithstanding the above, nothing herein shall supersede or modify\n the terms of any separate license agreement you may have executed\n with Licensor regarding such Contributions.\n\n6. Trademarks. This License does not grant permission to use the trade\n names, trademarks, service marks, or product names of the Licensor,\n except as required for reasonable and customary use in describing the\n origin of the Work and reproducing the content of the NOTICE file.\n\n7. Disclaimer of Warranty. 
Unless required by applicable law or\n agreed to in writing, Licensor provides the Work (and each\n Contributor provides its Contributions) on an \"AS IS\" BASIS,\n WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n implied, including, without limitation, any warranties or conditions\n of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n PARTICULAR PURPOSE. You are solely responsible for determining the\n appropriateness of using or redistributing the Work and assume any\n risks associated with Your exercise of permissions under this License.\n\n8. Limitation of Liability. In no event and under no legal theory,\n whether in tort (including negligence), contract, or otherwise,\n unless required by applicable law (such as deliberate and grossly\n negligent acts) or agreed to in writing, shall any Contributor be\n liable to You for damages, including any direct, indirect, special,\n incidental, or consequential damages of any character arising as a\n result of this License or out of the use or inability to use the\n Work (including but not limited to damages for loss of goodwill,\n work stoppage, computer failure or malfunction, or any and all\n other commercial damages or losses), even if such Contributor\n has been advised of the possibility of such damages.\n\n9. Accepting Warranty or Additional Liability. While redistributing\n the Work or Derivative Works thereof, You may choose to offer,\n and charge a fee for, acceptance of support, warranty, indemnity,\n or other liability obligations and/or rights consistent with this\n License. 
However, in accepting such obligations, You may act only\n on Your own behalf and on Your sole responsibility, not on behalf\n of any other Contributor, and only if You agree to indemnify,\n defend, and hold each Contributor harmless for any liability\n incurred by, or claims asserted against, such Contributor by reason\n of your accepting any such warranty or additional liability.\n\nEND OF TERMS AND CONDITIONS\n\nAPPENDIX: How to apply the Apache License to your work.\n\n To apply the Apache License to your work, attach the following\n boilerplate notice, with the fields enclosed by brackets \"[]\"\n replaced with your own identifying information. (Don't include\n the brackets!) The text should be enclosed in the appropriate\n comment syntax for the file format. We also recommend that a\n file or class name and description of purpose be included on the\n same \"printed page\" as the copyright notice for easier\n identification within third-party archives.\n\nCopyright [yyyy] [name of copyright owner]\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License."
},
{
"objectID": "index.html",
"href": "index.html",
"title": "Axolotl",
"section": "",
"text": "A Free and Open Source LLM Fine-tuning Framework",
"crumbs": [
"Home"
]
},
{
"objectID": "index.html#latest-updates",
"href": "index.html#latest-updates",
"title": "Axolotl",
"section": "🎉 Latest Updates",
"text": "🎉 Latest Updates\n\n2026/03:\n\nNew model support has been added in Axolotl for Mistral Small 4, Qwen3.5, Qwen3.5 MoE, GLM-4.7-Flash, GLM-4.6V, and GLM-4.5-Air.\nMoE expert quantization support (via quantize_moe_experts: true) greatly reduces VRAM when training MoE models (compatible with FSDP2).\n\n2026/02:\n\nScatterMoE LoRA support. LoRA fine-tuning directly on MoE expert weights using custom Triton kernels.\nAxolotl now has support for SageAttention and GDPO (Generalized DPO).\n\n2026/01:\n\nNew integrations for EAFT (Entropy-Aware Focal Training), which weights the loss by the entropy of the top-k logit distribution, and Scalable Softmax, which improves long-context attention.\n\n2025/12:\n\nAxolotl now includes support for Kimi-Linear, Plano-Orchestrator, MiMo, InternVL 3.5, Olmo3, Trinity, and Ministral3.\nDistributed Muon Optimizer support has been added for FSDP2 pretraining.\n\n2025/10: New model support has been added in Axolotl for: Qwen3 Next, Qwen2.5-vl, Qwen3-vl, Qwen3, Qwen3MoE, Granite 4, HunYuan, Magistral 2509, Apertus, and Seed-OSS.\n\n2025/09: Axolotl now has text diffusion training. Read more here.\n2025/08: QAT has been updated to include NVFP4 support. See PR.\n2025/07:\n\nND Parallelism support has been added to Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the blog post for more info.\nAxolotl adds more models: GPT-OSS, Gemma 3n, Liquid Foundation Model 2 (LFM2), and Arcee Foundation Models (AFM).\nFP8 finetuning with fp8 gather op is now possible in Axolotl via torchao. Get started here!\nSupport for Voxtral, Magistral 1.1, and Devstral with the mistral-common tokenizer has been integrated in Axolotl!\nTiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP has been added to support Arctic Long Sequence Training (ALST). See examples for using ALST with Axolotl!\n\n2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See docs to start training your own Magistral models with Axolotl!\n2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the docs to learn more!\n2025/04: Llama 4 support has been added in Axolotl. See docs to start training your own Llama 4 models with Axolotl's linearized version!\n2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the blog and docs to learn how to scale your context length when fine-tuning.\n2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the docs to fine-tune your own!\n2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the docs to give it a try.\n2025/02: Axolotl has added GRPO support. Dive into our blog and GRPO example and have some fun!\n2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See docs.",
"crumbs": [
"Home"
]
},
{
"objectID": "index.html#overview",
"href": "index.html#overview",
"title": "Axolotl",
"section": "✨ Overview",
"text": "✨ Overview\nAxolotl is a free and open-source tool designed to streamline post-training and fine-tuning for the latest large language models (LLMs).\nFeatures:\n\nMultiple Model Support: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub.\nMultimodal Training: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.\nTraining Methods: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).\nEasy Configuration: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.\nPerformance Optimizations: Multipacking, Flash Attention 2/3/4, Xformers, Flex Attention, SageAttention, Liger Kernel, Cut Cross Entropy, ScatterMoE, Sequence Parallelism (SP), LoRA optimizations, Multi-GPU training (FSDP1, FSDP2, DeepSpeed), Multi-node training (Torchrun, Ray), and many more!\nFlexible Dataset Handling: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.\nCloud Ready: We ship Docker images and also PyPI packages for use on cloud platforms and local hardware.",
"crumbs": [
"Home"
]
},
{
"objectID": "index.html#quick-start---llm-fine-tuning-in-minutes",
"href": "index.html#quick-start---llm-fine-tuning-in-minutes",
"title": "Axolotl",
"section": "🚀 Quick Start - LLM Fine-tuning in Minutes",
"text": "🚀 Quick Start - LLM Fine-tuning in Minutes\nRequirements:\n\nNVIDIA GPU (Ampere or newer for bf16 and Flash Attention) or AMD GPU\nPython 3.11\nPyTorch ≥2.9.1\n\n\nGoogle Colab\n\n\n\nOpen In Colab\n\n\n\n\nInstallation\n\nUsing pip\npip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation axolotl[flash-attn,deepspeed]\n\n# Download example axolotl configs, deepspeed configs\naxolotl fetch examples\naxolotl fetch deepspeed_configs # OPTIONAL\n\n\nUsing Docker\nInstalling with Docker can be less error-prone than installing in your own environment.\ndocker run --gpus '\"all\"' --rm -it axolotlai/axolotl:main-latest\nOther installation approaches are described here.\n\n\nCloud Providers\n\n\nRunPod\nVast.ai\nPRIME Intellect\nModal\nNovita\nJarvisLabs.ai\nLatitude.sh\n\n\n\n\n\nYour First Fine-tune\n# Fetch axolotl examples\naxolotl fetch examples\n\n# Or, specify a custom path\naxolotl fetch examples --dest path/to/folder\n\n# Train a model using LoRA\naxolotl train examples/llama-3/lora-1b.yml\nThat's it! Check out our Getting Started Guide for a more detailed walkthrough.",
"crumbs": [
"Home"
]
},
{
"objectID": "index.html#documentation",
"href": "index.html#documentation",
"title": "Axolotl",
"section": "📚 Documentation",
"text": "📚 Documentation\n\nInstallation Options - Detailed setup instructions for different environments\nConfiguration Guide - Full configuration options and examples\nDataset Loading - Loading datasets from various sources\nDataset Guide - Supported formats and how to use them\nMulti-GPU Training\nMulti-Node Training\nMultipacking\nAPI Reference - Auto-generated code documentation\nFAQ - Frequently asked questions",
"crumbs": [
"Home"
]
},
{
"objectID": "index.html#getting-help",
"href": "index.html#getting-help",
"title": "Axolotl",
"section": "🤝 Getting Help",
"text": "🤝 Getting Help\n\nJoin our Discord community for support\nCheck out our Examples directory\nRead our Debugging Guide\nNeed dedicated support? Please contact ✉wing@axolotl.ai for options",
"crumbs": [
"Home"
]
},
{
"objectID": "index.html#contributing",
"href": "index.html#contributing",
"title": "Axolotl",
"section": "🌟 Contributing",
"text": "🌟 Contributing\nContributions are welcome! Please see our Contributing Guide for details.",
"crumbs": [
"Home"
]
},
{
"objectID": "index.html#telemetry",
"href": "index.html#telemetry",
"title": "Axolotl",
"section": "📈 Telemetry",
"text": "📈 Telemetry\nAxolotl has opt-out telemetry that helps us understand how the project is being used\nand prioritize improvements. We collect basic system information, model types, and\nerror rates—never personal data or file paths. Telemetry is enabled by default. To\ndisable it, set AXOLOTL_DO_NOT_TRACK=1. For more details, see our telemetry documentation.",
"crumbs": [
"Home"
]
},
{
"objectID": "index.html#sponsors",
"href": "index.html#sponsors",
"title": "Axolotl",
"section": "❤️ Sponsors",
"text": "❤️ Sponsors\nInterested in sponsoring? Contact us at wing@axolotl.ai",
"crumbs": [
"Home"
]
},
{
"objectID": "index.html#citing-axolotl",
"href": "index.html#citing-axolotl",
"title": "Axolotl",
"section": "📝 Citing Axolotl",
"text": "📝 Citing Axolotl\nIf you use Axolotl in your research or projects, please cite it as follows:\n@software{axolotl,\n title = {Axolotl: Open Source LLM Post-Training},\n author = {{Axolotl maintainers and contributors}},\n url = {https://github.com/axolotl-ai-cloud/axolotl},\n license = {Apache-2.0},\n year = {2023}\n}",
"crumbs": [
"Home"
]
},
{
"objectID": "index.html#license",
"href": "index.html#license",
"title": "Axolotl",
"section": "📜 License",
"text": "📜 License\nThis project is licensed under the Apache 2.0 License - see the LICENSE file for details.",
"crumbs": [
"Home"
]
},
{
"objectID": "examples/colab-notebooks/colab-axolotl-example.html",
"href": "examples/colab-notebooks/colab-axolotl-example.html",
"title": "Fine-Tune Qwen3 14B with Axolotl",
"section": "",
"text": "Axolotl is the most performant LLM post-training framework available, delivering faster training with efficient, consistent and stable performance. Train your workload and ship your product 30% faster, saving you both time and money."
},
{
"objectID": "examples/colab-notebooks/colab-axolotl-example.html#demo-talk-like-a-pirate",
"href": "examples/colab-notebooks/colab-axolotl-example.html#demo-talk-like-a-pirate",
"title": "Fine-Tune Qwen3 14B with Axolotl",
"section": "Demo: Talk Like a Pirate",
"text": "Demo: Talk Like a Pirate\nIn this demo, we are training the model to respond like a pirate. This was chosen as a way to easily show how to train a model to respond in a certain style of your choosing (without being prompted) and is quite easy to validate within the scope of a Colab.\n\nUpload your own dataset or use a Hugging Face dataset\nYou can choose to use your own JSONL file from your own Google Drive, for example by downloading the Pirate-Ultrachat JSONL to your Google Drive. JSONL datasets should be formatted similarly to the OpenAI dataset format.\nYou can also simply use the winglian/pirate-ultrachat-10k dataset directly.\n\n# Default to HF dataset location\ndataset_id = \"winglian/pirate-ultrachat-10k\"\nuploaded = {}\n\n\nimport os\n\n# Optionally, upload your own JSONL to your Google Drive\nGOOGLE_DRIVE_PATH = \"\" # ex: \"MyDrive/Colab\\ Notebooks/train.jsonl\"\n\n# \"Select All\" permissions, or you may get the error:\n# \"MessageError: Error: credential propagation was unsuccessful\"\nif GOOGLE_DRIVE_PATH:\n from google.colab import drive\n\n # Mount your Google Drive\n GOOGLE_DRIVE_MNT = \"/content/drive/\"\n drive.mount(GOOGLE_DRIVE_MNT, force_remount=True)\n tmp_path = os.path.join(GOOGLE_DRIVE_MNT, GOOGLE_DRIVE_PATH.lstrip(\"/\"))\n # make sure file exists\n if not os.path.isfile(tmp_path):\n raise ValueError(f\"File {tmp_path} does not exist\")\n dataset_id = tmp_path"
},
{
"objectID": "src/axolotl/integrations/LICENSE.html",
"href": "src/axolotl/integrations/LICENSE.html",
"title": "Axolotl",
"section": "",
"text": "AXOLOTL COMMUNITY LICENSE AGREEMENT\nThis Axolotl Community License Agreement (“Agreement”) is entered into by and between Axolotl AI Corp. (“Axolotl”) and\nany individual or entity (“Licensee”) who wishes to use the Software (as defined below) in accordance with the terms\nand conditions set forth in this Agreement.\n\nDefinitions\n1.1 “Licensee” refers to any individual or entity who has obtained a copy of the Software under this Agreement.\n1.2 “Plugin Integration” means independent integration software modules which may or may not be offered by Axolotl,\nwhich may be licensed separately by their respective authors and/or licensors.\n1.3 “Software” refers to the specific sub-directory of the Axolotl, Inc. software located at\nhttps://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations and its subdirectories which\npermits Plugin Integrations to integrate with the Axolotl service.\nGrant of License\n2.1 Axolotl hereby grants Licensee a worldwide, non-exclusive, royalty-free, license to use, copy, modify, merge,\npublish, distribute, sublicense, and/or otherwise exploit the Software, subject to the following conditions:\n- Licensee must comply with all the terms and conditions of this Agreement.\n- Licensee must include the original copyright notice and disclaimer of warranty in all copies or substantial\nportions of the Software.\n2.2 Licensee may use the Software for any lawful purpose, except as restricted in Section 3.\nRestrictions\n3.1 Licensee shall not use the Software for any activity that constitutes a commercial activity of offering for\nfree or for sale any services, platform, or equivalent to third parties for the purposes of allowing such\nthird parties to fine-tune artificial intelligence models.\n3.2 Licensee shall not:\n- Use the Software for any illegal or unauthorized purpose.\n- Reverse engineer, decompile, or disassemble the Software.\n- Remove or modify any copyright, trademark, or other proprietary notices 
contained in the Software.\n- Use the Software in a way that could damage, disable, overburden, or impair the functionality of the\nSoftware or interfere with any third-party use of the Software.\n3.3 Axolotl reserves the right to restrict certain Plugin Integrations for use with the Software. To the extent Licensee integrates a permitted, applicable Plugin Integration with the Software, Licensee shall comply with any additional terms and conditions imposed by the licensors of such Plugin Integration for use of such Plugin Integrations. Licensee shall contact Axolotl if it has questions about whether its use of the Software falls beyond the scope of this Agreement.\nIntellectual Property Rights\n4.1 Axolotl and its contributors retain all intellectual property rights in and to the Software. Licensee\nacknowledges that this Agreement does not transfer any ownership rights or intellectual property rights to\nLicensee.\nDisclaimer of Warranty\n5.1 THE SOFTWARE IS PROVIDED “AS IS,” WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED\nTO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. IN NO EVENT SHALL\nTHE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN ACTION OF\nCONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\nDEALINGS IN THE SOFTWARE.\nTermination\n6.1 Axolotl may terminate this Agreement at any time if Licensee fails to comply with any of the terms and\nconditions set forth herein. 
Upon termination, Licensee shall cease all use of the Software and destroy any\ncopies in its possession.\nGoverning Law\n7.1 This Agreement shall be governed by and construed in accordance with the laws of the State of California,\nwithout regards to conflicts of laws provisions thereof.\nEntire Agreement\n8.1 This Agreement constitutes the entire agreement between Axolotl and Licensee with respect to the subject matter\nhereof and supersedes all prior or contemporaneous understandings or agreements between the parties concerning\nthe Software, whether written or oral. Axolotl may update the terms of this Agreement from time to time, and\nLicensee's continued use of the Software after any such updates shall constitute acceptance of updated terms\non a go-forward basis. Axolotl will use commercially reasonable efforts to provide Licensee notice of any\nmaterial updates. By using the Software, Licensee acknowledges that it has read, understood, and agrees to be\nbound by the terms and conditions of this Agreement.\n\nThis Agreement was last updated on August 23, 2024."
},
{
"objectID": "docs/batch_vs_grad.html",
"href": "docs/batch_vs_grad.html",
"title": "Batch size vs Gradient accumulation",
"section": "",
"text": "Gradient accumulation means accumulating gradients over several mini-batches and updating the model weights afterward. When the samples in each batch are diverse, this technique doesn't significantly impact learning.\nThis method allows for effective training with larger effective batch sizes without needing proportionally larger memory. Here's why:\n\nMemory Consumption with Batch Size: The primary reason increasing the batch size impacts memory is due to the storage requirements for intermediate activations. When you forward propagate a batch through a network, you have to store the activations at each layer for each sample in the batch, because these activations are used during backpropagation to compute gradients. Therefore, larger batches mean more activations, leading to greater GPU memory consumption.\nGradient Accumulation: With gradient accumulation, you're effectively simulating a larger batch size by accumulating gradients over several smaller batches (or micro-batches). However, at any given time, you're only forward and backward propagating a micro-batch. This means you only store activations for the micro-batch, not the full accumulated batch. As a result, you can simulate the effect of a larger batch size without the memory cost of storing activations for a large batch.\n\nExample 1:\nMicro batch size: 3\nGradient accumulation steps: 2\nNumber of GPUs: 3\nTotal batch size = 3 * 2 * 3 = 18\n| GPU 1 | GPU 2 | GPU 3 |\n|----------------|----------------|----------------|\n| S1, S2, S3 | S4, S5, S6 | S7, S8, S9 |\n| e1, e2, e3 | e4, e5, e6 | e7, e8, e9 |\n|----------------|----------------|----------------|\n| → (accumulate) | → (accumulate) | → (accumulate) |\n|----------------|----------------|----------------|\n| S10, S11, S12 | S13, S14, S15 | S16, S17, S18 |\n| e10, e11, e12 | e13, e14, e15 | e16, e17, e18 |\n|----------------|----------------|----------------|\n| → (apply) | → (apply) | → (apply) |\n\nAccumulated gradient for the weight w1 after the second iteration (considering all GPUs):\nTotal gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6 + e7 + e8 + e9 + e10 + e11 + e12 + e13 + e14 + e15 + e16 + e17 + e18\n\nWeight update for w1:\nw1_new = w1_old - learning rate × (Total gradient for w1 / 18)\nExample 2:\nMicro batch size: 2\nGradient accumulation steps: 1\nNumber of GPUs: 3\nTotal batch size = 2 * 1 * 3 = 6\n| GPU 1 | GPU 2 | GPU 3 |\n|-----------|-----------|-----------|\n| S1, S2 | S3, S4 | S5, S6 |\n| e1, e2 | e3, e4 | e5, e6 |\n|-----------|-----------|-----------|\n| → (apply) | → (apply) | → (apply) |\n\nAccumulated gradient for the weight w1 (considering all GPUs):\nTotal gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6\n\nWeight update for w1:\nw1_new = w1_old - learning rate × (Total gradient for w1 / 6)",
"crumbs": [
"Core Concepts",
"Batch size vs Gradient accumulation"
]
},
{
"objectID": "docs/sequence_parallelism.html",
"href": "docs/sequence_parallelism.html",
"title": "Sequence Parallelism",
"section": "",
"text": "Sequence parallelism is a technique that splits sequences across multiple GPUs,\nallowing you to train with very long sequences that wouldnt fit on a single GPU. Each\nGPU processes a different portion of the sequence, and the results are aggregated\nthrough a ring communication pattern.",
"crumbs": [
"Advanced Features",
"Sequence Parallelism"
]
},
{
"objectID": "docs/sequence_parallelism.html#when-to-use-sequence-parallelism",
"href": "docs/sequence_parallelism.html#when-to-use-sequence-parallelism",
"title": "Sequence Parallelism",
"section": "When to Use Sequence Parallelism",
"text": "When to Use Sequence Parallelism\nUse sequence parallelism when:\n\nYou need to train with sequence lengths that dont fit into a single GPUs memory\nYou have multiple GPUs available\nYoure experiencing OOM (Out Of Memory) errors with long sequences",
"crumbs": [
"Advanced Features",
"Sequence Parallelism"
]
},
{
"objectID": "docs/sequence_parallelism.html#configuration",
"href": "docs/sequence_parallelism.html#configuration",
"title": "Sequence Parallelism",
"section": "Configuration",
"text": "Configuration\nTo enable sequence parallelism, add the following to your configuration file:\n# Set to a divisor (> 1) of the number of GPUs available\ncontext_parallel_size: 4 # Split sequences across 4 GPUs\n# Optional; strides across the key dimension. Larger values use more memory but should make training faster.\nheads_k_stride: 1\n# Optional; one of \"varlen_llama3\" or \"batch_ring\". Defaults to\n# \"varlen_llama3\" when `sample_packing: true`, and \"batch_ring\" otherwise.\nring_attn_func:\nThe context_parallel_size should be a divisor of the total number of GPUs. For example:\n\nWith 8 GPUs, valid values would be 2, 4, or 8\nWith 4 GPUs, valid values would be 2 or 4",
"crumbs": [
"Advanced Features",
"Sequence Parallelism"
]
},
{
"objectID": "docs/sequence_parallelism.html#implementation-details",
"href": "docs/sequence_parallelism.html#implementation-details",
"title": "Sequence Parallelism",
"section": "Implementation Details",
"text": "Implementation Details\nWhen sequence parallelism is enabled:\n\nEach sequence is divided into equal chunks across the GPUs in a sequence parallel group\nThe data collator handles the chunking of input_ids, attention_mask, labels, and position_ids\nPosition IDs are adjusted to maintain proper relative positions\nThe trainer uses special ring communication patterns for attention operations",
"crumbs": [
"Advanced Features",
"Sequence Parallelism"
]
},
{
"objectID": "docs/sequence_parallelism.html#requirements",
"href": "docs/sequence_parallelism.html#requirements",
"title": "Sequence Parallelism",
"section": "Requirements",
"text": "Requirements\nTo use sequence parallelism, you need:\n\nMultiple GPUs (at least 2)\nThe ring-flash-attn package. Install with:\n\npip install axolotl[ring-flash-attn] (preferred)\npip install ring-flash-attn>=0.1.4",
"crumbs": [
"Advanced Features",
"Sequence Parallelism"
]
},
{
"objectID": "docs/sequence_parallelism.html#limitations",
"href": "docs/sequence_parallelism.html#limitations",
"title": "Sequence Parallelism",
"section": "Limitations",
"text": "Limitations\n\nFlash attention must be enabled for this to work (flash_attention: true in config YAML)\nMay have a small performance overhead due to communication between GPUs",
"crumbs": [
"Advanced Features",
"Sequence Parallelism"
]
},
{
"objectID": "docs/sequence_parallelism.html#example",
"href": "docs/sequence_parallelism.html#example",
"title": "Sequence Parallelism",
"section": "Example",
"text": "Example\nbase_model: meta-llama/Llama-3-8B-Instruct\nsequence_len: 8192\n\n...\n\ncontext_parallel_size: 4 # Split each sequence into 4 parts, one per GPU\n# Optional; strides across the key dimension. Larger values use more memory but should make training faster.\nheads_k_stride: 1\n# Optional; one of \"varlen_llama3\" or \"batch_ring\". Defaults to\n# \"varlen_llama3\" when `sample_packing: true`, and \"batch_ring\" otherwise.\nring_attn_func:\n\n...\nThis will train the Llama 3 8B model with 8K context length, with each sequence split\ninto 2 subsequences of length 4096 across 2 GPUs.",
"crumbs": [
"Advanced Features",
"Sequence Parallelism"
]
},
{
"objectID": "docs/sequence_parallelism.html#sample-packing-with-sequence-parallelism",
"href": "docs/sequence_parallelism.html#sample-packing-with-sequence-parallelism",
"title": "Sequence Parallelism",
"section": "Sample Packing with Sequence Parallelism",
"text": "Sample Packing with Sequence Parallelism\nSequence parallelism is compatible with Axolotls sample packing functionality. When using both features together:\n\nSamples are first packed together\nThe packed sequences are then divided across GPUs in the sequence parallel group\nPosition IDs are automatically adjusted to maintain proper relative positions",
"crumbs": [
"Advanced Features",
"Sequence Parallelism"
]
},
{
"objectID": "docs/sequence_parallelism.html#effect-on-batch-size",
"href": "docs/sequence_parallelism.html#effect-on-batch-size",
"title": "Sequence Parallelism",
"section": "Effect on Batch Size",
"text": "Effect on Batch Size\nWhen using sequence parallelism, your effective global batch size is divided by the context_parallel_size. This happens because:\n\nEach group of context_parallel_size GPUs works on the same batch (just different parts of each sequence)\nThe number of batches processed per step decreases\n\nFor example:\n- With 8 GPUs and no sequence parallelism: 8 different batches processed per step\n- With 8 GPUs and context_parallel_size=4: Only 2 different batches processed per step (each split across 4 GPUs)\n- If your per-GPU micro_batch_size is 2, the global batch size decreases from 16 to 4",
"crumbs": [
"Advanced Features",
"Sequence Parallelism"
]
},
{
"objectID": "docs/quantize.html",
"href": "docs/quantize.html",
"title": "Quantization with torchao",
"section": "",
"text": "Quantization is a technique to lower the memory footprint of your model, potentially at the cost of accuracy or model performance. We support quantizing your model using the torchao library. Quantization is supported for both post-training quantization (PTQ) and quantization-aware training (QAT).",
"crumbs": [
"How To Guides",
"Quantization with torchao"
]
},
{
"objectID": "docs/quantize.html#configuring-quantization-in-axolotl",
"href": "docs/quantize.html#configuring-quantization-in-axolotl",
"title": "Quantization with torchao",
"section": "Configuring Quantization in Axolotl",
"text": "Configuring Quantization in Axolotl\nQuantization is configured using the quantization key in your configuration file.\nbase_model: # The path to the model to quantize.\nquantization:\n activation_dtype: # Optional[str] = \"int8\". Fake quantization layout to use for activation quantization. Valid options are \"int4\", \"int8\", \"float8\"\n weight_dtype: # Optional[str] = \"int8\". Fake quantization layout to use for weight quantization. Valid options are \"int4\", \"fp8\", and \"nvfp4\".\n group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization\n quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.\n\noutput_dir: # The path to the output directory.\nOnce quantization is complete, your quantized model will be saved in the {output_dir}/quantized directory.\nYou may also use the quantize command to quantize a model which has been trained with QAT - you can do this by using the existing QAT configuration file which\nyou used to train the model:\n# qat.yml\nqat:\n activation_dtype: int8\n weight_dtype: int4\n group_size: 256\n\noutput_dir: # The path to the output directory used during training where the final checkpoint has been saved.\naxolotl quantize qat.yml\nThis ensures that an identical quantization configuration is used to quantize the model as was used to train it.\n\n\n\n\n\n\nNote\n\n\n\nIf you have configured pushing to hub with hub_model_id, your model hub name will have the quantization schema appended to it,\ne.g. axolotl-ai-cloud/qat-nvfp4-llama3B will become axolotl-ai-cloud/qat-nvfp4-llama3B-nvfp4w",
"crumbs": [
"How To Guides",
"Quantization with torchao"
]
},
{
"objectID": "docs/docker.html",
"href": "docs/docker.html",
"title": "Docker",
"section": "",
"text": "This section describes the different Docker images that are released by AxolotlAI at Docker Hub.",
"crumbs": [
"Deployments",
"Docker"
]
},
{
"objectID": "docs/docker.html#base",
"href": "docs/docker.html#base",
"title": "Docker",
"section": "Base",
"text": "Base\nThe base image is the most minimal image that can install Axolotl. It is based on the nvidia/cuda image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.\n\nImage\naxolotlai/axolotl-base\nLink: Docker Hub\n\n\nTags format\nmain-base-py{python_version}-cu{cuda_version}-{pytorch_version}\nTags examples:\n\nmain-base-py3.11-cu128-2.8.0\nmain-base-py3.11-cu128-2.9.1",
"crumbs": [
"Deployments",
"Docker"
]
},
{
"objectID": "docs/docker.html#main",
"href": "docs/docker.html#main",
"title": "Docker",
"section": "Main",
"text": "Main\nThe main image is the image that is used to run Axolotl. It is based on the axolotlai/axolotl-base image and includes the Axolotl codebase, dependencies, and more.\n\nImage\naxolotlai/axolotl\nLink: Docker Hub\n\n\nTags format\n# on push to main\nmain-py{python_version}-cu{cuda_version}-{pytorch_version}\n\n# latest main (currently torch 2.6.0, python 3.11, cuda 12.4)\nmain-latest\n\n# nightly build\n{branch}-{date_in_YYYYMMDD}-py{python_version}-cu{cuda_version}-{pytorch_version}\n\n# tagged release\n{version}\n\n\n\n\n\n\nTip\n\n\n\nThere may be some extra tags appended to the image, like -vllm which installs those packages.\n\n\nTags examples:\n\nmain-py3.11-cu128-2.8.0\nmain-py3.11-cu128-2.9.1\nmain-latest\nmain-20250303-py3.11-cu124-2.6.0\nmain-20250303-py3.11-cu126-2.6.0\n0.12.0",
"crumbs": [
"Deployments",
"Docker"
]
},
{
"objectID": "docs/docker.html#cloud",
"href": "docs/docker.html#cloud",
"title": "Docker",
"section": "Cloud",
"text": "Cloud\nThe cloud image is the image that is used to run Axolotl in the cloud. It is based on the axolotlai/axolotl image and sets ENV variables like HuggingFace cache directories for volume mounts, tmux, and more for different cloud providers.\n\n\n\n\n\n\nTip\n\n\n\nJupyter lab is run by default. Set JUPYTER_DISABLE=1 in the environment variables to disable it.\n\n\n\nImage\naxolotlai/axolotl-cloud\nLink: Docker Hub\n\n\nTags format\nThis uses the same tags as the main image.\n\n\nEnvironment variables\n\nJUPYTER_DISABLE: Disable Jupyter lab.\nJUPYTER_PASSWORD: Set a password for the Jupyter lab.\nPUBLIC_KEY / SSH_KEY: Add a public key for the SSH service.\n\n\n\nVolume mounts\n\n\n\n\n\n\nTip\n\n\n\nWe recommend mounting volumes to /workspace/data for data persistence. /workspace/axolotl contains the source code and is ephemeral.\n\n\n\n/workspace/data/axolotl-artifacts: Directory to store Axolotl artifacts.\n/workspace/data/huggingface-cache: Directory to store HuggingFace cache.",
"crumbs": [
"Deployments",
"Docker"
]
},
{
"objectID": "docs/docker.html#cloud-no-tmux",
"href": "docs/docker.html#cloud-no-tmux",
"title": "Docker",
"section": "Cloud-no-tmux",
"text": "Cloud-no-tmux\nThis is the same as the cloud image but without tmux.\n\nImage\naxolotlai/axolotl-cloud-term\nLink: Docker Hub\n\n\n\n\n\n\nNote\n\n\n\nThe naming may be a bit confusing as it has -term appended to the end.\n\n\n\n\nTags format\nThis uses the same tags as the cloud image.",
"crumbs": [
"Deployments",
"Docker"
]
},
{
"objectID": "docs/attention.html",
"href": "docs/attention.html",
"title": "Attention",
"section": "",
"text": "This is the default built-in attention in PyTorch.\nsdp_attention: true\nFor more details: PyTorch docs",
"crumbs": [
"Core Concepts",
"Attention"
]
},
{
"objectID": "docs/attention.html#sdp-attention",
"href": "docs/attention.html#sdp-attention",
"title": "Attention",
"section": "",
"text": "This is the default built-in attention in PyTorch.\nsdp_attention: true\nFor more details: PyTorch docs",
"crumbs": [
"Core Concepts",
"Attention"
]
},
{
"objectID": "docs/attention.html#flash-attention",
"href": "docs/attention.html#flash-attention",
"title": "Attention",
"section": "Flash Attention",
"text": "Flash Attention\nAxolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically\nbased on your installed packages and GPU.\nflash_attention: true\nFor more details: Flash Attention\n\nFlash Attention 2\nRequirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)\npip install flash-attn --no-build-isolation\n\n\n\n\n\n\nTip\n\n\n\nIf you get undefined symbol while training, ensure you installed PyTorch prior to Axolotl.\nAlternatively, try reinstall or downgrade a version.\n\n\n\n\nFlash Attention 3\nRequirements: Hopper only and CUDA 12.8 (recommended)\ngit clone https://github.com/Dao-AILab/flash-attention.git\ncd flash-attention/hopper\n\npython setup.py install\n\n\nFlash Attention 4\nRequirements: Hopper or Blackwell GPUs\npip install flash-attn-4\nOr from source:\ngit clone https://github.com/Dao-AILab/flash-attention.git\ncd flash-attention/flash_attn/cute\n\npip install -e .\n\n# FA2's flash_attn package includes a cute/ stub that shadows FA4.\n# Remove it so Python can find the real FA4 module:\nrm -r $(python -c \"import flash_attn; print(flash_attn.__path__[0])\")/cute\n\n\n\n\n\n\nNote\n\n\n\nHopper (SM90) users: The backward kernel is not yet included in the pip package. To use FA4\nfor training on Hopper, install from source using the instructions above.\n\n\n\n\n\n\n\n\nWarning\n\n\n\nFA4 only supports head dimensions up to 128 (d ≤ 128). The DeepSeek shape (192, 128) is\nalso supported but only on Blackwell. Axolotl automatically detects incompatible head dimensions\nand falls back to FA2/3.\n\n\nFor more details: flash-attention/flash_attn/cute\n\n\nAMD\nRequirements: ROCm 6.0 and above.\nSee Flash Attention AMD docs.",
"crumbs": [
"Core Concepts",
"Attention"
]
},
{
"objectID": "docs/attention.html#flex-attention",
"href": "docs/attention.html#flex-attention",
"title": "Attention",
"section": "Flex Attention",
"text": "Flex Attention\nA flexible PyTorch API for attention used in combination with torch.compile.\nflex_attention: true\n\n# recommended\ntorch_compile: true\n\n\n\n\n\n\nNote\n\n\n\nWe recommend using latest stable version of PyTorch for best performance.\n\n\nFor more details: PyTorch docs",
"crumbs": [
"Core Concepts",
"Attention"
]
},
{
"objectID": "docs/attention.html#sageattention",
"href": "docs/attention.html#sageattention",
"title": "Attention",
"section": "SageAttention",
"text": "SageAttention\nAttention kernels with QK Int8 and PV FP16 accumulator.\nsage_attention: true\nRequirements: Ampere, Ada, or Hopper GPUs\npip install sageattention==2.2.0 --no-build-isolation\n\n\n\n\n\n\nWarning\n\n\n\nOnly LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See GitHub Issue.\n\n\nFor more details: Sage Attention\n\n\n\n\n\n\nNote\n\n\n\nWe do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.",
"crumbs": [
"Core Concepts",
"Attention"
]
},
{
"objectID": "docs/attention.html#xformers",
"href": "docs/attention.html#xformers",
"title": "Attention",
"section": "xFormers",
"text": "xFormers\nxformers_attention: true\n\n\n\n\n\n\nTip\n\n\n\nWe recommend using with Turing GPUs or below (such as on Colab).\n\n\nFor more details: xFormers",
"crumbs": [
"Core Concepts",
"Attention"
]
},
{
"objectID": "docs/attention.html#shifted-sparse-attention",
"href": "docs/attention.html#shifted-sparse-attention",
"title": "Attention",
"section": "Shifted Sparse Attention",
"text": "Shifted Sparse Attention\n\n\n\n\n\n\nWarning\n\n\n\nWe plan to deprecate this! If you use this feature, we recommend switching to methods above.\n\n\nRequirements: LLaMA model architecture\nflash_attention: true\ns2_attention: true\n\n\n\n\n\n\nTip\n\n\n\nNo sample packing support!",
"crumbs": [
"Core Concepts",
"Attention"
]
},
{
"objectID": "docs/unsloth.html",
"href": "docs/unsloth.html",
"title": "Unsloth",
"section": "",
"text": "Overview\nUnsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and VRAM over\nstandard industry baselines.\n\n\n\n\n\n\nImportant\n\n\n\nDue to breaking changes in transformers v4.48.0, users will need to downgrade to <=v4.47.1 to use this patch.\nThis will later be deprecated in favor of LoRA Optimizations.\n\n\n\n\nInstallation\nThe following will install the correct unsloth and extras from source.\npython scripts/unsloth_install.py | sh\n\n\nUsage\nAxolotl exposes a few configuration options to try out unsloth and get most of the performance gains.\nOur unsloth integration is currently limited to the following model architectures:\n- llama\nThese options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning\nunsloth_lora_mlp: true\nunsloth_lora_qkv: true\nunsloth_lora_o: true\nThese options are composable and can be used with multi-gpu finetuning\nunsloth_cross_entropy_loss: true\nunsloth_rms_norm: true\nunsloth_rope: true\n\n\nLimitations\n\nSingle GPU only; e.g. no multi-gpu support\nNo deepspeed or FSDP support (requires multi-gpu)\nLoRA + QLoRA support only. No full fine tunes or fp8 support.\nLimited model architecture support. Llama, Phi, Gemma, Mistral only\nNo MoE support.",
"crumbs": [
"Advanced Features",
"Unsloth"
]
},
{
"objectID": "docs/qat.html",
"href": "docs/qat.html",
"title": "Quantization Aware Training (QAT)",
"section": "",
"text": "Quantization Aware Training (QAT) is a technique for improving the accuracy of models which are quantized\nby applying “fake” quantizations to the models weights (and optionally, activations) during training. This fake\nquantization allows for the model to adjust for noise introduced by the quantization, so when the model is eventually\nquantized, the accuracy loss is minimized. We use the quantization techniques implemented in torchao to provide\nsupport for QAT and post-training quantization (PTQ) in axolotl.\nWe recommend reviewing the excellent QAT tutorial in the torchtune library,\nand the QAT documentation in the torchao library, for more details.",
"crumbs": [
"How To Guides",
"Quantization Aware Training (QAT)"
]
},
{
"objectID": "docs/qat.html#overview",
"href": "docs/qat.html#overview",
"title": "Quantization Aware Training (QAT)",
"section": "",
"text": "Quantization Aware Training (QAT) is a technique for improving the accuracy of models which are quantized\nby applying “fake” quantizations to the models weights (and optionally, activations) during training. This fake\nquantization allows for the model to adjust for noise introduced by the quantization, so when the model is eventually\nquantized, the accuracy loss is minimized. We use the quantization techniques implemented in torchao to provide\nsupport for QAT and post-training quantization (PTQ) in axolotl.\nWe recommend reviewing the excellent QAT tutorial in the torchtune library,\nand the QAT documentation in the torchao library, for more details.",
"crumbs": [
"How To Guides",
"Quantization Aware Training (QAT)"
]
},
{
"objectID": "docs/qat.html#configuring-qat-in-axolotl",
"href": "docs/qat.html#configuring-qat-in-axolotl",
"title": "Quantization Aware Training (QAT)",
"section": "Configuring QAT in Axolotl",
"text": "Configuring QAT in Axolotl\nTo enable QAT in axolotl, add the following to your configuration file:\nqat:\n activation_dtype: # Optional[str] = \"int8\". Fake quantization layout to use for activation quantization. Valid options are \"int4\", \"int8\", \"float8\"\n weight_dtype: # Optional[str] = \"int8\". Fake quantization layout to use for weight quantization. Valid options are \"int4\", \"fp8\", and \"nvfp4\".\n group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization\n fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after\nWe support the following quantization schemas:\n\nInt4WeightOnly (requires the fbgemm-gpu extra when installing Axolotl)\nInt8DynamicActivationInt4Weight\nFloat8DynamicActivationFloat8Weight\nFloat8DynamicActivationInt4Weight\nNVFP4\n\nOnce you have finished training, you must quantize your model by using the same quantization configuration which you used to train the model with. You can use the quantize command to do this.",
"crumbs": [
"How To Guides",
"Quantization Aware Training (QAT)"
]
},
{
"objectID": "docs/multi-node.html",
"href": "docs/multi-node.html",
"title": "Multi Node",
"section": "",
"text": "The below are three ways to train multi-node in Axolotl.",
"crumbs": [
"Deployments",
"Multi Node"
]
},
{
"objectID": "docs/multi-node.html#accelerate",
"href": "docs/multi-node.html#accelerate",
"title": "Multi Node",
"section": "Accelerate",
"text": "Accelerate\nYou will need to create a configuration for accelerate, either by using accelerate config and follow the instructions or you can use one of the preset below:\n~/.cache/huggingface/accelerate/default_config.yaml\ncompute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: FSDP\ndowncast_bf16: 'no'\nmachine_rank: 0 # Set to 0 for the main machine, increment by one for other machines\nmain_process_ip: 10.0.0.4 # Set to main machine's IP\nmain_process_port: 5000\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 2 # Change to the number of machines\nnum_processes: 4 # That's the total number of GPUs, (for example: if you have 2 machines with 4 GPU, put 8)\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\nConfigure your model to use FSDP in the Axolotl yaml. For example:\nfsdp_version: 2\nfsdp_config:\n offload_params: true\n state_dict_type: FULL_STATE_DICT\n auto_wrap_policy: TRANSFORMER_BASED_WRAP\n transformer_layer_cls_to_wrap: LlamaDecoderLayer\n reshard_after_forward: true\nAll you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine.",
"crumbs": [
"Deployments",
"Multi Node"
]
},
{
"objectID": "docs/multi-node.html#raytrain",
"href": "docs/multi-node.html#raytrain",
"title": "Multi Node",
"section": "Raytrain",
"text": "Raytrain\nPlease see ray train doc here.",
"crumbs": [
"Deployments",
"Multi Node"
]
},
{
"objectID": "docs/multi-node.html#torchrun",
"href": "docs/multi-node.html#torchrun",
"title": "Multi Node",
"section": "Torchrun",
"text": "Torchrun\nIf you are using Infiniband, we recommend torchrun to utilize the full bandwidth.\nSet the following env (change buffersize/socketname depending on your system):\nexport NCCL_IB_DISABLE=0\nexport NCCL_SOCKET_IFNAME=\"eth0,en,eth,em,bond\"\nexport NCCL_BUFFSIZE=2097152\nRun the following on each node:\n\nOption 1: New Axolotl CLI with launcher args (Recommended)\naxolotl train config.yaml --launcher torchrun -- --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint \"$head_node_ip:$head_node_port\"\n\n\nOption 2: Direct torchrun (Legacy)\ntorchrun --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint \"$head_node_ip:$head_node_port\" -m axolotl.cli.train config.yaml\nPlease make sure to substitute the placeholder variables:\n\nnum_nodes: Number of nodes (containing GPUs)\ngpu_per_node: Number of gpus per node\nhead_node_ip: IP of the head node (make sure other machines can connect to this)\nhead_node_port: Port of the head node (make sure other machines can connect to this. Default 29400)\nrdzv_id: A unique job ID that is used by the job across nodes.\n\nThe new CLI approach (Option 1) is recommended as it provides consistent argument handling and works seamlessly with other Axolotl CLI features.\nMore info on the available configs can be found on the Pytorch docs here",
"crumbs": [
"Deployments",
"Multi Node"
]
},
{
"objectID": "docs/custom_integrations.html",
"href": "docs/custom_integrations.html",
"title": "Custom Integrations",
"section": "",
"text": "Axolotl adds custom features through integrations. They are located within the src/axolotl/integrations directory.\nTo enable them, please check the respective documentations.",
"crumbs": [
"Advanced Features",
"Custom Integrations"
]
},
{
"objectID": "docs/custom_integrations.html#cut-cross-entropy",
"href": "docs/custom_integrations.html#cut-cross-entropy",
"title": "Custom Integrations",
"section": "Cut Cross Entropy",
"text": "Cut Cross Entropy\nCut Cross Entropy (CCE) reduces VRAM usage through optimization on the cross-entropy operation during loss calculation.\nSee https://github.com/apple/ml-cross-entropy\n\nRequirements\n\nPyTorch 2.4.0 or higher\n\n\n\nInstallation\nRun the following command to install cut_cross_entropy[transformers] if you dont have it already.\n\nIf you are in dev environment\n\npython scripts/cutcrossentropy_install.py | sh\n\nIf you are installing from pip\n\npip3 uninstall -y cut-cross-entropy && pip3 install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\"\n\n\nUsage\nplugins:\n - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\n\nSupported Models\n\nafmoe\napertus\narcee\ncohere\ncohere2\ndeepseek_v3\nexaone4\ngemma\ngemma2\ngemma3\ngemma3_text\ngemma3n\ngemma3n_text\nglm\nglm4\nglm4_moe\nglm4_moe_lite\nglm46v\nglm4v\nglm4v_moe\nglm_image\nglm_moe_dsa\ngpt_oss\ngranite\ngranitemoe\ngranitemoehybrid\ngranitemoeshared\nhunyuan_v1_dense\nhunyuan_v1_moe\ninternvl\nkimi_linear\nlfm2\nlfm2_moe\nlfm2_vl\nllama\nllama4\nllama4_text\nllava\nministral\nministral3\nmistral\nmistral3\nmistral4\nmixtral\nmllama\nnemotron_h\nolmo\nolmo2\nolmo3\nolmoe\nphi\nphi3\nphi4_multimodal\nqwen2\nqwen2_5_vl\nqwen2_moe\nqwen2_vl\nqwen3\nqwen3_5\nqwen3_5_text\nqwen3_5_moe\nqwen3_5_moe_text\nqwen3_moe\nqwen3_next\nqwen3_vl\nqwen3_vl_moe\nseed_oss\nsmollm3\nstep3p5\nvoxtral\n\n\n\nCitation\n@article{wijmans2024cut,\n author = {Erik Wijmans and\n Brody Huval and\n Alexander Hertzberg and\n Vladlen Koltun and\n Philipp Kr\\\"ahenb\\\"uhl},\n title = {Cut Your Losses in Large-Vocabulary Language Models},\n journal = {arXiv},\n year = {2024},\n url = {https://arxiv.org/abs/2411.09009},\n}\nPlease see reference here",
"crumbs": [
"Advanced Features",
"Custom Integrations"
]
},
{
"objectID": "docs/custom_integrations.html#densemixer",
"href": "docs/custom_integrations.html#densemixer",
"title": "Custom Integrations",
"section": "DenseMixer",
"text": "DenseMixer\nSee DenseMixer\nSimply add the following to your axolotl YAML config:\nplugins:\n - axolotl.integrations.densemixer.DenseMixerPlugin\nPlease see reference here",
"crumbs": [
"Advanced Features",
"Custom Integrations"
]
},
{
"objectID": "docs/custom_integrations.html#diffusion-lm-training-plugin-for-axolotl",
"href": "docs/custom_integrations.html#diffusion-lm-training-plugin-for-axolotl",
"title": "Custom Integrations",
"section": "Diffusion LM Training Plugin for Axolotl",
"text": "Diffusion LM Training Plugin for Axolotl\nThis plugin enables diffusion language model training using an approach inspired by\nLLaDA (Large Language Diffusion Models) within Axolotl.\n\nOverview\nLLaDA is a diffusion-based approach to language model training that uses:\n- Random token masking during training instead of next-token prediction\n- Bidirectional attention to allow the model to attend to the full context\n- Importance weighting based on masking probabilities for stable training\nThis approach can lead to more robust language models with better understanding of\nbidirectional context.\n\n\nInstallation\nThe plugin is included with Axolotl. See our\ninstallation docs.\n\n\nQuickstart\nTrain with an example config (Llama3.2 1B):\n- Pretrain: axolotl train examples/llama-3/diffusion-3.2-1b-pretrain.yaml\n- SFT: axolotl train examples/llama-3/diffusion-3.2-1b-sft.yaml\n\n\nBasic Configuration\nYou can also modify your existing configs to enable / customize diffusion training.\nAdd the following to your Axolotl config:\nplugins:\n - axolotl.integrations.diffusion.DiffusionPlugin\nAnd, configure the nested diffusion block (defaults shown):\ndiffusion:\n noise_schedule: linear # or \"cosine\"\n min_mask_ratio: 0.1\n max_mask_ratio: 0.9\n num_diffusion_steps: 128\n eps: 1e-3\n importance_weighting: true\n\n # Mask token (training auto-adds if missing, avoid pad/eos)\n mask_token_str: \"<|diffusion_mask|>\"\n # Or use an existing special token id (e.g., 128002 for Llama-3.x)\n # mask_token_id: 128002\n\n # Sample generation during training (optional)\n generate_samples: true\n generation_interval: 100\n num_generation_samples: 3\n generation_steps: 128\n generation_temperature: 0.0\n generation_max_length: 100\n\n\nSupported Models\nAny models that support 4D attention masks should work out of the box. 
If not, please\ncreate an issue or open a\nPR!\n\n\nHow It Works\n\n\nRandom Masking\nDuring training, tokens are randomly masked:\n- Sample timestep t uniformly from [0, 1]\n- Calculate masking probability: p = (1 - eps) * t + eps\n- Randomly mask tokens with probability p\n\n\nDiffusion Loss\nLoss is computed only on masked tokens with (optional) importance weighting:\nloss = sum(cross_entropy(pred, target) / p_mask) / total_tokens\n\n\nSample Generation\nWhen diffusion.generate_samples: true, the plugin generates samples during training:\nSample 1:\n Original (45 tokens): The quick brown fox jumps over the lazy dog...\n Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...\n Generated: The quick brown fox jumps over the lazy dog...\nSamples are logged to console and wandb (if enabled).\n\n\nInference\nDiffusion inference is integrated into the standard Axolotl CLI. Use the same config\nyou trained with and run:\naxolotl inference path/to/your-config.yaml\nOptionally, pass --gradio to use a simple web interface.\nInteractive controls (prefix the prompt with commands):\n- :complete N → completion mode with N new masked tokens appended (default 64)\n- :mask R → random masking mode with target mask ratio R in [0.0, 1.0]\nExample session:\n================================================================================\nCommands:\n:complete N -> completion mode with N tokens (default 64)\n:mask R -> random masking with ratio R (0.0–1.0)\n================================================================================\nGive me an instruction (Ctrl + D to submit):\n\n:mask 0.4 The quick brown fox jumps over the lazy dog\n\nMasked (40.0%):\nThe [MASK] brown [MASK] jumps over the [MASK] dog\n\nGenerated:\nThe quick brown fox jumps over the loud dog\n\n\nMetrics and Monitoring\nThe plugin adds (or modifies) several metrics to track diffusion training:\n\ntrain/loss: Weighted diffusion loss\ntrain/accuracy: Accuracy on masked 
tokens\ntrain/mask_ratio: Average fraction of tokens masked\ntrain/num_masked_tokens: Number of tokens masked\ntrain/avg_p_mask: Average masking probability\ntrain/ce_loss: Unweighted cross-entropy loss\ntrain/importance_weight_avg: Average importance weight\n\n\n\nLimitations\n\nNo flash attention support\nNo RL training support\n\n\n\nReferences\n\nLLaDA Paper\nAxolotl Documentation\nAPI reference for plugin\n\nPlease see reference here",
"crumbs": [
"Advanced Features",
"Custom Integrations"
]
},
{
"objectID": "docs/custom_integrations.html#grokfast",
"href": "docs/custom_integrations.html#grokfast",
"title": "Custom Integrations",
"section": "Grokfast",
"text": "Grokfast\nSee https://github.com/ironjr/grokfast\n\nUsage\nplugins:\n - axolotl.integrations.grokfast.GrokfastPlugin\n\ngrokfast_alpha: 2.0\ngrokfast_lamb: 0.98\n\n\nCitation\n@article{lee2024grokfast,\n title={{Grokfast}: Accelerated Grokking by Amplifying Slow Gradients},\n author={Lee, Jaerin and Kang, Bong Gyun and Kim, Kihoon and Lee, Kyoung Mu},\n journal={arXiv preprint arXiv:2405.20233},\n year={2024}\n}\nPlease see reference here",
"crumbs": [
"Advanced Features",
"Custom Integrations"
]
},
{
"objectID": "docs/custom_integrations.html#kernels-integration",
"href": "docs/custom_integrations.html#kernels-integration",
"title": "Custom Integrations",
"section": "Kernels Integration",
"text": "Kernels Integration\nMoE (Mixture of Experts) kernels speed up training for MoE layers and reduce VRAM costs. In transformers v5, batched_mm and grouped_mm were integrated as built-in options via the experts_implementation config kwarg:\nclass ExpertsInterface(GeneralInterface):\n _global_mapping = {\n \"batched_mm\": batched_mm_experts_forward,\n \"grouped_mm\": grouped_mm_experts_forward,\n }\nIn our custom integration, we add support for ScatterMoE and SonicMoE, which are more efficient and faster than grouped_mm.\n\nUsage\nAdd the following to your axolotl YAML config:\nplugins:\n - axolotl.integrations.kernels.KernelsPlugin\n\nuse_kernels: true\n\nuse_scattermoe: true\nuse_sonicmoe: true\nImportant: Setting experts_implementation is incompatible with custom kernel options.\n\n\nSonicMoE installation\nPrerequisites:\n- NVIDIA Hopper (H100, H200) or Blackwell (B200, GB200) GPU\n- CUDA 12.9+ (13.0+ for B300)\n- PyTorch 2.7+ (2.9.1 recommended)\n- For B300: Triton 3.6.0\npip install --ignore-requires-python --no-deps \"sonic-moe @ git+https://github.com/Dao-AILab/sonic-moe.git@116e2df0a41874f77fa0ad269ce7df3f0cfcb956\" && pip install nvidia-cutlass-dsl==4.4.0 quack-kernels==0.2.5\nSee the SonicMoE installation guide for the latest prerequisite details.\nNote: Blackwell support is in upstream beta. 
On Blackwell GPUs, Axolotl automatically sets USE_QUACK_GEMM=1 to enable the Blackwell kernels.\n\n\nHow It Works\nThe KernelsPlugin runs before model loading and:\n\n\nScatterMoE\n\nRegisters the ScatterMoE kernel from the local libs/scattermoe_lora package (includes fused LoRA support via Triton kernels).\nPatches the model's SparseMoeBlock forward method with the optimized ScatterMoE implementation via the HF kernels library.\n\n\n\nSonicMoE\n\nResolves the model's MoE block class(es) from constants.py.\nPatches the forward method with SonicMoE's optimized CUTLASS kernels and registers a weight converter for the interleaved gate/up projection format.\nSupports pluggable routing strategies (see routing table below).\n\nBoth paths use the shared resolve_moe_block_classes utility in constants.py for model-type-to-class resolution.\n\n\nModel Support Matrix\nAll models use the SwiGLU activation (act_fn(gate) * up). Neither kernel currently supports non-SwiGLU MoE architectures.\n\n\nRouting strategies\n\n\n\n\n\n\n\n\n\nRouting Strategy\nDescription\nScatterMoE\nSonicMoE\n\n\n\n\nsoftmax → topk\nSoftmax over experts, select top-K, optional renormalization\nYes\nYes\n\n\nsoftmax → group selection → topk\nSoftmax, select top groups (sum of top-2 per group), topk from selected groups, renorm + scaling\nNo\nYes\n\n\nsigmoid → topk (with groups)\nSigmoid + bias correction, group-based masking, topk from masked scores, weights from original sigmoid\nYes\nYes\n\n\nsigmoid → topk (no groups)\nSigmoid + bias correction, straight topk (n_group=1)\nYes\nYes\n\n\nsoftmax → bias correction → topk\nSoftmax, bias via gate.moe_statics, topk, gather from original probs, clamp-based renorm\nNo\nYes\n\n\nsoftmax → group_limited_greedy\nSoftmax, group selection (max per group), topk, scale only (no renorm)\nNo\nYes\n\n\nsoftmax → topk via gate.wg\nSoftmax, gate weight at gate.wg.weight (not gate.weight), always renormalize\nNo\nYes\n\n\nfused topk → softmax\nRouting + expert computation 
fused in a single kernel\nNo\nPlanned\n\n\n\n\n\nPer-model support\n\n\n\n\n\n\n\n\n\n\nModel Type\nArchitecture\nRouting\nScatterMoE\nSonicMoE\n\n\n\n\nqwen2_moe\nQwen2-MoE\nsoftmax → topk\nYes\nYes\n\n\nqwen3_moe\nQwen3-MoE\nsoftmax → topk\nYes\nYes\n\n\nqwen3_5_moe\nQwen3.5-MoE\nsoftmax → topk\nYes\nYes\n\n\nqwen3_5_moe_text\nQwen3.5-MoE (VLM text)\nsoftmax → topk\nYes\nYes\n\n\nqwen3_next\nQwen3-Next\nsoftmax → topk\nYes\nYes\n\n\nqwen3_vl_moe\nQwen3-VL-MoE\nsoftmax → topk\nYes\nYes\n\n\nqwen3_omni_moe\nQwen3-Omni (Thinker + Talker)\nsoftmax → topk\nYes\nYes\n\n\nolmoe\nOLMoE\nsoftmax → topk\nYes\nYes\n\n\nmixtral\nMixtral\nsoftmax → topk\nYes\nYes\n\n\nminimax\nMiniMax\nsoftmax → topk\nYes\nYes\n\n\nmistral4\nMistral 4\nsoftmax → group → topk\nNo\nYes\n\n\nglm_moe_dsa\nGLM-MoE DSA (GLM 5)\nsigmoid → topk (groups)\nYes\nYes\n\n\ndeepseek_v3\nDeepSeek-V3\nsigmoid → topk (groups)\nYes\nYes\n\n\nglm4_moe\nGLM4-MoE\nsigmoid → topk (groups)\nYes\nYes\n\n\nglm4_moe_lite\nGLM4-MoE Lite (GLM 4.7 Flash)\nsigmoid → topk (groups)\nYes*\nYes\n\n\nglm4v_moe\nGLM4v-MoE\nsigmoid → topk (groups)\nYes\nYes\n\n\nminimax_m2\nMiniMax M2\nsigmoid → topk (no groups)\nYes\nYes\n\n\nernie4_5_moe\nERNIE 4.5 MoE\nsoftmax → bias → topk\nNo\nYes\n\n\ndeepseek_v2\nDeepSeek-V2\nsoftmax → group_limited_greedy\nNo\nYes\n\n\nhunyuan_v1_moe\nHunYuan V1 MoE\nsoftmax → topk (gate.wg)\nNo\nYes\n\n\ngpt_oss\nGPT-OSS\nfused topk → softmax\nNo\nPlanned\n\n\n\n* glm4_moe_lite with ScatterMoE may have issues — see Limitations.\n\n\nFeature comparison\n\n\n\n\n\n\n\n\nFeature\nScatterMoE\nSonicMoE\n\n\n\n\nKernel backend\nTriton\nCUTLASS\n\n\nGPU requirement\nAny CUDA\nHopper (H100/H200) or Blackwell (B200+)\n\n\nLoRA approach\nFused in Triton kernel\nRuntime materialization + custom autograd\n\n\nLoRA overhead\nLower (fused computation)\nHigher (per-forward materialization)\n\n\nGate/router LoRA\nYes\nYes\n\n\nExpert LoRA\nYes (fused)\nYes (materialized)\n\n\nShared expert LoRA\nYes (standard 
PEFT)\nYes (standard PEFT)\n\n\nSelective expert dequantization\nYes (~97% memory savings)\nNo\n\n\nWeight format\nTransposed [E, hidden, 2*inter]\nInterleaved gate/up [2*I, H, E]\n\n\ntorch.compile routing\nNo\nYes (optional)\n\n\n\n\n\nShared Expert Handling\nBoth kernels handle shared experts identically. Shared expert attribute names are detected in order of priority:\n\nshared_expert (Qwen2-MoE)\nshared_experts (GLM-MoE, DeepSeek-V3)\nshared_mlp (HunYuan V1 MoE)\n\nIf shared_expert_gate exists, sigmoid gating is applied to the shared expert contribution before adding it to the routed output. PEFT wraps shared expert linear layers with standard LoRA — no special handling is needed.\n\n\nLimitations\n\nScatterMoE + GLM4-MoE Lite: ScatterMoE does not work reliably for GLM 4.7 Flash (glm4_moe_lite).\nNon-SwiGLU activations: Neither kernel supports MoE architectures with non-SwiGLU expert activations (e.g., GPT-OSS uses a custom GLU variant).\nGPT-OSS: Deferred — requires transposed weight layout [E, H, 2*I], expert biases, and custom GLU activation. A dedicated forward path is needed.\nFSDP + fused gate LoRA (SonicMoE): The fused topk→softmax path materializes a local tensor when LoRA delta is present to avoid DTensor + Tensor mixing under FSDP.\n\n\n\nNote on MegaBlocks\nWe tested MegaBlocks but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.\nPlease see reference here",
"crumbs": [
"Advanced Features",
"Custom Integrations"
]
},
{
"objectID": "docs/custom_integrations.html#knowledge-distillation-kd",
"href": "docs/custom_integrations.html#knowledge-distillation-kd",
"title": "Custom Integrations",
"section": "Knowledge Distillation (KD)",
"text": "Knowledge Distillation (KD)\n\nUsage\nplugins:\n - \"axolotl.integrations.kd.KDPlugin\"\n\nkd_trainer: True\nkd_ce_alpha: 0.1\nkd_alpha: 0.9\nkd_temperature: 1.0\n\ntorch_compile: True # torch>=2.6.0, recommended to reduce vram\n\ndatasets:\n - path: ...\n type: \"axolotl.integrations.kd.chat_template\"\n field_messages: \"messages_combined\"\n logprobs_field: \"llm_text_generation_vllm_logprobs\" # for kd only, field of logprobs\nAn example dataset can be found at axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample\nPlease see reference here",
"crumbs": [
"Advanced Features",
"Custom Integrations"
]
},
{
"objectID": "docs/custom_integrations.html#llmcompressor",
"href": "docs/custom_integrations.html#llmcompressor",
"title": "Custom Integrations",
"section": "LLMCompressor",
"text": "LLMCompressor\nFine-tune sparsified models in Axolotl using Neural Magics LLMCompressor.\nThis integration enables fine-tuning of models sparsified using LLMCompressor within the Axolotl training framework. By combining LLMCompressors model compression capabilities with Axolotls distributed training pipelines, users can efficiently fine-tune sparse models at scale.\nIt uses Axolotls plugin system to hook into the fine-tuning flows while maintaining sparsity throughout training.\n\n\nRequirements\n\nAxolotl with llmcompressor extras:\npip install \"axolotl[llmcompressor]\"\nRequires llmcompressor >= 0.5.1\n\nThis will install all necessary dependencies to fine-tune sparsified models using the integration.\n\n\n\nUsage\nTo enable sparse fine-tuning with this integration, include the plugin in your Axolotl config:\nplugins:\n - axolotl.integrations.llm_compressor.LLMCompressorPlugin\n\nllmcompressor:\n recipe:\n finetuning_stage:\n finetuning_modifiers:\n ConstantPruningModifier:\n targets: [\n 're:.*q_proj.weight',\n 're:.*k_proj.weight',\n 're:.*v_proj.weight',\n 're:.*o_proj.weight',\n 're:.*gate_proj.weight',\n 're:.*up_proj.weight',\n 're:.*down_proj.weight',\n ]\n start: 0\n save_compressed: true\nThis plugin does not apply pruning or sparsification itself — it is intended for fine-tuning models that have already been sparsified.\nPre-sparsified checkpoints can be:\n- Generated using LLMCompressor\n- Downloaded from Neural Magics Hugging Face page\n- Any custom LLM with compatible sparsity patterns that youve created yourself\nTo learn more about writing and customizing LLMCompressor recipes, refer to the official documentation:\nhttps://github.com/vllm-project/llm-compressor/blob/main/README.md\n\n\nStorage Optimization with save_compressed\nSetting save_compressed: true in your configuration enables saving models in a compressed format, which:\n- Reduces disk space usage by approximately 40%\n- Maintains compatibility with vLLM for accelerated 
inference\n- Maintains compatibility with llmcompressor for further optimization (example: quantization)\nThis option is highly recommended when working with sparse models to maximize the benefits of model compression.\n\n\nExample Config\nSee examples/llama-3/sparse-finetuning.yaml for a complete example.\n\n\n\nInference with vLLM\nAfter fine-tuning your sparse model, you can leverage vLLM for efficient inference.\nYou can also use LLMCompressor to apply additional quantization to your fine-tuned\nsparse model before inference for even greater performance benefits:\nfrom vllm import LLM, SamplingParams\n\nprompts = [\n \"Hello, my name is\",\n \"The president of the United States is\",\n \"The capital of France is\",\n \"The future of AI is\",\n]\nsampling_params = SamplingParams(temperature=0.8, top_p=0.95)\nllm = LLM(\"path/to/your/sparse/model\")\noutputs = llm.generate(prompts, sampling_params)\n\nfor output in outputs:\n prompt = output.prompt\n generated_text = output.outputs[0].text\n print(f\"Prompt: {prompt!r}, Generated text: {generated_text!r}\")\nFor more details on vLLM's capabilities and advanced configuration options, see the official vLLM documentation.\n\n\nLearn More\nFor details on available sparsity and quantization schemes, fine-tuning recipes, and usage examples, visit the official LLMCompressor repository:\nhttps://github.com/vllm-project/llm-compressor\nPlease see reference here",
"crumbs": [
"Advanced Features",
"Custom Integrations"
]
},
{
"objectID": "docs/custom_integrations.html#language-model-evaluation-harness-lm-eval",
"href": "docs/custom_integrations.html#language-model-evaluation-harness-lm-eval",
"title": "Custom Integrations",
"section": "Language Model Evaluation Harness (LM Eval)",
"text": "Language Model Evaluation Harness (LM Eval)\nRun evaluation on model using the popular lm-evaluation-harness library.\nSee https://github.com/EleutherAI/lm-evaluation-harness\n\nUsage\nThere are two ways to use the LM Eval integration:\n\n\n1. Post-Training Evaluation\nWhen training with the plugin enabled, evaluation runs automatically after training completes:\nplugins:\n - axolotl.integrations.lm_eval.LMEvalPlugin\n\nlm_eval_tasks:\n - gsm8k\n - hellaswag\n - arc_easy\n\nlm_eval_batch_size: # Batch size for evaluation\n\noutput_dir:\nRun training as usual:\naxolotl train config.yml\n\n\n2. Standalone CLI Evaluation\nEvaluate any model directly without training:\nlm_eval_model: meta-llama/Llama-2-7b-hf\n\nplugins:\n - axolotl.integrations.lm_eval.LMEvalPlugin\n\nlm_eval_tasks:\n - gsm8k\n - hellaswag\n - arc_easy\n\nlm_eval_batch_size: 8\noutput_dir: ./outputs\nRun evaluation:\naxolotl lm-eval config.yml\n\n\nModel Selection Priority\nThe model to evaluate is selected in the following priority order:\n\nlm_eval_model - Explicit model path or HuggingFace repo (highest priority)\nhub_model_id - Trained model pushed to HuggingFace Hub\noutput_dir - Local checkpoint directory containing trained model weights\n\n\n\nCitation\n@misc{eval-harness,\n author = {Gao, Leo and Tow, Jonathan and Abbasi, Baber and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and Le Noac'h, Alain and Li, Haonan and McDonell, Kyle and Muennighoff, Niklas and Ociepa, Chris and Phang, Jason and Reynolds, Laria and Schoelkopf, Hailey and Skowron, Aviya and Sutawika, Lintang and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy},\n title = {A framework for few-shot language model evaluation},\n month = 07,\n year = 2024,\n publisher = {Zenodo},\n version = {v0.4.3},\n doi = {10.5281/zenodo.12608602},\n url = {https://zenodo.org/records/12608602}\n}\nPlease see reference here",
"crumbs": [
"Advanced Features",
"Custom Integrations"
]
},
{
"objectID": "docs/custom_integrations.html#liger-kernels",
"href": "docs/custom_integrations.html#liger-kernels",
"title": "Custom Integrations",
"section": "Liger Kernels",
"text": "Liger Kernels\nLiger Kernel provides efficient Triton kernels for LLM training, offering:\n\n20% increase in multi-GPU training throughput\n60% reduction in memory usage\nCompatibility with both FSDP and DeepSpeed\n\nSee https://github.com/linkedin/Liger-Kernel\n\nUsage\nplugins:\n - axolotl.integrations.liger.LigerPlugin\nliger_rope: true\nliger_rms_norm: true\nliger_glu_activation: true\nliger_layer_norm: true\nliger_fused_linear_cross_entropy: true\n\nliger_use_token_scaling: true\n\n\nSupported Models\n\ndeepseek_v2\ngemma\ngemma2\ngemma3\ngranite\njamba\nllama\nmistral\nmixtral\nmllama\nmllama_text_model\nolmo2\npaligemma\nphi3\nqwen2\nqwen2_5_vl\nqwen2_vl\n\n\n\nCitation\n@article{hsu2024ligerkernelefficienttriton,\n title={Liger Kernel: Efficient Triton Kernels for LLM Training},\n author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},\n year={2024},\n eprint={2410.10989},\n archivePrefix={arXiv},\n primaryClass={cs.LG},\n url={https://arxiv.org/abs/2410.10989},\n journal={arXiv preprint arXiv:2410.10989},\n}\nPlease see reference here",
"crumbs": [
"Advanced Features",
"Custom Integrations"
]
},
{
"objectID": "docs/custom_integrations.html#nemo-gym-integration-for-axolotl",
"href": "docs/custom_integrations.html#nemo-gym-integration-for-axolotl",
"title": "Custom Integrations",
"section": "NeMo Gym Integration for Axolotl",
"text": "NeMo Gym Integration for Axolotl\nTrain LLMs with reinforcement learning using NVIDIA NeMo Gym environments as reward sources. NeMo Gym provides 50+ verified RL environments spanning math, coding, tool-use, reasoning, and safety — each with deterministic reward signals.\n\nValidated Training Paths\n\n\n\n\n\n\n\n\n\nPath\nSpeed\nMulti-turn\nArchitecture\n\n\n\n\nAsync GRPO + Data Producer\nFastest (3x)\nYes\nNemoGymDataProducer replaces vLLM generation\n\n\nStandard GRPO + Data Producer\nBaseline\nYes\nSame producer, no async prefetch\n\n\nStandard GRPO + /verify\nSimplest\nNo\nReward function calls /verify directly\n\n\nFSDP2 + /verify (2 GPU)\nDistributed\nNo\nfsdp_version: 2\n\n\n\nMulti-turn uses nemo_gym_multi_turn: true which auto-enables the async trainers\ndata producer protocol. The plugins NemoGymDataProducer calls NeMo Gym agent /run\nendpoints and returns RolloutDataset with proper IS correction, env_mask, and rewards.\nAll paths tested end-to-end with Qwen3-0.6B + LoRA, logged to wandb project nemo-gym-rl.\n\n\nQuick Start\n\n\nPrerequisites\n\nuv package manager (for NeMo Gyms venv)\nTwo GPUs recommended (one for vLLM server, one for training)\n\n\n\n1. Set Up NeMo Gym\ngit clone https://github.com/NVIDIA-NeMo/Gym.git ~/Gym\ncd ~/Gym\nuv venv --python 3.12 && source .venv/bin/activate && uv sync\n\nCFLAGS=\"\" uv pip install pycosat --python .venv/bin/python --no-build-isolation\n\nfor dir in resources_servers/reasoning_gym resources_servers/example_single_tool_call responses_api_models/vllm_model responses_api_agents/simple_agent; do\n uv venv --seed --allow-existing --python 3.12 $dir/.venv\n CFLAGS=\"\" uv pip install --python $dir/.venv/bin/python pycosat --no-build-isolation 2>/dev/null\n uv pip install --python $dir/.venv/bin/python -e . \"ray[default]==2.52.1\"\ndone\n\nuv pip install --python resources_servers/reasoning_gym/.venv/bin/python \\\n reasoning-gym matplotlib pillow cycler contourpy kiwisolver\n\n\n2. 
Multi-Turn with Async GRPO (Recommended — Fastest Path)\nThis is the fully validated, highest-performance path. NeMo Gyms agent server handles\nmulti-turn tool execution while axolotls async GRPO prefetches data in background threads.\nStep 1: Create the NeMo Gym agent config\nCreate ~/Gym/configs/axolotl_tool_calling.yaml:\nexample_single_tool_call:\n resources_servers:\n example_single_tool_call:\n entrypoint: app.py\n domain: agent\n verified: false\n\npolicy_model:\n responses_api_models:\n vllm_model:\n entrypoint: app.py\n base_url: http://localhost:8000/v1\n api_key: dummy_key\n model: Qwen/Qwen3-0.6B # Must match your training model\n return_token_id_information: true\n uses_reasoning_parser: false\n\nexample_single_tool_call_simple_agent:\n responses_api_agents:\n simple_agent:\n entrypoint: app.py\n resources_server:\n type: resources_servers\n name: example_single_tool_call\n model_server:\n type: responses_api_models\n name: policy_model\n datasets:\n - name: weather\n type: example\n jsonl_fpath: resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl\nStep 2: Start three services\nCUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \\\n --model Qwen/Qwen3-0.6B --max-model-len 2048 --gpu-memory-utilization 0.85\n\ncd ~/Gym && .venv/bin/ng_run \\\n \"+config_paths=[configs/axolotl_tool_calling.yaml]\" \"+skip_venv_if_present=true\"\n\ncd experiments && CUDA_VISIBLE_DEVICES=1 CUDA_HOME=$HOME/env-claude-cu130/cuda_shim \\\n axolotl train nemo_gym_async_agent.yaml\nStep 3: Training config (nemo_gym_async_agent.yaml):\nbase_model: Qwen/Qwen3-0.6B\nadapter: lora\nlora_r: 16\nlora_alpha: 32\nlora_target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]\nsequence_len: 2048\n\nrl: grpo\nchat_template: tokenizer_default\n\ntrl:\n use_vllm: true\n vllm_mode: server\n vllm_server_host: localhost\n vllm_server_port: 8000\n vllm_lora_sync: true\n vllm_sync_interval: 5\n # Async GRPO — 3x faster than 
standard\n use_data_producer: true\n async_prefetch: true\n num_generations: 4\n max_completion_length: 512\n temperature: 0.8\n reward_funcs:\n - axolotl.integrations.nemo_gym.rewards.reward_env\n\nplugins:\n - axolotl.integrations.nemo_gym.NemoGymPlugin\n\nnemo_gym_enabled: true\nnemo_gym_auto_start: false\nnemo_gym_head_port: 11000\nnemo_gym_multi_turn: true\nnemo_gym_verify_timeout: 120\nnemo_gym_datasets:\n - path: ~/Gym/resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl\n server_name: example_single_tool_call\n\ndatasets:\n - path: ~/Gym/resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl\n type: chat_template\n field_messages: responses_create_params.input\n message_field_content: content\n message_field_role: role\n\nvllm:\n gpu_memory_utilization: 0.85\n max_model_len: 2048\n tensor_parallel_size: 1\n\nlearning_rate: 5e-6\nmicro_batch_size: 1\ngradient_accumulation_steps: 4\nmax_steps: 30\ngradient_checkpointing: true\nbf16: true\noutput_dir: ./outputs/nemo_gym_async\n\nuse_wandb: true\nwandb_project: nemo-gym-rl\n\n\n3. Single-Turn Training (Simplest — No Agent Server Needed)\nFor environments that only need single-turn verify (math, coding challenges), you don't need\nan agent server. 
The plugins reward function calls /verify directly.\nbase_model: Qwen/Qwen2.5-0.5B-Instruct\nrl: grpo\nchat_template: tokenizer_default\n\ntrl:\n use_vllm: true\n vllm_mode: colocate\n vllm_enable_sleep_mode: false\n num_generations: 8\n max_completion_length: 128\n temperature: 0.9\n reward_funcs:\n - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify\n\nplugins:\n - axolotl.integrations.nemo_gym.NemoGymPlugin\n\nnemo_gym_enabled: true\nnemo_gym_auto_start: false\nnemo_gym_head_port: 11000\nnemo_gym_datasets:\n - path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl\n server_name: reasoning_gym\n\ndatasets:\n - path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl\n type: chat_template\n field_messages: responses_create_params.input\n message_field_content: content\n message_field_role: role\n\nvllm:\n gpu_memory_utilization: 0.3\n max_model_len: 512\n tensor_parallel_size: 1\n\nlearning_rate: 1e-5\nmicro_batch_size: 4\ngradient_accumulation_steps: 2\nmax_steps: 50\noutput_dir: ./outputs/nemo_gym_arithmetic\nOnly needs ng_run with resource servers (no agent config):\ncd ~/Gym && ng_run \"+config_paths=[resources_servers/reasoning_gym/configs/resources_only.yaml]\" \"+skip_venv_if_present=true\"\n\n\nHow It Works\n\n\nSingle-Turn\naxolotl train → GRPO Trainer generates completions\n → NeMo Gym plugin reward_fn calls POST /verify on resource server\n → reward flows back to GRPO for advantage computation\n\n\nMulti-Turn (Agent /run)\n┌─────────────┐ ┌──────────────┐ ┌──────────────────┐\n│ axolotl │ │ NeMo Gym │────▶│ vLLM OpenAI │\n│ train │────▶│ Agent /run │◀────│ Server (GPU 0) │\n│ (GPU 1) │ │ │ │ /v1/completions │\n└─────────────┘ └──────┬───────┘ └──────────────────┘\n │\n ▼\n ┌──────────────┐\n │ Resource │\n │ Server │\n │ (tools + │\n │ verify) │\n └─────────────┘\nThe agent server orchestrates the entire multi-turn loop:\n1. Calls our vLLM server for model generation\n2. 
Parses tool calls from model output\n3. Executes tools against resource servers\n4. Feeds tool results back to the model\n5. Repeats until done, then calls /verify for reward\n6. Returns token IDs + logprobs + reward to our rollout_func\n\n\nData Producer Architecture (Multi-Turn)\nWhen nemo_gym_multi_turn: true, the plugin automatically forces use_data_producer: true\nwhich selects the AxolotlAsyncGRPOTrainer. The plugin then swaps the trainer's data\nproducer with NemoGymDataProducer, which:\n\nGets a prompt batch from the dataset iterator\nExpands by num_generations (one agent call per rollout)\nCalls NeMo Gym agents via async HTTP (aiohttp.gather)\nParses responses into padded tensors (RolloutDataset)\nReturns with _pending_policy_logps=True for deferred scoring\n\nThe main thread then runs _compute_deferred_scores() which:\n- Computes policy logprobs on the training model (GPU forward pass)\n- Computes IS correction using agent's sampling logprobs vs training model logprobs\n- Computes advantages with group-level normalization\n- All downstream features work: replay buffer, re-roll, streaming, zero-adv skip\nWith async_prefetch: true, the data producer runs in a background thread — giving ~3x\nspeedup as generation and training overlap. 
With async_prefetch: false, it runs\nsynchronously on the main thread (still uses the data producer protocol).\n\n\nWeight Sync (LoRA Mode)\nWith vllm_lora_sync: true, the plugin (or async trainer) replaces NCCL-based weight\nsync with filesystem + HTTP:\n\naccelerator.get_state_dict() gathers LoRA weights from all ranks\nRank 0 saves adapter to /tmp/lora_sync_*/vN/\nRank 0 POSTs to /set_lora_adapter/ on vLLM server\nvLLM loads adapter natively via Punica kernels\nOnly ~40MB transferred (vs multiple GBs for full model weights)\n\n\n\nMulti-Environment Support\nDatasets support per-row environment routing via agent_ref:\n{\"agent_ref\": {\"name\": \"reasoning_gym\"}, \"responses_create_params\": {...}}\n{\"agent_ref\": {\"name\": \"instruction_following\"}, \"responses_create_params\": {...}}\nOr use the simpler per-dataset routing:\nnemo_gym_datasets:\n - path: reasoning_data.jsonl\n server_name: reasoning_gym\n - path: tool_data.jsonl\n server_name: example_single_tool_call\n\n\nConfiguration Reference\n\n\n\n\n\n\n\n\n\nParameter\nType\nDefault\nDescription\n\n\n\n\nnemo_gym_enabled\nbool\nnull\nEnable the NeMo Gym integration\n\n\nnemo_gym_dir\nstr\n~/Gym\nPath to NeMo Gym repo\n\n\nnemo_gym_auto_clone\nbool\ntrue\nAuto-clone NeMo Gym repo if missing\n\n\nnemo_gym_auto_start\nbool\ntrue\nAuto-start resource servers\n\n\nnemo_gym_config_paths\nlist[str]\n—\nServer config YAMLs (relative to gym_dir)\n\n\nnemo_gym_datasets\nlist[dict]\nrequired\nDataset configs with path and optional server_name\n\n\nnemo_gym_head_port\nint\n11000\nHead server port\n\n\nnemo_gym_server_timeout\nint\n360\nServer startup timeout (seconds)\n\n\nnemo_gym_verify_timeout\nint\n30\nPer-request timeout (seconds)\n\n\nnemo_gym_multi_turn\nbool\nfalse\nEnable multi-turn via agent /run\n\n\n\n\n\nDataset JSONL Format\nEach line must have responses_create_params with input messages:\n{\n \"responses_create_params\": {\n \"input\": [{\"role\": \"user\", \"content\": \"What's the weather in 
SF?\"}],\n \"tools\": [{\"name\": \"get_weather\", \"type\": \"function\", \"strict\": true, \"parameters\": {...}}]\n }\n}\nFor multi-turn agent routing, include agent_ref:\n{\"agent_ref\": {\"name\": \"my_agent\"}, \"responses_create_params\": {...}}\nNote: Tool definitions MUST include \"strict\": true and \"additionalProperties\": false for NeMo Gym agent compatibility.\n\n\nReward Functions\nThe plugin provides two built-in reward functions — no user code needed:\ntrl:\n reward_funcs:\n # Multi-turn (nemo_gym_multi_turn: true):\n # Passthrough — agent /run already computed the reward\n - axolotl.integrations.nemo_gym.rewards.reward_env\n\n # Single-turn (nemo_gym_multi_turn: false):\n # Calls /verify endpoints on NeMo Gym resource servers\n - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify\nBoth are also importable from Python:\nfrom axolotl.integrations.nemo_gym import reward_env, reward_nemo_gym_verify\n\n\nKnown Issues / Troubleshooting\n\n\nNeMo Gym Server Setup\n\npycosat build failure: CFLAGS=\"\" uv pip install pycosat --no-build-isolation\nRay version mismatch: Pin ray[default]==2.52.1 in all server venvs\nPre-build venvs: ng_run creates per-server venvs via Ray. Pre-build them and use +skip_venv_if_present=true\nTool strict field required: The agent server validates that tool definitions include strict: true\n\n\n\nvLLM / Weight Sync\n\nStart vLLM with LoRA + tool calling + runtime loading:\nVLLM_ALLOW_RUNTIME_LORA_UPDATING=1 \\\nCUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \\\n --model Qwen/Qwen3-4B-Instruct-2507 \\\n --max-model-len 4096 \\\n --gpu-memory-utilization 0.7 \\\n --enable-lora --max-lora-rank 64 \\\n --enable-auto-tool-choice --tool-call-parser hermes\nVLLM_ALLOW_RUNTIME_LORA_UPDATING=1: Required for vllm_lora_sync: true. Without it, vLLM won't expose the /v1/load_lora_adapter endpoint and weight sync will fail silently. 
The plugin warns if this endpoint is missing.\n--enable-lora: Enables LoRA adapter support in vLLM\n--enable-auto-tool-choice --tool-call-parser hermes: Required for Qwen3 tool calling\nmax_model_len must be > max_completion_length: Leave room for prompt tokens (~200). If equal, the NeMo Gym model proxy gets a 400 error and returns empty completions.\nCUDA_HOME required: DeepSpeed import needs it for the nvcc shim\nNCCL weight sync broken with vLLM 0.17: Use vllm_lora_sync: true (filesystem + HTTP via /v1/load_lora_adapter)\n\n\n\nMulti-Turn\n\nAgent server required: Multi-turn delegates to NeMo Gym's agent server /run endpoint. Without an agent, the plugin falls back to single-turn /verify\nModel server proxy: NeMo Gym needs a responses_api_models server that proxies to your vLLM. See the agent config example above\n\n\n\nFSDP2\n\nValidated on 2 GPUs with single-turn + LoRA\nAsync field filtering: The builder automatically filters async-only config fields when using the standard GRPO trainer\n\n\n\nComparison with Other Integrations\n\n\n\n\n\n\n\n\n\nFeature\nAxolotl + NeMo Gym\nUnsloth + NeMo Gym\nNeMo RL (native)\n\n\n\n\nServer management\nAutomatic\nManual (notebook)\nBuilt-in\n\n\nMulti-environment\nPer-row routing\nManual code\nYAML config\n\n\nMulti-turn / tool use\nAgent /run delegation\nNo\nAgent /run (Ray)\n\n\nAsync GRPO (3x speedup)\nYes\nNo\nYes\n\n\nLoRA sync\nFilesystem + HTTP\nN/A\nNCCL\n\n\nMulti-GPU (FSDP2)\nYes\nNo\nYes (Ray)\n\n\nConfig-driven\nYes\nNo (code)\nYes\n\n\n\nPlease see reference here",
"crumbs": [
"Advanced Features",
"Custom Integrations"
]
},
{
"objectID": "docs/custom_integrations.html#spectrum",
"href": "docs/custom_integrations.html#spectrum",
"title": "Custom Integrations",
"section": "Spectrum",
"text": "Spectrum\nby Eric Hartford, Lucas Atkins, Fernando Fernandes, David Golchinfar\nThis plugin contains code to freeze the bottom fraction of modules in a model, based on the Signal-to-Noise Ratio (SNR).\nSee https://github.com/cognitivecomputations/spectrum\n\nOverview\nSpectrum is a tool for scanning and evaluating the Signal-to-Noise Ratio (SNR) of layers in large language models.\nBy identifying the top n% of layers with the highest SNR, you can optimize training efficiency.\n\n\nUsage\nplugins:\n - axolotl.integrations.spectrum.SpectrumPlugin\n\nspectrum_top_fraction: 0.5\nspectrum_model_name: meta-llama/Meta-Llama-3.1-8B\n\n\nCitation\n@misc{hartford2024spectrumtargetedtrainingsignal,\n title={Spectrum: Targeted Training on Signal to Noise Ratio},\n author={Eric Hartford and Lucas Atkins and Fernando Fernandes Neto and David Golchinfar},\n year={2024},\n eprint={2406.06623},\n archivePrefix={arXiv},\n primaryClass={cs.LG},\n url={https://arxiv.org/abs/2406.06623},\n}\nPlease see reference here",
"crumbs": [
"Advanced Features",
"Custom Integrations"
]
},
{
"objectID": "docs/custom_integrations.html#swanlab-integration-for-axolotl",
"href": "docs/custom_integrations.html#swanlab-integration-for-axolotl",
"title": "Custom Integrations",
"section": "SwanLab Integration for Axolotl",
"text": "SwanLab Integration for Axolotl\nSwanLab is an open-source, lightweight AI experiment tracking and visualization tool that provides a platform for tracking, recording, comparing, and collaborating on experiments.\nThis integration enables seamless experiment tracking and visualization of Axolotl training runs using SwanLab.\n\nFeatures\n\n📊 Automatic Metrics Logging: Training loss, learning rate, and other metrics are automatically logged\n🎯 Hyperparameter Tracking: Model configuration and training parameters are tracked\n📈 Real-time Visualization: Monitor training progress in real-time through SwanLab dashboard\n☁ Cloud & Local Support: Works in both cloud-synced and offline modes\n🔄 Experiment Comparison: Compare multiple training runs easily\n🤝 Team Collaboration: Share experiments with team members\n🎭 RLHF Completion Logging: Automatically log model outputs during DPO/KTO/ORPO/GRPO training for qualitative analysis\n⚡ Performance Profiling: Built-in profiling decorators to measure and optimize training performance\n🔔 Lark Notifications: Send real-time training updates to team chat (Feishu/Lark integration)\n\n\n\nInstallation\npip install swanlab\n\n\nQuick Start\n\n\n1. Register for SwanLab (Optional for cloud mode)\nIf you want to use cloud sync features, register at https://swanlab.cn to get your API key.\n\n\n2. Configure Axolotl Config File\nAdd SwanLab configuration to your Axolotl YAML config:\nplugins:\n - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: my-llm-project\nswanlab_experiment_name: qwen-finetune-v1\nswanlab_mode: cloud # Options: cloud, local, offline, disabled\nswanlab_workspace: my-team # Optional: organization name\nswanlab_api_key: YOUR_API_KEY # Optional: can also use env var SWANLAB_API_KEY\n\n\n3. 
Run Training\nexport SWANLAB_API_KEY=your-api-key-here\n\nswanlab login\n\naccelerate launch -m axolotl.cli.train your-config.yaml\n\n\nConfiguration Options\n\n\nBasic Configuration\n\n\n\n\n\n\n\n\n\nParameter\nType\nDefault\nDescription\n\n\n\n\nuse_swanlab\nbool\nfalse\nEnable SwanLab tracking\n\n\nswanlab_project\nstr\nNone\nProject name (required)\n\n\nswanlab_experiment_name\nstr\nNone\nExperiment name\n\n\nswanlab_description\nstr\nNone\nExperiment description\n\n\nswanlab_mode\nstr\ncloud\nSync mode: cloud, local, offline, disabled\n\n\n\n\n\nAdvanced Configuration\n\n\n\n\n\n\n\n\n\nParameter\nType\nDefault\nDescription\n\n\n\n\nswanlab_workspace\nstr\nNone\nWorkspace/organization name\n\n\nswanlab_api_key\nstr\nNone\nAPI key (prefer env var)\n\n\nswanlab_web_host\nstr\nNone\nPrivate deployment web host\n\n\nswanlab_api_host\nstr\nNone\nPrivate deployment API host\n\n\nswanlab_log_model\nbool\nfalse\nLog model checkpoints (coming soon)\n\n\nswanlab_lark_webhook_url\nstr\nNone\nLark (Feishu) webhook URL for team notifications\n\n\nswanlab_lark_secret\nstr\nNone\nLark webhook HMAC secret for authentication\n\n\nswanlab_log_completions\nbool\ntrue\nEnable RLHF completion table logging (DPO/KTO/ORPO/GRPO)\n\n\nswanlab_completion_log_interval\nint\n100\nSteps between completion logging\n\n\nswanlab_completion_max_buffer\nint\n128\nMax completions to buffer (memory bound)\n\n\n\n\n\nConfiguration Examples\n\n\nExample 1: Basic Cloud Sync\nplugins:\n - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: llama-finetune\nswanlab_experiment_name: llama-3-8b-instruct-v1\nswanlab_mode: cloud\n\n\nExample 2: Offline/Local Mode\nplugins:\n - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: local-experiments\nswanlab_experiment_name: test-run-1\nswanlab_mode: local # or 'offline'\n\n\nExample 3: Team Workspace\nplugins:\n - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: 
research-project\nswanlab_experiment_name: experiment-42\nswanlab_workspace: my-research-team\nswanlab_mode: cloud\n\n\nExample 4: Private Deployment\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: internal-project\nswanlab_experiment_name: secure-training\nswanlab_mode: cloud\nswanlab_web_host: https://swanlab.yourcompany.com\nswanlab_api_host: https://api.swanlab.yourcompany.com\n\n\nTeam Notifications with Lark (Feishu)\nSwanLab supports sending real-time training notifications to your team chat via Lark (Feishu), ByteDance's enterprise collaboration platform. This is especially useful for:\n- Production training monitoring: Get alerts when training starts, completes, or encounters errors\n- Team collaboration: Keep your ML team informed about long-running experiments\n- Multi-timezone teams: Team members can check training progress without being online\n\n\nPrerequisites\n\nLark Bot Setup: Create a custom bot in your Lark group chat\nWebhook URL: Get the webhook URL from your Lark bot settings\nHMAC Secret (recommended): Enable signature verification in your Lark bot for security\n\nFor detailed Lark bot setup instructions, see Lark Custom Bot Documentation.\n\n\nExample 5: Basic Lark Notifications\nSend training notifications to a Lark group chat:\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: production-training\nswanlab_experiment_name: llama-3-finetune-v2\nswanlab_mode: cloud\n\nswanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx\nNote: This configuration will work, but you'll see a security warning recommending HMAC secret configuration.\n\n\nExample 6: Lark Notifications with HMAC Security (Recommended)\nFor production use, enable HMAC signature verification:\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: production-training\nswanlab_experiment_name: llama-3-finetune-v2\nswanlab_mode: 
cloud\n\nswanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx\nswanlab_lark_secret: your-webhook-secret-key\nWhy HMAC secret matters:\n- Prevents unauthorized parties from sending fake notifications to your Lark group\n- Ensures notifications genuinely come from your training jobs\n- Required for production deployments with sensitive training data\n\n\nExample 7: Team Workspace + Lark Notifications\nCombine team workspace collaboration with Lark notifications:\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: research-project\nswanlab_experiment_name: multimodal-experiment-42\nswanlab_workspace: ml-research-team\nswanlab_mode: cloud\n\nswanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxxxxxxxxx\nswanlab_lark_secret: your-webhook-secret-key\n\n\nWhat Notifications Are Sent?\nSwanLab's Lark integration sends notifications for key training events:\n- Training Start: When your experiment begins\n- Training Complete: When training finishes successfully\n- Training Errors: If training crashes or encounters critical errors\n- Metric Milestones: Configurable alerts for metric thresholds (if configured in SwanLab)\nEach notification includes:\n- Experiment name and project\n- Training status\n- Key metrics (loss, learning rate)\n- Direct link to SwanLab dashboard\n\n\nLark Configuration Validation\nThe plugin validates your Lark configuration at startup:\n\n✅ Valid Configurations\nuse_swanlab: true\nswanlab_project: my-project\n\nuse_swanlab: true\nswanlab_project: my-project\nswanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxx\nswanlab_lark_secret: your-secret\n\nuse_swanlab: true\nswanlab_project: my-project\nswanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxx\n\n\n\nSecurity Best Practices\n\nAlways use HMAC secret in production:\nswanlab_lark_webhook_url: https://open.feishu.cn/...\nswanlab_lark_secret: your-secret-key  # ✅ Add 
this!\nStore secrets in environment variables (even better):\n# In your training script/environment\nexport SWANLAB_LARK_WEBHOOK_URL=\"https://open.feishu.cn/...\"\nexport SWANLAB_LARK_SECRET=\"your-secret-key\"\nThen in config:\n# SwanLab plugin will auto-detect environment variables\nuse_swanlab: true\nswanlab_project: my-project\n# Lark URL and secret read from env vars\nRotate webhook secrets periodically: Update your Lark bot's secret every 90 days\nUse separate webhooks for dev/prod: Don't mix development and production notifications\n\n\n\nDistributed Training\nLark notifications are automatically deduplicated in distributed training:\n- Only rank 0 sends notifications\n- Other GPU ranks skip Lark registration\n- Prevents duplicate messages in multi-GPU training\ntorchrun --nproc_per_node=4 -m axolotl.cli.train config.yml\n\n\nRLHF Completion Table Logging\nFor RLHF (Reinforcement Learning from Human Feedback) training methods like DPO, KTO, ORPO, and GRPO, SwanLab can log model completions (prompts, chosen/rejected responses, rewards) to a visual table for qualitative analysis. 
This helps you:\n\nInspect model behavior: See actual model outputs during training\nDebug preference learning: Compare chosen vs rejected responses\nTrack reward patterns: Monitor how rewards evolve over training\nShare examples with team: Visual tables in SwanLab dashboard\n\n\n\nFeatures\n\n✅ Automatic detection: Works with DPO, KTO, ORPO, GRPO trainers\n✅ Memory-safe buffering: Bounded buffer prevents memory leaks in long training runs\n✅ Periodic logging: Configurable logging interval to reduce overhead\n✅ Rich visualization: SwanLab tables show prompts, responses, and metrics side-by-side\n\n\n\nConfiguration\n\n\n\n\n\n\n\n\n\nParameter\nType\nDefault\nDescription\n\n\n\n\nswanlab_log_completions\nbool\ntrue\nEnable completion logging for RLHF trainers\n\n\nswanlab_completion_log_interval\nint\n100\nLog completions to SwanLab every N training steps\n\n\nswanlab_completion_max_buffer\nint\n128\nMaximum completions to buffer (memory bound)\n\n\n\n\n\nExample: DPO Training with Completion Logging\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: dpo-training\nswanlab_experiment_name: llama-3-dpo-v1\nswanlab_mode: cloud\n\nswanlab_log_completions: true\nswanlab_completion_log_interval: 100  # Log every 100 steps\nswanlab_completion_max_buffer: 128  # Keep last 128 completions\n\nrl: dpo\ndatasets:\n  - path: /path/to/preference_dataset\n    type: chatml.intel\n\n\nExample: Disable Completion Logging\nIf you're doing a quick test run or don't need completion tables:\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: dpo-training\n\nswanlab_log_completions: false\n\n\nSupported RLHF Trainers\nThe completion logging callback automatically activates for these trainer types:\n\nDPO (Direct Preference Optimization): Logs prompts, chosen, rejected, reward_diff\nKTO (Kahneman-Tversky Optimization): Logs prompts, completions, labels, rewards\nORPO (Odds Ratio Preference Optimization): Logs 
prompts, chosen, rejected, log_odds_ratio\nGRPO (Group Relative Policy Optimization): Logs prompts, completions, rewards, advantages\nCPO (Contrastive Preference Optimization): Logs prompts, chosen, rejected\n\nFor non-RLHF trainers (standard supervised fine-tuning), the completion callback is automatically skipped.\n\n\nHow It Works\n\nAuto-detection: Plugin detects trainer type at initialization\nBuffering: Completions are buffered in memory (up to swanlab_completion_max_buffer)\nPeriodic logging: Every swanlab_completion_log_interval steps, buffer is logged to SwanLab\nMemory safety: Old completions are automatically dropped when buffer is full (uses collections.deque)\nFinal flush: Remaining completions are logged when training completes\n\n\n\nViewing Completion Tables\nAfter training starts, you can view completion tables in your SwanLab dashboard:\n\nNavigate to your experiment in SwanLab\nLook for the “rlhf_completions” table in the metrics panel\nThe table shows:\n\nstep: Training step when completion was generated\nprompt: Input prompt\nchosen: Preferred response (DPO/ORPO)\nrejected: Non-preferred response (DPO/ORPO)\ncompletion: Model output (KTO/GRPO)\nreward_diff/reward: Reward metrics\nTrainer-specific metrics (e.g., log_odds_ratio for ORPO)\n\n\n\n\nMemory Management\nThe completion buffer is memory-bounded to prevent memory leaks:\nfrom collections import deque\n\nbuffer = deque(maxlen=128)  # Old completions automatically dropped\nMemory usage estimate:\n- Average completion: ~500 characters (prompt + responses)\n- Buffer size 128: ~64 KB (negligible)\n- Buffer size 1024: ~512 KB (still small)\nRecommendation: Default buffer size (128) works well for most cases. 
Increase to 512-1024 only if you need to review more historical completions.\n\n\nPerformance Impact\nCompletion logging has minimal overhead:\n\nBuffering: O(1) append operation, negligible CPU/memory\nLogging: Only happens every N steps (default: 100)\nNetwork: SwanLab batches table uploads efficiently\n\nExpected overhead: < 0.5% per training step\n\n\nTroubleshooting\n\nCompletions not appearing in SwanLab\nCause: Trainer may not be logging completion data in the expected format.\nDiagnostic steps:\n1. Check trainer type detection in logs:\nINFO: SwanLab RLHF completion logging enabled for DPOTrainer (type: dpo)\n2. Verify your trainer is an RLHF trainer (DPO/KTO/ORPO/GRPO)\n3. Check if trainer logs completion data (this depends on TRL version)\nNote: The current implementation expects trainers to log completion data in the logs dict during on_log() callback. Some TRL trainers may not expose this data by default. You may need to patch the trainer to expose completions.\n\n\nBuffer fills up too quickly\nCause: High logging frequency with small buffer size.\nSolution: Increase buffer size or logging interval:\nswanlab_completion_log_interval: 200  # Log less frequently\nswanlab_completion_max_buffer: 512  # Larger buffer\n\n\nMemory usage growing over time\nCause: Buffer should be bounded, so this indicates a bug.\nSolution:\n1. Verify swanlab_completion_max_buffer is set\n2. Check SwanLab version is up to date\n3. Report issue with memory profiling data\n\n\n\nPerformance Profiling\nSwanLab integration includes profiling utilities to measure and log execution time of trainer methods. 
This helps you:\n\nIdentify bottlenecks: Find slow operations in your training loop\nOptimize performance: Track improvements after optimization changes\nMonitor distributed training: See per-rank timing differences\nDebug hangs: Detect methods that take unexpectedly long\n\n\n\nFeatures\n\n✅ Zero-config profiling: Automatic timing of key trainer methods\n✅ Decorator-based: Easy to add profiling to custom methods with @swanlab_profile\n✅ Context manager: Fine-grained profiling with swanlab_profiling_context()\n✅ Advanced filtering: ProfilingConfig for throttling and minimum duration thresholds\n✅ Exception-safe: Logs duration even if function raises an exception\n\n\n\nBasic Usage: Decorator\nAdd profiling to any trainer method with the @swanlab_profile decorator:\nfrom axolotl.integrations.swanlab.profiling import swanlab_profile\n\nclass MyCustomTrainer(AxolotlTrainer):\n @swanlab_profile\n def training_step(self, model, inputs):\n # Your training step logic\n return super().training_step(model, inputs)\n\n @swanlab_profile\n def prediction_step(self, model, inputs, prediction_loss_only):\n # Your prediction logic\n return super().prediction_step(model, inputs, prediction_loss_only)\nThe decorator automatically:\n1. Measures execution time with high-precision timer\n2. Logs to SwanLab as profiling/Time taken: ClassName.method_name\n3. Only logs if SwanLab is enabled (use_swanlab: true)\n4. 
Gracefully handles exceptions (logs duration, then re-raises)\n\n\nAdvanced Usage: Context Manager\nFor fine-grained profiling within a method:\nfrom axolotl.integrations.swanlab.profiling import swanlab_profiling_context\n\nclass MyTrainer(AxolotlTrainer):\n def complex_training_step(self, model, inputs):\n # Profile just the forward pass\n with swanlab_profiling_context(self, \"forward_pass\"):\n outputs = model(**inputs)\n\n # Profile just the backward pass\n with swanlab_profiling_context(self, \"backward_pass\"):\n loss = outputs.loss\n loss.backward()\n\n return outputs\n\n\nAdvanced Usage: ProfilingConfig\nFilter and throttle profiling logs with ProfilingConfig:\nfrom axolotl.integrations.swanlab.profiling import (\n swanlab_profiling_context_advanced,\n ProfilingConfig,\n)\n\nprofiling_config = ProfilingConfig(\n enabled=True,\n min_duration_ms=1.0, # Only log if duration > 1ms\n log_interval=10, # Log every 10th call\n)\n\nclass MyTrainer(AxolotlTrainer):\n def frequently_called_method(self, data):\n with swanlab_profiling_context_advanced(\n self,\n \"frequent_op\",\n config=profiling_config\n ):\n # This only logs every 10th call, and only if it takes > 1ms\n result = expensive_computation(data)\n return result\nProfilingConfig Parameters:\n- enabled: Enable/disable profiling globally (default: True)\n- min_duration_ms: Minimum duration to log in milliseconds (default: 0.1)\n- log_interval: Log every Nth function call (default: 1 = log all)\nUse cases:\n- High-frequency methods: Use log_interval=100 to reduce logging overhead\n- Filter noise: Use min_duration_ms=1.0 to skip very fast operations\n- Debugging: Use log_interval=1, min_duration_ms=0.0 to log everything\n\n\nViewing Profiling Metrics\nIn your SwanLab dashboard, profiling metrics appear under the “profiling” namespace:\nprofiling/Time taken: AxolotlTrainer.training_step\nprofiling/Time taken: AxolotlTrainer.prediction_step\nprofiling/Time taken: MyTrainer.forward_pass\nprofiling/Time taken: 
MyTrainer.backward_pass\nYou can:\n- Track over time: See if methods get faster/slower during training\n- Compare runs: Compare profiling metrics across experiments\n- Identify regressions: Detect if a code change slowed down training\n\n\nConfiguration in Axolotl Config\nProfiling is automatically enabled when SwanLab is enabled. No additional config needed:\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: my-project\nTo disable profiling while keeping SwanLab enabled:\nfrom axolotl.integrations.swanlab.profiling import DEFAULT_PROFILING_CONFIG\n\nDEFAULT_PROFILING_CONFIG.enabled = False\n\n\nPerformance Impact\n\nDecorator overhead: ~2-5 microseconds per call (negligible)\nContext manager overhead: ~1-3 microseconds (negligible)\nLogging overhead: Only when SwanLab is enabled and method duration exceeds threshold\nNetwork overhead: SwanLab batches metrics efficiently\n\nExpected overhead: < 0.1% per training step (effectively zero)\n\n\nBest Practices\n\nProfile bottlenecks first: Start by profiling suspected slow operations\nUse min_duration_ms: Filter out fast operations (< 1ms) to reduce noise\nThrottle high-frequency calls: Use log_interval for methods called > 100 times/step\nProfile across runs: Compare profiling metrics before/after optimization\nMonitor distributed training: Check for rank-specific slowdowns\n\n\n\nExample: Complete Profiling Setup\nfrom axolotl.integrations.swanlab.profiling import (\n    swanlab_profile,\n    swanlab_profiling_context,\n    swanlab_profiling_context_advanced,\n    ProfilingConfig,\n)\n\nclass OptimizedTrainer(AxolotlTrainer):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        # Custom profiling config for high-frequency operations\n        self.fast_op_config = ProfilingConfig(\n            enabled=True,\n            min_duration_ms=0.5,\n            log_interval=50,\n        )\n\n    @swanlab_profile\n    def training_step(self, model, inputs):\n        \"\"\"Main training step - always profile.\"\"\"\n        return super().training_step(model, inputs)\n\n    
@swanlab_profile\n    def compute_loss(self, model, inputs, return_outputs=False):\n        \"\"\"Loss computation - always profile.\"\"\"\n        return super().compute_loss(model, inputs, return_outputs)\n\n    def _prepare_inputs(self, inputs):\n        \"\"\"High-frequency operation - throttled profiling.\"\"\"\n        with swanlab_profiling_context_advanced(\n            self,\n            \"prepare_inputs\",\n            config=self.fast_op_config,\n        ):\n            return super()._prepare_inputs(inputs)\n\n\nTroubleshooting\n\nProfiling metrics not appearing in SwanLab\nCause: SwanLab is not enabled or not initialized.\nSolution:\nuse_swanlab: true\nswanlab_project: my-project\nCheck logs for:\nINFO: SwanLab initialized for project: my-project\n\n\nToo many profiling metrics cluttering dashboard\nCause: Profiling every function call for high-frequency operations.\nSolution: Use ProfilingConfig with throttling:\nconfig = ProfilingConfig(\n    min_duration_ms=1.0,  # Skip fast ops\n    log_interval=100,  # Log every 100th call\n)\n\n\nProfiling overhead impacting training speed\nCause: Profiling itself should have negligible overhead (< 0.1%). If you see > 1% slowdown, this indicates a bug.\nSolution:\n1. Disable profiling temporarily to confirm:\nDEFAULT_PROFILING_CONFIG.enabled = False\n2. 
Report issue with profiling data and trainer details\n\n\nProfiling shows inconsistent timing\nCause: Normal variation due to GPU warmup, data loading, or system load.\nSolution:\n- Ignore first few steps (warmup period)\n- Look at average/median timing over many steps\n- Use log_interval to reduce noise from individual outliers\n\n\n\nComplete Config Example\nHere's a complete example integrating SwanLab with your RVQ-Alpha training:\nbase_model: /path/to/your/model\nmodel_type: Qwen2ForCausalLM\n\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin\n  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin\n\nuse_swanlab: true\nswanlab_project: RVQ-Alpha-Training\nswanlab_experiment_name: Qwen2.5-7B-MetaQA-Perturb-P020\nswanlab_description: \"Training on MetaQA and Perturbation datasets with NEW-RVQ encoding\"\nswanlab_mode: cloud\nswanlab_workspace: single-cell-genomics\n\nsequence_len: 32768\nmicro_batch_size: 1\ngradient_accumulation_steps: 1\nnum_epochs: 2\nlearning_rate: 2e-5\noptimizer: adamw_torch_fused\n\ndatasets:\n  - path: /path/to/dataset\n    type: chat_template\n\noutput_dir: ./outputs\n\n\nModes Explained\n\n\ncloud Mode (Default)\n\nSyncs experiments to SwanLab cloud in real-time\nRequires API key and internet connection\nBest for: Team collaboration, remote monitoring\n\n\n\nlocal Mode\n\nSaves experiments locally only\nNo cloud sync\nBest for: Local development, air-gapped environments\n\n\n\noffline Mode\n\nSaves metadata locally\nCan sync to cloud later using swanlab sync\nBest for: Unstable internet, sync later\n\n\n\ndisabled Mode\n\nTurns off SwanLab completely\nNo logging or tracking\nBest for: Debugging, testing\n\n\n\nConfiguration Validation & Conflict Detection\nSwanLab integration includes comprehensive validation and conflict detection to help you catch configuration errors early and avoid performance issues.\n\n\nRequired Fields Validation\nThe plugin validates your configuration at startup and provides clear error messages 
with solutions:\n\nMissing Project Name\nuse_swanlab: true\nSolution:\nuse_swanlab: true\nswanlab_project: my-project\n\n\nInvalid Mode\nuse_swanlab: true\nswanlab_project: my-project\nswanlab_mode: invalid-mode\nSolution:\nuse_swanlab: true\nswanlab_project: my-project\nswanlab_mode: cloud  # or: local, offline, disabled\n\n\nEmpty Project Name\nuse_swanlab: true\nswanlab_project: \"\"\nSolution:\nuse_swanlab: true\nswanlab_project: my-project\n\n\n\nCloud Mode API Key Warning\nWhen using cloud mode without an API key, you'll receive a warning with multiple solutions:\nuse_swanlab: true\nswanlab_project: my-project\nswanlab_mode: cloud\nSolutions:\n1. Set environment variable: export SWANLAB_API_KEY=your-api-key\n2. Add to config (less secure): swanlab_api_key: your-api-key\n3. Run swanlab login before training\n4. Use swanlab_mode: local for offline tracking\n\n\nMulti-Logger Performance Warnings\nUsing multiple logging tools simultaneously (SwanLab + WandB + MLflow + Comet) can impact training performance:\n\nTwo Loggers - Warning\nuse_swanlab: true\nswanlab_project: my-project\n\nuse_wandb: true\nwandb_project: my-project\nImpact:\n- Performance overhead: ~1-2% per logger (cumulative)\n- Increased memory usage\n- Longer training time per step\n- Potential config/callback conflicts\nRecommendations:\n- Choose ONE primary logging tool for production training\n- Use multiple loggers only for:\n- Migration period (transitioning between tools)\n- Short comparison runs\n- Debugging specific tool issues\n- Monitor system resources (CPU, memory) during training\n\n\nThree+ Loggers - Error-Level Warning\nuse_swanlab: true\nswanlab_project: my-project\n\nuse_wandb: true\nwandb_project: my-project\n\nuse_mlflow: true\nmlflow_tracking_uri: http://localhost:5000\nWhy This Matters:\n- With 3 loggers: ~4-5% overhead per step → significant slowdown over long training\n- Example: 10,000 steps at 2s/step (20,000 s total) → ~800-1,000 seconds extra (13-17 minutes)\n- Memory overhead scales with number 
of loggers\n- Rare edge cases with callback ordering conflicts\n\n\n\nAuto-Enable Logic\nFor convenience, SwanLab will auto-enable if you specify a project without setting use_swanlab:\nswanlab_project: my-project\n\nuse_swanlab: true\nswanlab_project: my-project\n\n\nDistributed Training Detection\nIn distributed training scenarios (multi-GPU), the plugin automatically detects and reports:\nuse_swanlab: true\nswanlab_project: my-project\nswanlab_mode: cloud\nWhy Only Rank 0:\n- Avoids duplicate experiment runs\n- Reduces network/cloud API overhead on worker ranks\n- Prevents race conditions in metric logging\n\n\nAuthentication\n\n\nMethod 1: Environment Variable (Recommended)\nexport SWANLAB_API_KEY=your-api-key-here\n\n\nMethod 2: Login Command\nswanlab login\n\n\nMethod 3: Config File\nswanlab_api_key: your-api-key-here\n\n\nWhat Gets Logged?\n\n\nAutomatically Logged Metrics\n\nTraining loss\nLearning rate\nGradient norm\nTraining steps\nEpoch progress\n\n\n\nAutomatically Logged Config\n\nModel configuration (base_model, model_type)\nTraining hyperparameters (learning_rate, batch_size, etc.)\nOptimizer settings\nParallelization settings (FSDP, DeepSpeed, Context Parallel)\nAxolotl configuration file\nDeepSpeed configuration (if used)\n\n\n\nViewing Your Experiments\n\n\nCloud Mode\nVisit https://swanlab.cn and navigate to your project to view:\n- Real-time training metrics\n- Hyperparameter comparison\n- System resource usage\n- Configuration files\n\n\nLocal Mode\nswanlab watch ./swanlog\n\n\nIntegration with Existing Tools\nSwanLab can work alongside other tracking tools:\nplugins:\n - axolotl.integrations.swanlab.SwanLabPlugin\n\nuse_swanlab: true\nswanlab_project: my-project\n\nuse_wandb: true\nwandb_project: my-project\n\n\nTroubleshooting\n\n\nConfiguration Errors\n\nError: “SwanLab enabled but swanlab_project is not set”\nCause: You enabled SwanLab (use_swanlab: true) but forgot to specify a project name.\nSolution:\nuse_swanlab: true\nswanlab_project: 
my-project  # Add this line\n\n\nError: “Invalid swanlab_mode: xxx”\nCause: You provided an invalid mode value.\nSolution: Use one of the valid modes:\nswanlab_mode: cloud  # or: local, offline, disabled\n\n\nError: “swanlab_project cannot be an empty string”\nCause: You set swanlab_project: \"\" (empty string).\nSolution: Either provide a valid name or remove the field:\nswanlab_project: my-project\n\n\n\nImport Errors\n\nError: “SwanLab is not installed”\nCause: SwanLab package is not installed in your environment.\nSolution:\npip install swanlab\npip install 'swanlab>=0.3.0'\n\n\n\nPerformance Issues\n\nWarning: “Multiple logging tools enabled”\nCause: You have multiple experiment tracking tools enabled (e.g., SwanLab + WandB + MLflow).\nImpact: ~1-2% performance overhead per logger, cumulative.\nSolution: For production training, disable all but one logger:\nuse_swanlab: true\nswanlab_project: my-project\nuse_wandb: false  # Disable others\nuse_mlflow: false\n\nuse_swanlab: false\nuse_wandb: true\nwandb_project: my-project\nException: Multiple loggers are acceptable for:\n- Short comparison runs (< 100 steps)\n- Migration testing between logging tools\n- Debugging logger-specific issues\n\n\n\nDistributed Training Issues\n\nSwanLab creates duplicate runs in multi-GPU training\nCause: All ranks are initializing SwanLab instead of just rank 0.\nExpected Behavior: The plugin automatically ensures only rank 0 initializes SwanLab. You should see:\nInfo: Distributed training detected (world_size=4)\nInfo: Only rank 0 will initialize SwanLab\nInfo: Other ranks will skip SwanLab to avoid conflicts\nIf you see duplicates:\n1. Check your plugin is loaded correctly\n2. Verify you're using the latest SwanLab integration code\n3. Check logs for initialization messages on all ranks\n\n\n\nSwanLab not logging metrics\nSolution: Ensure SwanLab is initialized before training starts. 
The plugin automatically handles this in pre_model_load.\n\n\nAPI Key errors\nSolution:\necho $SWANLAB_API_KEY\n\nswanlab login\n\n\nCloud sync issues\nSolution: Use offline mode and sync later:\nswanlab_mode: offline\nThen sync when ready:\nswanlab sync ./swanlog\n\n\nPlugin not loaded\nSolution: Verify plugin path in config:\nplugins:\n  - axolotl.integrations.swanlab.SwanLabPlugin  # Correct path\n\n\nLark Notification Issues\n\nError: “Failed to import SwanLab Lark plugin”\nCause: Your SwanLab version doesn't include the Lark plugin (requires SwanLab >= 0.3.0).\nSolution:\npip install --upgrade swanlab\n\npip install 'swanlab>=0.3.0'\n\n\nWarning: “Lark webhook has no secret configured”\nCause: You provided swanlab_lark_webhook_url but no swanlab_lark_secret.\nImpact: Lark notifications will work, but without HMAC authentication (security risk).\nSolution: Add HMAC secret for production use:\nswanlab_lark_webhook_url: https://open.feishu.cn/open-apis/bot/v2/hook/xxx\nswanlab_lark_secret: your-webhook-secret  # Add this line\nWhen it's OK to skip secret:\n- Local development and testing\n- Internal networks with restricted access\n- Non-sensitive training experiments\nWhen secret is required:\n- Production training jobs\n- Training with proprietary data\n- Multi-team shared Lark groups\n\n\nError: “Failed to register Lark callback”\nCause: Invalid webhook URL or network connectivity issues.\nDiagnostic steps:\ncurl -X POST \"YOUR_WEBHOOK_URL\" \\\n  -H 'Content-Type: application/json' \\\n  -d '{\"msg_type\":\"text\",\"content\":{\"text\":\"Test from Axolotl\"}}'\n\npip show swanlab\nSolution:\n1. Verify webhook URL is correct (copy from Lark bot settings)\n2. Check network connectivity to Lark API\n3. Ensure webhook is not expired (Lark webhooks can expire)\n4. 
Regenerate webhook URL in Lark bot settings if needed\n\n\nLark notifications not received\nCause: Multiple possible causes.\nDiagnostic checklist:\n\nCheck training logs for Lark registration confirmation:\n# Expected log message (rank 0 only):\nINFO: Registered Lark notification callback with HMAC authentication\nVerify webhook in Lark: Test webhook manually (see above)\nCheck distributed training: Only rank 0 sends notifications\n# If running multi-GPU, check rank 0 logs specifically\ngrep \"Registered Lark\" logs/rank_0.log\nVerify SwanLab is initialized: Lark callback needs SwanLab to be running\nuse_swanlab: true # Must be enabled\nswanlab_project: my-project # Must be set\nCheck Lark bot permissions: Ensure bot is added to the target group chat\n\n\n\nDuplicate Lark notifications in multi-GPU training\nExpected Behavior: Should NOT happen - only rank 0 sends notifications.\nIf you see duplicates:\n1. Check that all GPUs are using the same config file\n2. Verify plugin is loaded correctly on all ranks\n3. Check logs for unexpected Lark initialization on non-zero ranks\n4. Ensure RANK or LOCAL_RANK environment variables are set correctly\nSolution: This is a bug if it occurs. 
Report with:\n- Full training command\n- Logs from all ranks\n- Config file\n\n\n\nComparison: SwanLab vs WandB\n\n\n\nFeature\nSwanLab\nWandB\n\n\n\n\nOpen Source\n✅ Yes\n❌ No\n\n\nSelf-Hosting\n✅ Easy\n⚠ Complex\n\n\nFree Tier\n✅ Generous\n⚠ Limited\n\n\nChinese Support\n✅ Native\n⚠ Limited\n\n\nOffline Mode\n✅ Full support\n✅ Supported\n\n\nIntegration\n🆕 New\n✅ Mature\n\n\n\n\n\nAdvanced Usage\n\n\nCustom Logging\nYou can add custom metrics in your callbacks:\nimport swanlab\n\nswanlab.log({\n \"custom_metric\": value,\n \"epoch\": epoch_num\n})\n\n\nExperiment Comparison\nswanlab compare run1 run2 run3\n\n\nSupport\n\nDocumentation: https://docs.swanlab.cn\nGitHub: https://github.com/SwanHubX/SwanLab\nIssues: Report bugs at GitHub Issues\n\n\n\nLicense\nThis integration follows the Axolotl Community License Agreement.\n\n\nAcknowledgements\nThis integration is built on top of:\n- SwanLab - Experiment tracking tool\n- Transformers - SwanLabCallback\n- Axolotl - Training framework\nPlease see reference here",
"crumbs": [
"Advanced Features",
"Custom Integrations"
]
},
{
"objectID": "docs/custom_integrations.html#adding-a-new-integration",
"href": "docs/custom_integrations.html#adding-a-new-integration",
"title": "Custom Integrations",
"section": "Adding a new integration",
"text": "Adding a new integration\nPlugins can be used to customize the behavior of the training pipeline through hooks. See axolotl.integrations.BasePlugin for the possible hooks.\nTo add a new integration, please follow these steps:\n\nCreate a new folder in the src/axolotl/integrations directory.\nAdd any relevant files (LICENSE, README.md, ACKNOWLEDGEMENTS.md, etc.) to the new folder.\nAdd __init__.py and args.py files to the new folder.\n\n\n__init__.py should import the integration and hook into the appropriate functions.\nargs.py should define the arguments for the integration.\n\n\n(If applicable) Add CPU tests under tests/integrations or GPU tests under tests/e2e/integrations.\n\n\n\n\n\n\n\nTip\n\n\n\nSee src/axolotl/integrations/cut_cross_entropy for a minimal integration example.\n\n\n\n\n\n\n\n\nWarning\n\n\n\nIf you could not load your integration, please ensure you are pip installing in editable mode.\npip install -e .\nand correctly spelled the integration name in the config file.\nplugins:\n - axolotl.integrations.your_integration_name.YourIntegrationPlugin\n\n\n\n\n\n\n\n\nNote\n\n\n\nIt is not necessary to place your integration in the integrations folder. It can be in any location, so long as its installed in a package in your python env.\nSee this repo for an example: https://github.com/axolotl-ai-cloud/diff-transformer",
"crumbs": [
"Advanced Features",
"Custom Integrations"
]
},
{
"objectID": "docs/ray-integration.html",
"href": "docs/ray-integration.html",
"title": "Ray Train",
"section": "",
"text": "Axolotl supports using Ray as an alternative to accelerate for orchestrating training. This is especially useful for multi-node training, since you only have to set up code and dependencies on a single node and can launch training as if you were using a single node.\nWith the --use-ray CLI flag, Axolotl will use Ray Train's TorchTrainer to run training.",
"crumbs": [
"Deployments",
"Ray Train"
]
},
{
"objectID": "docs/ray-integration.html#ray-cluster-setup",
"href": "docs/ray-integration.html#ray-cluster-setup",
"title": "Ray Train",
"section": "Ray cluster setup",
"text": "Ray cluster setup\nA prerequisite for using the Ray Train integration is to set up a Ray cluster on your desired node(s). For a detailed guide on how to get started with Ray clusters, check the official Ray docs here.\nEvery Ray cluster has one head node and a set of worker nodes. The head node is just like any other worker node, but it also runs certain special processes related to scheduling and orchestration. Ray-enabled scripts are run on the head node and, depending on the resources (number of CPUs, GPUs, etc.) they request, are scheduled to run certain tasks on the worker nodes. For more on the key concepts behind a Ray cluster, you can refer to this doc.",
"crumbs": [
"Deployments",
"Ray Train"
]
},
{
"objectID": "docs/ray-integration.html#sanity-check",
"href": "docs/ray-integration.html#sanity-check",
"title": "Ray Train",
"section": "Sanity check",
"text": "Sanity check\nTo check whether your Ray cluster is set up properly, execute the following on the head node:\nray status\nThe output should contain a summary of your Ray cluster: a list of all the nodes in your cluster, the number of CPUs and GPUs in your cluster, etc. For example, if you have a cluster with 1 CPU-only head node and 2 4xL40S worker nodes, the output can look like this:\nNode status\n---------------------------------------------------------------\nActive:\n 1 head\nIdle:\n 2 4xL40S:48CPU-384GB\nPending:\n (no pending nodes)\nRecent failures:\n (no failures)\n\nResources\n---------------------------------------------------------------\nUsage:\n 0.0/96.0 CPU\n 0.0/8.0 GPU\n 0B/800.00GiB memory\n 0B/229.57GiB object_store_memory\n\nDemands:\n (no resource demands)\nYou should also be able to see the same on the Ray dashboard.",
"crumbs": [
"Deployments",
"Ray Train"
]
},
{
"objectID": "docs/ray-integration.html#configuring-training-with-ray-train",
"href": "docs/ray-integration.html#configuring-training-with-ray-train",
"title": "Ray Train",
"section": "Configuring training with Ray Train",
"text": "Configuring training with Ray Train\nYou can find an example configuration at configs/llama-3/lora-1b-ray.yaml.\nThe key parameters to note here are:\nuse_ray: true\nray_num_workers: 4\n# optional\nresources_per_worker:\n GPU: 1\n\nuse_ray: This is the flag that enables the Ray Train integration. You can either use the corresponding --use-ray flag in the CLI or set use_ray in the config file.\nray_num_workers: This is the number of workers/GPUs to use for training.\nresources_per_worker: This is the Ray resource request for each worker. It can be used to request a specific GPU type or a custom resource for each worker. For example, if your Ray cluster has GPUs of different types and you only want to use NVIDIA L40S GPUs, you can set:\n\nresources_per_worker:\n accelerator_type:L40S: 0.001",
"crumbs": [
"Deployments",
"Ray Train"
]
},
{
"objectID": "docs/ray-integration.html#launching-training",
"href": "docs/ray-integration.html#launching-training",
"title": "Ray Train",
"section": "Launching training",
"text": "Launching training\nYou can simply run the following command on the head node:\naxolotl train examples/llama-3/lora-1b-ray.yml --use-ray\nThis will launch training on the head node, and Ray Train will automatically schedule the workers to run on the appropriate head or worker nodes.\nYou can also monitor training progress on the Ray dashboard.\nComing back to the example of a Ray cluster with 1 head node and 2 4xL40S worker nodes, let's say you want to make use of all 8 GPUs. You can just set ray_num_workers: 8 and run the previous command. The Cluster tab will show the following:\n\n\n\nRay dashboard",
"crumbs": [
"Deployments",
"Ray Train"
]
},
{
"objectID": "docs/config-reference.html",
"href": "docs/config-reference.html",
"title": "Config Reference",
"section": "",
"text": "# Allow overwrite yml config using from cli\nstrict: bool | None = False\n# Resume from a specific checkpoint dir\nresume_from_checkpoint: str | None\n# If resume_from_checkpoint isn't set and you simply want it to start where it left off.\n# Be careful with this being turned on between different models.\nauto_resume_from_checkpoints: bool | None\n# Resize the model embeddings when new tokens are added to multiples of 32. This is\n# reported to improve training speed on some models\nresize_token_embeddings_to_32x: bool | None\nmean_resizing_embeddings: bool | None = False\n\n# Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.\nshrink_embeddings: bool | None\n# Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs\nembeddings_skip_upcast: bool | None\n# Reinitialize model weights randomly instead of loading pretrained weights\nreinit_weights: bool | None\n\n# module to custom trainer class to use for training\ntrainer_cls: str | None\n\n# Use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo', 'ebft'\nrl: RLType | None\n\ntrl: TRLConfig | None\n # For TRLConfig:\n # Beta parameter for the RL training. Same as `rl_beta`. Use\n beta: float | None\n # Maximum length of the completion for RL training.\n max_completion_length: int | None\n\n # Whether to use VLLM for RL training.\n use_vllm: bool = False\n # VLLM mode to use, one of 'server' or 'colocate'\n vllm_mode: Literal['server', 'colocate'] | None\n # Host of the vLLM server to connect to.\n vllm_server_host: str | None = 0.0.0.0\n # Port of the vLLM server to connect to.\n vllm_server_port: int | None = 8000\n # Total timeout (in seconds) to wait for the vLLM server to respond.\n vllm_server_timeout: int | None\n # Regex for vLLM guided decoding.\n vllm_guided_decoding_regex: str | None\n\n # List of reward functions to load. 
Paths must be importable from current dir.\n reward_funcs: list[str] | None\n # List of reward weights for the reward functions.\n reward_weights: list[float] | None\n # Batch size for generation. Controls how many unique prompts are generated per step.\n # For full DP utilization, set to num_generations * data_parallel_size (or a multiple\n # thereof).\n generation_batch_size: int | None\n # Number of generations to sample.\n num_generations: int | None\n # Whether to log completions.\n log_completions: bool | None = False\n # Number of completions to print when log_completions is True.\n num_completions_to_print: int | None\n # Controls whether importance sampling ratios are computed at the `'token'` or\n # `'sequence'` level. For GSPO, use `sequence`, default is None which corresponds to\n # the original GRPO paper.\n importance_sampling_level: Literal['sequence', 'token'] | None\n\n # Whether to sync the reference model.\n sync_ref_model: bool | None = False\n # Mixup alpha for the reference model.\n ref_model_mixup_alpha: float | None = 0.9\n # Sync steps for the reference model.\n ref_model_sync_steps: int | None = 64\n # Whether to scale rewards by their standard deviation.\n scale_rewards: bool = True\n\n # Sampling temperature for the GRPO policy.\n temperature: float | None\n # Top-p sampling probability for the generation policy.\n top_p: float | None\n # Top-k sampling for the generation policy.\n top_k: int | None\n # Minimum probability for the generation policy.\n min_p: float | None\n # Penalty for tokens that appear in prompt and generated text.\n repetition_penalty: float | None\n # Additional generation parameters passed to vLLM SamplingParams. Useful for\n # stop_token_ids, seed, frequency_penalty, etc.\n generation_kwargs: dict[str, Any] | None\n # Additional kwargs for the chat template. 
E.g., {enable_thinking: false} for Qwen3.5\n # models.\n chat_template_kwargs: dict[str, Any] | None\n # Number of iterations per batch (μ) for GRPO.\n num_iterations: int | None\n # Epsilon value for clipping in the GRPO algorithm.\n epsilon: float | None\n # Upper-bound epsilon value for clipping in the GRPO algorithm.\n epsilon_high: float | None\n # Whether to use Liger loss for GRPO.\n use_liger_loss: bool | None\n # Loss formulation to use. Supported values: grpo, bnpo, dr_grpo.\n loss_type: str | None\n # Whether to exclude truncated completions from loss calculation.\n mask_truncated_completions: bool = False\n # Enable sleep mode for vLLM to offload VRAM when idle\n vllm_enable_sleep_mode: bool | None\n # Path to custom rollout function. Must be importable from current dir.\n rollout_func: str | None\n # Multi-objective reward aggregation strategy. 'sum_then_normalize' (GRPO default):\n # weights and sums rewards first, then normalizes. 'normalize_then_sum' (GDPO):\n # normalizes each reward independently, then sums.\n multi_objective_aggregation: Literal['sum_then_normalize', 'normalize_then_sum'] | None\n\n # Use the GRPODataProducer protocol for online data generation.\n use_data_producer: bool = False\n # Generate rollouts in a background thread while training on the previous rollout.\n async_prefetch: bool = False\n # Number of rollouts to prefetch ahead of training.\n prefetch_depth: int | None\n # Sync model weights to vLLM every N optimizer steps (async mode only).\n vllm_sync_interval: int | None\n # Score prompt groups incrementally instead of the full batch at once.\n streaming_partial_batch: bool | None\n # Minimum prompt groups to score per streaming chunk.\n streaming_min_groups: int | None\n # Apply IS correction for distribution mismatch between vLLM and training model.\n vllm_importance_sampling_correction: bool | None\n # IS mode: token_truncate, token_mask, sequence_truncate, or sequence_mask.\n vllm_importance_sampling_mode: 
Literal['token_truncate', 'token_mask', 'sequence_truncate', 'sequence_mask'] | None\n # Cap C for IS ratio clipping/masking.\n vllm_importance_sampling_cap: float | None\n # KL threshold for off-policy sequence masking (OPSM). None = disabled.\n off_policy_mask_threshold: float | None\n # Apply IS correction to KL divergence term.\n use_bias_correction_kl: bool | None\n\n # Number of persistent subprocess workers for parallel reward computation. Each worker\n # has its own main thread so signal.alarm() (used by math_verify) works correctly.\n # Work is sharded across workers by prompt groups. Only used with\n # use_data_producer=True and non-nn.Module reward functions.\n reward_num_workers: int = 1\n # [Experimental, disabled by default] Size of the replay buffer for storing high-\n # signal rollout groups. When > 0, groups with reward variance are cached and used to\n # replace zero-signal groups (where all rewards are identical). Set to 0 to disable.\n # Only used with use_data_producer=True.\n replay_buffer_size: int = 0\n # When True (default), recompute old_per_token_logps for replayed groups using the\n # current training model. This fixes the importance sampling mismatch that occurs when\n # replaying stale data. Only relevant when replay_buffer_size > 0.\n replay_recompute_logps: bool = True\n # Fraction of total training steps after which deferred re-rolling begins. Zero-signal\n # prompts (where all rewards in a group are identical) are buffered and re-injected\n # into later batches when the model is more likely to solve them. Set to 1.0 to\n # disable. Only used with use_data_producer=True.\n reroll_start_fraction: float = 1.0\n # Maximum number of prompt groups to replace with re-roll candidates per batch. Higher\n # values increase data utilization but reduce prompt diversity. 
Only used with\n # use_data_producer=True.\n reroll_max_groups: int = 1\n # When True, skip gradient computation for micro-batches where all advantages are zero\n # (no learning signal). This avoids the forward/backward pass entirely when no\n # learning signal is present. The step is logged with skipped_zero_adv_batches=1 for\n # monitoring.\n skip_zero_advantage_batches: bool = True\n # Sync LoRA adapter to vLLM via filesystem instead of merging + NCCL broadcast. Auto-\n # selects vllm_serve_lora serve module. Syncs only LoRA adapter weights vs full merged\n # model.\n vllm_lora_sync: bool = False\n\nvllm: VllmConfig | None\n # For VllmConfig:\n # Device to use for VLLM\n device: str | None = auto\n # Tensor parallel size for VLLM\n tensor_parallel_size: int | None\n # Data parallel size for VLLM\n data_parallel_size: int | None\n # GPU memory utilization for VLLM\n gpu_memory_utilization: float | None = 0.9\n # Data type for VLLM\n dtype: str | None = auto\n # Maximum length of the model context for VLLM\n max_model_len: int | None\n # Enable prefix caching for VLLM\n enable_prefix_caching: bool | None\n # Host for the vLLM server to start on\n host: str | None = 0.0.0.0\n # Port of the vLLM server to start on\n port: int | None = 8000\n\n # Enable reasoning for VLLM\n enable_reasoning: bool | None\n # Reasoning parser for VLLM\n reasoning_parser: str | None\n # Disable CUDA graph capture in vLLM. Required for models with causal_conv1d (e.g.,\n # Qwen3.5 hybrid linear attention).\n enforce_eager: bool | None\n # Python module for vLLM serve script. Set to 'axolotl.scripts.vllm_serve_lora' for\n # native LoRA support, or leave None for default TRL serve.\n serve_module: str | None\n # vLLM worker extension class for weight synchronization. 
Defaults to\n # 'trl.scripts.vllm_serve.WeightSyncWorkerExtension'.\n worker_extension_cls: str | None\n\n# Configuration for Energy-Based Fine-Tuning (EBFT)\nebft: EBFTConfig | None\n # For EBFTConfig:\n # Fractional layer depths for feature extraction (e.g., [0.25, 0.5, 0.75])\n feature_layers: list[float] = [0.25, 0.5, 0.75]\n # Embedding method: 'last_token', 'mean_pooling', 'completion_mean', or 'concat'\n embed_method: Literal['last_token', 'mean_pooling', 'completion_mean', 'concat'] = last_token\n # Apply SVD whitening to feature embeddings\n use_whitening: bool = False\n # Coefficient for alignment reward (cosine similarity with ground truth)\n alignment_coef: float = 1.0\n # Coefficient for diversity penalty (pairwise similarity between samples)\n diversity_coef: float = 1.0\n # Cross-entropy loss coefficient on ground-truth tokens\n ce_coef: float = 0.0\n # Set per-batch max_tokens based on ground-truth length\n adaptive_max_tokens: bool = True\n # Multiplier for ground-truth token count when computing adaptive max_tokens\n gt_length_multiplier: float = 1.5\n\n # EBFT mode: 'structured' (QA with vLLM) or 'strided' (unstructured text)\n mode: Literal['structured', 'strided'] = structured\n # Stride between anchor points (tokens)\n stride: int = 8\n # Context window size per block\n context_length: int = 8\n # Tokens to generate per block\n generate_max_len: int = 8\n # Independent rollouts per document\n n_samples_per_prompt: int = 4\n # Sampling temperature for strided generation\n temperature: float = 0.6\n # Top-p nucleus sampling threshold\n top_p: float = 1.0\n # RL policy gradient loss coefficient\n rl_coef: float = 1.0\n # Advantage estimator: 'rloo', 'group_norm', 'reinforce'\n advantage_estimator: Literal['rloo', 'group_norm', 'reinforce'] = rloo\n # Minimum tokens into completion before placing anchors. 
Skips anchors too close to\n # the prompt boundary where features are dominated by prompt context.\n min_completion_prefix: int = 0\n\nqat: QATConfig | None\n # For QATConfig:\n # Fake quantization layout to use for activation quantization.\n activation_dtype: TorchAOQuantDType | None\n # Fake quantization layout to use for weight quantization.\n weight_dtype: TorchAOQuantDType = TorchAOQuantDType.int8\n # Quantize embedding\n quantize_embedding: bool | None = False\n # The number of elements in each group for per-group fake quantization\n group_size: int | None = 32\n # The number of steps to apply fake quantization after\n fake_quant_after_n_steps: int | None\n\nquantization: PTQConfig | None\n # For PTQConfig:\n # Fake quantization layout to use for weight quantization.\n weight_dtype: TorchAOQuantDType = TorchAOQuantDType.int8\n # Fake quantization layout to use for activation quantization.\n activation_dtype: TorchAOQuantDType | None\n # Whether to quantize the embedding layer.\n quantize_embedding: bool | None\n # The number of elements in each group for per-group fake quantization\n group_size: int | None = 32\n\n# Reward modelling: `True` or `False`\nreward_model: bool | None\n\n# Configuration for dynamic checkpointing (trigger by file or signal). Set 'enabled:\n# true' to activate this feature.\ndynamic_checkpoint: DynamicCheckpointConfig | None\n # For DynamicCheckpointConfig:\n # Enable dynamic checkpoint triggering during training. Create a file\n # 'axolotl_checkpoint.save' in the configured `output_dir` to trigger.\n enabled: bool = False\n # Check for trigger file every N steps (reduces I/O overhead). Default: 100\n check_interval: int = 10\n # Custom trigger filename (optional). If not specified, defaults to\n # 'axolotl_checkpoint.save'. 
Specify a filename (not a full path) to override the\n # default.\n trigger_file_path: str = \n\n# Process reward modelling: `True` or `False`\nprocess_reward_model: bool | None\n# Coefficient to incentivize the reward model to output mean-zero rewards (proposed by\n# https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.\ncenter_rewards_coefficient: float | None\nnum_labels: int | None\n\n# Whether to perform weighting in DPO trainer\ndpo_use_weighting: bool | None\ndpo_label_smoothing: float | None\n# Precompute reference model log probabilities for DPO\nprecompute_ref_log_probs: bool | None\n\n# Whether to use Liger kernel for DPO loss.\ndpo_use_liger_kernel: bool | None\n\ndpo_padding_free: bool | None\n\n# A list of one or more datasets to finetune the model with\ndatasets: Annotated[list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset | SyntheticDataset], MinLen(1)] | None\n # For SFTDataset:\n # HuggingFace dataset repo | s3:// | gs:// | path to local file or directory\n path: str | None\n # name of dataset split to load from\n split: str | None\n # The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]\n type: str | UserDefinedPrompterType | None\n # For UserDefinedPrompterType:\n # Custom user instruction prompt\n system_prompt: str | None\n # Use {system} as key to be replaced\n system_format: str | None\n field_system: str | None\n field_instruction: str | None\n field_input: str | None\n field_output: str | None\n\n # Customizable to be single line or multi-line. Use {instruction}/{input} as key to\n # be replaced. 
'format' can include {input}\n format: str | None\n # 'no_input_format' cannot include {input}\n no_input_format: str | None\n input_transform: str | None\n # split dataset into N pieces (use with shards_idx)\n shards: int | None\n # the index of sharded dataset to use\n shards_idx: int | None\n # process dataset in N sequential chunks for memory efficiency (exclusive with\n # `shards`)\n preprocess_shards: int | None\n conversation: str | None\n\n # The name of the chat template to use for training, following values are supported:\n # tokenizer_default: Uses the chat template that is available in the\n # tokenizer_config.json. If the chat template is not available in the tokenizer, it\n # will raise an error. This is the default.\n # alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates\n # are available in the axolotl codebase at src/axolotl/utils/chat_templates.py.\n # tokenizer_default_fallback_*: where * is the name of the chat template to fallback\n # to if the tokenizer does not have a chat template else default to tokenizer. E.g.\n # tokenizer_default_fallback_chatml. jinja: Uses a custom jinja template for the chat\n # template. The custom jinja template should be provided in the chat_template_jinja\n # field.\n chat_template: ChatTemplate | str | None\n # Custom jinja chat template or path to jinja file. Used only if `chat_template:\n # jinja` or empty.\n chat_template_jinja: str | None\n # path to source data files\n data_files: str | list[str] | None\n input_format: str | None\n # name of dataset configuration to load\n name: str | None\n # defines the datatype when path is a file\n ds_type: str | None\n # For `completion` datasets only, uses the provided field instead of `text` column\n field: str | None\n field_human: str | None\n field_model: str | None\n # Key containing the messages (default: \"messages\")\n field_messages: str | None\n # Key containing the tools (default: \"tools\"). 
Must be a list[dict] and follow [JSON\n # schema](https://json-schema.org/learn/getting-started-step-by-step).\n field_tools: str | None\n # Key containing the reasoning trace (default: \"reasoning_content\").\n field_thinking: str | None\n # The key the chat template expects that indicates the reasoning trace.\n template_thinking_key: str | None\n\n message_field_role: str | None\n\n message_field_content: str | None\n # Mapping of properties from the input dataset to the chat template. (default:\n # message_property_mappings={'role':'role', 'content':'content'}) If a property exists\n # in the template but not in this mapping, the system will attempt to load it directly\n # from the message using the property name as the key. Example: In the mapping below,\n # 'from' is loaded from input dataset and used as 'role', while 'value' is loaded and\n # used as 'content' in the chat template.\n message_property_mappings: dict[str, str] | None\n # The key in the message turn that indicates via boolean whether tokens of a turn\n # should be considered for training. Useful to selectively train on certain turns\n # besides the `roles_to_train`.\n message_field_training: str | None\n # The key in the message turn that contains the training details. Useful to\n # selectively train on certain tokens in a turn. The value of the key is a List[Dict]\n # containing `begin_offset` (start character index in content), `end_offset` (end\n # character index in content), and `train` (boolean whether to train).\n message_field_training_detail: str | None\n # (for Qwen3 template only) Whether to split the assistant content based on a\n # reasoning trace inside delimited tags\n split_thinking: bool | None\n logprobs_field: str | None\n temperature: float | None\n # Roles to train on. The tokens from these roles will be considered for the loss.\n roles_to_train: list[str] | None\n # Which EOS tokens to train on in the conversation. 
Possible values are: all: train on\n # all EOS tokens, turn (default): train on the EOS token at the end of each trainable\n # turn, last: train on the last EOS token in the conversation\n train_on_eos: Literal['all', 'turn', 'last'] | None\n # Roles mapping in the messages. The format is {target_role: [source_roles]}. All\n # source roles will be mapped to the target role. The default is: user: [\"human\",\n # \"user\"], assistant: [\"gpt\", \"assistant\"], system: [\"system\"], tool: [\"tool\"]\n roles: dict[str, list[str]] | None\n # Whether to drop the system turn from the dataset. Only works with chat_template.\n # This does not drop the default system message from chat_template if it exists. If\n # you wish to, we recommend using a custom jinja template with the default system\n # message removed or adding a system turn with empty content.\n drop_system_message: bool | None\n # Trust remote code for untrusted source\n trust_remote_code: bool | None = False\n # The specific revision of the dataset to use when loading from the Hugging Face Hub.\n # This can be a commit hash, tag, or branch name. If not specified, the latest version\n # will be used. 
This parameter is ignored for local datasets.\n revision: str | None\n\n # For DPODataset:\n path: str | None\n split: str | None\n type: UserDefinedDPOType | str | None\n # For UserDefinedDPOType:\n field_system: str | None\n field_prompt: str | None\n field_chosen: str | None\n field_rejected: str | None\n prompt_format: str | None\n chosen_format: str | None\n rejected_format: str | None\n data_files: list[str] | None\n revision: str | None\n field_messages: str | None\n\n # For KTODataset:\n path: str | None\n split: str | None\n type: UserDefinedKTOType | str | None\n # For UserDefinedKTOType:\n field_system: str | None\n field_prompt: str | None\n field_completion: str | None\n field_label: bool | None\n prompt_format: str | None\n completion_format: str | None\n data_files: list[str] | None\n trust_remote_code: bool | None = False\n revision: str | None\n\n # For StepwiseSupervisedDataset:\n path: str | None\n split: str | None\n data_files: list[str] | None\n revision: str | None\n step_separator: str | None\n max_completion_length: int | None\n train_on_last_step_only: bool | None\n\n # For SyntheticDataset:\n path: Literal['synthetic'] = synthetic\n type: Literal['_synthetic'] = _synthetic\n # Number of rows to generate\n length: int = 1000\n # Sequence length per row (defaults to sequence_len from config)\n sequence_length: int | None\n # Minimum token ID for generation\n min_input_id: int = 100\n # Maximum token ID for generation (defaults to tokenizer vocab_size)\n max_input_id: int | None\n # Random seed for reproducibility\n seed: int | None\n\n# A list of one or more datasets to eval the model with. 
You can use either\n# test_datasets, or val_set_size, but not both.\ntest_datasets: Annotated[list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset | SyntheticDataset], MinLen(1)] | None\n # For SFTDataset:\n # HuggingFace dataset repo | s3:// | gs:// | path to local file or directory\n path: str | None\n # name of dataset split to load from\n split: str | None\n # The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]\n type: str | UserDefinedPrompterType | None\n # For UserDefinedPrompterType:\n # Custom user instruction prompt\n system_prompt: str | None\n # Use {system} as key to be replaced\n system_format: str | None\n field_system: str | None\n field_instruction: str | None\n field_input: str | None\n field_output: str | None\n\n # Customizable to be single line or multi-line. Use {instruction}/{input} as key to\n # be replaced. 'format' can include {input}\n format: str | None\n # 'no_input_format' cannot include {input}\n no_input_format: str | None\n input_transform: str | None\n # split dataset into N pieces (use with shards_idx)\n shards: int | None\n # the index of sharded dataset to use\n shards_idx: int | None\n # process dataset in N sequential chunks for memory efficiency (exclusive with\n # `shards`)\n preprocess_shards: int | None\n conversation: str | None\n\n # The name of the chat template to use for training, following values are supported:\n # tokenizer_default: Uses the chat template that is available in the\n # tokenizer_config.json. If the chat template is not available in the tokenizer, it\n # will raise an error. This is the default.\n # alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates\n # are available in the axolotl codebase at src/axolotl/utils/chat_templates.py.\n # tokenizer_default_fallback_*: where * is the name of the chat template to fallback\n # to if the tokenizer does not have a chat template else default to tokenizer. 
E.g.\n # tokenizer_default_fallback_chatml. jinja: Uses a custom jinja template for the chat\n # template. The custom jinja template should be provided in the chat_template_jinja\n # field.\n chat_template: ChatTemplate | str | None\n # Custom jinja chat template or path to jinja file. Used only if `chat_template:\n # jinja` or empty.\n chat_template_jinja: str | None\n # path to source data files\n data_files: str | list[str] | None\n input_format: str | None\n # name of dataset configuration to load\n name: str | None\n # defines the datatype when path is a file\n ds_type: str | None\n # For `completion` datasets only, uses the provided field instead of `text` column\n field: str | None\n field_human: str | None\n field_model: str | None\n # Key containing the messages (default: \"messages\")\n field_messages: str | None\n # Key containing the tools (default: \"tools\"). Must be a list[dict] and follow [JSON\n # schema](https://json-schema.org/learn/getting-started-step-by-step).\n field_tools: str | None\n # Key containing the reasoning trace (default: \"reasoning_content\").\n field_thinking: str | None\n # The key the chat template expects that indicates the reasoning trace.\n template_thinking_key: str | None\n\n message_field_role: str | None\n\n message_field_content: str | None\n # Mapping of properties from the input dataset to the chat template. (default:\n # message_property_mappings={'role':'role', 'content':'content'}) If a property exists\n # in the template but not in this mapping, the system will attempt to load it directly\n # from the message using the property name as the key. Example: In the mapping below,\n # 'from' is loaded from input dataset and used as 'role', while 'value' is loaded and\n # used as 'content' in the chat template.\n message_property_mappings: dict[str, str] | None\n # The key in the message turn that indicates via boolean whether tokens of a turn\n # should be considered for training. 
Useful to selectively train on certain turns\n # besides the `roles_to_train`.\n message_field_training: str | None\n # The key in the message turn that contains the training details. Useful to\n # selectively train on certain tokens in a turn. The value of the key is a List[Dict]\n # containing `begin_offset` (start character index in content), `end_offset` (end\n # character index in content), and `train` (boolean whether to train).\n message_field_training_detail: str | None\n # (for Qwen3 template only) Whether to split the assistant content based on a\n # reasoning trace inside delimited tags\n split_thinking: bool | None\n logprobs_field: str | None\n temperature: float | None\n # Roles to train on. The tokens from these roles will be considered for the loss.\n roles_to_train: list[str] | None\n # Which EOS tokens to train on in the conversation. Possible values are: all: train on\n # all EOS tokens, turn (default): train on the EOS token at the end of each trainable\n # turn, last: train on the last EOS token in the conversation\n train_on_eos: Literal['all', 'turn', 'last'] | None\n # Roles mapping in the messages. The format is {target_role: [source_roles]}. All\n # source roles will be mapped to the target role. The default is: user: [\"human\",\n # \"user\"], assistant: [\"gpt\", \"assistant\"], system: [\"system\"], tool: [\"tool\"]\n roles: dict[str, list[str]] | None\n # Whether to drop the system turn from the dataset. Only works with chat_template.\n # This does not drop the default system message from chat_template if it exists. If\n # you wish to, we recommend using a custom jinja template with the default system\n # message removed or adding a system turn with empty content.\n drop_system_message: bool | None\n # Trust remote code for untrusted source\n trust_remote_code: bool | None = False\n # The specific revision of the dataset to use when loading from the Hugging Face Hub.\n # This can be a commit hash, tag, or branch name. 
If not specified, the latest version\n # will be used. This parameter is ignored for local datasets.\n revision: str | None\n\n # For DPODataset:\n path: str | None\n split: str | None\n type: UserDefinedDPOType | str | None\n # For UserDefinedDPOType:\n field_system: str | None\n field_prompt: str | None\n field_chosen: str | None\n field_rejected: str | None\n prompt_format: str | None\n chosen_format: str | None\n rejected_format: str | None\n data_files: list[str] | None\n revision: str | None\n field_messages: str | None\n\n # For KTODataset:\n path: str | None\n split: str | None\n type: UserDefinedKTOType | str | None\n # For UserDefinedKTOType:\n field_system: str | None\n field_prompt: str | None\n field_completion: str | None\n field_label: bool | None\n prompt_format: str | None\n completion_format: str | None\n data_files: list[str] | None\n trust_remote_code: bool | None = False\n revision: str | None\n\n # For StepwiseSupervisedDataset:\n path: str | None\n split: str | None\n data_files: list[str] | None\n revision: str | None\n step_separator: str | None\n max_completion_length: int | None\n train_on_last_step_only: bool | None\n\n # For SyntheticDataset:\n path: Literal['synthetic'] = synthetic\n type: Literal['_synthetic'] = _synthetic\n # Number of rows to generate\n length: int = 1000\n # Sequence length per row (defaults to sequence_len from config)\n sequence_length: int | None\n # Minimum token ID for generation\n min_input_id: int = 100\n # Maximum token ID for generation (defaults to tokenizer vocab_size)\n max_input_id: int | None\n # Random seed for reproducibility\n seed: int | None\n\n# If false, the datasets will not be shuffled and will keep their original order in\n# `datasets`. The same applies to the `test_datasets` option and the\n# `pretraining_dataset` option. Default is true.\nshuffle_merged_datasets: bool | None = True\n# If true, each dataset in `datasets` will be shuffled before merging. 
This allows\n# curriculum learning strategies to be applied at the dataset level. Default is false.\nshuffle_before_merging_datasets: bool | None = False\n# Axolotl attempts to save the dataset as an arrow after packing the data together so\n# subsequent training attempts load faster, relative path\ndataset_prepared_path: str | None\n# Num shards for whole dataset\ndataset_shard_num: int | None\n# Index of shard to use for whole dataset\ndataset_shard_idx: int | None\nskip_prepare_dataset: bool | None = False\n# Number of shards to save the prepared dataset\nnum_dataset_shards_to_save: int | None\n\n# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize\npretraining_dataset: Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None\n # For PretrainingDataset:\n name: str | None\n path: str | None\n split: str | None = train\n text_column: str | None = text\n type: str | None = pretrain\n trust_remote_code: bool | None = False\n data_files: str | None\n skip: int | None\n\n # For SFTDataset:\n # HuggingFace dataset repo | s3:// | gs:// | path to local file or directory\n path: str | None\n # name of dataset split to load from\n split: str | None\n # The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]\n type: str | UserDefinedPrompterType | None\n # For UserDefinedPrompterType:\n # Custom user instruction prompt\n system_prompt: str | None\n # Use {system} as key to be replaced\n system_format: str | None\n field_system: str | None\n field_instruction: str | None\n field_input: str | None\n field_output: str | None\n\n # Customizable to be single line or multi-line. Use {instruction}/{input} as key to\n # be replaced. 
'format' can include {input}\n format: str | None\n # 'no_input_format' cannot include {input}\n no_input_format: str | None\n input_transform: str | None\n # split dataset into N pieces (use with shards_idx)\n shards: int | None\n # the index of sharded dataset to use\n shards_idx: int | None\n # process dataset in N sequential chunks for memory efficiency (exclusive with\n # `shards`)\n preprocess_shards: int | None\n conversation: str | None\n\n # The name of the chat template to use for training, following values are supported:\n # tokenizer_default: Uses the chat template that is available in the\n # tokenizer_config.json. If the chat template is not available in the tokenizer, it\n # will raise an error. This is the default.\n # alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates\n # are available in the axolotl codebase at src/axolotl/utils/chat_templates.py.\n # tokenizer_default_fallback_*: where * is the name of the chat template to fallback\n # to if the tokenizer does not have a chat template else default to tokenizer. E.g.\n # tokenizer_default_fallback_chatml. jinja: Uses a custom jinja template for the chat\n # template. The custom jinja template should be provided in the chat_template_jinja\n # field.\n chat_template: ChatTemplate | str | None\n # Custom jinja chat template or path to jinja file. Used only if `chat_template:\n # jinja` or empty.\n chat_template_jinja: str | None\n # path to source data files\n data_files: str | list[str] | None\n input_format: str | None\n # name of dataset configuration to load\n name: str | None\n # defines the datatype when path is a file\n ds_type: str | None\n # For `completion` datasets only, uses the provided field instead of `text` column\n field: str | None\n field_human: str | None\n field_model: str | None\n # Key containing the messages (default: \"messages\")\n field_messages: str | None\n # Key containing the tools (default: \"tools\"). 
Must be a list[dict] and follow [JSON\n # schema](https://json-schema.org/learn/getting-started-step-by-step).\n field_tools: str | None\n # Key containing the reasoning trace (default: \"reasoning_content\").\n field_thinking: str | None\n # The key the chat template expects that indicates the reasoning trace.\n template_thinking_key: str | None\n\n message_field_role: str | None\n\n message_field_content: str | None\n # Mapping of properties from the input dataset to the chat template. (default:\n # message_property_mappings={'role':'role', 'content':'content'}) If a property exists\n # in the template but not in this mapping, the system will attempt to load it directly\n # from the message using the property name as the key. Example: In the mapping below,\n # 'from' is loaded from input dataset and used as 'role', while 'value' is loaded and\n # used as 'content' in the chat template.\n message_property_mappings: dict[str, str] | None\n # The key in the message turn that indicates via boolean whether tokens of a turn\n # should be considered for training. Useful to selectively train on certain turns\n # besides the `roles_to_train`.\n message_field_training: str | None\n # The key in the message turn that contains the training details. Useful to\n # selectively train on certain tokens in a turn. The value of the key is a List[Dict]\n # containing `begin_offset` (start character index in content), `end_offset` (end\n # character index in content), and `train` (boolean whether to train).\n message_field_training_detail: str | None\n # (for Qwen3 template only) Whether to split the assistant content based on a\n # reasoning trace inside delimited tags\n split_thinking: bool | None\n logprobs_field: str | None\n temperature: float | None\n # Roles to train on. The tokens from these roles will be considered for the loss.\n roles_to_train: list[str] | None\n # Which EOS tokens to train on in the conversation. 
Possible values are: all: train on\n # all EOS tokens, turn (default): train on the EOS token at the end of each trainable\n # turn, last: train on the last EOS token in the conversation\n train_on_eos: Literal['all', 'turn', 'last'] | None\n # Roles mapping in the messages. The format is {target_role: [source_roles]}. All\n # source roles will be mapped to the target role. The default is: user: [\"human\",\n # \"user\"], assistant: [\"gpt\", \"assistant\"], system: [\"system\"], tool: [\"tool\"]\n roles: dict[str, list[str]] | None\n # Whether to drop the system turn from the dataset. Only works with chat_template.\n # This does not drop the default system message from chat_template if it exists. If\n # you wish to, we recommend using a custom jinja template with the default system\n # message removed or adding a system turn with empty content.\n drop_system_message: bool | None\n # Trust remote code for untrusted source\n trust_remote_code: bool | None = False\n # The specific revision of the dataset to use when loading from the Hugging Face Hub.\n # This can be a commit hash, tag, or branch name. If not specified, the latest version\n # will be used. This parameter is ignored for local datasets.\n revision: str | None\n\n# The maximum number of processes to use while preprocessing your input dataset. This\n# defaults to `os.cpu_count()` if not set. For Runpod VMs, it will default to number of\n# vCPUs via RUNPOD_CPU_COUNT.\ndataset_processes: int | None\n# The maximum number of processes to use while preprocessing your input dataset. This\n# defaults to `os.cpu_count()` if not set. For Runpod VMs, it will default to number of\n# vCPUs via RUNPOD_CPU_COUNT.\ndataset_num_proc: int | None\n\n# Deduplicates datasets and test_datasets with identical entries\ndataset_exact_deduplication: bool | None\n# Keep dataset in memory while preprocessing. 
Only needed if cached dataset is taking\n# too much storage\ndataset_keep_in_memory: bool | None\ndataloader_pin_memory: bool | None\ndataloader_num_workers: int | None\ndataloader_prefetch_factor: int | None\ndataloader_drop_last: bool | None\n\naccelerator_config: dict[str, Any] | None\n\nremove_unused_columns: bool | None\n\n# Push prepared dataset to hub - repo_org/repo_name\npush_dataset_to_hub: str | None\n# Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private\n# datasets. Required to be true when used in combination with `push_dataset_to_hub`\nhf_use_auth_token: bool | None\n\ndevice: Any | None\n# Passed through to transformers when loading the model when launched without\n# accelerate. Use `sequential` when training w/ model parallelism to limit memory\ndevice_map: Any | None\nworld_size: int | None\n# Don't mess with this, it's here for accelerate and torchrun\nlocal_rank: int | None\nddp: bool | None\n\n# Seed for reproducibility\nseed: int | None\n# Advanced DDP Arguments - timeout\nddp_timeout: int | None\n# Advanced DDP Arguments - bucket cap in MB\nddp_bucket_cap_mb: int | None\n# Advanced DDP Arguments - broadcast buffers\nddp_broadcast_buffers: bool | None\nddp_find_unused_parameters: bool | None\n\n# Whether to run causal language model evaluation for metrics in\n# `eval_causal_lm_metrics`\ndo_causal_lm_eval: bool | None\n# HF evaluate metrics used during evaluation. 
Default is ['sacrebleu', 'comet', 'ter',\n# 'chrf', 'perplexity']\neval_causal_lm_metrics: list[str] | None\ndo_bench_eval: bool | None\nbench_dataset: str | None\nbench_split: str | None\nmetric_for_best_model: str | None\ngreater_is_better: bool | None\n\n# High loss value, indicating the learning has broken down (a good estimate is ~2 times\n# the loss at the start of training)\nloss_watchdog_threshold: float | None\n# Number of high-loss steps in a row before the trainer aborts (default: 3)\nloss_watchdog_patience: int | None\n\n# Run garbage collection every `gc_steps` steps. -1 will run on epoch end and before\n# evaluations. Default is 0 (disabled).\ngc_steps: int | None\n\n# Use CUDA bf16. bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection.\n# require >=ampere\nbf16: Literal['auto'] | bool | None = auto\n# Use CUDA fp16\nfp16: bool | None\n# Enable FP8 mixed precision training using TorchAO. Best used in combination with\n# torch.compile.\nfp8: bool | None\n# Enable FSDP float8 all-gather optimization for FP8 training. Can improve training\n# speed by 10-15% when FSDP is enabled.\nfp8_enable_fsdp_float8_all_gather: bool | None\n# No AMP (automatic mixed precision) - require >=ampere\nbfloat16: bool | None\n# No AMP (automatic mixed precision)\nfloat16: bool | None\n# bool to use CUDA tf32 or 'auto' for automatic detection - require >=ampere\ntf32: Literal['auto'] | bool | None = auto\nfloat32: bool | None\n\n# Whether to use gradient checkpointing. Available options are: true, false, 'offload',\n# 'offload_disk'.\n# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing\ngradient_checkpointing: Literal['offload', 'offload_disk'] | bool | None = False\n# Additional kwargs to pass to the trainer for gradient checkpointing\ngradient_checkpointing_kwargs: dict[str, Any] | None\n# Whether to offload activations. 
Available options are: true, false, 'legacy', 'disk'.\nactivation_offloading: Literal['legacy', 'disk'] | bool | None = False\n# Offload model layer parameters to CPU during forward, prefetch back during backward.\nlayer_offloading: bool | None = False\n\n# List of regex patterns for parameter names to keep unfrozen. All other parameters will\n# be frozen via requires_grad=False. Note: range-based patterns (e.g.\n# embed_tokens.weight$[:32000]) use gradient zeroing rather than a true freeze, so\n# weight decay will still apply to the frozen portion and optimizer states are allocated\n# for the full parameter.\nunfrozen_parameters: list[str] | None\n\n# The maximum length of an input to train with. This should typically be less than 2048\n# as most models have a token/context limit of 2048\nsequence_len: int = 512\n# What to do when a tokenized row exceeds sequence_len. 'drop' removes the row;\n# 'truncate' slices tensors to sequence_len; 'raise' raises a ValueError. Defaults to\n# 'drop' for backward compatibility.\nexcess_length_strategy: Literal['drop', 'truncate', 'raise'] | None\n# The maximum length of an input for evaluation. If not specified, defaults to\n# sequence_len\neval_sequence_len: int | None\nmin_sample_len: int | None\n# maximum prompt length for RL training\nmax_prompt_len: int | None\n# Use efficient multi-packing with block diagonal attention and per sequence\n# position_ids. Recommended: set to 'true'\nsample_packing: bool | None\n# The number of samples packed at a time. Increasing the following values helps with\n# packing, but usually only slightly (<1%).\nsample_packing_group_size: int | None = 100000\n# The number of samples which can be packed into one sequence. Increase if using a large\n# sequence_len with many short samples.\nsample_packing_bin_size: int | None = 200\n# Whether to pack samples sequentially\nsample_packing_sequentially: bool | None\n# The multiprocessing start method to use for packing. 
Should be 'fork', 'spawn' or\n# 'forkserver'\nsample_packing_mp_start_method: str | None\n# Set to 'false' if getting errors during eval with sample_packing on\neval_sample_packing: bool | None\n# Pad inputs so each step uses constant sized buffers. This will reduce memory\n# fragmentation and may prevent OOMs, by re-using memory more efficiently. Defaults to\n# True if `sample_packing` enabled\npad_to_sequence_len: bool | None\n# Whether to use sequential sampling for curriculum learning\ncurriculum_sampling: bool | None\nmultipack_real_batches: bool | None\n\n# Use batch flattening for speedups when not using sample_packing\nbatch_flattening: Literal['auto'] | bool | None\n\nuse_pose: bool | None\npose_split_on_token_ids: list[int] | None\npose_max_context_len: int | None\npose_num_chunks: int | None\n\npretrain_multipack_buffer_size: int | None\n# whether to prevent cross attention for packed sequences during pretraining\npretrain_multipack_attn: bool | None = True\n# whether to concatenate samples during pretraining\npretraining_sample_concatenation: bool | None\n\n# Use streaming mode for loading datasets\nstreaming: bool | None\n# Buffer size for multipack streaming datasets\nstreaming_multipack_buffer_size: int | None = 10000\n\n# Whether to use xformers attention patch https://github.com/facebookresearch/xformers\nxformers_attention: bool | None\n# Whether to use scaled-dot-product attention https://pytorch.org/docs/stable/generated/\n# torch.nn.functional.scaled_dot_product_attention.html\nsdp_attention: bool | None\n# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf\ns2_attention: bool | None\nflex_attention: bool | None\nflex_attn_compile_kwargs: dict[str, Any] | None\n# Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention\nflash_attention: bool | None\n# Whether to use flash-attention cross entropy implementation - advanced use only\nflash_attn_cross_entropy: bool | None\n# Whether to use 
flash-attention rms norm implementation - advanced use only\nflash_attn_rms_norm: bool | None\n# Whether to fuse part of the MLP into a single operation\nflash_attn_fuse_mlp: bool | None\n# Whether to use bettertransformers\nflash_optimum: bool | None\n# Whether to use SageAttention https://github.com/thu-ml/SageAttention\nsage_attention: bool | None\n\neager_attention: bool | None\n\n# Specify a custom attention implementation, used mostly for kernels.\nattn_implementation: str | None\n\n# Which experts implementation to use for MoE models,\nexperts_implementation: str | None\n\n# Quantize MoE expert weights on load to reduce VRAM. Requires adapter (lora/qlora) with\n# load_in_4bit or load_in_8bit. Requires CUDA (not compatible with ROCm or other\n# backends). Note: total parameter count may be reported incorrectly when enabled\n# (trainable param count is correct).\nquantize_moe_experts: bool = False\n\n# Whether to use Scaled Softmax (SSMax) attention. Ref: https://arxiv.org/abs/2501.19399\nscaling_softmax: bool | None\n# Scaling factor for SSMax attention. Default is 0.43\nscaling_softmax_factor: float | None\n# Bias for SSMax attention. Default is 0.0. Note: The paper recommends bias=0 for better\n# length generalization.\nscaling_softmax_bias: float | None\n\nunsloth_cross_entropy_loss: bool | None\nunsloth_lora_mlp: bool | None\nunsloth_lora_qkv: bool | None\nunsloth_lora_o: bool | None\nunsloth_rms_norm: bool | None\nunsloth_rope: bool | None\n\n# Apply custom LoRA autograd functions and activation function Triton kernels for speed\n# and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html\nlora_mlp_kernel: bool | None\n# Apply custom LoRA autograd functions and activation function Triton kernels for speed\n# and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html\nlora_qkv_kernel: bool | None\n# Apply custom LoRA autograd functions and activation function Triton kernels for speed\n# and memory savings. 
See: https://docs.axolotl.ai/docs/lora_optims.html\nlora_o_kernel: bool | None\n# Apply custom LoRA autograd function for embedding layers. See:\n# https://docs.axolotl.ai/docs/lora_optims.html\nlora_embedding_kernel: bool | None\n\n# Whether to use chunked cross entropy loss for memory efficiency\nchunked_cross_entropy: bool | None\n# Number of chunks to use for chunked cross entropy loss\nchunked_cross_entropy_num_chunks: int | None\n# Enable Entropy-Aware Focal Training loss (EAFT)\nuse_eaft: bool | None\n# Exponent for entropy weighting in EAFT (default: 1.0)\neaft_alpha: float | None = 1.0\n# Number of top logits for entropy approximation (default: 20)\neaft_k: int | None = 20\n\n# Whether to use ALST tiled mlp for memory efficient long context\ntiled_mlp: bool | None\n\n# Number of shards to use for ALST tiled mlp. If unset, it will be set based on\n# seqlen/hidden_size\ntiled_mlp_num_shards: int | None\n\n# Whether to use original mlp for ALST tiled mlp. Otherwise uses a generic MLP based on\n# llama.\ntiled_mlp_use_original_mlp: bool | None = True\n\nllama4_linearized_experts: bool | None\n\n# Deepspeed config path. 
e.g., deepspeed_configs/zero3.json\ndeepspeed: str | dict[str, Any] | None\n# Whether to use deepcompile for faster training with deepspeed\ndeepcompile: bool | None\n# FSDP configuration\nfsdp: list[str] | None\n\n# FSDP configuration options\nfsdp_config: FSDPConfig | None\n # For FSDPConfig:\n # FSDP version\n fsdp_version: int | None\n # Enable activation checkpointing to reduce memory usage during forward passes\n activation_checkpointing: bool | None\n # Offload parameters to CPU to reduce GPU memory usage\n offload_params: bool | None\n # Synchronize module states across all processes\n sync_module_states: bool | None\n # Enable CPU RAM efficient loading to reduce memory usage during model loading\n cpu_ram_efficient_loading: bool | None\n # Disabling this enables swap memory usage for resource-constrained setups when\n # offload_params is enabled.\n cpu_offload_pin_memory: bool | None\n # Use original parameters instead of flattened parameters\n use_orig_params: bool | None\n\n # Type of state dict to use for saving/loading checkpoints\n state_dict_type: Literal['FULL_STATE_DICT', 'LOCAL_STATE_DICT', 'SHARDED_STATE_DICT'] | None\n # Final state dict type to use after training completion\n final_state_dict_type: Literal['FULL_STATE_DICT', 'LOCAL_STATE_DICT', 'SHARDED_STATE_DICT'] | None\n\n # Policy for automatically wrapping modules with FSDP\n auto_wrap_policy: Literal['TRANSFORMER_BASED_WRAP', 'SIZE_BASED_WRAP'] | None\n # Class name of transformer layers to wrap (e.g., 'LlamaDecoderLayer')\n transformer_layer_cls_to_wrap: str | None\n\n # Reshard parameters after forward pass to save memory\n reshard_after_forward: bool | None\n # Mixed precision policy for FSDP (e.g., 'fp16', 'bf16')\n mixed_precision_policy: str | None\n\n# FSDP version\nfsdp_version: int | None\nfsdp_final_state_dict_type: Literal['FULL_STATE_DICT', 'LOCAL_STATE_DICT', 'SHARDED_STATE_DICT'] | None\n\n# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 
0 for\n# no eval.\nval_set_size: float | None = 0.0\n\n# Number of devices to shard across. If not set, will use all available devices.\ndp_shard_size: int | None\n# Number of devices to replicate across.\ndp_replicate_size: int | None\n# Deprecated: use `context_parallel_size` instead\nsequence_parallel_degree: int | None\n# Set to a divisor of the number of GPUs available to split sequences into chunks of\n# equal size. Use in long context training to prevent OOM when sequences cannot fit into\n# a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each\n# sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized\n# subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more\n# details.\ncontext_parallel_size: int | None\n# Optional; strides across the key dimension. Larger values use more memory but should\n# make training faster. Must evenly divide the number of KV heads in your model.\nheads_k_stride: int | None\n# One of 'varlen_llama3', 'batch_ring', 'batch_zigzag', 'batch_stripe'. Defaults to\n# 'varlen_llama3' in the sample packing case, and 'batch_ring' in the non-sample packing\n# case.\nring_attn_func: RingAttnFunc | None\n# Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP.\ntensor_parallel_size: int | None\n\n# Add or change special tokens. If you add tokens here, you don't need to add them to\n# the `tokens` list.\nspecial_tokens: SpecialTokensConfig | None\n # For SpecialTokensConfig:\n bos_token: str | None\n eos_token: str | None\n pad_token: str | None\n unk_token: str | None\n additional_special_tokens: list[str] | None\n\n# Add extra tokens to the tokenizer\ntokens: list[str] | None\n# Mapping token_id to new_token_string to override reserved added_tokens in the\n# tokenizer. Only works for tokens that are not part of the base vocab (aka are\n# added_tokens). 
Can be checked if they exist in tokenizer.json added_tokens.\nadded_tokens_overrides: dict[int, str] | None\n\n# Whether to use torch.compile and which backend to use. setting to `auto` will enable\n# torch compile when torch>=2.6.0\ntorch_compile: Literal['auto'] | bool | None\n# Backend to use for torch.compile\ntorch_compile_backend: str | None\ntorch_compile_mode: Literal['default', 'reduce-overhead', 'max-autotune'] | None\n\n# Maximum number of iterations to train for. It precedes num_epochs which means that if\n# both are set, num_epochs will not be guaranteed. e.g., when 1 epoch is 1000 steps =>\n# `num_epochs: 2` and `max_steps: 100` will train for 100 steps\nmax_steps: int | None\n# Number of warmup steps. Cannot use with warmup_ratio\nwarmup_steps: int | None\n# Warmup ratio. Cannot use with warmup_steps\nwarmup_ratio: float | None\n# Leave empty to eval at each epoch, integer for every N steps. float for fraction of\n# total steps\neval_steps: int | float | None\n# Number of times per epoch to run evals, mutually exclusive with eval_steps\nevals_per_epoch: int | None\n# Set to `no` to skip evaluation, `epoch` at end of each epoch, leave empty to infer\n# from `eval_steps`\neval_strategy: str | None\n\n# Leave empty to save at each epoch, integer for every N steps. float for fraction of\n# total steps\nsave_steps: int | float | None\n# Number of times per epoch to save a checkpoint, mutually exclusive with save_steps\nsaves_per_epoch: int | None\n# Set to `no` to skip checkpoint saves, `epoch` at end of each epoch, `best` when better\n# result is achieved, leave empty to infer from `save_steps`\nsave_strategy: str | None\n# Checkpoints saved at a time\nsave_total_limit: int | None\n# Whether to checkpoint a model after the first step of training. Defaults to False.\nsave_first_step: bool | None\n\n# Logging frequency\nlogging_steps: int | None\n# Stop training after this many evaluation losses have increased in a row. 
https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback\nearly_stopping_patience: int | None\nload_best_model_at_end: bool | None = False\n# Save only the model weights, skipping the optimizer. Using this means you can't resume\n# from checkpoints.\nsave_only_model: bool | None = False\n# Use tensorboard for logging\nuse_tensorboard: bool | None\n# Enable the pytorch profiler to capture the first N steps of training to the\n# output_dir. See https://pytorch.org/blog/understanding-gpu-memory-1/ for more\n# information. Snapshots can be visualized @ https://pytorch.org/memory_viz\nprofiler_steps: int | None\n# Which step to start the profiler at. Useful for only capturing a few steps mid-run.\nprofiler_steps_start: int | None = 0\n# bool of whether to report tokens per second at the end of training. This is not\n# supported with pre-training datasets.\ninclude_tokens_per_second: bool | None\n# bool of whether to report tokens per second per-gpu during training by measuring\n# throughput of non-padding tokens.\ninclude_tkps: bool | None = True\n# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to\n# add noise to embeddings. Currently only supported on Llama and Mistral\nneftune_noise_alpha: float | None\n\n# Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to\n# `beta` in `ORPOConfig` due to trl mapping.\norpo_alpha: float | None\n# Target reward margin for the SimPO loss\nsimpo_gamma: float | None\n# Weight of the BC regularizer\ncpo_alpha: float | None\n\n# Factor for desirable loss term in KTO loss\nkto_desirable_weight: float | None\n# Factor for undesirable loss term in KTO loss\nkto_undesirable_weight: float | None\n# The beta parameter for the RL training\nrl_beta: float | None\n\n# Defines the max memory usage per gpu on the system. 
Passed through to transformers\n# when loading the model.\nmax_memory: dict[int | Literal['cpu', 'disk'], int | str] | None\n# Limit the memory for all available GPUs to this amount (if an integer, expressed in\n# gigabytes); default: unset\ngpu_memory_limit: int | str | None\n# Whether to use low_cpu_mem_usage\nlow_cpu_mem_usage: bool | None\n\n# The name of the chat template to use for training, following values are supported:\n# tokenizer_default: Uses the chat template that is available in the\n# tokenizer_config.json. If the chat template is not available in the tokenizer, it will\n# raise an error. This is the default value.\n# alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates\n# are available in the axolotl codebase at src/axolotl/utils/chat_templates.py.\n# tokenizer_default_fallback_*: where * is the name of the chat template to fallback to.\n# E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not\n# available in the tokenizer. jinja: Uses a custom jinja template for the chat template.\n# The custom jinja template should be provided in the chat_template_jinja field. The\n# selected chat template will be saved to the tokenizer_config.json for easier\n# inferencing\nchat_template: ChatTemplate | Annotated[str, StringConstraints(pattern='^tokenizer_default_fallback_')] | None\n# Custom jinja template or path to jinja file for chat template. This will be only used\n# if chat_template is set to `jinja` or `null` (in which case chat_template is\n# automatically set to `jinja`). Default is null.\nchat_template_jinja: str | None\n# Additional kwargs to pass to the chat template. This is useful for customizing the\n# chat template. For example, you can pass `thinking=False` to add a generation prompt\n# to the chat template.\nchat_template_kwargs: dict[str, Any] | None\n# Custom EOT (End-of-Turn) tokens to mask/unmask during training. These tokens mark the\n# boundaries between conversation turns. 
For example: ['/INST', '</s>',\n# '[/SYSTEM_PROMPT]']. If not specified, defaults to just the model's eos_token. This is\n# useful for templates that use multiple delimiter tokens.\neot_tokens: list[str] | None\n# Changes the default system message. Currently only supports chatml.\ndefault_system_message: str | None\n\n# Token index or indices to adjust embedding weights to the mean of the other tokens.\n# This is useful when the model has untrained embeddings.\nfix_untrained_tokens: int | list[int] | None\n\nis_preprocess: bool | None\npreprocess_iterable: bool | None\n\n# Total number of tokens - internal use\ntotal_num_tokens: int | None\ntotal_supervised_tokens: int | None\n# You can set these packing optimizations AFTER starting a training at least once. The\n# trainer will provide recommended values for these values.\nsample_packing_eff_est: float | None\naxolotl_config_path: str | None\n\n# Internal use only - Used to identify which the model is based on\nis_falcon_derived_model: bool | None\n# Internal use only - Used to identify which the model is based on\nis_llama_derived_model: bool | None\n# Internal use only - Used to identify which the model is based on. Please note that if\n# you set this to true, `padding_side` will be set to 'left' by default\nis_mistral_derived_model: bool | None\n# Internal use only - Used to identify which the model is based on\nis_qwen_derived_model: bool | None\n\n# Add plugins to extend the pipeline. 
See `src/axolotl/integrations` for the available\n# plugins or doc below for more details.\n# https://docs.axolotl.ai/docs/custom_integrations.html\nplugins: list[str] | None\n# Enable sample generation during training for monitoring\ngenerate_samples: bool | None = False\n# Number of samples to generate at each interval\nnum_generation_samples: int | None = 3\n# Maximum new tokens to generate per sample\ngeneration_max_new_tokens: int | None = 50\n# Temperature for sample generation (0.0 = greedy)\ngeneration_temperature: float | None = 0.7\n# Nucleus sampling parameter for generation\ngeneration_top_p: float | None\n# Top-k sampling parameter for generation\ngeneration_top_k: int | None\n# Ratio of input to use as prompt (0.0-1.0)\ngeneration_prompt_ratio: float | None = 0.5\n# Whether to use sampling (vs greedy decoding)\ngeneration_do_sample: bool | None = True\n\n# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files. This\n# can also be a relative path to a model on disk\nbase_model: str (required)\n# If the base_model repo on hf hub doesn't include configuration .json files, You can\n# set that here, or leave this empty to default to base_model\nbase_model_config: str | None\n# transformers config class (e.g., 'LlamaConfig', 'MistralConfig'). Defaults to\n# AutoConfig.\ncls_model_config: str | None\n# Optional tokenizer configuration path in case you want to use a different tokenizer\n# than the one defined in the base model\ntokenizer_config: str | None\n# use_fast option for tokenizer loading from_pretrained, default to True\ntokenizer_use_fast: bool | None\n# Whether to use the legacy tokenizer setting, defaults to True\ntokenizer_legacy: bool | None\n# Whether to use mistral-common tokenizer. 
If set to True, it will use the mistral-\n# common tokenizer.\ntokenizer_use_mistral_common: bool | None\n# Corresponding tokenizer for the model AutoTokenizer is a good choice\ntokenizer_type: str | None\n# transformers processor class\nprocessor_type: str | None\n# Whether to save jinja files for tokenizer, transformers default is True\ntokenizer_save_jinja_files: bool | None = True\n# Trust remote code for untrusted source\ntrust_remote_code: bool | None\n\n# Don't move the model to the device before sharding. Set to `false` to revert to legacy\n# behavior.\nexperimental_skip_move_to_device: bool | None = True\n\n# Use custom kernels, e.g. MegaBlocks.\nuse_kernels: bool | None\n\n# Model loading quantization config\nmodel_quantization_config: Literal['Mxfp4Config'] | None\n# kwargs for model quantization config\nmodel_quantization_config_kwargs: dict[str, Any] | None\n\n# Where to save the full-finetuned model to\noutput_dir: str = ./model-out\n# push checkpoints to hub\nhub_model_id: str | None\n# how to push checkpoints to hub\nhub_strategy: str | None\n# branch/revision to push to on hub (default: main)\nhub_revision: str | None\n# Whether to save the model using safetensors format. Defaults to True.\nsave_safetensors: bool | None = True\n\n# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer\nload_in_8bit: bool | None = False\n# Use bitsandbytes 4 bit\nload_in_4bit: bool | None = False\n\n# If you want to use 'lora', 'qlora', or 'llama-adapter', or leave blank to train all\n# parameters in original model\nadapter: Literal['lora', 'qlora', 'llama-adapter'] | None\n# If you already have a lora model trained that you want to load, put that here. This\n# means after training, if you want to test the model, you should set this to the value\n# of `output_dir`. 
Note that if you merge an adapter to the base model, a new\n# subdirectory `merged` will be created under the `output_dir`.\nlora_model_dir: str | None\nlora_r: int | None\nlora_alpha: int | None\nlora_fan_in_fan_out: bool | None\nlora_target_modules: str | list[str] | None\nlora_target_parameters: str | list[str] | None\n# If true, will target all linear modules\nlora_target_linear: bool | None\n# If you added new tokens to the tokenizer, you may need to save some LoRA modules\n# because they need to know the new tokens. For LLaMA and Mistral, you need to save\n# `embed_tokens` and `lm_head`. It may vary for other models. `embed_tokens` converts\n# tokens to embeddings, and `lm_head` converts embeddings to token probabilities.\nlora_modules_to_save: list[str] | None\nlora_dropout: float | None = 0.0\n# The layer indices to transform, otherwise, apply to all layers\npeft_layers_to_transform: list[int] | None\npeft_layers_pattern: list[str] | None\n\npeft: PeftConfig | None\n # For PeftConfig:\n # Configuration options for loftq initialization for LoRA\n loftq_config: LoftQConfig | None\n # For LoftQConfig:\n # typically 4 bits\n loftq_bits: int = 4\n\n# Whether to use DoRA.\npeft_use_dora: bool | None\n# Whether to use RSLoRA.\npeft_use_rslora: bool | None\n# List of layer indices to replicate.\npeft_layer_replication: list[tuple[int, int]] | None\n# How to initialize LoRA weights. Default to True which is MS original implementation.\npeft_init_lora_weights: bool | str | None\n# A list of token indices to fine-tune on the `embed_tokens` layer. Otherwise, a dict\n# mapping an embedding layer name to its trainable token indices. See\n# https://huggingface.co/docs/peft/v0.17.0/en/developer_guides/lora#efficiently-train-\n# tokens-alongside-lora\npeft_trainable_token_indices: list[int] | dict[str, list[int]] | None\n# Whether to tie adapter weights for tied model weights. 
See\n# https://github.com/huggingface/peft/issues/2864\npeft_ensure_weight_tying: bool | None\n# Whether to upcast the LoRA adapter to fp32. This is enabled by default in PEFT.\npeft_autocast_adapter_dtype: bool | None\n\n# load qlora model in sharded format for FSDP using answer.ai technique.\nqlora_sharded_model_loading: bool | None = False\n# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it\n# takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge\nlora_on_cpu: bool | None\n# Whether you are training a 4-bit GPTQ quantized model\ngptq: bool | None\n# optional overrides to the bnb 4bit quantization configuration\nbnb_config_kwargs: dict[str, Any] | None\n\n# loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4.\nloraplus_lr_ratio: float | None\n# loraplus learning rate for lora embedding layers. Default value is 1e-6.\nloraplus_lr_embedding: float | None = 1e-06\n\nmerge_lora: bool | None\n# Method to use for LoRA merging. 'memory_efficient' (default) processes shards\n# individually to reduce memory usage, 'legacy' loads the full model into memory.\nmerge_method: Literal['legacy', 'memory_efficient'] | None = memory_efficient\n\n# Whether to use ReLoRA. 
Use with jagged_restart_*steps options.\nrelora: bool | None\n# threshold for optimizer magnitude when pruning\nrelora_prune_ratio: float | None\n# True to perform lora weight merges on cpu during restarts, for modest gpu memory\n# savings\nrelora_cpu_offload: bool | None\n\n# how often to reset for jagged restarts\njagged_restart_steps: int | None\n# how many warmup steps to take after reset for jagged restarts\njagged_restart_warmup_steps: int | None\n# how many anneal steps to take before reset for jagged restarts\njagged_restart_anneal_steps: int | None\n\n# If greater than 1, the optimizer step will be skipped and the gradients will be\n# accumulated for the given number of steps.\ngradient_accumulation_steps: int | None = 1\n# The number of samples to include in each batch. This is the number of samples sent to\n# each GPU. Batch size per gpu = micro_batch_size * gradient_accumulation_steps\nmicro_batch_size: int | None = 1\n# Total batch size, we do not recommend setting this manually\nbatch_size: int | None\n# per gpu micro batch size for evals, defaults to value of micro_batch_size\neval_batch_size: int | None\n\n# whether to find batch size that fits in memory. Passed to underlying transformers\n# Trainer\nauto_find_batch_size: bool | None\n\n# Whether to mask out or include the human's prompt from the training labels\ntrain_on_inputs: bool | None = False\n# Group similarly sized data to minimize padding. May be slower to start, as it must\n# download and sort the entire dataset. 
Note that training loss may have an oscillating\n# pattern with this enabled.\ngroup_by_length: bool | None\n\nlearning_rate: str | float (required)\nembedding_lr: float | None\nembedding_lr_scale: float | None\n# Specify weight decay\nweight_decay: float | None = 0.0\n# Specify optimizer\noptimizer: OptimizerNames | CustomSupportedOptimizers | None = OptimizerNames.ADAMW_TORCH_FUSED\n# Dictionary of arguments to pass to the optimizer\noptim_args: str | dict[str, Any] | None\n# The target modules to optimize, i.e. the module names that you would like to train,\n# right now this is used only for GaLore algorithm\noptim_target_modules: list[str] | Literal['all_linear'] | None\n# Path to torch distx for optim 'adamw_anyprecision'\ntorchdistx_path: str | None\nlr_scheduler: SchedulerType | Literal['one_cycle'] | Literal['rex'] | None = SchedulerType.COSINE\n# Specify a scheduler and kwargs to use with the optimizer\nlr_scheduler_kwargs: dict[str, Any] | None\nlr_quadratic_warmup: bool | None\n# decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of\n# peak lr\ncosine_min_lr_ratio: float | None\n# freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means\n# start cosine_min_lr at 80% of training step\ncosine_constant_lr_ratio: float | None\n# Learning rate div factor\nlr_div_factor: float | None\n\nlr_groups: list[LrGroup] | None\n # For LrGroup:\n name: str (required)\n modules: list[str] (required)\n lr: float (required)\n\n# adamw hyperparams\nadam_epsilon: float | None\n# only used for CAME Optimizer\nadam_epsilon2: float | None\n# adamw hyperparams\nadam_beta1: float | None\n# adamw hyperparams\nadam_beta2: float | None\n# only used for CAME Optimizer\nadam_beta3: float | None\n\n# Dion Optimizer learning rate\ndion_lr: float | None\n# Dion Optimizer momentum\ndion_momentum: float | None\n# Dion Optimizer: r/d fraction for low-rank approximation. 
Used to compute the low-rank\n# dimension.\ndion_rank_fraction: float | None = 1.0\n# Dion Optimizer: Round up the low-rank dimension to a multiple of this number. This may\n# be useful to ensure even sharding.\ndion_rank_multiple_of: int | None = 1\n\n# Gradient clipping max norm\nmax_grad_norm: float | None\nnum_epochs: float = 1.0\n\nuse_wandb: bool | None\n# Set the name of your wandb run\nwandb_name: str | None\n# Set the ID of your wandb run\nwandb_run_id: str | None\n# \"offline\" to save run metadata locally and not sync to the server, \"disabled\" to turn\n# off wandb\nwandb_mode: str | None\n# Your wandb project name\nwandb_project: str | None\n# A wandb Team name if using a Team\nwandb_entity: str | None\nwandb_watch: str | None\n# \"checkpoint\" to log model to wandb Artifacts every `save_steps` or \"end\" to log only\n# at the end of training\nwandb_log_model: str | None\n\nuse_mlflow: bool | None\n# URI to mlflow\nmlflow_tracking_uri: str | None\n# Your experiment name\nmlflow_experiment_name: str | None\n# Your run name\nmlflow_run_name: str | None\n# set to true to copy each saved checkpoint on each save to mlflow artifact registry\nhf_mlflow_log_artifacts: bool | None\n\n# Enable or disable Comet integration.\nuse_comet: bool | None\n# API key for Comet. Recommended to set via `comet login`.\ncomet_api_key: str | None\n# Workspace name in Comet. Defaults to the user's default workspace.\ncomet_workspace: str | None\n# Project name in Comet. Defaults to Uncategorized.\ncomet_project_name: str | None\n# Identifier for the experiment. Used to append data to an existing experiment or\n# control the key of new experiments. Default to a random key.\ncomet_experiment_key: str | None\n# Create a new experiment (\"create\") or log to an existing one (\"get\"). Default\n# (\"get_or_create\") auto-selects based on configuration.\ncomet_mode: str | None\n# Set to True to log data to Comet server, or False for offline storage. 
Default is\n# True.\ncomet_online: bool | None\n# Dictionary for additional configuration settings, see the doc for more details.\ncomet_experiment_config: dict[str, Any] | None\n\nuse_trackio: bool | None\n# Your trackio project name\ntrackio_project_name: str | None\n# Set the name of your trackio run\ntrackio_run_name: str | None\n# Hugging Face Space ID to sync dashboard to (optional, runs locally if not provided)\ntrackio_space_id: str | None\n\n# Enable OpenTelemetry metrics collection and Prometheus export\nuse_otel_metrics: bool | None = False\n# Host to bind the OpenTelemetry metrics server to\notel_metrics_host: str | None = localhost\n# Port for the Prometheus metrics HTTP server\notel_metrics_port: int | None = 8000\n\n# the number of active layers in LISA\nlisa_n_layers: int | None\n# how often to switch layers in LISA\nlisa_step_interval: int | None\n# path under the model to access the layers\nlisa_layers_attribute: str | None = model.layers\n\ngradio_title: str | None\ngradio_share: bool | None\ngradio_server_name: str | None\ngradio_server_port: int | None\ngradio_max_new_tokens: int | None\ngradio_temperature: float | None\n\nuse_ray: bool = False\nray_run_name: str | None\nray_num_workers: int = 1\nresources_per_worker: dict\n\n# The size of the image to resize to. It can be an integer (resized into padded-square\n# image) or a tuple (width, height). If not provided, we will attempt to load from\n# preprocessor.size, otherwise, images won't be resized.\nimage_size: int | tuple[int, int] | None\n# The resampling algorithm to use for image resizing. Default is bilinear. 
Please refer\n# to PIL.Image.Resampling for more details.\nimage_resize_algorithm: Literal['bilinear', 'bicubic', 'lanczos'] | Resampling | None\n\n# optional overrides to the base model configuration\noverrides_of_model_config: dict[str, Any] | None\n# optional overrides the base model loading from_pretrained\noverrides_of_model_kwargs: dict[str, Any] | None\n# If you want to specify the type of model to load, AutoModelForCausalLM is a good\n# choice too\ntype_of_model: str | None\n# You can specify to choose a specific model revision from huggingface hub\nrevision_of_model: str | None\n\nmax_packed_sequence_len: int | None\nrope_scaling: Any | None\nnoisy_embedding_alpha: float | None\ndpo_beta: float | None\nevaluation_strategy: str | None\neval_table_size: int | None\neval_max_new_tokens: int | None\ndpo_use_logits_to_keep: bool | None\ndpo_generate_during_eval: bool | None\ndpo_norm_loss: bool | None\nrpo_alpha: float | None",
"crumbs": [
"Getting Started",
"Config Reference"
]
},
{
"objectID": "docs/gradient_checkpointing.html",
"href": "docs/gradient_checkpointing.html",
"title": "Gradient Checkpointing, Activation Offloading, and Layer Offloading",
"section": "",
"text": "Gradient checkpointing and activation offloading are techniques that reduce the GPU memory footprint of training deep learning\nmodels, at the cost of some extra recomputation or CPU-GPU transfer time.\n\nEnabling Gradient Checkpointing\ngradient_checkpointing: true\n\n\nEnabling Activation Offloading\ngradient_checkpointing: true # required for activation offloading\nactivation_offloading: true\nActivation offloading variants:\nThe default activation_offloading: true offloads activations to CPU and uses CUDA streams\nto overlap the communications and computations when offloading.\nThe activation_offloading: legacy variant naively offloads activations to CPU without additional optimizations.\nFor resource-constrained environments with limited CPU memory, activation_offloading: disk offloads\nactivations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.\n\n\nEnabling Layer Offloading\nlayer_offloading: true\nLayer offloading reduces GPU memory usage by moving frozen (non-trainable) decoder layer parameters to CPU\nand streaming them back to GPU one layer at a time during the forward and backward passes. This is\nparticularly useful for LoRA/QLoRA training where most of the model's parameters are frozen; only the\ntrainable adapter weights stay on GPU permanently.\nDuring training, forward and backward hooks on each decoder layer handle the transfer automatically:\n\nForward pass: Before a layer executes, its frozen params are loaded to GPU. 
The next layer is\nprefetched asynchronously on a separate CUDA stream for overlap.\nBackward pass: Same pattern in reverse: the current layer's frozen params are loaded and the\nprevious layer is prefetched.\n\nAfter each layer finishes, its frozen params are offloaded back to CPU pinned memory.\nThis approach trades some CPU-GPU transfer overhead for significant GPU memory savings: the freed memory\nis roughly equal to the size of all frozen parameters across all decoder layers, minus one layer's worth\nthat is kept on GPU at any given time.\nRequirements:\n\nCUDA GPU (CPU-only training is not supported for this feature)\nWorks with any HuggingFace model architecture that uses decoder layers (Llama, Mistral, Qwen, etc.)\nBest combined with LoRA/QLoRA where most parameters are frozen",
"crumbs": [
"Advanced Features",
"Gradient Checkpointing, Activation Offloading, and Layer Offloading"
]
},
{
"objectID": "docs/grpo.html",
"href": "docs/grpo.html",
"title": "GRPO Training",
"section": "",
"text": "Group Relative Policy Optimization (GRPO) is a reinforcement learning method that improves language models by generating multiple completions per prompt, scoring them with reward functions, and using the relative ranking within each group to compute advantage estimates. Unlike DPO, which requires pre-collected preference pairs, GRPO generates its own training data online and can work with any programmatic reward signal (math correctness, format compliance, code execution results, etc.).\nUse GRPO when you have a task with a verifiable reward signal and want the model to discover solution strategies on its own. Use DPO when you already have human preference data. Use SFT when you have gold-standard completions to imitate directly.\nAxolotl's GRPO implementation builds on TRL and adds async generation, streaming scoring, importance sampling correction, replay buffers, and multi-GPU scaling via FSDP and DeepSpeed.",
"crumbs": [
"How To Guides",
"GRPO Training"
]
},
{
"objectID": "docs/grpo.html#overview",
"href": "docs/grpo.html#overview",
"title": "GRPO Training",
"section": "",
"text": "Group Relative Policy Optimization (GRPO) is a reinforcement learning method that improves language models by generating multiple completions per prompt, scoring them with reward functions, and using the relative ranking within each group to compute advantage estimates. Unlike DPO, which requires pre-collected preference pairs, GRPO generates its own training data online and can work with any programmatic reward signal (math correctness, format compliance, code execution results, etc.).\nUse GRPO when you have a task with a verifiable reward signal and want the model to discover solution strategies on its own. Use DPO when you already have human preference data. Use SFT when you have gold-standard completions to imitate directly.\nAxolotl's GRPO implementation builds on TRL and adds async generation, streaming scoring, importance sampling correction, replay buffers, and multi-GPU scaling via FSDP and DeepSpeed.",
"crumbs": [
"How To Guides",
"GRPO Training"
]
},
{
"objectID": "docs/grpo.html#architecture",
"href": "docs/grpo.html#architecture",
"title": "GRPO Training",
"section": "Architecture",
"text": "Architecture\nGRPO training uses a two-process architecture: a vLLM server for fast generation and a trainer process for scoring and gradient updates.\nTerminal 1 (GPU 0) Terminal 2 (GPU 1)\n┌──────────────────────┐ ┌──────────────────────────────────┐\n│ vLLM Server │ │ Trainer │\n│ │ HTTP │ │\n│ Serves base model │◄────────────►│ Background thread: │\n│ + LoRA adapter │ /generate │ Send prompts to vLLM │\n│ │ /set_lora │ Pad & collate completions │\n│ Punica kernels for │ │ │\n│ LoRA inference │ │ Main thread: │\n│ │ │ Score completions (rewards) │\n└──────────────────────┘ │ Compute policy log-probs │\n │ Calculate advantages │\n │ PPO-clip gradient update │\n │ Sync LoRA weights to vLLM │\n └──────────────────────────────────┘\nData flow for each training step:\n\nThe background thread sends prompts to vLLM, which generates num_generations completions per prompt.\nThe main thread scores completions using your reward functions.\nAdvantages are computed within each prompt group (group-relative normalization).\nPolicy log-probabilities are computed by running a forward pass on the training model.\nThe PPO-clip loss is computed and gradients are applied.\nPeriodically, LoRA adapter weights are synced back to vLLM so future generations reflect the updated policy.\n\nWith async prefetch enabled, step 1 for the next batch runs concurrently with steps 2-6 for the current batch.",
"crumbs": [
"How To Guides",
"GRPO Training"
]
},
{
"objectID": "docs/grpo.html#quick-start",
"href": "docs/grpo.html#quick-start",
"title": "GRPO Training",
"section": "Quick Start",
"text": "Quick Start\nA GRPO training run requires three components: a YAML config, a reward module (Python file), and a running vLLM server.\n\n1. Write a reward module\nCreate a file called rewards.py in your working directory:\n# rewards.py\nimport re\n\n\ndef accuracy_reward(completions, answer, **kwargs) -> list[float]:\n \"\"\"Check if the completion contains the correct numerical answer.\"\"\"\n rewards = []\n for completion, correct in zip(completions, answer):\n text = completion[0][\"content\"]\n # Extract the last number from the completion\n numbers = re.findall(r\"-?\\d+(?:\\.\\d+)?\", text)\n predicted = numbers[-1] if numbers else \"\"\n rewards.append(1.0 if predicted == str(correct) else 0.0)\n return rewards\n\n\ndef format_reward(completions, **kwargs) -> list[float]:\n \"\"\"Reward completions that use a structured thinking format.\"\"\"\n rewards = []\n for completion in completions:\n text = completion[0][\"content\"]\n has_think = \"<think>\" in text and \"</think>\" in text\n has_answer = \"<answer>\" in text and \"</answer>\" in text\n rewards.append(1.0 if has_think and has_answer else 0.0)\n return rewards\n\n\ndef prompt_transform(cfg, *args, **kwargs):\n \"\"\"Convert GSM8K dataset rows into chat prompts.\"\"\"\n def transform_fn(example, tokenizer=None):\n label = example[\"answer\"].split(\"####\")[-1].strip().replace(\",\", \"\")\n return {\n \"prompt\": [\n {\"role\": \"system\", \"content\": \"Solve the math problem. Show your reasoning in <think> tags and your final numerical answer in <answer> tags.\"},\n {\"role\": \"user\", \"content\": example[\"question\"]},\n ],\n \"answer\": label,\n }\n return transform_fn, {\"remove_columns\": [\"question\"]}\n\n\n2. 
Write the config\nCreate config.yaml:\nbase_model: Qwen/Qwen2.5-1.5B-Instruct\n\nrl: grpo\nchat_template: tokenizer_default\n\nvllm:\n host: 0.0.0.0\n port: 8000\n gpu_memory_utilization: 0.85\n dtype: auto\n max_model_len: 2048\n\nadapter: lora\nlora_r: 32\nlora_alpha: 64\nlora_target_linear: true\n\ntrl:\n use_vllm: true\n use_data_producer: true\n vllm_server_host: 0.0.0.0\n vllm_server_port: 8000\n vllm_server_timeout: 300\n vllm_lora_sync: true\n num_generations: 8\n max_completion_length: 512\n temperature: 0.7\n reward_funcs:\n - rewards.accuracy_reward\n - rewards.format_reward\n reward_weights:\n - 1.0\n - 0.5\n\ndatasets:\n - path: openai/gsm8k\n name: main\n type: rewards.prompt_transform\n split: train\n\nskip_prepare_dataset: true\nval_set_size: 0.0\nsequence_len: 512\nmicro_batch_size: 2\ngradient_accumulation_steps: 4\nmax_steps: 200\nlearning_rate: 5.0e-6\noptimizer: adamw_torch_fused\nlr_scheduler: cosine\nwarmup_steps: 10\n\nbf16: true\nflash_attention: true\ngradient_checkpointing: true\n\nspecial_tokens:\n pad_token: \"<|endoftext|>\"\n\noutput_dir: ./grpo-output\nlogging_steps: 1\n\n\n3. Start vLLM and train\n# Terminal 1: Start vLLM server on GPU 0\nCUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml\n\n# Wait 30-90 seconds for model loading and CUDA graph capture\n\n# Terminal 2: Train on GPU 1\nCUDA_VISIBLE_DEVICES=1 axolotl train config.yaml\n\n\n\n\n\n\nTip\n\n\n\nUse tmux or separate terminal sessions to manage the two processes. The vLLM server must remain running for the entire training duration.",
"crumbs": [
"How To Guides",
"GRPO Training"
]
},
{
"objectID": "docs/grpo.html#custom-reward-functions",
"href": "docs/grpo.html#custom-reward-functions",
"title": "GRPO Training",
"section": "Custom Reward Functions",
"text": "Custom Reward Functions\n\nFunction signature\nTRL calls reward functions with this signature:\ndef my_reward(completions, **kwargs) -> list[float]:\n\ncompletions is a list of single-element lists, where each element is a dict {\"role\": \"assistant\", \"content\": \"...\"}. So completions[i][0][\"content\"] gives you the text of the i-th completion.\n**kwargs contains all dataset columns that were not removed by the dataset transform. This is how you pass ground truth answers, metadata, or any other information to your reward function.\nReturn a list[float] with the same length as completions. You may return None for individual elements to exclude them from aggregation.\n\n\n\nExample: accuracy reward with answer extraction\ndef accuracy_reward(completions, answer, **kwargs) -> list[float]:\n rewards = []\n for completion, correct_answer in zip(completions, answer):\n text = completion[0][\"content\"]\n # Extract answer from <answer>...</answer> tags\n match = re.search(r\"<answer>(.*?)</answer>\", text, re.DOTALL)\n predicted = match.group(1).strip() if match else \"\"\n rewards.append(1.0 if predicted == str(correct_answer) else 0.0)\n return rewards\n\n\nExample: length penalty\ndef length_penalty(completions, **kwargs) -> list[float]:\n \"\"\"Penalize very short or very long completions.\"\"\"\n rewards = []\n for completion in completions:\n length = len(completion[0][\"content\"])\n if length < 50:\n rewards.append(-0.5)\n elif length > 2000:\n rewards.append(-0.2)\n else:\n rewards.append(0.0)\n return rewards\n\n\nMultiple rewards and weighting\nYou can combine multiple reward functions with different weights:\ntrl:\n reward_funcs:\n - rewards.accuracy_reward\n - rewards.format_reward\n - rewards.length_penalty\n reward_weights:\n - 1.0 # accuracy is most important\n - 0.5 # format compliance\n - 0.1 # mild length preference\nRewards are combined by the multi_objective_aggregation strategy:\n\nsum_then_normalize (default): weights and sums all 
rewards first, then normalizes across the group.\nnormalize_then_sum (GDPO): normalizes each reward independently, then sums. This prevents one reward from dominating and is recommended when using multiple reward functions with different scales.\n\ntrl:\n multi_objective_aggregation: normalize_then_sum\n\n\nDataset transforms\nThe dataset transform converts raw HuggingFace dataset rows into chat-format prompts:\ndef prompt_transform(cfg, *args, **kwargs):\n def map_fn(example, tokenizer=None):\n return {\n \"prompt\": [\n {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n {\"role\": \"user\", \"content\": example[\"question\"]},\n ],\n # Keep 'answer' column for the reward function\n \"answer\": example[\"answer\"],\n }\n # Remove columns consumed by the transform; keep columns needed by rewards\n return map_fn, {\"remove_columns\": [\"question\"]}\nThe transform returns a tuple of (map_function, kwargs_dict). The remove_columns in the kwargs dict removes columns that are no longer needed. Columns that your reward functions reference via **kwargs (like answer) must not be removed.\n\n\n\n\n\n\nWarning\n\n\n\nThe reward module must be importable from the directory where you run axolotl train. If your reward file is rewards.py, the import path is rewards.accuracy_reward. If it is inside a package my_rewards/scoring.py, use my_rewards.scoring.accuracy_reward.\n\n\n\n\nReward models (neural network rewards)\nInstead of a Python function, you can pass a HuggingFace model path as a reward function. TRL will load it as a reward model and use its scalar output as the reward:\ntrl:\n reward_funcs:\n - OpenAssistant/reward-model-deberta-v3-large-v2\n - rewards.format_reward\n reward_weights:\n - 1.0\n - 0.3\n\n\nUsing math_verify\nThe math_verify library provides robust mathematical answer verification but uses signal.alarm() internally, which only works in the main thread. 
If you use math_verify in a reward function, set reward_num_workers to use subprocess workers:\ntrl:\n reward_num_workers: 4\nEach worker runs in its own subprocess with its own main thread, so signal.alarm() works correctly.",
"crumbs": [
"How To Guides",
"GRPO Training"
]
},
{
"objectID": "docs/grpo.html#vllm-setup",
"href": "docs/grpo.html#vllm-setup",
"title": "GRPO Training",
"section": "vLLM Setup",
"text": "vLLM Setup\nGRPO requires a running vLLM server for generation. For a complete guide on server modes, LoRA sync, weight synchronization, and restart procedures, see vLLM Serving.\nThe minimal setup:\nvllm:\n host: 0.0.0.0\n port: 8000\n gpu_memory_utilization: 0.85\n\ntrl:\n use_vllm: true\n vllm_lora_sync: true # Recommended with LoRA — faster sync, no NCCL contention\n vllm_sync_interval: 5 # Sync weights every 5 steps\nCUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml # GPU 0: vLLM\nCUDA_VISIBLE_DEVICES=1 axolotl train config.yaml # GPU 1: training\n\n\n\n\n\n\nWarning\n\n\n\nvLLM must be restarted between experiments — stale weight syncs corrupt server state. See Restart Requirements.",
"crumbs": [
"How To Guides",
"GRPO Training"
]
},
{
"objectID": "docs/grpo.html#async-training-features",
"href": "docs/grpo.html#async-training-features",
"title": "GRPO Training",
"section": "Async Training Features",
"text": "Async Training Features\nAsync GRPO overlaps generation and training to reduce wall-clock time. While the model trains on the current batch, the next batch is already being generated by vLLM.\n\nEnabling async prefetch\ntrl:\n use_data_producer: true\n async_prefetch: true\n prefetch_depth: 1\n vllm_sync_interval: 2\n\nuse_data_producer: true enables the data producer protocol (required for all async features).\nasync_prefetch: true runs generation in a background thread.\nprefetch_depth controls how many batches to prefetch ahead (1 is usually sufficient).\nvllm_sync_interval controls how often LoRA weights are synced to vLLM (every N optimizer steps). Lower values mean fresher generations but more sync overhead.\n\n\n\n\n\n\n\nTip\n\n\n\nBecause the background thread generates with slightly stale model weights, async mode benefits from importance sampling correction (see next section). Enable vllm_importance_sampling_correction: true when using async_prefetch: true.\n\n\n\n\nStreaming partial batch\nInstead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This reduces peak memory during scoring and enables finer-grained zero-advantage skipping.\ntrl:\n streaming_partial_batch: true\n streaming_min_groups: 1\nstreaming_min_groups controls the minimum number of prompt groups scored per chunk. Setting it to 1 gives maximum granularity.\n\n\nZero-advantage batch skipping\nWhen all advantages in a micro-batch are zero (every completion in the group got the same reward), there is no learning signal. This feature skips the forward/backward pass entirely for such micro-batches.\ntrl:\n skip_zero_advantage_batches: true # default\nThis is enabled by default and logged as skipped_zero_adv_batches in training metrics. 
It is a safety net, not a major optimization; it only saves significant time when the model cannot solve any prompts in the batch.\n\n\nReplay buffer\nThe replay buffer caches rollout groups that had learning signal (non-zero reward variance) and replaces zero-signal groups in later batches. This improves data utilization when many prompts yield no reward variance.\ntrl:\n replay_buffer_size: 100\n replay_recompute_logps: true\n\n\n\n\n\n\nWarning\n\n\n\nWhen replay_recompute_logps: false, replayed data uses stale log-probabilities, which creates an IS mismatch. Keep the default true unless you have a specific reason to disable it.\n\n\n\n\nDeferred re-rolling\nPrompts where the model gets zero reward for all generations are buffered and re-injected into later batches, when the model may have improved enough to produce useful completions.\ntrl:\n reroll_start_fraction: 0.5 # Start re-rolling after 50% of training\n reroll_max_groups: 1 # Max groups to replace per batch\nSet reroll_start_fraction: 1.0 to disable. This is most useful for tasks where the model starts weak but steadily improves.\n\n\nParallel reward workers\nReward functions that use signal.alarm() (like math_verify) only work in the main thread. Parallel reward workers run each function in its own subprocess:\ntrl:\n reward_num_workers: 4\nWork is sharded across workers by prompt group. For simple reward functions, a single worker is usually sufficient, because the overhead of IPC can exceed the computation time.",
"crumbs": [
"How To Guides",
"GRPO Training"
]
},
{
"objectID": "docs/grpo.html#importance-sampling-and-off-policy-correction",
"href": "docs/grpo.html#importance-sampling-and-off-policy-correction",
"title": "GRPO Training",
"section": "Importance Sampling and Off-Policy Correction",
"text": "Importance Sampling and Off-Policy Correction\nWhen using async prefetch, completions are generated from a slightly older policy. IS correction adjusts the gradient to account for this mismatch.\ntrl:\n vllm_importance_sampling_correction: true\n importance_sampling_level: token # 'token' recommended (especially with Liger kernel)\n off_policy_mask_threshold: 0.5 # KL threshold; masks sequences that are too off-policy\nUse token-level IS. Sequence-level IS has numerical issues with Liger's chunked computation. The off_policy_mask_threshold (OPSM) is a safety net that drops sequences where KL divergence exceeds the threshold; 0.5 is a reasonable starting point.\nFor detailed coverage of IS modes (token_mask, token_truncate, etc.), capping, and bias-corrected KL, see vLLM Serving: IS Correction.",
"crumbs": [
"How To Guides",
"GRPO Training"
]
},
{
"objectID": "docs/grpo.html#scaling",
"href": "docs/grpo.html#scaling",
"title": "GRPO Training",
"section": "Scaling",
"text": "Scaling\n\nFP8 training\nFP8 quantization halves model VRAM usage with minimal impact on training quality. It does not significantly speed up computation for small models but allows larger models to fit in memory.\nfp8: true\ntorch_compile: true\n\n\n\n\n\n\nWarning\n\n\n\nFP8 requires patching for zero-padding edge cases. The act_quant_kernel can produce NaN when input is all zeros (padding positions). If you see NaN in grad norms, check whether your padding token embedding is non-zero.\n\n\n\n\nFSDP (Fully Sharded Data Parallel)\nFSDP distributes model parameters across multiple GPUs for training while vLLM runs on a separate GPU:\nfsdp:\n - full_shard\n - auto_wrap\nfsdp_config:\n fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer\ngradient_checkpointing_kwargs:\n use_reentrant: false\nLaunch with:\n# GPU 0: vLLM\nCUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml\n\n# GPUs 1,2: Training (FSDP will use both visible GPUs)\nCUDA_VISIBLE_DEVICES=1,2 axolotl train config.yaml\n\n\n\n\n\n\nWarning\n\n\n\nasync_prefetch: true can deadlock with FSDP because background threads perform unsynchronized FSDP collectives across ranks. With multi-GPU FSDP, only rank 0 generates in the background thread and results are broadcast to all ranks. If you still see hangs, set async_prefetch: false.\n\n\n\n\nDeepSpeed ZeRO-3\ndeepspeed: deepspeed_configs/zero3_bf16.json\ngradient_checkpointing_kwargs:\n use_reentrant: true # Required -- non-reentrant causes CheckpointError with ZeRO-3\n\n\n\n\n\n\nNote\n\n\n\nDeepSpeed ZeRO-3 requires use_reentrant: true for gradient checkpointing. This is the opposite of the FSDP recommendation. 
Non-reentrant checkpointing causes tensor metadata mismatches during recomputation with ZeRO-3's parameter partitioning.\n\n\n\n\nMulti-GPU considerations\n\n\n\n\n\n\n\nConcern\nRecommendation\n\n\n\n\nvLLM GPU allocation\nDedicate one or more GPUs to vLLM; do not share with trainer GPUs\n\n\nWeight sync contention\nUse vllm_lora_sync: true to avoid NCCL contention between training and vLLM\n\n\nFSDP + async\nUse async_prefetch: false or rely on rank-0-only background generation\n\n\nDeepSpeed + gradient checkpoint\nMust use use_reentrant: true\n\n\nOOM during scoring\nReduce micro_batch_size or num_generations. The logits tensor scales with batch_size * vocab_size",
"crumbs": [
"How To Guides",
"GRPO Training"
]
},
{
"objectID": "docs/grpo.html#monitoring-and-debugging",
"href": "docs/grpo.html#monitoring-and-debugging",
"title": "GRPO Training",
"section": "Monitoring and Debugging",
"text": "Monitoring and Debugging\nFor detailed metric ranges, failure diagnosis, and OOM debugging, see Training Stability & Debugging.\nQuick health checks during GRPO training:\n\nrewards/*/mean should be > 0.15 within 20 steps — if it stays at 0, test your reward function standalone\nreward_std should be > 0 on most steps — all-zero means no learning signal\nentropy in 0.05-0.5 — below 0.01 suggests mode collapse\ngrad_norm in 0.001-1.0 — > 10 is unstable, 0.0 is expected when zero-advantage skip fires\n\n\n\n\n\n\n\nTip\n\n\n\nPipe training output to a log file: axolotl train config.yaml 2>&1 | tee /tmp/training.log",
"crumbs": [
"How To Guides",
"GRPO Training"
]
},
{
"objectID": "docs/grpo.html#configuration-reference",
"href": "docs/grpo.html#configuration-reference",
"title": "GRPO Training",
"section": "Configuration Reference",
"text": "Configuration Reference\nAll GRPO-specific options live under the trl: key in your config. Standard training options (learning_rate, micro_batch_size, etc.) are set at the top level as usual.\n\nCore GRPO\n\n\n\n\n\n\n\n\n\nOption\nType\nDefault\nDescription\n\n\n\n\nuse_vllm\nbool\nfalse\nEnable vLLM for generation\n\n\nvllm_mode\n\"server\" or \"colocate\"\nnull\nvLLM deployment mode\n\n\nvllm_server_host\nstr\n\"0.0.0.0\"\nvLLM server hostname\n\n\nvllm_server_port\nint\n8000\nvLLM server port\n\n\nvllm_server_timeout\nint\nnull\nTimeout (seconds) for vLLM responses\n\n\nnum_generations\nint\nnull\nCompletions generated per prompt\n\n\ngeneration_batch_size\nint\nnull\nNumber of unique prompts per generation step\n\n\nmax_completion_length\nint\nnull\nMaximum tokens per completion\n\n\nbeta\nfloat\nnull\nKL penalty coefficient\n\n\nnum_iterations\nint\nnull\nIterations per batch (mu in the GRPO paper)\n\n\nepsilon\nfloat\nnull\nPPO clipping lower bound\n\n\nepsilon_high\nfloat\nnull\nPPO clipping upper bound\n\n\nloss_type\nstr\nnull\nLoss formulation: grpo, bnpo, or dr_grpo\n\n\nscale_rewards\nbool\ntrue\nNormalize rewards by standard deviation\n\n\nmask_truncated_completions\nbool\nfalse\nExclude truncated completions from loss\n\n\n\n\n\nReward functions\n\n\n\n\n\n\n\n\n\nOption\nType\nDefault\nDescription\n\n\n\n\nreward_funcs\nlist[str]\nnull\nImport paths to reward functions or HF model IDs\n\n\nreward_weights\nlist[float]\nnull\nRelative weights for each reward function\n\n\nmulti_objective_aggregation\nstr\nnull\n\"sum_then_normalize\" (GRPO) or \"normalize_then_sum\" (GDPO)\n\n\nrollout_func\nstr\nnull\nImport path to custom rollout function for OpenEnv-style tasks\n\n\n\n\n\nGeneration parameters\n\n\n\n\n\n\n\n\n\nOption\nType\nDefault\nDescription\n\n\n\n\ntemperature\nfloat\nnull\nSampling temperature\n\n\ntop_p\nfloat\nnull\nNucleus sampling probability\n\n\ntop_k\nint\nnull\nTop-k sampling\n\n\nmin_p\nfloat\nnull\nMinimum probability 
threshold\n\n\nrepetition_penalty\nfloat\nnull\nPenalty for repeated tokens\n\n\ngeneration_kwargs\ndict\nnull\nAdditional vLLM SamplingParams (e.g., stop_token_ids)\n\n\nchat_template_kwargs\ndict\nnull\nChat template kwargs (e.g., {enable_thinking: false})\n\n\nvllm_guided_decoding_regex\nstr\nnull\nRegex constraint for guided decoding\n\n\n\n\n\nAsync pipeline\n\n\n\n\n\n\n\n\n\nOption\nType\nDefault\nDescription\n\n\n\n\nuse_data_producer\nbool\nfalse\nEnable data producer protocol (required for async features)\n\n\nasync_prefetch\nbool\nfalse\nGenerate next batch in background thread\n\n\nprefetch_depth\nint\nnull\nNumber of batches to prefetch ahead\n\n\nvllm_sync_interval\nint\nnull\nSync LoRA weights to vLLM every N steps\n\n\nvllm_lora_sync\nbool\nfalse\nUse filesystem LoRA sync instead of NCCL merge\n\n\nstreaming_partial_batch\nbool\nnull\nScore prompt groups incrementally\n\n\nstreaming_min_groups\nint\nnull\nMinimum groups per streaming chunk\n\n\nskip_zero_advantage_batches\nbool\ntrue\nSkip micro-batches with zero learning signal\n\n\nreward_num_workers\nint\n1\nSubprocess workers for reward computation\n\n\nvllm_enable_sleep_mode\nbool\nnull\nOffload vLLM weights when idle (colocate mode)\n\n\n\n\n\nImportance sampling\n\n\n\n\n\n\n\n\n\nOption\nType\nDefault\nDescription\n\n\n\n\nvllm_importance_sampling_correction\nbool\nnull\nEnable IS correction for async distribution shift\n\n\nimportance_sampling_level\n\"token\" or \"sequence\"\nnull\nGranularity of IS ratios. 
Use token with Liger\n\n\nvllm_importance_sampling_mode\nstr\nnull\ntoken_mask, token_truncate, sequence_mask, or sequence_truncate\n\n\nvllm_importance_sampling_cap\nfloat\nnull\nCap C for IS ratio clipping/masking\n\n\noff_policy_mask_threshold\nfloat\nnull\nKL threshold for off-policy sequence masking (OPSM)\n\n\nuse_bias_correction_kl\nbool\nnull\nApply IS correction to KL divergence term\n\n\n\n\n\nReplay and re-roll\n\n\n\n\n\n\n\n\n\nOption\nType\nDefault\nDescription\n\n\n\n\nreplay_buffer_size\nint\n0\nMax cached high-signal groups. 0 = disabled\n\n\nreplay_recompute_logps\nbool\ntrue\nRecompute log-probs for replayed data with current model\n\n\nreroll_start_fraction\nfloat\n1.0\nStart re-rolling failed prompts after this fraction of training. 1.0 = disabled\n\n\nreroll_max_groups\nint\n1\nMax prompt groups to replace with re-rolls per batch\n\n\n\n\n\nReference model\n\n\n\n\n\n\n\n\n\nOption\nType\nDefault\nDescription\n\n\n\n\nsync_ref_model\nbool\nfalse\nPeriodically sync reference model with training model\n\n\nref_model_mixup_alpha\nfloat\n0.9\nEMA coefficient for reference model sync\n\n\nref_model_sync_steps\nint\n64\nSync reference model every N steps\n\n\n\n\n\nLogging\n\n\n\n\n\n\n\n\n\nOption\nType\nDefault\nDescription\n\n\n\n\nlog_completions\nbool\nfalse\nLog sample completions to W&B\n\n\nnum_completions_to_print\nint\nnull\nNumber of completions to print per step\n\n\nuse_liger_loss\nbool\nnull\nUse Liger fused kernel for GRPO loss (reduces VRAM)",
"crumbs": [
"How To Guides",
"GRPO Training"
]
},
{
"objectID": "docs/choosing_method.html",
"href": "docs/choosing_method.html",
"title": "Which Fine-Tuning Method Should I Use?",
"section": "",
"text": "Axolotl supports four broad categories of fine-tuning, each suited to different data types, objectives, and resource constraints.\n\n\n\n\n\n\n\n\nMethod\nWhat It Does\nData You Need\n\n\n\n\nSupervised Fine-Tuning (SFT)\nTeaches the model to produce specific outputs given inputs\nInput-output pairs (instructions, conversations, completions)\n\n\nPreference Learning (DPO/KTO/ORPO)\nSteers the model toward preferred outputs and away from dispreferred ones\nChosen/rejected response pairs (DPO, ORPO) or binary labels (KTO)\n\n\nReinforcement Learning (GRPO)\nOptimizes the model against a reward signal through online generation\nA reward function (code or model-based) and a prompt dataset\n\n\nReward Modeling\nTrains a model to score responses, for use as a reward signal in RL\nPreference pairs ranked by quality\n\n\n\nEach method is configured through a YAML file with rl: <method> (or omitted for SFT). All methods support LoRA, QLoRA, and full fine-tuning unless otherwise noted.",
"crumbs": [
"Getting Started",
"Which Fine-Tuning Method Should I Use?"
]
},
{
"objectID": "docs/choosing_method.html#sec-overview",
"href": "docs/choosing_method.html#sec-overview",
"title": "Which Fine-Tuning Method Should I Use?",
"section": "",
"text": "Axolotl supports four broad categories of fine-tuning, each suited to different data types, objectives, and resource constraints.\n\n\n\n\n\n\n\n\nMethod\nWhat It Does\nData You Need\n\n\n\n\nSupervised Fine-Tuning (SFT)\nTeaches the model to produce specific outputs given inputs\nInput-output pairs (instructions, conversations, completions)\n\n\nPreference Learning (DPO/KTO/ORPO)\nSteers the model toward preferred outputs and away from dispreferred ones\nChosen/rejected response pairs (DPO, ORPO) or binary labels (KTO)\n\n\nReinforcement Learning (GRPO)\nOptimizes the model against a reward signal through online generation\nA reward function (code or model-based) and a prompt dataset\n\n\nReward Modeling\nTrains a model to score responses, for use as a reward signal in RL\nPreference pairs ranked by quality\n\n\n\nEach method is configured through a YAML file with rl: <method> (or omitted for SFT). All methods support LoRA, QLoRA, and full fine-tuning unless otherwise noted.",
"crumbs": [
"Getting Started",
"Which Fine-Tuning Method Should I Use?"
]
},
{
"objectID": "docs/choosing_method.html#sec-decision-tree",
"href": "docs/choosing_method.html#sec-decision-tree",
"title": "Which Fine-Tuning Method Should I Use?",
"section": "2 Decision Tree",
"text": "2 Decision Tree\nUse the following flowchart to choose your method. Start at the top and follow the path that matches your situation.\nDo you have a reward function (code-based or model-based)?\n├── YES\n│ └── Use GRPO (rl: grpo)\n│ The model generates its own completions and learns from reward scores.\n│ Best for: math, code, reasoning, tasks with verifiable answers.\n│ See: rlhf.qmd#grpo\n│\n└── NO\n │\n Do you have preference pairs (chosen vs. rejected responses)?\n ├── YES\n │ │\n │ Are they paired (same prompt, one chosen, one rejected)?\n │ ├── YES → Use DPO (rl: dpo)\n │ │ Direct optimization without a separate reward model.\n │ │ See: rlhf.qmd#dpo\n │ │\n │ └── NO (only binary good/bad labels)\n │ └── Use KTO (rl: kto)\n │ Works with unpaired preference data.\n │ See: rlhf.qmd#kto\n │\n └── NO\n │\n Do you have input-output examples?\n ├── YES → Use SFT\n │ The simplest and most common method.\n │ See: getting-started.qmd\n │\n └── NO\n └── You need to create training data first.\n Consider generating preference pairs with an LLM judge,\n or writing a reward function for GRPO.\n\n\n\n\n\n\nTip\n\n\n\nWhen in doubt, start with SFT. It is the most straightforward method and works well for most tasks. 
You can always move to preference learning or RL later to further refine behavior.\n\n\n\n2.1 Method Comparison at a Glance\n\n\n\n\n\n\n\n\n\n\nCriterion\nSFT\nDPO\nKTO\nGRPO\n\n\n\n\nData complexity\nLow (input-output pairs)\nMedium (preference pairs)\nMedium (binary labels)\nLow (prompts + reward code)\n\n\nCompute cost\nLow\nMedium\nMedium\nHigh (requires vLLM server)\n\n\nLearning signal\nSupervised\nContrastive\nContrastive\nOnline reward\n\n\nOnline generation\nNo\nNo\nNo\nYes\n\n\nReward model needed\nNo\nNo\nNo\nNo (uses reward functions)\n\n\nBest for\nTask adaptation, instruction following\nSafety, style alignment\nUnpaired preference data\nReasoning, math, code\n\n\n\n\n\n\n\n\n\nNote\n\n\n\nORPO is an alternative to DPO that combines SFT and preference optimization in a single training stage, removing the need for a separate SFT step. Configure with rl: orpo. See rlhf.qmd for details.",
"crumbs": [
"Getting Started",
"Which Fine-Tuning Method Should I Use?"
]
},
{
"objectID": "docs/choosing_method.html#sec-adapter-selection",
"href": "docs/choosing_method.html#sec-adapter-selection",
"title": "Which Fine-Tuning Method Should I Use?",
"section": "3 Adapter Selection",
"text": "3 Adapter Selection\nOnce you have chosen a method, decide how to apply the parameter updates. The three main options trade off VRAM usage against model quality.\n\n3.1 QLoRA\n\nHow it works: The base model is loaded in 4-bit (NF4) quantization. Small low-rank adapter matrices are trained in higher precision on top.\nVRAM savings: Roughly 4x reduction in model memory compared to full fine-tuning.\nQuality: Slight degradation due to quantization noise, but often negligible for task-specific fine-tuning.\nWhen to use: When your GPU cannot fit the model in full precision, or when you want fast experimentation.\n\nadapter: qlora\nload_in_4bit: true\nlora_r: 32\nlora_alpha: 64\nlora_target_linear: true\n\n\n3.2 LoRA\n\nHow it works: The base model is loaded at full precision (or 8-bit). Low-rank adapter matrices are trained alongside.\nVRAM savings: Roughly 2-3x reduction compared to full fine-tuning (model weights are frozen, only adapters + optimizer states for adapters are stored).\nQuality: Very close to full fine-tuning for most tasks, especially with higher rank values.\nWhen to use: When you have enough VRAM for the base model but not for full optimizer states.\n\nadapter: lora\nlora_r: 32\nlora_alpha: 64\nlora_target_linear: true\n\n\n\n\n\n\nTip\n\n\n\nFor GRPO training, LoRA is strongly recommended. The vLLM server needs to sync weights from the trainer, and LoRA sync (trl.vllm_lora_sync: true) is far more efficient than syncing full merged weights. See vLLM Serving for details.\n\n\n\n\n3.3 Full Fine-Tuning\n\nHow it works: All model parameters are updated during training. No adapters.\nVRAM savings: None. Requires memory for model weights, gradients, and optimizer states (roughly 4x model size in bf16 with AdamW).\nQuality: Highest potential quality, especially for large distribution shifts.\nWhen to use: When you have ample GPU memory or multi-GPU setups, and need maximum performance. 
Also required for pre-training.\n\n# No adapter or load_in_* lines needed\nmicro_batch_size: 1\ngradient_accumulation_steps: 16\n\n\n3.4 Quick Comparison\n\n\n\n\n\n\n\n\n\n\nQLoRA\nLoRA\nFull\n\n\n\n\nTrainable params\n~0.1-1%\n~0.1-1%\n100%\n\n\nModel memory\n~25% of full\n~50-100% of full\n100%\n\n\nOptimizer memory\nTiny (adapters only)\nTiny (adapters only)\n2x model size (AdamW)\n\n\nTraining speed\nSlower (dequantization overhead)\nBaseline\nFaster per-step (no adapter overhead)\n\n\nInference\nMerge or serve with adapter\nMerge or serve with adapter\nDirect\n\n\nMulti-GPU required?\nRarely\nFor 13B+ models\nFor 7B+ models",
"crumbs": [
"Getting Started",
"Which Fine-Tuning Method Should I Use?"
]
},
{
"objectID": "docs/choosing_method.html#sec-hardware-mapping",
"href": "docs/choosing_method.html#sec-hardware-mapping",
"title": "Which Fine-Tuning Method Should I Use?",
"section": "4 Hardware Mapping",
"text": "4 Hardware Mapping\nThe tables below provide approximate GPU memory requirements. Actual usage depends on context length, batch size, and optimizer choice.\n\n4.1 SFT / Preference Learning\n\n\n\nModel Size\nQLoRA (4-bit)\nLoRA (bf16)\nFull (bf16 + AdamW)\n\n\n\n\n1-3B\n6-8 GB\n8-12 GB\n24-32 GB\n\n\n7-8B\n10-14 GB\n16-24 GB\n60-80 GB\n\n\n13-14B\n16-20 GB\n28-40 GB\n120+ GB\n\n\n30-34B\n24-32 GB\n64-80 GB\n2-4x 80 GB\n\n\n70-72B\n40-48 GB\n2x 80 GB\n4-8x 80 GB\n\n\n\n\n\n\n\n\n\nImportant\n\n\n\nThese estimates assume a short context length (512-2048 tokens) and micro_batch_size of 1-2. Longer sequences and larger batches increase memory significantly due to activations. Use gradient checkpointing to reduce activation memory at the cost of ~30% slower training.\n\n\n\n\n4.2 GRPO (RL Training)\nGRPO requires additional GPU(s) for the vLLM generation server. Plan for at least two GPUs: one for training, one for vLLM.\n\n\n\n\n\n\n\n\n\nModel Size\nTraining GPU (LoRA, bf16)\nvLLM GPU\nTotal GPUs\n\n\n\n\n0.5-3B\n1x 24 GB\n1x 24 GB\n2x 24 GB\n\n\n7-8B\n1x 80 GB\n1x 80 GB\n2x 80 GB\n\n\n13-14B\n1-2x 80 GB\n1-2x 80 GB\n2-4x 80 GB\n\n\n30-72B\n2-4x 80 GB (FSDP/DeepSpeed)\n2-4x 80 GB (tensor parallel)\n4-8x 80 GB\n\n\n\n\n\n\n\n\n\nTip\n\n\n\nFor single-GPU GRPO, use vllm_mode: colocate with vllm_enable_sleep_mode: true. The vLLM engine shares the GPU and offloads VRAM when not generating. This works for smaller models (up to ~3B on a 24 GB GPU) but is slower than the two-GPU server mode.\n\n\n\n\n4.3 Multi-GPU Threshold\nYou need multi-GPU training when:\n\nFull fine-tuning of models 7B+ (use FSDP or DeepSpeed ZeRO)\nLoRA of models 30B+ (or 13B+ with long contexts)\nGRPO almost always (separate vLLM server), unless using colocate mode\n\nSee Multi-GPU Training for FSDP and DeepSpeed configuration.",
"crumbs": [
"Getting Started",
"Which Fine-Tuning Method Should I Use?"
]
},
{
"objectID": "docs/choosing_method.html#sec-quick-links",
"href": "docs/choosing_method.html#sec-quick-links",
"title": "Which Fine-Tuning Method Should I Use?",
"section": "5 Quick Links",
"text": "5 Quick Links\n\n\n\nMethod\nConfig Key\nDocumentation\nExample Config\n\n\n\n\nSFT\n(default, no rl: key)\nGetting Started\nexamples/llama-3/lora-1b.yml\n\n\nDPO\nrl: dpo\nRLHF - DPO\nSee rlhf.qmd\n\n\nKTO\nrl: kto\nRLHF - KTO\nSee rlhf.qmd\n\n\nORPO\nrl: orpo\nRLHF - ORPO\nSee rlhf.qmd\n\n\nGRPO\nrl: grpo\nRLHF - GRPO, vLLM Serving\nSee rlhf.qmd\n\n\nReward Modeling\nrl: reward_trainer\nReward Modelling\nSee reward_modelling.qmd\n\n\n\n\n5.1 Related Guides\n\nConfiguration Reference Full list of all config options\nDataset Formats How to structure your training data\nOptimizations Flash attention, gradient checkpointing, mixed precision\nMulti-GPU Training FSDP and DeepSpeed setup\nvLLM Serving Setting up vLLM for GRPO training",
"crumbs": [
"Getting Started",
"Which Fine-Tuning Method Should I Use?"
]
},
{
"objectID": "docs/models/LiquidAI.html",
"href": "docs/models/LiquidAI.html",
"title": "Liquid Foundation Models 2",
"section": "",
"text": "Liquid Foundation Models 2 (LFM2) are a family of small, open-weight models from Liquid AI focused on quality, speed, and memory efficiency. Liquid AI released text-only LFM2 and text+vision LFM2-VL models.\nLFM2 features a new hybrid Liquid architecture with multiplicative gates, short-range convolutions, and grouped query attention, enabling fast training and inference.\nThis guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.\nThanks to the team at LiquidAI for giving us early access to prepare for these releases.",
"crumbs": [
"Getting Started",
"Model Guides",
"Liquid Foundation Models 2"
]
},
{
"objectID": "docs/models/LiquidAI.html#getting-started",
"href": "docs/models/LiquidAI.html#getting-started",
"title": "Liquid Foundation Models 2",
"section": "Getting Started",
"text": "Getting Started\n\nInstall Axolotl following the installation guide.\nHere is an example of how to install from pip:\n# Ensure you have a compatible version of Pytorch installed\npip3 install packaging setuptools wheel ninja\npip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'\nRun one of the finetuning examples below.\nLFM2\n# FFT SFT (1x48GB @ 25GiB)\naxolotl train examples/LiquidAI/lfm2-350m-fft.yaml\nLFM2-VL\n# LoRA SFT (1x48GB @ 2.7GiB)\naxolotl train examples/LiquidAI/lfm2-vl-lora.yaml\nLFM2-MoE\npip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6\n\n# LoRA SFT (1x48GB @ 16.2GiB)\naxolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml\n\n\nTIPS\n\nInstallation Error: If you encounter ImportError: ... undefined symbol ... or ModuleNotFoundError: No module named 'causal_conv1d_cuda', the causal-conv1d package may have been installed incorrectly. Try uninstalling it:\npip uninstall -y causal-conv1d\nDataset Loading: Read more on how to load your own dataset in our documentation.\nDataset Formats:\n\nFor LFM2 models, the dataset format follows the OpenAI Messages format as seen here.\nFor LFM2-VL models, Axolotl follows the multi-content Messages format. See our Multimodal docs for details.",
"crumbs": [
"Getting Started",
"Model Guides",
"Liquid Foundation Models 2"
]
},
{
"objectID": "docs/models/LiquidAI.html#optimization-guides",
"href": "docs/models/LiquidAI.html#optimization-guides",
"title": "Liquid Foundation Models 2",
"section": "Optimization Guides",
"text": "Optimization Guides\n\nOptimizations Guide",
"crumbs": [
"Getting Started",
"Model Guides",
"Liquid Foundation Models 2"
]
},
{
"objectID": "docs/models/LiquidAI.html#related-resources",
"href": "docs/models/LiquidAI.html#related-resources",
"title": "Liquid Foundation Models 2",
"section": "Related Resources",
"text": "Related Resources\n\nLFM2 Blog\nLFM2-VL Blog\nLFM2-MoE Blog\nAxolotl Docs\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"Liquid Foundation Models 2"
]
},
{
"objectID": "docs/models/magistral.html",
"href": "docs/models/magistral.html",
"title": "Magistral",
"section": "",
"text": "Magistral Small is a 24B parameter opensource model from MistralAI found on HuggingFace at 2506, 2507 (see Thinking), and 2509 (see Vision). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.\nMistralAI has also released a proprietary medium-sized version called Magistral Medium.\nThanks to the team at MistralAI for giving us early access to prepare for these releases.",
"crumbs": [
"Getting Started",
"Model Guides",
"Magistral",
"Magistral"
]
},
{
"objectID": "docs/models/magistral.html#getting-started",
"href": "docs/models/magistral.html#getting-started",
"title": "Magistral",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the installation guide.\nHere is an example of how to install from pip:\n\n# Ensure you have Pytorch installed (Pytorch 2.7.0 min)\npip3 install packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'\n\nInstall Cut Cross Entropy to reduce training VRAM usage\n\npython scripts/cutcrossentropy_install.py | sh\n\nRun the finetuning example:\n\naxolotl train examples/magistral/magistral-small-qlora.yaml\nThis config uses about 24GB VRAM.\nLet us know how it goes. Happy finetuning! 🚀\n\nThinking\nMistralAI has released their 2507 model with thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps.\n📚 See the Thinking fine-tuning guide →\n\n\nVision\nMistralAI has released their 2509 model with vision capabilities.\n📚 See the Vision fine-tuning guide →\n\n\nTips\n\nWe recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repos files titled SYSTEM_PROMPT.txt.\nFor inference, the official MistralAI team recommends top_p: 0.95 and temperature: 0.7 with max_tokens: 40960.\nYou can run a full finetuning by removing the adapter: qlora and load_in_4bit: true from the config.\nRead more on how to load your own dataset at docs.\nThe text dataset format follows the OpenAI Messages format as seen here.",
"crumbs": [
"Getting Started",
"Model Guides",
"Magistral",
"Magistral"
]
},
{
"objectID": "docs/models/magistral.html#optimization-guides",
"href": "docs/models/magistral.html#optimization-guides",
"title": "Magistral",
"section": "Optimization Guides",
"text": "Optimization Guides\n\nMulti-GPU Training\nMulti-Node Training\nLoRA Optimizations",
"crumbs": [
"Getting Started",
"Model Guides",
"Magistral",
"Magistral"
]
},
{
"objectID": "docs/models/magistral.html#limitations",
"href": "docs/models/magistral.html#limitations",
"title": "Magistral",
"section": "Limitations",
"text": "Limitations\nWe only support the mistral-common tokenizer for Supervised Fine-tuning at the moment and for type: chat_template only.\nIn addition, we do not support overriding tokens yet.",
"crumbs": [
"Getting Started",
"Model Guides",
"Magistral",
"Magistral"
]
},
{
"objectID": "docs/models/magistral.html#related-resources",
"href": "docs/models/magistral.html#related-resources",
"title": "Magistral",
"section": "Related Resources",
"text": "Related Resources\n\nMistralAI Magistral Blog\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"Magistral",
"Magistral"
]
},
{
"objectID": "docs/models/magistral.html#future-work",
"href": "docs/models/magistral.html#future-work",
"title": "Magistral",
"section": "Future Work",
"text": "Future Work\n\nAdd parity to Preference Tuning, RL, etc.\nAdd parity to other tokenizer configs like overriding tokens.",
"crumbs": [
"Getting Started",
"Model Guides",
"Magistral",
"Magistral"
]
},
{
"objectID": "docs/models/devstral.html",
"href": "docs/models/devstral.html",
"title": "Devstral",
"section": "",
"text": "Devstral Small is a 24B parameter opensource model from MistralAI found on HuggingFace Devstral-Small-2505 and Devstral-Small-2507. Devstral-Small-2507 is the latest version of the model and has function calling support.\nThis guide shows how to fine-tune it with Axolotl with multi-turn conversations with proper masking.\nThe model was fine-tuned ontop of Mistral-Small-3.1 without the vision layer and has a context of up to 128k tokens.\nThanks to the team at MistralAI for giving us early access to prepare for this release.",
"crumbs": [
"Getting Started",
"Model Guides",
"Devstral"
]
},
{
"objectID": "docs/models/devstral.html#getting-started",
"href": "docs/models/devstral.html#getting-started",
"title": "Devstral",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the installation guide.\nHere is an example of how to install from pip:\n\n# Ensure you have Pytorch installed (Pytorch 2.6.0 min)\npip3 install packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'\n\nInstall Cut Cross Entropy to reduce training VRAM usage\n\npython scripts/cutcrossentropy_install.py | sh\n\nRun the finetuning example:\n\naxolotl train examples/devstral/devstral-small-qlora.yml\nThis config uses about 21GB VRAM.\nLet us know how it goes. Happy finetuning! 🚀\n\nTIPS\n\nYou can run a full finetuning by removing the adapter: qlora and load_in_4bit: true from the config.\nRead more on how to load your own dataset at docs.\nThe dataset format follows the OpenAI Messages format as seen here.\nLearn how to use function calling with Axolotl at docs.",
"crumbs": [
"Getting Started",
"Model Guides",
"Devstral"
]
},
{
"objectID": "docs/models/devstral.html#optimization-guides",
"href": "docs/models/devstral.html#optimization-guides",
"title": "Devstral",
"section": "Optimization Guides",
"text": "Optimization Guides\n\nMulti-GPU Training\nMulti-Node Training\nLoRA Optimizations\nCut Cross Entropy\nLiger Kernel",
"crumbs": [
"Getting Started",
"Model Guides",
"Devstral"
]
},
{
"objectID": "docs/models/devstral.html#limitations",
"href": "docs/models/devstral.html#limitations",
"title": "Devstral",
"section": "Limitations",
"text": "Limitations\nWe only support the mistral-common tokenizer for Supervised Fine-tuning at the moment and for type: chat_template only.\nIn addition, we do not support overriding tokens yet.",
"crumbs": [
"Getting Started",
"Model Guides",
"Devstral"
]
},
{
"objectID": "docs/models/devstral.html#related-resources",
"href": "docs/models/devstral.html#related-resources",
"title": "Devstral",
"section": "Related Resources",
"text": "Related Resources\n\nMistralAI Devstral Blog\nMistralAI Devstral 1.1 Blog\nAxolotl Docs\nAxolotl GitHub\nAxolotl Website\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"Devstral"
]
},
{
"objectID": "docs/models/devstral.html#future-work",
"href": "docs/models/devstral.html#future-work",
"title": "Devstral",
"section": "Future Work",
"text": "Future Work\n\nAdd parity to Preference Tuning, RL, Multi-modal, etc.\nAdd parity to other tokenizer configs like overriding tokens.",
"crumbs": [
"Getting Started",
"Model Guides",
"Devstral"
]
},
{
"objectID": "docs/models/qwen3-next.html",
"href": "docs/models/qwen3-next.html",
"title": "Qwen 3 Next",
"section": "",
"text": "Qwen3-Next represents the next-generation foundation models optimized for extreme context length and large-scale parameter efficiency. The series introduces architectural innovations including Hybrid Attention (Gated DeltaNet + Gated Attention), High-Sparsity MoE with 1:50 activation ratio, and Multi-Token Prediction for enhanced performance and inference acceleration.\nThis guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.",
"crumbs": [
"Getting Started",
"Model Guides",
"Qwen 3 Next"
]
},
{
"objectID": "docs/models/qwen3-next.html#getting-started",
"href": "docs/models/qwen3-next.html#getting-started",
"title": "Qwen 3 Next",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the installation guide.\nInstall Cut Cross Entropy to reduce training VRAM usage.\nInstall FLA for improved performance\n\npip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1\n\nRun the finetuning example:\n\naxolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml\nThis config uses about ~47 GiB (no target experts) and ~71GiB (target experts) VRAM.\nLet us know how it goes. Happy finetuning! 🚀\n\nTIPS\n\nFor inference, you can experiment with temperature: 0.7, top_p: 0.8, top_k: 20, and min_p: 0.\nYou can run a full finetuning by removing the adapter: qlora and load_in_4bit: true from the config. See Multi-GPU section below.\nRead more on how to load your own dataset at docs.\nThe dataset format follows the OpenAI Messages format as seen here.",
"crumbs": [
"Getting Started",
"Model Guides",
"Qwen 3 Next"
]
},
{
"objectID": "docs/models/qwen3-next.html#optimization-guides",
"href": "docs/models/qwen3-next.html#optimization-guides",
"title": "Qwen 3 Next",
"section": "Optimization Guides",
"text": "Optimization Guides\n\nMulti-GPU Training\nMulti-Node Training\nLoRA Optimizations",
"crumbs": [
"Getting Started",
"Model Guides",
"Qwen 3 Next"
]
},
{
"objectID": "docs/models/qwen3-next.html#related-resources",
"href": "docs/models/qwen3-next.html#related-resources",
"title": "Qwen 3 Next",
"section": "Related Resources",
"text": "Related Resources\n\nQwen3-Next Blog\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"Qwen 3 Next"
]
},
{
"objectID": "docs/models/mistral.html",
"href": "docs/models/mistral.html",
"title": "Mistral 7B",
"section": "",
"text": "Mistral 7B is a language model with a total of 7.3 billion parameters, showcasing a notable performance across a variety of benchmarks.\nFine Tune:\naccelerate launch -m axolotl.cli.train examples/mistral/config.yml\n\nIf you run into CUDA OOM, use deepspeed with config zero2.json:\naccelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed_configs/zero2.json",
"crumbs": [
"Getting Started",
"Model Guides",
"Mistral 7B"
]
},
{
"objectID": "docs/models/plano.html",
"href": "docs/models/plano.html",
"title": "Plano Orchestrator",
"section": "",
"text": "Plano-Orchestrator is a family of 4B and 30B-A3B routing and orchestration models designed for multi-agent systems. It analyzes user intent and conversation context to make precise routing decisions, excelling at multi-turn context understanding, multi-intent detection, and context-dependent routing.\nThis guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.",
"crumbs": [
"Getting Started",
"Model Guides",
"Plano Orchestrator"
]
},
{
"objectID": "docs/models/plano.html#getting-started",
"href": "docs/models/plano.html#getting-started",
"title": "Plano Orchestrator",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the installation guide.\nInstall Cut Cross Entropy to reduce training VRAM usage.\nRun the finetuning example:\naxolotl train examples/plano/plano-4b-qlora.yaml\n\nThis config uses about 5.1 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀\n\nOrchestration Prompt\nPlano-Orchestrator uses a specific orchestration prompt format for routing/agent decisions. Please check the official model card for proper prompt formatting and the ORCHESTRATION_PROMPT template.\n\n\nTips\n\nTo use the larger Plano-Orchestrator-30B-A3B MoE model, simply change base_model: katanemo/Plano-Orchestrator-30B-A3B in the config and enable multi-GPU training if needed.\nYou can run a full finetuning by removing the adapter: qlora and load_in_4bit: true from the config.\nRead more on how to load your own dataset at docs.\nThe dataset format follows the OpenAI Messages format as seen here.",
"crumbs": [
"Getting Started",
"Model Guides",
"Plano Orchestrator"
]
},
{
"objectID": "docs/models/plano.html#optimization-guides",
"href": "docs/models/plano.html#optimization-guides",
"title": "Plano Orchestrator",
"section": "Optimization Guides",
"text": "Optimization Guides\nPlease check the Optimizations doc.",
"crumbs": [
"Getting Started",
"Model Guides",
"Plano Orchestrator"
]
},
{
"objectID": "docs/models/plano.html#related-resources",
"href": "docs/models/plano.html#related-resources",
"title": "Plano Orchestrator",
"section": "Related Resources",
"text": "Related Resources\n\nPlano GitHub\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"Plano Orchestrator"
]
},
{
"objectID": "docs/models/olmo3.html",
"href": "docs/models/olmo3.html",
"title": "OLMo 3",
"section": "",
"text": "Olmo 3 are a family of 7B and 32B models open source models trained by The Allen Institute for Artificial Intelligence.\nThis guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.",
"crumbs": [
"Getting Started",
"Model Guides",
"OLMo 3"
]
},
{
"objectID": "docs/models/olmo3.html#getting-started",
"href": "docs/models/olmo3.html#getting-started",
"title": "OLMo 3",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the installation guide.\nInstall Cut Cross Entropy to reduce training VRAM usage.\nRun the finetuning example:\naxolotl train examples/olmo3/olmo3-7b-qlora.yaml\n\nThis uses about 11.3 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀\n\nTIPS\n\nThe example config can be re-used for Olmo and Olmo 2.\nYou can run a full finetuning by removing the adapter: qlora and load_in_4bit: true from the config.\nRead more on how to load your own dataset at docs.\nThe dataset format follows the OpenAI Messages format as seen here.",
"crumbs": [
"Getting Started",
"Model Guides",
"OLMo 3"
]
},
{
"objectID": "docs/models/olmo3.html#optimization-guides",
"href": "docs/models/olmo3.html#optimization-guides",
"title": "OLMo 3",
"section": "Optimization Guides",
"text": "Optimization Guides\nPlease check the Optimizations doc.",
"crumbs": [
"Getting Started",
"Model Guides",
"OLMo 3"
]
},
{
"objectID": "docs/models/olmo3.html#related-resources",
"href": "docs/models/olmo3.html#related-resources",
"title": "OLMo 3",
"section": "Related Resources",
"text": "Related Resources\n\nOlmo 3 Blog\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"OLMo 3"
]
},
{
"objectID": "docs/models/magistral/vision.html",
"href": "docs/models/magistral/vision.html",
"title": "Magistral Vision",
"section": "",
"text": "This guide covers fine-tuning Magistral Small 2509 with vision capabilities using Axolotl.",
"crumbs": [
"Getting Started",
"Model Guides",
"Magistral",
"Magistral Vision"
]
},
{
"objectID": "docs/models/magistral/vision.html#prerequisites",
"href": "docs/models/magistral/vision.html#prerequisites",
"title": "Magistral Vision",
"section": "Prerequisites",
"text": "Prerequisites\nBefore starting, ensure you have:\n\nInstalled Axolotl from source (see main README)",
"crumbs": [
"Getting Started",
"Model Guides",
"Magistral",
"Magistral Vision"
]
},
{
"objectID": "docs/models/magistral/vision.html#getting-started",
"href": "docs/models/magistral/vision.html#getting-started",
"title": "Magistral Vision",
"section": "Getting started",
"text": "Getting started\n\nInstall the required vision lib:\nbash pip install 'mistral-common[opencv]==1.8.5'\nDownload the example dataset image:\nwget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg\nRun the fine-tuning:\naxolotl train examples/magistral/vision/magistral-small-vision-24B-qlora.yml\n\nThis config uses about 17GiB VRAM.\nWARNING: The loss and grad norm will be much higher than normal at first. We suspect this to be inherent to the model as of the moment. If anyone would like to submit a fix for this, we are happy to take a look.\n\nTips\nKey differences from text-only model:\n- max_tokens: 131072 for inference\n- Multi-modal dataset format required\n- Sample packing not supported",
"crumbs": [
"Getting Started",
"Model Guides",
"Magistral",
"Magistral Vision"
]
},
{
"objectID": "docs/models/magistral/vision.html#dataset-format",
"href": "docs/models/magistral/vision.html#dataset-format",
"title": "Magistral Vision",
"section": "Dataset Format",
"text": "Dataset Format\nThe vision model requires multi-modal dataset format as documented here.\nOne exception is that, passing \"image\": PIL.Image is not supported. MistralTokenizer only supports path, url, and base64 for now.\nExample:\n{\n \"messages\": [\n {\"role\": \"system\", \"content\": [{ \"type\": \"text\", \"text\": \"{SYSTEM_PROMPT}\"}]},\n {\"role\": \"user\", \"content\": [\n { \"type\": \"text\", \"text\": \"What's in this image?\"},\n {\"type\": \"image\", \"path\": \"path/to/image.jpg\" }\n ]},\n {\"role\": \"assistant\", \"content\": [{ \"type\": \"text\", \"text\": \"...\" }]},\n ],\n}",
"crumbs": [
"Getting Started",
"Model Guides",
"Magistral",
"Magistral Vision"
]
},
{
"objectID": "docs/models/magistral/vision.html#limitations",
"href": "docs/models/magistral/vision.html#limitations",
"title": "Magistral Vision",
"section": "Limitations",
"text": "Limitations\n\nSample Packing is not supported for multi-modality training currently.",
"crumbs": [
"Getting Started",
"Model Guides",
"Magistral",
"Magistral Vision"
]
},
{
"objectID": "docs/models/mimo.html",
"href": "docs/models/mimo.html",
"title": "MiMo",
"section": "",
"text": "MiMo is a family of models trained from scratch for reasoning tasks, incorporating Multiple-Token Prediction (MTP) as an additional training objective for enhanced performance and faster inference. Pre-trained on ~25T tokens with a three-stage data mixture strategy and optimized reasoning pattern density.\nThis guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.",
"crumbs": [
"Getting Started",
"Model Guides",
"MiMo"
]
},
{
"objectID": "docs/models/mimo.html#getting-started",
"href": "docs/models/mimo.html#getting-started",
"title": "MiMo",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the installation guide.\nRun the finetuning example:\naxolotl train examples/mimo/mimo-7b-qlora.yaml\n\nThis config uses about 17.2 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀\n\nTips\n\nYou can run a full finetuning by removing the adapter: qlora and load_in_4bit: true from the config.\nRead more on how to load your own dataset at docs.\nThe dataset format follows the OpenAI Messages format as seen here.",
"crumbs": [
"Getting Started",
"Model Guides",
"MiMo"
]
},
{
"objectID": "docs/models/mimo.html#optimization-guides",
"href": "docs/models/mimo.html#optimization-guides",
"title": "MiMo",
"section": "Optimization Guides",
"text": "Optimization Guides\nPlease check the Optimizations doc.",
"crumbs": [
"Getting Started",
"Model Guides",
"MiMo"
]
},
{
"objectID": "docs/models/mimo.html#limitations",
"href": "docs/models/mimo.html#limitations",
"title": "MiMo",
"section": "Limitations",
"text": "Limitations\nCut Cross Entropy (CCE): Currently not supported. We plan to include CCE support for MiMo in the near future.",
"crumbs": [
"Getting Started",
"Model Guides",
"MiMo"
]
},
{
"objectID": "docs/models/mimo.html#related-resources",
"href": "docs/models/mimo.html#related-resources",
"title": "MiMo",
"section": "Related Resources",
"text": "Related Resources\n\nMiMo Paper\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"MiMo"
]
},
{
"objectID": "docs/models/index.html",
"href": "docs/models/index.html",
"title": "Model Guides",
"section": "",
"text": "Model Guides\nBelow are the curated examples for training various model architectures:\n\nKimi Linear\nPlano Orchestrator\nMiMo\nInternVL 3.5\nOLMo 3\nTrinity\nArcee AFM\nMinistral3\nMinistral 3 Thinking\nMinistral 3 Vision\nMagistral\nMagistral Thinking\nMagistral Vision\nMinistral\nMistral Small 3.1/3.2\nVoxtral\nDevstral\nMistral 7B\nLlama 4\nLlama 2\nQwen 3 Next\nQwen 3\nGemma 3n\nApertus\nGPT-OSS\nSeed-OSS\nPhi\nSmolVLM 2\nGranite 4\nLiquid Foundation Models 2\nHunyuan\nJamba\nOrpheus"
},
{
"objectID": "docs/models/trinity.html",
"href": "docs/models/trinity.html",
"title": "Trinity",
"section": "",
"text": "Trinity is a family of open weight MoE models trained by Arcee.ai.\nThis guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.",
"crumbs": [
"Getting Started",
"Model Guides",
"Trinity"
]
},
{
"objectID": "docs/models/trinity.html#getting-started",
"href": "docs/models/trinity.html#getting-started",
"title": "Trinity",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the main from the installation guide.\nInstall Cut Cross Entropy to reduce training VRAM usage.\nRun the finetuning example:\naxolotl train examples/trinity/trinity-nano-preview-qlora.yaml\n\nThis config uses about 24.9 GiB VRAM (w/o CCE).\nLet us know how it goes. Happy finetuning! 🚀\n\nTIPS\n\nFor inference, the official Arcee.ai team recommends top_p: 0.75, temperature: 0.15, top_k: 50, and min_p: 0.06.\nYou can run a full finetuning by removing the adapter: qlora and load_in_4bit: true from the config.\nRead more on how to load your own dataset at docs.\nThe dataset format follows the OpenAI Messages format as seen here.",
"crumbs": [
"Getting Started",
"Model Guides",
"Trinity"
]
},
{
"objectID": "docs/models/trinity.html#optimization-guides",
"href": "docs/models/trinity.html#optimization-guides",
"title": "Trinity",
"section": "Optimization Guides",
"text": "Optimization Guides\nPlease check the Optimizations doc.",
"crumbs": [
"Getting Started",
"Model Guides",
"Trinity"
]
},
{
"objectID": "docs/models/trinity.html#related-resources",
"href": "docs/models/trinity.html#related-resources",
"title": "Trinity",
"section": "Related Resources",
"text": "Related Resources\n\nTrinity Blog\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"Trinity"
]
},
{
"objectID": "docs/models/kimi-linear.html",
"href": "docs/models/kimi-linear.html",
"title": "Kimi Linear",
"section": "",
"text": "Kimi Linear is a MoE model (48B total, 3B active) by MoonshotAI using a hybrid linear attention architecture to achieve a 1M token context length. It uses Kimi Delta Attention (KDA), a refined version of Gated DeltaNet that reduces KV cache size by up to 75% and boosts decoding throughput by up to 6x for long contexts.\nThis guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.\nNote: Axolotl uses experimental training code for Kimi Linear as their original modeling code is inference-only.",
"crumbs": [
"Getting Started",
"Model Guides",
"Kimi Linear"
]
},
{
"objectID": "docs/models/kimi-linear.html#getting-started",
"href": "docs/models/kimi-linear.html#getting-started",
"title": "Kimi Linear",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the installation guide.\nInstall CCE via docs\nRun the finetuning example:\naxolotl train examples/kimi-linear/kimi-48b-lora.yaml\n\nThis config uses about 98.7GiB VRAM.\nLet us know how it goes. Happy finetuning!\n\nTIPS\n\nKimi Linear requires trust_remote_code: true.\nYou can run a full finetuning by removing the adapter: lora and load_in_8bit: true.\nRead more on how to load your own dataset at docs\nThe dataset format follows the OpenAI Messages format as seen here",
"crumbs": [
"Getting Started",
"Model Guides",
"Kimi Linear"
]
},
{
"objectID": "docs/models/kimi-linear.html#optimization-guides",
"href": "docs/models/kimi-linear.html#optimization-guides",
"title": "Kimi Linear",
"section": "Optimization Guides",
"text": "Optimization Guides\nSee 👉 docs.",
"crumbs": [
"Getting Started",
"Model Guides",
"Kimi Linear"
]
},
{
"objectID": "docs/models/kimi-linear.html#limitations",
"href": "docs/models/kimi-linear.html#limitations",
"title": "Kimi Linear",
"section": "Limitations",
"text": "Limitations\nThis is not yet compatible with MoE kernels from transformers v5.",
"crumbs": [
"Getting Started",
"Model Guides",
"Kimi Linear"
]
},
{
"objectID": "docs/models/kimi-linear.html#related-resources",
"href": "docs/models/kimi-linear.html#related-resources",
"title": "Kimi Linear",
"section": "Related Resources",
"text": "Related Resources\n\nKimi Linear Paper\nKimi Linear GitHub\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"Kimi Linear"
]
},
{
"objectID": "docs/models/orpheus.html",
"href": "docs/models/orpheus.html",
"title": "Orpheus",
"section": "",
"text": "In this example, we finetune Orpcanopylabs/orpheus-tts-0.1-pretrained (a LLaMA 3.2 3b model) to output audio.\nThe finetune.yml withe current settings will run on any Nvidia GPU with 45GB VRAM or more. If you adjust the batch size it can easily run on any GPU under 24GB.",
"crumbs": [
"Getting Started",
"Model Guides",
"Orpheus"
]
},
{
"objectID": "docs/models/orpheus.html#dataset-pre-processing-for-pre-training",
"href": "docs/models/orpheus.html#dataset-pre-processing-for-pre-training",
"title": "Orpheus",
"section": "Dataset pre-processing for pre-training",
"text": "Dataset pre-processing for pre-training\nIf you are adding another voice in English, please jump ahead to finetuning pre-processing.\nFor this to work, we need to preprocess our dataset. Since we are expecting to output audio, we will need to add tokens to the tokenizer.\nUsing this code, it will download the SNAC model and add the correct tokens and upload the final dataset.\nimport torch\nfrom snac import SNAC\nfrom datasets import load_dataset\nfrom huggingface_hub import snapshot_download\nfrom datasets import load_dataset\nimport random\nimport torchaudio.transforms as T\nfrom transformers import AutoTokenizer\nimport os\n\nmy_original_dataset_name = \"<huggingface-id-of-dataset-that-we-want-to-preprocess>\"\nname_to_push_dataset_to = \"<huggingface-id-of-where-to-save-dataset>\"\n\ndsn = my_original_dataset_name\n\nsnapshot_download(\n repo_id=dsn,\n repo_type=\"dataset\",\n revision=\"main\",\n max_workers=64,\n)\n\n\nds = load_dataset(dsn, split=\"train\")\nds_sample_rate = ds[0][\"audio\"][\"sampling_rate\"]\n\nmodel = SNAC.from_pretrained(\"hubertsiuzdak/snac_24khz\")\nmodel = model.to(\"mps\")\n\ndef tokenise_audio(waveform):\n waveform = torch.from_numpy(waveform).unsqueeze(0)\n waveform = waveform.to(dtype=torch.float32)\n resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)\n waveform = resample_transform(waveform)\n\n waveform = waveform.unsqueeze(0).to(\"cuda\")\n\n #generate the codes from snac\n with torch.inference_mode():\n codes = model.encode(waveform)\n\n all_codes = []\n for i in range(codes[0].shape[1]):\n all_codes.append(codes[0][0][i].item()+128266)\n all_codes.append(codes[1][0][2*i].item()+128266+4096)\n all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))\n all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))\n all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))\n all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))\n 
all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))\n\n\n return all_codes\n\ndef add_codes(example):\n # Always initialize codes_list to None\n codes_list = None\n\n try:\n answer_audio = example.get(\"audio\")\n # If there's a valid audio array, tokenise it\n if answer_audio and \"array\" in answer_audio:\n audio_array = answer_audio[\"array\"]\n codes_list = tokenise_audio(audio_array)\n except Exception as e:\n print(f\"Skipping row due to error: {e}\")\n # Keep codes_list as None if we fail\n example[\"codes_list\"] = codes_list\n\n return example\n\nds = ds.map(add_codes, remove_columns=[\"audio\"])\n\n#@title Load Tokenizer\ntokeniser_length = 128256\nstart_of_text = 128000\nend_of_text = 128009\n\nstart_of_speech = tokeniser_length + 1\nend_of_speech = tokeniser_length + 2\n\nstart_of_human = tokeniser_length + 3\nend_of_human = tokeniser_length + 4\n\nstart_of_ai = tokeniser_length + 5\nend_of_ai = tokeniser_length + 6\npad_token = tokeniser_length + 7\n\naudio_tokens_start = tokeniser_length + 10\n\ntokenizer_name = \"canopylabs/orpheus-3b-0.1-pretrained\"\n\n\ntokenizer = AutoTokenizer.from_pretrained(tokenizer_name)\nnum_proc = os.cpu_count() - 2\n\nds = ds.filter(lambda x: x[\"codes_list\"] is not None)\nds = ds.filter(lambda x: len(x[\"codes_list\"]) > 0)\n\n#@title Create Input Ids\ndef remove_duplicate_frames(example):\n vals = example[\"codes_list\"]\n if len(vals) % 7 != 0:\n raise ValueError(\"Input list length must be divisible by 7\")\n\n result = vals[:7]\n\n removed_frames = 0\n\n for i in range(7, len(vals), 7):\n current_first = vals[i]\n previous_first = result[-7]\n\n if current_first != previous_first:\n result.extend(vals[i:i+7])\n else:\n removed_frames += 1\n\n example[\"codes_list\"] = result\n\n return example\n\nds = ds.map(remove_duplicate_frames, num_proc=num_proc)\n\n\ndef create_input_ids(example):\n text_ids = tokenizer.encode({example['text']}, add_special_tokens=True)\n text_ids.append(end_of_text)\n 
example[\"text_tokens\"] = text_ids\n input_ids = (\n [start_of_human]\n + example[\"text_tokens\"]\n + [end_of_human]\n + [start_of_ai]\n + [start_of_speech]\n + example[\"codes_list\"]\n + [end_of_speech]\n + [end_of_ai]\n )\n example[\"input_ids\"] = input_ids\n example[\"labels\"] = input_ids\n example[\"attention_mask\"] = [1] * len(input_ids)\n\n return example\n\nds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=[\"text\", \"codes_list\"])\n\n#@title Remove unnecessary columns\ncolumns_to_keep = [\"input_ids\", \"labels\", \"attention_mask\"]\ncolumns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]\n\nds = ds.remove_columns(columns_to_remove)\n\nds.push_to_hub(name_to_push_dataset_to)",
"crumbs": [
"Getting Started",
"Model Guides",
"Orpheus"
]
},
{
"objectID": "docs/models/orpheus.html#finetune-pre-processing",
"href": "docs/models/orpheus.html#finetune-pre-processing",
"title": "Orpheus",
"section": "Finetune pre-processing",
"text": "Finetune pre-processing\nUse this code to add a new voice.\nimport torch\nfrom snac import SNAC\nfrom datasets import load_dataset\nfrom huggingface_hub import snapshot_download\nfrom datasets import load_dataset\nimport random\nimport torchaudio.transforms as T\nfrom transformers import AutoTokenizer\nimport os\n\nmy_original_dataset_name = \"<huggingface-id-of-dataset-that-we-want-to-preprocess>\"\nname_to_push_dataset_to = \"<huggingface-id-of-where-to-save-dataset>\"\n\ndsn = my_original_dataset_name\n\nsnapshot_download(\n repo_id=dsn,\n repo_type=\"dataset\",\n revision=\"main\",\n max_workers=64,\n)\n\n\nds = load_dataset(dsn, split=\"train\")\nds_sample_rate = ds[0][\"audio\"][\"sampling_rate\"]\n\nmodel = SNAC.from_pretrained(\"hubertsiuzdak/snac_24khz\")\nmodel = model.to(\"mps\")\n\ndef tokenise_audio(waveform):\n waveform = torch.from_numpy(waveform).unsqueeze(0)\n waveform = waveform.to(dtype=torch.float32)\n resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)\n waveform = resample_transform(waveform)\n\n waveform = waveform.unsqueeze(0).to(\"cuda\")\n\n #generate the codes from snac\n with torch.inference_mode():\n codes = model.encode(waveform)\n\n all_codes = []\n for i in range(codes[0].shape[1]):\n all_codes.append(codes[0][0][i].item()+128266)\n all_codes.append(codes[1][0][2*i].item()+128266+4096)\n all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))\n all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))\n all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))\n all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))\n all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))\n\n\n return all_codes\n\ndef add_codes(example):\n # Always initialize codes_list to None\n codes_list = None\n\n try:\n answer_audio = example.get(\"audio\")\n # If there's a valid audio array, tokenise it\n if answer_audio and \"array\" in answer_audio:\n audio_array = answer_audio[\"array\"]\n codes_list = 
tokenise_audio(audio_array)\n except Exception as e:\n print(f\"Skipping row due to error: {e}\")\n # Keep codes_list as None if we fail\n example[\"codes_list\"] = codes_list\n\n return example\n\nds = ds.map(add_codes, remove_columns=[\"audio\"])\n\n#@title Load Tokenizer\ntokeniser_length = 128256\nstart_of_text = 128000\nend_of_text = 128009\n\nstart_of_speech = tokeniser_length + 1\nend_of_speech = tokeniser_length + 2\n\nstart_of_human = tokeniser_length + 3\nend_of_human = tokeniser_length + 4\n\nstart_of_ai = tokeniser_length + 5\nend_of_ai = tokeniser_length + 6\npad_token = tokeniser_length + 7\n\naudio_tokens_start = tokeniser_length + 10\n\ntokenizer_name = \"canopylabs/orpheus-3b-0.1-pretrained\"\n\n\ntokenizer = AutoTokenizer.from_pretrained(tokenizer_name)\nnum_proc = os.cpu_count() - 2\n\nds = ds.filter(lambda x: x[\"codes_list\"] is not None)\nds = ds.filter(lambda x: len(x[\"codes_list\"]) > 0)\n\n#@title Create Input Ids\ndef remove_duplicate_frames(example):\n vals = example[\"codes_list\"]\n if len(vals) % 7 != 0:\n raise ValueError(\"Input list length must be divisible by 7\")\n\n result = vals[:7]\n\n removed_frames = 0\n\n for i in range(7, len(vals), 7):\n current_first = vals[i]\n previous_first = result[-7]\n\n if current_first != previous_first:\n result.extend(vals[i:i+7])\n else:\n removed_frames += 1\n\n example[\"codes_list\"] = result\n\n return example\n\nds = ds.map(remove_duplicate_frames, num_proc=num_proc)\n\ntok_info = '''*** HERE you can modify the text prompt\ni.e. 
if you wanted a multispeaker model like canopylabs/orpheus-3b-0.1-ft, you can pass:\nf\"{example[\"source\"]}: {example[\"text\"]}\", as is passed.\n'''\nprint(tok_info)\n\ndef create_input_ids(example):\n text_ids = tokenizer.encode(f\"{example['speaker_id']}: {example['text']}\", add_special_tokens=True)\n text_ids.append(end_of_text)\n example[\"text_tokens\"] = text_ids\n input_ids = (\n [start_of_human]\n + example[\"text_tokens\"]\n + [end_of_human]\n + [start_of_ai]\n + [start_of_speech]\n + example[\"codes_list\"]\n + [end_of_speech]\n + [end_of_ai]\n )\n example[\"input_ids\"] = input_ids\n example[\"labels\"] = input_ids\n example[\"attention_mask\"] = [1] * len(input_ids)\n\n return example\n\nds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=[\"text\", \"codes_list\"])\n\n#@title Remove unnecessary columns\ncolumns_to_keep = [\"input_ids\", \"labels\", \"attention_mask\"]\ncolumns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]\n\nds = ds.remove_columns(columns_to_remove)\n\nds.push_to_hub(name_to_push_dataset_to)",
"crumbs": [
"Getting Started",
"Model Guides",
"Orpheus"
]
},
{
"objectID": "docs/models/orpheus.html#training",
"href": "docs/models/orpheus.html#training",
"title": "Orpheus",
"section": "Training",
"text": "Training\nAfter preprocessing is done, fill out the blanks in finetune.yml and simply run axolotl train finetune.yml",
"crumbs": [
"Getting Started",
"Model Guides",
"Orpheus"
]
},
{
"objectID": "docs/models/orpheus.html#inference",
"href": "docs/models/orpheus.html#inference",
"title": "Orpheus",
"section": "Inference",
"text": "Inference\nFor inference, please refer to the original orpheus github.",
"crumbs": [
"Getting Started",
"Model Guides",
"Orpheus"
]
},
{
"objectID": "docs/models/qwen3.html",
"href": "docs/models/qwen3.html",
"title": "Qwen 3",
"section": "",
"text": "Qwen3 are a family of open source models trained by Alibaba.\nThis guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.",
"crumbs": [
"Getting Started",
"Model Guides",
"Qwen 3"
]
},
{
"objectID": "docs/models/qwen3.html#getting-started",
"href": "docs/models/qwen3.html#getting-started",
"title": "Qwen 3",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the installation guide.\nInstall Cut Cross Entropy to reduce training VRAM usage.\nRun the finetuning example:\naxolotl train examples/qwen3/32b-qlora.yaml\n\nLet us know how it goes. Happy finetuning! 🚀\n\nChat template masking a few tokens off\nIf you notice that the chat_template masking for assistant prompts are off by a few tokens, please ensure that you are adding the below to the yaml.\nchat_template: qwen3\n\n\nTIPS\n\nFor inference, please check the official model card as it depends on your reasoning mode.\nYou can run a full finetuning by removing the adapter: qlora and load_in_4bit: true from the config.\nRead more on how to load your own dataset at docs.\nThe dataset format follows the OpenAI Messages format as seen here.",
"crumbs": [
"Getting Started",
"Model Guides",
"Qwen 3"
]
},
{
"objectID": "docs/models/qwen3.html#optimization-guides",
"href": "docs/models/qwen3.html#optimization-guides",
"title": "Qwen 3",
"section": "Optimization Guides",
"text": "Optimization Guides\nPlease check the Optimizations doc.",
"crumbs": [
"Getting Started",
"Model Guides",
"Qwen 3"
]
},
{
"objectID": "docs/models/qwen3.html#related-resources",
"href": "docs/models/qwen3.html#related-resources",
"title": "Qwen 3",
"section": "Related Resources",
"text": "Related Resources\n\nQwen3 Blog\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"Qwen 3"
]
},
{
"objectID": "docs/models/ministral3/think.html",
"href": "docs/models/ministral3/think.html",
"title": "Ministral 3 Thinking",
"section": "",
"text": "This guide covers fine-tuning Ministral3 2512 with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections.",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral3",
"Ministral 3 Thinking"
]
},
{
"objectID": "docs/models/ministral3/think.html#prerequisites",
"href": "docs/models/ministral3/think.html#prerequisites",
"title": "Ministral 3 Thinking",
"section": "Prerequisites",
"text": "Prerequisites\nBefore starting, ensure you have:\n\nInstalled Axolotl (see main README)",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral3",
"Ministral 3 Thinking"
]
},
{
"objectID": "docs/models/ministral3/think.html#getting-started",
"href": "docs/models/ministral3/think.html#getting-started",
"title": "Ministral 3 Thinking",
"section": "Getting Started",
"text": "Getting Started\nRun the thinking model fine-tuning:\naxolotl train examples/ministral3/think/ministral3-3b-think-qlora.yaml\nThis config uses about 4.76 GiB VRAM.\n\nTips\n\nDataset uses multi-content format with type: thinking support. See Dataset Format below.\nYou cannot mix content: str and content: list[dict], otherwise, dataset loading will fail. Keep it consistent.",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral3",
"Ministral 3 Thinking"
]
},
{
"objectID": "docs/models/ministral3/think.html#dataset-format",
"href": "docs/models/ministral3/think.html#dataset-format",
"title": "Ministral 3 Thinking",
"section": "Dataset Format",
"text": "Dataset Format\nThe thinking model requires the multi-content dataset format with support for an extra role: thinking within system and assistant messages.\nExample format:\n{\n \"messages\": [\n {\n \"role\": \"system\",\n \"content\": [\n { \"type\": \"text\", \"text\": \"{SYSTEM_PROMPT}\"}\n ]\n },\n {\n \"role\": \"user\",\n \"content\": [\n { \"type\": \"text\", \"text\": \"Solve this step by step: What is 15% of 240?\"}\n ]\n },\n {\n \"role\": \"assistant\",\n \"content\": [\n {\n \"type\": \"thinking\",\n \"thinking\": \"I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. Then multiply: 0.15 × 240 = 36.\"\n },\n {\n \"type\": \"text\",\n \"text\": \"To find 15% of 240, I'll multiply 240 by 0.15:\\n\\n240 × 0.15 = 36\\n\\nTherefore, 15% of 240 is 36.\"\n }\n ]\n }\n ]\n}\n\nAdvanced Options\nThe thinking section supports an optional closed parameter:\n{\n \"type\": \"thinking\",\n \"thinking\": \"Internal reasoning here...\",\n \"closed\": true // Default: true, controls adding the closing [/THINK] tag\n}",
"crumbs": [
"Getting Started",
"Model Guides",
"Ministral3",
"Ministral 3 Thinking"
]
},
{
"objectID": "docs/models/apertus.html",
"href": "docs/models/apertus.html",
"title": "Apertus",
"section": "",
"text": "Apertus is a family of opensource models trained by Swiss-ai.\nThis guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.",
"crumbs": [
"Getting Started",
"Model Guides",
"Apertus"
]
},
{
"objectID": "docs/models/apertus.html#getting-started",
"href": "docs/models/apertus.html#getting-started",
"title": "Apertus",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the installation guide. You need to install from main as Apertus is only on nightly or use our latest Docker images.\nHere is an example of how to install from main for pip:\n\n# Ensure you have Pytorch installed (Pytorch 2.6.0 min)\ngit clone https://github.com/axolotl-ai-cloud/axolotl.git\ncd axolotl\n\npip3 install packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation -e '.[flash-attn]'\n\n# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy\npython scripts/cutcrossentropy_install.py | sh\n\n(Optional, highly recommended) Install XIELU CUDA\n\n## Recommended for reduced VRAM and faster speeds\n\n# Point to CUDA toolkit directory\n# For those using our Docker image, use the below path.\nexport CUDA_HOME=/usr/local/cuda\n\npip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps\nFor any installation errors, see XIELU Installation Issues\n\nRun the finetuning example:\n\naxolotl train examples/apertus/apertus-8b-qlora.yaml\nThis config uses about 8.7 GiB VRAM.\nLet us know how it goes. Happy finetuning! 
🚀\n\nTips\n\nFor inference, the official Apertus team recommends top_p=0.9 and temperature=0.8.\nYou can instead use full paremter fine-tuning by removing the adapter: qlora and load_in_4bit: true from the config.\nRead more on how to load your own dataset at docs.\nThe dataset format follows the OpenAI Messages format as seen here.\n\n\n\nXIELU Installation Issues\n\nModuleNotFoundError: No module named 'torch'\nPlease check these one by one:\n- Running in correct environment\n- Env has PyTorch installed\n- CUDA toolkit is at CUDA_HOME\nIf those didnt help, please try the below solutions:\n\nPass env for CMAKE and try install again:\nPython_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps\nGit clone the repo and manually hardcode python path:\ngit clone https://github.com/nickjbrowning/XIELU\ncd xielu\ngit checkout 59d6031\n\ncd xielu\nnano CMakeLists.txt # or vi depending on your preference\nexecute_process(\n- COMMAND ${Python_EXECUTABLE} -c \"import torch.utils; print(torch.utils.cmake_prefix_path)\"\n+ COMMAND /root/miniconda3/envs/py3.11/bin/python -c \"import torch.utils; print(torch.utils.cmake_prefix_path)\"\n RESULT_VARIABLE TORCH_CMAKE_PATH_RESULT\n OUTPUT_VARIABLE TORCH_CMAKE_PATH_OUTPUT\n ERROR_VARIABLE TORCH_CMAKE_PATH_ERROR\n)\npip3 install . --no-build-isolation --no-deps",
"crumbs": [
"Getting Started",
"Model Guides",
"Apertus"
]
},
{
"objectID": "docs/models/apertus.html#optimization-guides",
"href": "docs/models/apertus.html#optimization-guides",
"title": "Apertus",
"section": "Optimization Guides",
"text": "Optimization Guides\n\nMulti-GPU Training\nMulti-Node Training\nLoRA Optimizations",
"crumbs": [
"Getting Started",
"Model Guides",
"Apertus"
]
},
{
"objectID": "docs/models/apertus.html#related-resources",
"href": "docs/models/apertus.html#related-resources",
"title": "Apertus",
"section": "Related Resources",
"text": "Related Resources\n\nApertus Tech Report\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"Apertus"
]
},
{
"objectID": "docs/models/gpt-oss.html",
"href": "docs/models/gpt-oss.html",
"title": "GPT-OSS",
"section": "",
"text": "GPT-OSS is a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B.\nIn October 2025, OpenAI released safeguard models built upon GPT-OSS called GPT-OSS-Safeguard. They use the same architecture, so the examples below can be reused.\nThis guide shows how to fine-tune them with Axolotl on multi-turn conversations with proper masking.",
"crumbs": [
"Getting Started",
"Model Guides",
"GPT-OSS"
]
},
{
"objectID": "docs/models/gpt-oss.html#getting-started",
"href": "docs/models/gpt-oss.html#getting-started",
"title": "GPT-OSS",
"section": "Getting started",
"text": "Getting started\n\nInstall Axolotl following the installation guide.\nHere is an example of how to install from pip:\n\n# Ensure you have PyTorch installed (PyTorch 2.6.0 min)\npip3 install packaging==26.0 setuptools==75.8.0 wheel ninja\npip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'\n\nChoose one of the following configs for training the 20B model (for 120B, see below).\n\n# LoRA SFT linear layers (1x48GB @ ~44GiB)\naxolotl train examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml\n\n# FFT SFT with offloading (2x24GB @ ~21GiB/GPU)\naxolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml\n\n# FFT SFT (8x48GB @ ~36GiB/GPU or 4x80GB @ ~46GiB/GPU)\naxolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml\nNote: Memory usage is taken from device_mem_reserved(gib) in the logs.\n\nTraining 120B\nOn 8xH100s, make sure you have ~3TB of free disk space. Each checkpoint takes ~720GB, so together with the base\nmodel and the final model output, you need about that much space to keep at least 2 checkpoints.\n# FFT SFT with offloading (8x80GB @ ~49GiB/GPU)\naxolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml\nTo simplify fine-tuning across 2 nodes × 8x H100 (80GB) GPUs, we've partnered with Baseten to showcase multi-node\ntraining of the 120B model using Baseten Truss. You can read more about this recipe on\nBaseten's blog. The recipe can\nbe found on their\nGitHub.\nERRATA: Transformers saves the model architecture with an FSDP prefix, which needs to be manually renamed in config.json.\nSee https://github.com/huggingface/transformers/pull/40207 for the status of this issue.\nsed -i 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json\nWhen using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your\nconfigured output_dir. 
However, if that step fails due to a disk space error, you can take an additional step to\nmerge the sharded weights. This step will automatically determine the last checkpoint directory and merge the sharded\nweights to {output_dir}/merged.\naxolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml\nmv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/\n\n\nHow to set reasoning_effort in the template?\nThe Harmony template can set the reasoning_effort during prompt building. The default is medium. If you would like to adjust this, you can add the following to your config:\nchat_template_kwargs:\n reasoning_effort: \"high\" # low | medium | high\nCurrently, this applies globally. There is no way to set it per sample yet. If you are interested in adding this, please feel free to create an Issue to discuss.\n\n\nInferencing your fine-tuned model\n\nvLLM\nGPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425\nfor more information about using a special vllm-openai docker image for inferencing with vLLM.\nOptionally, vLLM can be installed from nightly:\npip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly\nand the vLLM server can be started with the following command (modify --tensor-parallel-size 8 to match your environment):\nvllm serve ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-20b --host 0.0.0.0 --port 8888 --tensor-parallel-size 8\n\n\nSGLang\nSGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for information on installing\nSGLang from source. Once you've installed SGLang, run the following command to launch an SGLang server:\npython3 -m sglang.launch_server --model ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-120b --host 0.0.0.0 --port 8888 --tp 8\n\n\n\nTool use\nGPT-OSS has comprehensive tool understanding. 
Axolotl supports tool calling datasets for Supervised Fine-tuning.\nHere is an example dataset config:\ndatasets:\n - path: Nanobit/text-tools-2k-test\n type: chat_template\nSee Nanobit/text-tools-2k-test for the sample dataset.\nRefer to our docs for more info.\n\n\nThinking and chat_template masking conflict\nOpenAI's Harmony template hides thinking in all non-final turns, which conflicts with Axolotl's chat_template masking.\nIf your dataset has thinking content mid-turn, there are two paths we recommend:\n\nTrain only on the last turn. This can be accomplished via chat_template's train on last doc.\nAdjust your dataset to only have thinking content in the last turn.\n\n\n\nTIPS\n\nRead more on how to load your own dataset at docs.\nThe dataset format follows the OpenAI Messages format as seen here.",
"crumbs": [
"Getting Started",
"Model Guides",
"GPT-OSS"
]
},
{
"objectID": "docs/models/gpt-oss.html#optimization-guides",
"href": "docs/models/gpt-oss.html#optimization-guides",
"title": "GPT-OSS",
"section": "Optimization Guides",
"text": "Optimization Guides\n\nMulti-GPU Training\nMulti-Node Training",
"crumbs": [
"Getting Started",
"Model Guides",
"GPT-OSS"
]
},
{
"objectID": "docs/models/gpt-oss.html#related-resources",
"href": "docs/models/gpt-oss.html#related-resources",
"title": "GPT-OSS",
"section": "Related Resources",
"text": "Related Resources\n\nGPT-OSS Blog\nAxolotl Docs\nAxolotl Website\nAxolotl GitHub\nAxolotl Discord",
"crumbs": [
"Getting Started",
"Model Guides",
"GPT-OSS"
]
},
{
"objectID": "docs/mixed_precision.html",
"href": "docs/mixed_precision.html",
"title": "Mixed Precision Training",
"section": "",
"text": "Mixed precision training uses lower precision data types to reduce memory usage and increase training speed while maintaining model quality. Axolotl supports several mixed precision formats:",
"crumbs": [
"Core Concepts",
"Mixed Precision Training"
]
},
{
"objectID": "docs/mixed_precision.html#sec-fp16",
"href": "docs/mixed_precision.html#sec-fp16",
"title": "Mixed Precision Training",
"section": "1 FP16 Mixed Precision",
"text": "1 FP16 Mixed Precision\n\n1.1 Overview\nFP16 is the traditional half-precision format, supported on older GPUs but can be less numerically stable than BF16.\n\n\n1.2 Configuration\nfp16: true\n\n\n1.3 FP16 Considerations\n\nMay require gradient scaling to prevent underflow\nLess numerically stable than BF16\nCan cause training instability with some model architectures\nConsider using BF16 if your hardware supports it",
"crumbs": [
"Core Concepts",
"Mixed Precision Training"
]
},
{
"objectID": "docs/mixed_precision.html#sec-bf16",
"href": "docs/mixed_precision.html#sec-bf16",
"title": "Mixed Precision Training",
"section": "2 BF16 Mixed Precision",
"text": "2 BF16 Mixed Precision\n\n2.1 Overview\nBF16 (Brain Float 16) offers better numerical stability than FP16 and is the recommended mixed precision format for modern GPUs. It provides the same dynamic range as FP32 while using half the memory.\n\n\n2.2 Configuration\n# Automatic BF16 detection (recommended)\nbf16: auto\n\n# Or explicitly enable\nbf16: true\n\n# For evaluation with BF16\nbf16: full # Equivalent to bf16_full_eval in the HF trainer",
"crumbs": [
"Core Concepts",
"Mixed Precision Training"
]
},
{
"objectID": "docs/mixed_precision.html#sec-fp8",
"href": "docs/mixed_precision.html#sec-fp8",
"title": "Mixed Precision Training",
"section": "3 FP8 Mixed Precision",
"text": "3 FP8 Mixed Precision\n\n\n\n\n\n\nNote\n\n\n\nFP8 support is experimental and requires compatible hardware (H100, H200) and recent PyTorch versions with TorchAO.\n\n\n\n3.1 What is FP8?\nFP8 (8-bit floating point) can provide significant time savings compared to FP16/BF16 while maintaining training stability. Axolotl's implementation uses PyTorch's TorchAO library with the “tensorwise” scaling strategy.\n\n\n3.2 Requirements\n\nHopper+ GPUs (H100/H200)\nPyTorch 2.7+ (+ compatible TorchAO version)\nCUDA 12.4+\n\n\n\n3.3 Configuration\nAdd to your YAML config:\n# Enable FP8 mixed precision\nfp8: true\n\n# Optional: Enable FP8 for FSDP all-gather operations\nfp8_enable_fsdp_float8_all_gather: true\n\n# Enable torch.compile (almost always necessary for FP8 speedups)\ntorch_compile: true\n\n\n\n\n\n\nImportant\n\n\n\ntorch.compile is critical for FP8 performance\nFP8 training requires torch_compile: true to see meaningful speedups. Without compilation, FP8 may actually be slower and use more memory than FP16/BF16.\n\n\n\n\n3.4 Advanced FP8 Configs\nFor FSDP (Fully Sharded Data Parallel) training:\nfp8: true\nfp8_enable_fsdp_float8_all_gather: true\n\ntorch_compile: true\n\n# FSDP configuration\nfsdp_version: 2\nfsdp_config:\n offload_params: false\n cpu_ram_efficient_loading: true\n auto_wrap_policy: TRANSFORMER_BASED_WRAP\n transformer_layer_cls_to_wrap: LlamaDecoderLayer\n state_dict_type: FULL_STATE_DICT\n reshard_after_forward: true",
"crumbs": [
"Core Concepts",
"Mixed Precision Training"
]
},
{
"objectID": "docs/mixed_precision.html#sec-best-practices",
"href": "docs/mixed_precision.html#sec-best-practices",
"title": "Mixed Precision Training",
"section": "4 Best Practices",
"text": "4 Best Practices\n\n4.1 Choosing Precision Format\n\nStart with automatic detection: bf16: auto\nFor Hopper+ (H100/H200): Try FP8 + torch.compile for maximum speed\nFor Ampere/Ada (A100/RTX 30/40): Use BF16\nFor older Pascal/Turing GPUs: Use FP16 with caution\nFor very old or unsupported GPUs: Use FP32\n\n\n\n4.2 Validation and Testing\nAlways validate your mixed precision setup:\n\nStart with a small dataset to verify stability\nMonitor loss curves for irregularities\nCompare with FP32 baseline when possible\nTest that evaluation metrics match expectations\n\n\n\n4.3 FP8 Particulars\n\nUse cases\n\nSingle GPU training\nMulti GPU training with FSDP2 or DeepSpeed\n\nSpeedups\n\nPlease refer to the TorchAO FP8 training benchmarks for expected matmul speedups for different (M, K, N) settings\nConcrete numbers for LLaMA 3 8B training can be found here\n\nKnown issues:\n\nFP8 + DDP + torch.compile (causes error)\nFP8 + FSDP2 + torch.compile + FSDP2 activation checkpointing tends to be slower than the BF16 equivalent training\nFlash Attention 2 does not play nicely with torch.compile\n\n\nSee examples/llama-3/3b-fp8-fsdp2.yaml for an optimized example config. Enabling FP8 mixed precision + FP8 all-gather training results in ~10% faster iterations per second vs. BF16 for a relatively small (3B param) model.\nFor more information on multi-GPU training, see our Multi-GPU guide.",
"crumbs": [
"Core Concepts",
"Mixed Precision Training"
]
},
{
"objectID": "docs/lora_optims.html",
"href": "docs/lora_optims.html",
"title": "LoRA Optimizations",
"section": "",
"text": "Inspired by Unsloth, we've implemented two\noptimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU\n(including the DDP, DeepSpeed, and FSDP2 settings) training. These include (1) SwiGLU\nand GeGLU activation function Triton kernels, and (2) LoRA MLP and attention custom\nautograd functions. Our goal was to leverage operator fusion and tensor re-use in order\nto improve speed and reduce memory usage during the forward and backward passes of\nthese calculations.\nWe currently support several common model architectures, including (but not limited to):",
"crumbs": [
"How To Guides",
"LoRA Optimizations"
]
},
{
"objectID": "docs/lora_optims.html#usage",
"href": "docs/lora_optims.html#usage",
"title": "LoRA Optimizations",
"section": "Usage",
"text": "Usage\nThese optimizations can be enabled in your Axolotl config YAML file. The\nlora_mlp_kernel option enables the optimized MLP path, while lora_qkv_kernel and\nlora_o_kernel enable the fused query-key-value projection and optimized output\nprojection, respectively.\nlora_mlp_kernel: true\nlora_qkv_kernel: true\nlora_o_kernel: true\n\n\n\n\n\n\nNote\n\n\n\nCurrently, LoRA kernels are not supported for RLHF training, only SFT.\n\n\n\n\n\n\n\n\nWarning\n\n\n\nLoRA kernels do not support remote modeling code.",
"crumbs": [
"How To Guides",
"LoRA Optimizations"
]
},
{
"objectID": "docs/lora_optims.html#requirements",
"href": "docs/lora_optims.html#requirements",
"title": "LoRA Optimizations",
"section": "Requirements",
"text": "Requirements\n\nOne or more NVIDIA or AMD GPUs (in order to use the Triton kernels)\n\nNote: Set TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 to enable memory-efficient attention on AMD GPUs\n\nTargeted LoRA adapters cannot use Dropout\n\nThis may limit model expressivity / cause overfitting\n\nTargeted LoRA adapters cannot have bias terms\n\nThis may limit model expressivity\n\n\nModels with pre-existing LoRA adapters that use Dropout or have bias terms may need to\nbe re-finetuned without these features in order to be useful.",
"crumbs": [
"How To Guides",
"LoRA Optimizations"
]
},
{
"objectID": "docs/lora_optims.html#implementation-details",
"href": "docs/lora_optims.html#implementation-details",
"title": "LoRA Optimizations",
"section": "Implementation details",
"text": "Implementation details\n\nCustom autograd functions\nThe LoRA MLP autograd function optimizes the entire MLP computation path. It fuses the\nLoRA and base weight computations together and provides a single, efficient backward\npass for the entire MLP block.\nFor attention components, similar optimizations are provided through a function that\nhandles the query, key, and value projections, and a function that handles the output\nprojection. They are designed to work with the existing transformers attention\nimplementation via some monkey-patching logic.\n\n\nTriton kernels\nTwo activation functions (SwiGLU and GeGLU) are implemented with Triton kernels for\nimproved speed and memory performance. These kernels handle both the forward and\nbackward passes.\n\n\nIntegration\nThe custom autograd functions and Triton kernels are designed to work together. The\nautograd function manages the high-level computation flow and gradient tracking, while\ncalling the Triton kernels for the activation function computation. During the backward\npass, the kernel computes both the activation output and the required gradients, which\nthe autograd function then uses to compute the final gradients for the entire\ncomputation path.",
"crumbs": [
"How To Guides",
"LoRA Optimizations"
]
},
{
"objectID": "docs/lora_optims.html#future-work",
"href": "docs/lora_optims.html#future-work",
"title": "LoRA Optimizations",
"section": "Future Work",
"text": "Future Work\n\nSupport for additional model architectures\nSupport for dropout and bias\nAdditional operator fusions",
"crumbs": [
"How To Guides",
"LoRA Optimizations"
]
},
{
"objectID": "docs/dataset_loading.html",
"href": "docs/dataset_loading.html",
"title": "Dataset Loading",
"section": "",
"text": "Datasets can be loaded in a number of different ways depending on how they are saved (the file extension) and where they are stored.",
"crumbs": [
"How To Guides",
"Dataset Loading"
]
},
{
"objectID": "docs/dataset_loading.html#overview",
"href": "docs/dataset_loading.html#overview",
"title": "Dataset Loading",
"section": "",
"text": "Datasets can be loaded in a number of different ways depending on how they are saved (the file extension) and where they are stored.",
"crumbs": [
"How To Guides",
"Dataset Loading"
]
},
{
"objectID": "docs/dataset_loading.html#loading-datasets",
"href": "docs/dataset_loading.html#loading-datasets",
"title": "Dataset Loading",
"section": "Loading Datasets",
"text": "Loading Datasets\nWe use the datasets library, with a mix of load_dataset and load_from_disk, to load datasets.\nYou may recognize the similarly named configs between load_dataset and the datasets section of the config file.\ndatasets:\n - path:\n name:\n data_files:\n split:\n revision:\n trust_remote_code:\n\n\n\n\n\n\nTip\n\n\n\nDo not feel overwhelmed by the number of options here. A lot of them are optional. In fact, the most common config to use would be path and sometimes data_files.\n\n\nThis matches the API of datasets.load_dataset, so if you're familiar with that, you will feel right at home.\nFor HuggingFace's guide to loading different dataset types, see here.\nFor full details on the config, see config-reference.qmd.\n\n\n\n\n\n\nNote\n\n\n\nYou can set multiple datasets in the config file by adding more than one entry under datasets.\ndatasets:\n - path: /path/to/your/dataset\n - path: /path/to/your/other/dataset\n\n\n\nLocal dataset\n\nFiles\nTo load a JSON file, you would do something like this:\nfrom datasets import load_dataset\n\ndataset = load_dataset(\"json\", data_files=\"data.json\")\nWhich translates to the following config:\ndatasets:\n - path: data.json\n ds_type: json\nIn the example above, we point the path to the file or directory and set ds_type to load the dataset.\nThis works for CSV, JSON, Parquet, and Arrow files.\n\n\n\n\n\n\nTip\n\n\n\nIf path points to a file and ds_type is not specified, we will automatically infer the dataset type from the file extension, so you could omit ds_type if you'd like.\n\n\n\n\nDirectory\nIf you're loading a directory, you can point the path to the directory.\nThen, you have two options:\n\nLoading entire directory\nYou do not need any additional configs.\nWe will attempt to load in the following order:\n- datasets saved with datasets.save_to_disk\n- loading entire directory of files (such as with parquet/arrow files)\ndatasets:\n - path: 
/path/to/your/directory\n\n\nLoading specific files in directory\nProvide data_files with a list of files to load.\ndatasets:\n # single file\n - path: /path/to/your/directory\n ds_type: csv\n data_files: file1.csv\n\n # multiple files\n - path: /path/to/your/directory\n ds_type: json\n data_files:\n - file1.jsonl\n - file2.jsonl\n\n # multiple files for parquet\n - path: /path/to/your/directory\n ds_type: parquet\n data_files:\n - file1.parquet\n - file2.parquet\n\n\n\n\nHuggingFace Hub\nThe method you use to load the dataset depends on how the dataset was created: whether a folder was uploaded directly or a HuggingFace Dataset was pushed.\n\n\n\n\n\n\nNote\n\n\n\nIf you're using a private dataset, you will need to enable the hf_use_auth_token flag at the root level of the config file.\n\n\n\nFolder uploaded\nThis means that the dataset is one or more files uploaded to the Hub.\ndatasets:\n - path: org/dataset-name\n data_files:\n - file1.jsonl\n - file2.jsonl\n\n\nHuggingFace Dataset\nThis means that the dataset is created as a HuggingFace Dataset and pushed to the Hub via datasets.push_to_hub.\ndatasets:\n - path: org/dataset-name\n\n\n\n\n\n\nNote\n\n\n\nOther configs such as name, split, revision, and trust_remote_code may be required depending on the dataset.\n\n\n\n\n\nRemote Filesystems\nVia the storage_options config under load_dataset, you can load datasets from remote filesystems like S3, GCS, Azure, and OCI.\n\n\n\n\n\n\nWarning\n\n\n\nThis is currently experimental. 
Please let us know if you run into any issues!\n\n\nThe only difference between the providers is that you need to prepend the path with the respective protocols.\ndatasets:\n # Single file\n - path: s3://bucket-name/path/to/your/file.jsonl\n\n # Directory\n - path: s3://bucket-name/path/to/your/directory\nFor directories, we load via load_from_disk.\n\nS3\nPrepend the path with s3://.\nThe credentials are pulled in the following order:\n\nAWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN environment variables\nfrom the ~/.aws/credentials file\nfor nodes on EC2, the IAM metadata provider\n\n\n\n\n\n\n\nNote\n\n\n\nWe assume you have credentials set up and are not using anonymous access. If you want to use anonymous access, let us know! We may have to open a config option for this.\n\n\nOther environment variables that can be set can be found in boto3 docs\n\n\nGCS\nPrepend the path with gs:// or gcs://.\nThe credentials are loaded in the following order:\n\ngcloud credentials\nfor nodes on GCP, the google metadata service\nanonymous access\n\n\n\nAzure\n\nGen 1\nPrepend the path with adl://.\nEnsure you have the following environment variables set:\n\nAZURE_STORAGE_TENANT_ID\nAZURE_STORAGE_CLIENT_ID\nAZURE_STORAGE_CLIENT_SECRET\n\n\n\nGen 2\nPrepend the path with abfs:// or az://.\nEnsure you have the following environment variables set:\n\nAZURE_STORAGE_ACCOUNT_NAME\nAZURE_STORAGE_ACCOUNT_KEY\n\nOther environment variables that can be set can be found in adlfs docs\n\n\n\nOCI\nPrepend the path with oci://.\nThe credentials are read in the following order:\n\nOCIFS_IAM_TYPE, OCIFS_CONFIG_LOCATION, and OCIFS_CONFIG_PROFILE environment variables\nwhen on OCI resource, resource principal\n\nOther environment variables:\n\nOCI_REGION_METADATA\n\nPlease see the ocifs docs.\n\n\n\nHTTPS\nThe path should start with https://.\ndatasets:\n - path: https://path/to/your/dataset/file.jsonl\nThis must be publicly accessible.",
"crumbs": [
"How To Guides",
"Dataset Loading"
]
},
{
"objectID": "docs/dataset_loading.html#next-steps",
"href": "docs/dataset_loading.html#next-steps",
"title": "Dataset Loading",
"section": "Next steps",
"text": "Next steps\nNow that you know how to load datasets, you can learn more on how to load your specific dataset format into your target output format dataset formats docs.",
"crumbs": [
"How To Guides",
"Dataset Loading"
]
},
{
"objectID": "docs/input_output.html",
"href": "docs/input_output.html",
"title": "Template-free prompt construction",
"section": "",
"text": "The documentation moved to here."
},
{
"objectID": "docs/fsdp_qlora.html",
"href": "docs/fsdp_qlora.html",
"title": "FSDP + QLoRA",
"section": "",
"text": "Using FSDP with QLoRA is essential for fine-tuning larger (70b+ parameter) LLMs on consumer GPUs. For example, you can use FSDP + QLoRA to train a 70b model on two 24GB GPUs1.\nBelow, we describe how to use this feature in Axolotl.",
"crumbs": [
"Advanced Features",
"FSDP + QLoRA"
]
},
{
"objectID": "docs/fsdp_qlora.html#background",
"href": "docs/fsdp_qlora.html#background",
"title": "FSDP + QLoRA",
"section": "",
"text": "Using FSDP with QLoRA is essential for fine-tuning larger (70b+ parameter) LLMs on consumer GPUs. For example, you can use FSDP + QLoRA to train a 70b model on two 24GB GPUs1.\nBelow, we describe how to use this feature in Axolotl.",
"crumbs": [
"Advanced Features",
"FSDP + QLoRA"
]
},
{
"objectID": "docs/fsdp_qlora.html#usage",
"href": "docs/fsdp_qlora.html#usage",
"title": "FSDP + QLoRA",
"section": "Usage",
"text": "Usage\nTo enable QLoRA with FSDP, you need to perform the following steps:\n\nTip: See the example config file in addition to reading these instructions.\n\n\nSet adapter: qlora in your axolotl config file.\nEnable FSDP in your axolotl config, as described here.\nUse one of the supported model types: llama, mistral or mixtral.",
"crumbs": [
"Advanced Features",
"FSDP + QLoRA"
]
},
{
"objectID": "docs/fsdp_qlora.html#enabling-swap-for-fsdp2",
"href": "docs/fsdp_qlora.html#enabling-swap-for-fsdp2",
"title": "FSDP + QLoRA",
"section": "Enabling Swap for FSDP2",
"text": "Enabling Swap for FSDP2\nIf available memory is insufficient even after FSDP's CPU offloading, you can enable swap memory usage by setting cpu_offload_pin_memory: false alongside offload_params: true in FSDP config.\nThis disables memory pinning, allowing FSDP to use disk swap space as fallback. Disabling memory pinning itself incurs performance overhead, and actually having to use swap adds more, but it may enable training larger models that would otherwise cause OOM errors on resource constrained systems.",
"crumbs": [
"Advanced Features",
"FSDP + QLoRA"
]
},
{
"objectID": "docs/fsdp_qlora.html#example-config",
"href": "docs/fsdp_qlora.html#example-config",
"title": "FSDP + QLoRA",
"section": "Example Config",
"text": "Example Config\nexamples/llama-2/qlora-fsdp.yml contains an example of how to enable QLoRA + FSDP in axolotl.",
"crumbs": [
"Advanced Features",
"FSDP + QLoRA"
]
},
{
"objectID": "docs/fsdp_qlora.html#references",
"href": "docs/fsdp_qlora.html#references",
"title": "FSDP + QLoRA",
"section": "References",
"text": "References\n\nPR #1378 enabling QLoRA in FSDP in Axolotl.\nBlog Post from the Answer.AI team describing the work that enabled QLoRA in FSDP.\nRelated HuggingFace PRs Enabling FSDP + QLoRA:\n\nAccelerate PR#2544\nTransformers PR#29587\nTRL PR#1416\nPEFT PR#1550",
"crumbs": [
"Advanced Features",
"FSDP + QLoRA"
]
},
{
"objectID": "docs/fsdp_qlora.html#footnotes",
"href": "docs/fsdp_qlora.html#footnotes",
"title": "FSDP + QLoRA",
"section": "Footnotes",
"text": "Footnotes\n\n\nThis was enabled by this work from the Answer.AI team.↩︎",
"crumbs": [
"Advanced Features",
"FSDP + QLoRA"
]
},
{
"objectID": "docs/agents/preference_tuning.html",
"href": "docs/agents/preference_tuning.html",
"title": "Preference Learning (RLHF) — Agent Reference",
"section": "",
"text": "Reference for DPO, IPO, KTO, ORPO, and SimPO. For config templates and dataset format examples, see rlhf.qmd. For GRPO, see grpo.qmd. For EBFT, see ebft.qmd.\n\n\n\n\n\n\n\n\n\n\n\nMethod\nData Requirement\nKey Idea\nBest For\n\n\n\n\nDPO\nPaired (chosen + rejected)\nImplicit reward via preference pairs\nGeneral alignment, most common\n\n\nIPO\nPaired (chosen + rejected)\nDPO with different loss (avoids overfitting)\nWhen DPO overfits\n\n\nKTO\nUnpaired (completion + binary label)\nKahneman-Tversky loss, no pairs needed\nWhen you only have thumbs-up/down\n\n\nORPO\nPaired (chosen + rejected)\nCombined SFT + preference, no ref model\nSingle-stage alignment, saves VRAM\n\n\nSimPO\nPaired (chosen + rejected)\nLength-normalized, no ref model\nSimple setup, length-robust\n\n\n\nDefault: start with DPO. All methods require sample_packing: false.\n\n\n\n┌──────────────┐ ┌───────────────┐ ┌───────────────┐\n│ Policy Model │ │ Reference │ │ Preference │\n│ (trainable) │ │ Model (frozen)│ │ Dataset │\n└──────┬───────┘ └──────┬────────┘ └──────┬────────┘\n └──────────┬───────┘ │\n v │\n Forward pass on chosen + rejected <─────┘\n │\n Preference Loss (DPO/IPO/KTO/...)\n │\n Backprop + Update\n\nException: ORPO and SimPO do NOT use a reference model (~50% less VRAM).\nNo vLLM server needed (unlike GRPO). Offline RL with pre-collected preference data.\n\n\n\n\nPaired preference data (chosen + rejected)?\n\nDefault → rl: dpo\nOverfitting → rl: ipo\nVRAM-limited → rl: orpo (no ref model)\nLength-sensitive → rl: simpo (no ref model)\n\nOnly binary labels (good/bad)? → rl: kto\nSingle-stage training (no separate SFT)? 
→ rl: orpo\n\n\n\n\n\n\n\n\n\n\n\n\n\nDPO\nIPO\nKTO\nORPO\nSimPO\n\n\n\n\nReference model\nYes\nYes\nYes\nNo\nNo\n\n\nVRAM overhead\n~2x model\n~2x model\n~2x model\n~1x model\n~1x model\n\n\nTRL trainer class\nDPOTrainer\nDPOTrainer\nKTOTrainer\nORPOTrainer\nCPOTrainer\n\n\n\n\n\n\nThe type field resolves to a Python function:\ntype: \"chatml.intel\"\n → axolotl.prompt_strategies.dpo.chatml.intel(cfg, **kwargs)\n → returns transform_fn(sample) → {\"prompt\", \"chosen\", \"rejected\"}\n\ntype: \"chat_template.default\"\n → axolotl.prompt_strategies.dpo.chat_template.default(cfg, dataset_idx, **kwargs)\n\ntype: {\"field_prompt\": \"prompt\", ...} (dict)\n → axolotl.prompt_strategies.dpo.user_defined.default(...)\nModule base: axolotl.prompt_strategies.{rl_method} — replace dpo with kto or orpo.\n\n\n\n\n\n\n\n\n\n\n\nMetric\nHealthy Range\nProblem\n\n\n\n\ntrain/loss\nDecreasing, 0.3-0.7\nFlat or increasing = broken data or too high LR\n\n\nrewards/chosen\nIncreasing\nFlat = model not learning preferences\n\n\nrewards/rejected\nDecreasing\nIncreasing = model prefers wrong responses\n\n\nrewards/margins\nPositive and increasing\nNegative = prefers rejected over chosen\n\n\nrewards/accuracies\n> 0.5, toward 0.7+\n< 0.5 = worse than random\n\n\nlogps/rejected\nDecreasing\nIncreasing = reward hacking\n\n\ngrad_norm\n0.01 - 10.0\n> 100 = exploding gradients\n\n\n\nMethod-specific: DPO/IPO watch rewards/margins; KTO loss is noisier; ORPO monitor SFT + odds ratio components; SimPO check length-normalized reward separation.\n\n\n\n\n\n\n\n\n\n\nIssue\nFix\n\n\n\n\nSample packing crash\nSet sample_packing: false (required for all preference methods)\n\n\nKTO KeyError: 'label'\nEnsure dataset has boolean label column\n\n\nORPO/KTO KeyError during tokenization\nAdd remove_unused_columns: false\n\n\nORPO template not applied\nORPO requires explicit chat_template setting\n\n\nOOM with ref model (DPO/IPO/KTO)\nUse LoRA/QLoRA, or switch to ORPO/SimPO (no ref model)\n\n\nIPO + 
label_smoothing\nDo not set dpo_label_smoothing when rl: ipo\n\n\n\nFull troubleshooting: training_stability.qmd\n\n\n\nsrc/axolotl/\n core/trainers/dpo/ # DPO trainer, args, strategy\n core/builders/rl.py # HFRLTrainerBuilder — routes rl type → trainer class\n core/training_args.py # AxolotlKTOConfig, AxolotlORPOConfig, AxolotlCPOConfig\n prompt_strategies/\n dpo/ # DPO/IPO/SimPO dataset strategies\n chat_template.py # chat_template.default, chat_template.argilla_chat\n chatml.py # chatml.default/intel/icr/argilla_chat/prompt_pairs/ultra\n llama3.py # llama3 variants (same subtypes as chatml)\n user_defined.py # Custom field mapping\n passthrough.py # No transform\n kto/ # KTO dataset strategies (chatml, llama3, user_defined)\n orpo/ # ORPO dataset strategies (chat_template.argilla)\n utils/schemas/enums.py # RLType enum (dpo, ipo, kto, orpo, simpo, grpo, gdpo, ebft)\n utils/schemas/config.py # All rl/dpo/kto/orpo/simpo config fields\n\ndocs/rlhf.qmd # Full user docs: all dataset formats, config templates\ndocs/choosing_method.qmd # SFT vs DPO vs GRPO decision guide\nexamples/qwen2/dpo.yaml # DPO example\nexamples/llama-3/qlora-1b-kto.yaml # KTO example"
},
{
"objectID": "docs/agents/preference_tuning.html#method-overview",
"href": "docs/agents/preference_tuning.html#method-overview",
"title": "Preference Learning (RLHF) — Agent Reference",
"section": "",
"text": "Method\nData Requirement\nKey Idea\nBest For\n\n\n\n\nDPO\nPaired (chosen + rejected)\nImplicit reward via preference pairs\nGeneral alignment, most common\n\n\nIPO\nPaired (chosen + rejected)\nDPO with different loss (avoids overfitting)\nWhen DPO overfits\n\n\nKTO\nUnpaired (completion + binary label)\nKahneman-Tversky loss, no pairs needed\nWhen you only have thumbs-up/down\n\n\nORPO\nPaired (chosen + rejected)\nCombined SFT + preference, no ref model\nSingle-stage alignment, saves VRAM\n\n\nSimPO\nPaired (chosen + rejected)\nLength-normalized, no ref model\nSimple setup, length-robust\n\n\n\nDefault: start with DPO. All methods require sample_packing: false."
},
{
"objectID": "docs/agents/preference_tuning.html#architecture",
"href": "docs/agents/preference_tuning.html#architecture",
"title": "Preference Learning (RLHF) — Agent Reference",
"section": "",
"text": "┌──────────────┐ ┌───────────────┐ ┌───────────────┐\n│ Policy Model │ │ Reference │ │ Preference │\n│ (trainable) │ │ Model (frozen)│ │ Dataset │\n└──────┬───────┘ └──────┬────────┘ └──────┬────────┘\n └──────────┬───────┘ │\n v │\n Forward pass on chosen + rejected <─────┘\n │\n Preference Loss (DPO/IPO/KTO/...)\n │\n Backprop + Update\n\nException: ORPO and SimPO do NOT use a reference model (~50% less VRAM).\nNo vLLM server needed (unlike GRPO). Offline RL with pre-collected preference data."
},
{
"objectID": "docs/agents/preference_tuning.html#method-selection",
"href": "docs/agents/preference_tuning.html#method-selection",
"title": "Preference Learning (RLHF) — Agent Reference",
"section": "",
"text": "Paired preference data (chosen + rejected)?\n\nDefault → rl: dpo\nOverfitting → rl: ipo\nVRAM-limited → rl: orpo (no ref model)\nLength-sensitive → rl: simpo (no ref model)\n\nOnly binary labels (good/bad)? → rl: kto\nSingle-stage training (no separate SFT)? → rl: orpo\n\n\n\n\n\n\n\n\n\n\n\n\n\nDPO\nIPO\nKTO\nORPO\nSimPO\n\n\n\n\nReference model\nYes\nYes\nYes\nNo\nNo\n\n\nVRAM overhead\n~2x model\n~2x model\n~2x model\n~1x model\n~1x model\n\n\nTRL trainer class\nDPOTrainer\nDPOTrainer\nKTOTrainer\nORPOTrainer\nCPOTrainer"
},
{
"objectID": "docs/agents/preference_tuning.html#prompt-strategy-resolution",
"href": "docs/agents/preference_tuning.html#prompt-strategy-resolution",
"title": "Preference Learning (RLHF) — Agent Reference",
"section": "",
"text": "The type field resolves to a Python function:\ntype: \"chatml.intel\"\n → axolotl.prompt_strategies.dpo.chatml.intel(cfg, **kwargs)\n → returns transform_fn(sample) → {\"prompt\", \"chosen\", \"rejected\"}\n\ntype: \"chat_template.default\"\n → axolotl.prompt_strategies.dpo.chat_template.default(cfg, dataset_idx, **kwargs)\n\ntype: {\"field_prompt\": \"prompt\", ...} (dict)\n → axolotl.prompt_strategies.dpo.user_defined.default(...)\nModule base: axolotl.prompt_strategies.{rl_method} — replace dpo with kto or orpo."
},
{
"objectID": "docs/agents/preference_tuning.html#healthy-training-indicators",
"href": "docs/agents/preference_tuning.html#healthy-training-indicators",
"title": "Preference Learning (RLHF) — Agent Reference",
"section": "",
"text": "Metric\nHealthy Range\nProblem\n\n\n\n\ntrain/loss\nDecreasing, 0.3-0.7\nFlat or increasing = broken data or too high LR\n\n\nrewards/chosen\nIncreasing\nFlat = model not learning preferences\n\n\nrewards/rejected\nDecreasing\nIncreasing = model prefers wrong responses\n\n\nrewards/margins\nPositive and increasing\nNegative = prefers rejected over chosen\n\n\nrewards/accuracies\n> 0.5, toward 0.7+\n< 0.5 = worse than random\n\n\nlogps/rejected\nDecreasing\nIncreasing = reward hacking\n\n\ngrad_norm\n0.01 - 10.0\n> 100 = exploding gradients\n\n\n\nMethod-specific: DPO/IPO watch rewards/margins; KTO loss is noisier; ORPO monitor SFT + odds ratio components; SimPO check length-normalized reward separation."
},
{
"objectID": "docs/agents/preference_tuning.html#known-issues",
"href": "docs/agents/preference_tuning.html#known-issues",
"title": "Preference Learning (RLHF) — Agent Reference",
"section": "",
"text": "Issue\nFix\n\n\n\n\nSample packing crash\nSet sample_packing: false (required for all preference methods)\n\n\nKTO KeyError: 'label'\nEnsure dataset has boolean label column\n\n\nORPO/KTO KeyError during tokenization\nAdd remove_unused_columns: false\n\n\nORPO template not applied\nORPO requires explicit chat_template setting\n\n\nOOM with ref model (DPO/IPO/KTO)\nUse LoRA/QLoRA, or switch to ORPO/SimPO (no ref model)\n\n\nIPO + label_smoothing\nDo not set dpo_label_smoothing when rl: ipo\n\n\n\nFull troubleshooting: training_stability.qmd"
},
{
"objectID": "docs/agents/preference_tuning.html#file-map",
"href": "docs/agents/preference_tuning.html#file-map",
"title": "Preference Learning (RLHF) — Agent Reference",
"section": "",
"text": "src/axolotl/\n core/trainers/dpo/ # DPO trainer, args, strategy\n core/builders/rl.py # HFRLTrainerBuilder — routes rl type → trainer class\n core/training_args.py # AxolotlKTOConfig, AxolotlORPOConfig, AxolotlCPOConfig\n prompt_strategies/\n dpo/ # DPO/IPO/SimPO dataset strategies\n chat_template.py # chat_template.default, chat_template.argilla_chat\n chatml.py # chatml.default/intel/icr/argilla_chat/prompt_pairs/ultra\n llama3.py # llama3 variants (same subtypes as chatml)\n user_defined.py # Custom field mapping\n passthrough.py # No transform\n kto/ # KTO dataset strategies (chatml, llama3, user_defined)\n orpo/ # ORPO dataset strategies (chat_template.argilla)\n utils/schemas/enums.py # RLType enum (dpo, ipo, kto, orpo, simpo, grpo, gdpo, ebft)\n utils/schemas/config.py # All rl/dpo/kto/orpo/simpo config fields\n\ndocs/rlhf.qmd # Full user docs: all dataset formats, config templates\ndocs/choosing_method.qmd # SFT vs DPO vs GRPO decision guide\nexamples/qwen2/dpo.yaml # DPO example\nexamples/llama-3/qlora-1b-kto.yaml # KTO example"
},
{
"objectID": "docs/agents/reward_modelling.html",
"href": "docs/agents/reward_modelling.html",
"title": "Reward Modelling — Agent Reference",
"section": "",
"text": "Train models to score responses for use as reward signals in RL. For full docs, see reward_modelling.qmd.\n\n\n\n\nTrain a classifier to predict preference over entire interactions. Uses AutoModelForSequenceClassification.\nbase_model: google/gemma-2-2b\nmodel_type: AutoModelForSequenceClassification\nnum_labels: 1\nreward_model: true\nchat_template: gemma\ndatasets:\n - path: argilla/distilabel-intel-orca-dpo-pairs\n type: bradley_terry.chat_template\nDataset format: {\"system\": \"...\", \"input\": \"...\", \"chosen\": \"...\", \"rejected\": \"...\"}\n\n\n\nTrain a token classifier to score each reasoning step. Uses AutoModelForTokenClassification.\nbase_model: Qwen/Qwen2.5-3B\nmodel_type: AutoModelForTokenClassification\nnum_labels: 2\nprocess_reward_model: true\ndatasets:\n - path: trl-lib/math_shepherd\n type: stepwise_supervised\nDataset format: see stepwise_supervised.qmd.\n\n\n\n\nsrc/axolotl/\n core/builders/causal.py # Handles reward_model flag in trainer builder\n prompt_strategies/bradley_terry/ # Bradley-Terry prompt strategies\n prompt_strategies/stepwise_supervised.py # PRM dataset strategy\n utils/schemas/config.py # reward_model, process_reward_model config fields"
},
{
"objectID": "docs/agents/reward_modelling.html#types",
"href": "docs/agents/reward_modelling.html#types",
"title": "Reward Modelling — Agent Reference",
"section": "",
"text": "Train a classifier to predict preference over entire interactions. Uses AutoModelForSequenceClassification.\nbase_model: google/gemma-2-2b\nmodel_type: AutoModelForSequenceClassification\nnum_labels: 1\nreward_model: true\nchat_template: gemma\ndatasets:\n - path: argilla/distilabel-intel-orca-dpo-pairs\n type: bradley_terry.chat_template\nDataset format: {\"system\": \"...\", \"input\": \"...\", \"chosen\": \"...\", \"rejected\": \"...\"}\n\n\n\nTrain a token classifier to score each reasoning step. Uses AutoModelForTokenClassification.\nbase_model: Qwen/Qwen2.5-3B\nmodel_type: AutoModelForTokenClassification\nnum_labels: 2\nprocess_reward_model: true\ndatasets:\n - path: trl-lib/math_shepherd\n type: stepwise_supervised\nDataset format: see stepwise_supervised.qmd."
},
{
"objectID": "docs/agents/reward_modelling.html#file-map",
"href": "docs/agents/reward_modelling.html#file-map",
"title": "Reward Modelling — Agent Reference",
"section": "",
"text": "src/axolotl/\n core/builders/causal.py # Handles reward_model flag in trainer builder\n prompt_strategies/bradley_terry/ # Bradley-Terry prompt strategies\n prompt_strategies/stepwise_supervised.py # PRM dataset strategy\n utils/schemas/config.py # reward_model, process_reward_model config fields"
},
{
"objectID": "docs/optimizations.html",
"href": "docs/optimizations.html",
"title": "Optimizations Guide",
"section": "",
"text": "Axolotl includes numerous optimizations to speed up training, reduce memory usage, and handle large models.\nThis guide provides a high-level overview and directs you to the detailed documentation for each feature.",
"crumbs": [
"How To Guides",
"Optimizations Guide"
]
},
{
"objectID": "docs/optimizations.html#speed-optimizations",
"href": "docs/optimizations.html#speed-optimizations",
"title": "Optimizations Guide",
"section": "Speed Optimizations",
"text": "Speed Optimizations\nThese optimizations focus on increasing training throughput and reducing total training time.\n\nSample Packing\nImproves GPU utilization by combining multiple short sequences into a single packed sequence for training. This requires enabling one of the attention implementations below.\n\nConfig: sample_packing: true\nLearn more: Sample Packing\n\n\n\nAttention Implementations\nUsing an optimized attention implementation is critical for training speed.\n\nFlash Attention 2: flash_attention: true. (Recommended) The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check AMD Support.\nFlex Attention: flex_attention: true.\nSDP Attention: sdp_attention: true. PyTorchs native implementation.\nXformers: xformers_attention: true. Works with FP16.\n\nNote: You should only enable one attention backend.\n\n\nLoRA Optimizations\nLeverages optimized kernels to accelerate LoRA training and reduce memory usage.\n\nLearn more: LoRA Optimizations Documentation",
"crumbs": [
"How To Guides",
"Optimizations Guide"
]
},
{
"objectID": "docs/optimizations.html#memory-optimizations",
"href": "docs/optimizations.html#memory-optimizations",
"title": "Optimizations Guide",
"section": "Memory Optimizations",
"text": "Memory Optimizations\nThese techniques help you fit larger models or use bigger batch sizes on your existing hardware.\n\nParameter Efficient Finetuning (LoRA & QLoRA)\nDrastically reduces memory by training a small set of “adapter” parameters instead of the full model. This is the most common and effective memory-saving technique.\n\nExamples: Find configs with lora or qlora in the examples directory.\nConfig Reference: See adapter, load_in_4bit, and load_in_8bit in the Configuration Reference.\n\n\n\nGradient Checkpointing & Activation Offloading\nThese techniques save VRAM by changing how activations are handled.\n\nGradient Checkpointing: re-computes activations during the backward pass, trading compute time for VRAM.\nActivation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.\nLearn more: Gradient Checkpointing and Offloading Docs\n\n\n\nLayer Offloading\nOffloads frozen (non-trainable) decoder layer parameters to CPU and streams them back to GPU one layer at a time during forward/backward passes using CUDA stream prefetching. Especially effective for LoRA/QLoRA where most parameters are frozen.\n\nConfig: layer_offloading: true\nLearn more: Layer Offloading Docs\n\n\n\nCut Cross Entropy (CCE)\nReduces VRAM usage by using an optimized cross-entropy loss calculation.\n\nLearn more: Custom Integrations - CCE\n\n\n\nLiger Kernels\nProvides efficient Triton kernels to improve training speed and reduce memory usage.\n\nLearn more: Custom Integrations - Liger Kernels\n\n\n\nExpert Kernels\nOptimized kernel implementations for Mixture of Experts (MoE) model training.\n\nScatterMoE: Triton-based MoE kernels with fused LoRA support.\nSonicMoE: CUTLASS-based MoE kernels for NVIDIA Hopper and Blackwell GPUs.\nLearn more: Custom Integrations - Kernels Integration",
"crumbs": [
"How To Guides",
"Optimizations Guide"
]
},
{
"objectID": "docs/optimizations.html#long-context-models",
"href": "docs/optimizations.html#long-context-models",
"title": "Optimizations Guide",
"section": "Long Context Models",
"text": "Long Context Models\nTechniques to train models on sequences longer than their original context window.\n\nRoPE Scaling\nExtends a models context window by interpolating its Rotary Position Embeddings.\n\nConfig: Pass the rope_scaling config under the overrides_of_model_config:. To learn how to set RoPE, check the respective model config.\n\n\n\nSequence Parallelism\nSplits long sequences across multiple GPUs, enabling training with sequence lengths that would not fit on a single device.\n\nLearn more: Sequence Parallelism Documentation\n\n\n\nArtic Long Sequence Training (ALST)\nALST is a recipe that combines several techniques to train long-context models efficiently. It typically involves:\n\nTiledMLP to reduce memory usage in MLP layers.\nTiled Loss functions (like CCE.\nActivation Offloading to CPU.\nExample: ALST Example Configuration",
"crumbs": [
"How To Guides",
"Optimizations Guide"
]
},
{
"objectID": "docs/optimizations.html#large-models-distributed-training",
"href": "docs/optimizations.html#large-models-distributed-training",
"title": "Optimizations Guide",
"section": "Large Models (Distributed Training)",
"text": "Large Models (Distributed Training)\nTo train models that dont fit on a single GPU, youll need to use a distributed training strategy like FSDP or DeepSpeed. These frameworks shard the model weights, gradients, and optimizer states across multiple GPUs and nodes.\n\nLearn more: Multi-GPU Guide\nLearn more: Multi-Node Guide\n\n\nN-D Parallelism (Beta)\nFor advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., Data, Tensor, Sequence Parallelism). This is a powerful approach to train an extremely large model by overcoming multiple bottlenecks at once.\n\nLearn more: N-D Parallelism Guide",
"crumbs": [
"How To Guides",
"Optimizations Guide"
]
},
{
"objectID": "docs/optimizations.html#quantization",
"href": "docs/optimizations.html#quantization",
"title": "Optimizations Guide",
"section": "Quantization",
"text": "Quantization\nTechniques to reduce the precision of model weights for memory savings.\n\n4-bit Training (QLoRA)\nThe recommended approach for quantization-based training. It loads the base model in 4-bit using bitsandbytes and then trains QLoRA adapters. See Adapter Finetuning for details.\n\n\nFP8 Training\nEnables training with 8-bit floating point precision on supported hardware (e.g., NVIDIA Hopper series GPUs) for significant speed and memory gains.\n\nExample: Llama 3 FP8 FSDP Example\n\n\n\nQuantization Aware Training (QAT)\nSimulates quantization effects during training, helping the model adapt and potentially improving the final accuracy of the quantized model.\n\nLearn more: QAT Documentation\n\n\n\nGPTQ\nAllows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.\n\nExample: GPTQ LoRA Example\n\n\n\nMoE Expert Quantization\nQuantizes MoE expert weights on load to reduce VRAM when training MoE models with adapters. Required for Transformers v5+ MoE models where experts use fused nn.Parameter tensors.\n\nConfig: quantize_moe_experts: true\nLearn more: MoE Expert Quantization",
"crumbs": [
"How To Guides",
"Optimizations Guide"
]
},
{
"objectID": "docs/training_stability.html",
"href": "docs/training_stability.html",
"title": "Training Stability & Debugging",
"section": "",
"text": "This guide covers practical techniques for monitoring training health, diagnosing instability, and resolving common failures in both supervised fine-tuning (SFT) and reinforcement learning (GRPO/EBFT) workflows.",
"crumbs": [
"Troubleshooting",
"Training Stability & Debugging"
]
},
{
"objectID": "docs/training_stability.html#monitoring-training",
"href": "docs/training_stability.html#monitoring-training",
"title": "Training Stability & Debugging",
"section": "Monitoring Training",
"text": "Monitoring Training\n\nKey Metrics for SFT\nEvery SFT run should be monitored through at least these four metrics:\n\n\n\n\n\n\n\n\nMetric\nWhat It Tells You\nHealthy Range\n\n\n\n\ntrain/loss\nHow well the model fits training data\nDecreasing; typically 0.52.0 for chat fine-tuning\n\n\neval/loss\nGeneralization performance\nTracks train loss with small gap; divergence signals overfitting\n\n\ngrad_norm\nGradient magnitude\n0.110.0; spikes above 100 indicate instability\n\n\nlearning_rate\nCurrent LR from scheduler\nShould follow expected schedule (warmup then decay)\n\n\n\n\n\n\n\n\n\nTipSet Up Logging Early\n\n\n\nEnable W&B or TensorBoard from the start. Debugging a failed run without metrics is guesswork.\nwandb_project: my-project\nwandb_run_id: # optional, for resuming\nlogging_steps: 1\n\n\n\n\nKey Metrics for RL (GRPO)\nGRPO training logs a richer set of metrics. These are the critical ones:\n\n\n\n\n\n\n\n\nMetric\nHealthy Range\nRed Flag\n\n\n\n\nrewards/<name>/mean\n> 0.15 within 20 steps\nStays at 0 reward function is broken or task is too hard\n\n\nreward_std\n> 0 on most steps\nAlways 0 no learning signal (all completions get the same reward)\n\n\nfrac_reward_zero_std\n< 0.8\n1.0 on every step zero-advantage skip fires constantly, no gradient updates\n\n\ngrad_norm\n0.0011.0\n0.0 is acceptable occasionally (zero-adv skip); > 10.0 is unstable\n\n\nentropy\n0.050.5\n< 0.01 suggests mode collapse; > 1.0 suggests the model is not converging\n\n\nkl\n0.00.5\n> 2.0 suggests policy has diverged too far from reference\n\n\nsampling/sampling_logp_difference/mean\n< 0.1\n> 1.0 means policy has diverged far from vLLM server weights\n\n\nsampling/importance_sampling_ratio/min\n> 0.1\nNear 0 indicates stale off-policy data; increase vllm_sync_interval\n\n\nclip_ratio/region_mean\n< 0.1\n> 0.3 means PPO clipping is too aggressive\n\n\ncompletions/mean_length\nTask-dependent\nMonotonically increasing to max length suggests reward 
hacking\n\n\ncompletions/clipped_ratio\n< 0.3\n> 0.8 means most completions hit max_completion_length increase it\n\n\n\n\n\n\n\n\n\nNoteEBFT-Specific Metrics\n\n\n\nFor EBFT training, also monitor ebft/alignment (should trend upward, healthy 0.30.9), ebft/diversity (healthy 0.010.1; > 1.0 indicates mode collapse), and ebft/cfm_loss (should trend downward, < 10).",
"crumbs": [
"Troubleshooting",
"Training Stability & Debugging"
]
},
{
"objectID": "docs/training_stability.html#sft-stability",
"href": "docs/training_stability.html#sft-stability",
"title": "Training Stability & Debugging",
"section": "SFT Stability",
"text": "SFT Stability\n\nLoss Plateau\nSymptom: Loss stops decreasing early in training, well above expected values.\nCauses and fixes:\n\nLearning rate too low: Increase by 25x. Typical ranges: full fine-tune 1e-5 to 5e-5, LoRA 1e-4 to 3e-4.\nInsufficient warmup: Set warmup_steps to 510% of total steps. Too-aggressive learning at the start can push the model into a flat region.\nData quality: Check that labels are correctly masked. Use axolotl preprocess and inspect tokenized samples to confirm only the target tokens are trainable.\nWeight decay too high: Default 0.01 is usually fine. Values above 0.1 can suppress learning in LoRA.\n\n\n\nLoss Spikes\nSymptom: Loss suddenly jumps by 210x then (possibly) recovers.\nCauses and fixes:\n\nBad data samples: A single malformed or extremely long example can cause a spike. Enable sample_packing: false temporarily and check if spikes correlate with specific batches.\nLearning rate too high: Reduce by 25x, or increase warmup.\nGradient accumulation mismatch: Effective batch size = micro_batch_size * gradient_accumulation_steps * num_gpus. Very large effective batch sizes amplify gradient noise.\nMixed precision issues: With bf16: true, some operations can lose precision. If spikes are severe, try fp32 for diagnosis.\n\n\n\nOverfitting\nSymptom: Train loss keeps decreasing but eval loss starts increasing.\nFixes:\n\nIncrease val_set_size (e.g., 0.05) and monitor eval/loss.\nReduce num_epochs or max_steps.\nIncrease weight_decay (try 0.010.1).\nUse a smaller LoRA rank (lora_r). Typical values: 832.\nIncrease dropout: lora_dropout: 0.05.",
"crumbs": [
"Troubleshooting",
"Training Stability & Debugging"
]
},
{
"objectID": "docs/training_stability.html#rlgrpo-stability",
"href": "docs/training_stability.html#rlgrpo-stability",
"title": "Training Stability & Debugging",
"section": "RL/GRPO Stability",
"text": "RL/GRPO Stability\n\nReward Never Increases\nIf rewards/*/mean stays at 0 for more than 20 steps:\n\nTest reward function standalone: Run it outside training with known inputs to verify it returns nonzero values.\ncd experiments && python -c \"import my_rewards; print(my_rewards.accuracy_reward(...))\"\nCheck dataset columns: The reward function receives **kwargs containing dataset columns. Verify the columns it needs (e.g., answer) are not removed by the dataset transform.\nCheck completion content: Enable log_completions: true in the trl: config and inspect logged completions in W&B. If completions are empty or incoherent, the model may be too weak for the task.\nVerify vLLM is serving the right model: Hit the vLLM health endpoint and confirm the model name matches your config.\n\n\n\nEntropy Collapse (Mode Collapse)\nSymptom: entropy drops below 0.01; all completions become nearly identical.\nFixes:\n\nIncrease temperature in generation kwargs (try 0.81.0).\nReduce learning rate.\nAdd a KL penalty term (beta parameter in GRPO config).\nCheck that num_generations is sufficient (16+ gives better advantage estimates).\n\n\n\nIS Ratio Divergence\nSymptom: sampling/importance_sampling_ratio/min drops near 0, or sampling/sampling_logp_difference/mean exceeds 1.0.\nThis means the policy has diverged significantly from the weights used by vLLM for generation. The importance sampling correction becomes unreliable.\nFixes:\n\nDecrease vllm_sync_interval (sync weights more often).\nEnable off_policy_mask_threshold (e.g., 0.5) to mask stale off-policy samples.\nUse importance_sampling_level: token for finer-grained correction.\n\n\n\nGradient Norm Instability\nSymptom: grad_norm oscillates wildly or exceeds 10.0 regularly.\nFixes:\n\nEnable gradient clipping: max_grad_norm: 1.0 (default in most configs).\nReduce learning rate.\nIncrease gradient_accumulation_steps to smooth out noisy batches.\nCheck for NaN issues (see next section).",
"crumbs": [
"Troubleshooting",
"Training Stability & Debugging"
]
},
{
"objectID": "docs/training_stability.html#nan-and-inf-handling",
"href": "docs/training_stability.html#nan-and-inf-handling",
"title": "Training Stability & Debugging",
"section": "NaN and Inf Handling",
"text": "NaN and Inf Handling\n\nCommon Causes\n\n\n\n\n\n\n\n\nCause\nWhere It Manifests\nDetection\n\n\n\n\nFP8 zero-scale division\nForward pass logits\ngrad_norm: nan, loss becomes NaN immediately\n\n\nGradient explosion\nBackward pass\ngrad_norm spikes to inf, then loss goes NaN\n\n\nBad data (empty sequences)\nLogprob computation\nNaN in specific batches only\n\n\nNumerical overflow in log-softmax\nLoss computation\nLarge negative logprobs cause exp() overflow\n\n\n\n\n\nFP8-Specific NaN Issues\nFP8 quantization (fp8: true) can produce NaN when the activation quantization kernel divides by max(abs(x)) / 448. If the input tensor is all zeros (e.g., padding positions), the scale becomes 0, causing division by zero.\nFixes applied in axolotl:\n\nThe act_quant_kernel has a zero-guard: s = tl.where(s == 0, 1.0, s).\nA safety net nan_to_num(logits, nan=0.0) is applied in _get_per_token_logps_and_entropies.\nEmbedding padding is zero-padded for FP8 compatibility.\n\n\n\n\n\n\n\nImportantAfter Modifying Triton Kernels\n\n\n\nIf you patch any Triton JIT kernel (e.g., the FP8 quantization kernels in transformers), you must clear the Triton cache for changes to take effect:\nrm -rf ~/.triton/cache\n\n\n\n\nGeneral NaN Debugging Steps\n\nEnable anomaly detection (slow, but pinpoints the source):\ntorch.autograd.set_detect_anomaly(True)\nCheck grad_norm: If it goes to NaN, the backward pass is the problem. If loss is NaN but grad_norm was fine on the previous step, the forward pass is the problem.\nReduce to single GPU, single batch: Eliminate distributed training variables.\nInspect data: Print the batch that triggers NaN. Look for empty sequences, extreme token IDs, or unexpected padding patterns.",
"crumbs": [
"Troubleshooting",
"Training Stability & Debugging"
]
},
{
"objectID": "docs/training_stability.html#oom-debugging",
"href": "docs/training_stability.html#oom-debugging",
"title": "Training Stability & Debugging",
"section": "OOM Debugging",
"text": "OOM Debugging\nOut-of-memory errors are the most common training failure. Use this systematic approach, from least to most disruptive:\n\nStep 1: Reduce Batch Size\nThe single highest-impact change. VRAM scales roughly linearly with batch size.\nmicro_batch_size: 1 # Start here\ngradient_accumulation_steps: 16 # Increase to maintain effective batch size\nFor GRPO specifically, the logits tensor for policy logprob computation can be very large. batch_size * num_generations * seq_len * vocab_size in bf16. For example, with num_generations: 16 and micro_batch_size: 8, the logits tensor alone is:\n8 * 16 * 2048 * 151936 * 2 bytes = ~75 GB (way too large)\nReduce micro_batch_size to 24 for GRPO.\n\n\nStep 2: Enable Gradient Checkpointing\nTrades compute for memory by recomputing activations during the backward pass instead of storing them.\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n use_reentrant: false # Recommended default\n\n\n\n\n\n\nWarningReentrant Checkpointing Exceptions\n\n\n\nSome configurations require use_reentrant: true:\n\nDeepSpeed ZeRO-3 (non-reentrant causes CheckpointError)\nEBFT strided mode with flex_attention\n\n\n\n\n\nStep 3: Use Quantization\nLoad the base model in reduced precision:\n# 4-bit QLoRA\nadapter: qlora\nload_in_4bit: true\n\n# 8-bit\nload_in_8bit: true\n\n# FP8 (saves ~50% model VRAM, same compute speed as bf16)\nfp8: true\n\n\nStep 4: Reduce Sequence Length\nsequence_len: 1024 # Down from 2048 or 4096\nFor GRPO, also reduce max_completion_length. 
Memory scales quadratically with sequence length when using standard attention.\n\n\nStep 5: Use Flash Attention\nReduces attention memory from O(n^2) to O(n):\nflash_attention: true\n\n\nStep 6: Offload with DeepSpeed\nFor extreme cases, offload optimizer states or parameters to CPU:\ndeepspeed: deepspeed_configs/zero3_bf16.json\n\n\nDiagnosing the Specific Culprit\nUse the profiler_steps config option to capture GPU memory snapshots:\nprofiler_steps: [1, 2]\nThis generates PyTorch profiler traces you can inspect to see exactly which tensor allocation caused the OOM.",
"crumbs": [
"Troubleshooting",
"Training Stability & Debugging"
]
},
{
"objectID": "docs/training_stability.html#common-errors",
"href": "docs/training_stability.html#common-errors",
"title": "Training Stability & Debugging",
"section": "Common Errors",
"text": "Common Errors\n\n\n\nError Message\nLikely Cause\nFix\n\n\n\n\nexitcode: -9\nSystem RAM exhaustion\nReduce dataset size, dataset_num_proc, or number of data workers\n\n\nexitcode: -7 (DeepSpeed)\nDeepSpeed version issue\npip install -U deepspeed\n\n\nCUDA out of memory\nGPU VRAM exhaustion\nFollow OOM debugging steps above\n\n\nRuntimeError: NCCL communicator was aborted\nGPU communication failure\nSee NCCL docs; check NCCL_DEBUG=INFO output\n\n\nValueError: Asking to pad but the tokenizer does not have a padding token\nMissing pad token\nAdd special_tokens: { pad_token: \"<\\|endoftext\\|>\" } to config\n\n\n'DummyOptim' object has no attribute 'step'\nDeepSpeed on single GPU\nRemove deepspeed: section from config\n\n\nunable to load strategy X then None is not callable\nReward module not importable\nRun cd experiments && python -c \"import my_rewards\" to check\n\n\ngeneration_batch_size not divisible by num_generations\nmicro_batch_size too small\nSet micro_batch_size >= num_generations and make it divisible\n\n\n'weight' must be 2-D\nFSDP1 flattened parameters\nUse fsdp_version: 2 or skip unwrap_model when FSDP is enabled\n\n\nCheckpointError (tensor count mismatch)\nNon-reentrant checkpointing + ZeRO-3 or flex_attention\nSet use_reentrant: true in gradient_checkpointing_kwargs\n\n\nBFloat16 TypeError during weight sync\nNumPy does not support bf16\nFixed in axolotls weight_serde.py (auto bf16 to fp16 conversion)\n\n\nContent end boundary is before start boundary\nChat template parsing issue\nCheck eos_token matches template; file a GitHub issue if persistent\n\n\nCAS service error during data processing\nHuggingFace XET issue\nSet export HF_HUB_DISABLE_XET=1\n\n\nTraining hangs (multi-GPU)\nFSDP + async prefetch deadlock\nSet async_prefetch: false with FSDP",
"crumbs": [
"Troubleshooting",
"Training Stability & Debugging"
]
},
{
"objectID": "docs/training_stability.html#profiling",
"href": "docs/training_stability.html#profiling",
"title": "Training Stability & Debugging",
"section": "Profiling",
"text": "Profiling\n\nPyTorch Profiler\nAxolotl supports PyTorch profiler integration via the config:\nprofiler_steps: [1, 2, 3]\nThis captures profiler traces for the specified steps. View them in TensorBoard:\ntensorboard --logdir output_dir/runs\nOr open the .json trace file in chrome://tracing.\n\n\nCUDA Memory Snapshots\nFor detailed memory analysis, use PyTorchs memory snapshot API. Add this to your training script or use it interactively:\nimport torch\n\n# Enable memory history tracking\ntorch.cuda.memory._record_memory_history()\n\n# ... run your training step ...\n\n# Save snapshot\ntorch.cuda.memory._dump_snapshot(\"memory_snapshot.pickle\")\nVisualize with PyTorchs memory visualizer:\npython -m torch.cuda.memory._viz memory_snapshot.pickle\n\n\nQuick GPU Memory Check\nDuring training, monitor GPU utilization in a separate terminal:\nwatch -n 1 nvidia-smi\nFor programmatic access within axolotl, the logged metrics memory/max_alloc and memory/max_reserved come from torch.cuda.max_memory_allocated() and torch.cuda.max_memory_reserved(). Note these report PyTorchs view of memory, which may differ from nvidia-smi (see FAQ).",
"crumbs": [
"Troubleshooting",
"Training Stability & Debugging"
]
},
{
"objectID": "docs/training_stability.html#wb-and-logging",
"href": "docs/training_stability.html#wb-and-logging",
"title": "Training Stability & Debugging",
"section": "W&B and Logging",
"text": "W&B and Logging\n\nEnabling Logging\nwandb_project: my-project\nwandb_entity: my-team # optional\nwandb_run_id: run-123 # optional, for resuming\nwandb_name: experiment-name # optional\nlogging_steps: 1 # log every step (recommended for RL)\n\n\nDebug Logging\nFor detailed axolotl-internal debug output:\nAXOLOTL_LOG_LEVEL=DEBUG axolotl train config.yaml 2>&1 | tee /tmp/training.log\n\n\n\n\n\n\nTipAlways Log to a File\n\n\n\nPipe training output to a log file so you can inspect it after the run:\naxolotl train config.yaml 2>&1 | tee /tmp/my_run.log\n\n\n\n\nWhat Axolotl Logs\nSFT metrics (logged every logging_steps):\n\ntrain/loss, eval/loss training and validation loss\ntrain/grad_norm gradient L2 norm (before clipping)\ntrain/learning_rate current learning rate\nmemory/max_alloc, memory/max_reserved peak GPU memory\n\nGRPO/RL metrics (logged every step):\n\nrewards/<name>/mean, rewards/<name>/std per-reward-function statistics\nreward, reward_std aggregated reward across all reward functions\nfrac_reward_zero_std fraction of prompt groups where all completions got the same reward\ncompletions/mean_length, completions/min_length, completions/max_length completion token lengths\ncompletions/clipped_ratio fraction of completions that hit the max length\ncompletions/mean_terminated_length, completions/min_terminated_length, completions/max_terminated_length lengths of naturally terminated completions\nkl KL divergence between policy and reference\nentropy policy entropy (measure of output diversity)\nclip_ratio/region_mean, clip_ratio/low_mean, clip_ratio/high_mean PPO clipping statistics\nsampling/sampling_logp_difference/mean, sampling/sampling_logp_difference/max log-probability difference between policy and sampling distribution\nsampling/importance_sampling_ratio/min, sampling/importance_sampling_ratio/mean, sampling/importance_sampling_ratio/max IS ratio statistics for off-policy correction\nnum_tokens total tokens processed\n\n\n\nReading W&B 
Charts\nFor a healthy GRPO run, expect to see:\n\nreward/mean: Gradual upward trend. May start near 0 and reach 0.30.8 depending on task difficulty. Not monotonic fluctuations are normal.\nentropy: Gradual decrease from initial values (often 0.30.6) as the model becomes more confident. Should not collapse to near-zero.\ngrad_norm: Mostly in the 0.0011.0 range. Occasional 0.0 values are fine (zero-advantage skip). Persistent values above 10.0 need investigation.\nkl: Starts near 0 and grows slowly. If it shoots up rapidly, the policy is diverging from the reference.\ncompletions/mean_length: Should reflect the tasks natural answer length. If it steadily increases to max_completion_length, the model may be reward-hacking by generating longer outputs.",
"crumbs": [
"Troubleshooting",
"Training Stability & Debugging"
]
},
{
"objectID": "docs/cli.html",
"href": "docs/cli.html",
"title": "Command Line Interface (CLI)",
"section": "",
"text": "The Axolotl CLI provides a streamlined interface for training and fine-tuning large language models. This guide covers\nthe CLI commands, their usage, and common examples.",
"crumbs": [
"Getting Started",
"Command Line Interface (CLI)"
]
},
{
"objectID": "docs/cli.html#basic-commands",
"href": "docs/cli.html#basic-commands",
"title": "Command Line Interface (CLI)",
"section": "Basic Commands",
"text": "Basic Commands\nAll Axolotl commands follow this general structure:\naxolotl <command> [config.yml] [options]\nThe config file can be local or a URL to a raw YAML file.\n\nLauncher Arguments\nFor commands that support multi-GPU (train, evaluate, …), you can pass launcher-specific arguments using the -- separator:\n# Pass torchrun arguments\naxolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1\n\n# Pass accelerate arguments\naxolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml --num_processes=4\nArguments after -- are passed directly to the launcher (torchrun, accelerate launch, etc.).",
"crumbs": [
"Getting Started",
"Command Line Interface (CLI)"
]
},
{
"objectID": "docs/cli.html#command-reference",
"href": "docs/cli.html#command-reference",
"title": "Command Line Interface (CLI)",
"section": "Command Reference",
"text": "Command Reference\n\nfetch\nDownloads example configurations and deepspeed configs to your local machine.\n# Get example YAML files\naxolotl fetch examples\n\n# Get deepspeed config files\naxolotl fetch deepspeed_configs\n\n# Specify custom destination\naxolotl fetch examples --dest path/to/folder\n\n\npreprocess\nPreprocesses and tokenizes your dataset before training. This is recommended for large datasets.\n# Basic preprocessing\naxolotl preprocess config.yml\n\n# Preprocessing with one GPU\nCUDA_VISIBLE_DEVICES=\"0\" axolotl preprocess config.yml\n\n# Debug mode to see processed examples\naxolotl preprocess config.yml --debug\n\n# Debug with limited examples\naxolotl preprocess config.yml --debug --debug-num-examples 5\nConfiguration options:\ndataset_prepared_path: Local folder for saving preprocessed data\npush_dataset_to_hub: HuggingFace repo to push preprocessed data (optional)\n\n\ntrain\nTrains or fine-tunes a model using the configuration specified in your YAML file.\n# Basic training\naxolotl train config.yml\n\n# Train and set/override specific options\naxolotl train config.yml \\\n --learning-rate 1e-4 \\\n --micro-batch-size 2 \\\n --num-epochs 3\n\n# Training without accelerate\naxolotl train config.yml --launcher python\n\n# Pass launcher-specific arguments using -- separator\naxolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1\naxolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml\n\n# Resume training from checkpoint\naxolotl train config.yml --resume-from-checkpoint path/to/checkpoint\nIt is possible to run sweeps over multiple hyperparameters by passing in a sweeps config.\n# Basic training with sweeps\naxolotl train config.yml --sweep path/to/sweep.yaml\nExample sweep config:\n_:\n # This section is for dependent variables we need to fix\n - load_in_8bit: false\n load_in_4bit: false\n adapter: lora\n - load_in_8bit: true\n load_in_4bit: false\n adapter: lora\n\n# These are 
independent variables\nlearning_rate: [0.0003, 0.0006]\nlora_r:\n - 16\n - 32\nlora_alpha:\n - 16\n - 32\n - 64\n\n\ninference\nRuns inference using your trained model in either CLI or Gradio interface mode.\n# CLI inference with LoRA\naxolotl inference config.yml --lora-model-dir=\"./outputs/lora-out\"\n\n# CLI inference with full model\naxolotl inference config.yml --base-model=\"./completed-model\"\n\n# Gradio web interface\naxolotl inference config.yml --gradio \\\n --lora-model-dir=\"./outputs/lora-out\"\n\n# Inference with input from file\ncat prompt.txt | axolotl inference config.yml \\\n --base-model=\"./completed-model\"\n\n\nmerge-lora\nMerges trained LoRA adapters into the base model.\n# Basic merge\naxolotl merge-lora config.yml\n\n# Specify LoRA directory (usually used with checkpoints)\naxolotl merge-lora config.yml --lora-model-dir=\"./lora-output/checkpoint-100\"\n\n# Merge using CPU (if out of GPU memory)\nCUDA_VISIBLE_DEVICES=\"\" axolotl merge-lora config.yml\nConfiguration options:\ngpu_memory_limit: Limit GPU memory usage\nlora_on_cpu: Load LoRA weights on CPU\n\n\nmerge-sharded-fsdp-weights\nMerges sharded FSDP model checkpoints into a single combined checkpoint.\n# Basic merge\naxolotl merge-sharded-fsdp-weights config.yml\n\n\nevaluate\nEvaluates a model's performance (loss, etc.) on the train and eval datasets.\n# Basic evaluation\naxolotl evaluate config.yml\n\n# Evaluation with launcher arguments\naxolotl evaluate config.yml --launcher torchrun -- --nproc_per_node=2\n\n\nlm-eval\nRuns LM Evaluation Harness on your model.\n# Basic evaluation\naxolotl lm-eval config.yml\nConfiguration options:\nlm_eval_model: # model to evaluate (local or hf path)\n\n# List of tasks to evaluate\nlm_eval_tasks:\n - arc_challenge\n - hellaswag\nlm_eval_batch_size: # Batch size for evaluation\noutput_dir: # Directory to save evaluation results\nSee LM Eval Harness integration docs for full configuration details.\n\n\ndelinearize-llama4\nDelinearizes a Llama 4 
linearized model into a regular HuggingFace Llama 4 model. This only works with the non-quantized linearized model.\naxolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir\nThis would be necessary to use with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.\n\n\nquantize\nQuantizes a model using the quantization configuration specified in your YAML file.\naxolotl quantize config.yml\nSee Quantization for more details.",
"crumbs": [
"Getting Started",
"Command Line Interface (CLI)"
]
},
{
"objectID": "docs/cli.html#legacy-cli-usage",
"href": "docs/cli.html#legacy-cli-usage",
"title": "Command Line Interface (CLI)",
"section": "Legacy CLI Usage",
"text": "Legacy CLI Usage\nWhile the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:\n# Preprocess\npython -m axolotl.cli.preprocess config.yml\n\n# Train\naccelerate launch -m axolotl.cli.train config.yml\n\n# Inference\naccelerate launch -m axolotl.cli.inference config.yml \\\n --lora_model_dir=\"./outputs/lora-out\"\n\n# Gradio interface\naccelerate launch -m axolotl.cli.inference config.yml \\\n --lora_model_dir=\"./outputs/lora-out\" --gradio\n\n\n\n\n\n\nImportant\n\n\n\nWhen overriding CLI parameters in the legacy CLI, use the same underscore notation as in the YAML file (e.g., --lora_model_dir).\nNote: This differs from the new Click-based CLI, which uses dash notation (e.g., --lora-model-dir). Keep this in mind if you're referencing newer documentation or switching between CLI versions.",
"crumbs": [
"Getting Started",
"Command Line Interface (CLI)"
]
},
{
"objectID": "docs/cli.html#remote-compute-with-modal-cloud",
"href": "docs/cli.html#remote-compute-with-modal-cloud",
"title": "Command Line Interface (CLI)",
"section": "Remote Compute with Modal Cloud",
"text": "Remote Compute with Modal Cloud\nAxolotl supports running training and inference workloads on Modal cloud infrastructure. This is configured using a\ncloud YAML file alongside your regular Axolotl config.\n\nCloud Configuration\nCreate a cloud config YAML with your Modal settings:\n# cloud_config.yml\nprovider: modal\ngpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4\ngpu_count: 1 # Number of GPUs to use\ntimeout: 86400 # Maximum runtime in seconds (24 hours)\nbranch: main # Git branch to use (optional)\n\nvolumes: # Persistent storage volumes\n - name: axolotl-cache\n mount: /workspace/cache\n - name: axolotl-data\n mount: /workspace/data\n - name: axolotl-artifacts\n mount: /workspace/artifacts\n\nsecrets: # Secrets to inject\n - WANDB_API_KEY\n - HF_TOKEN\n\n\nRunning on Modal Cloud\nCommands that support the cloud flag:\n# Preprocess on cloud\naxolotl preprocess config.yml --cloud cloud_config.yml\n\n# Train on cloud\naxolotl train config.yml --cloud cloud_config.yml\n\n# Run lm-eval on cloud\naxolotl lm-eval config.yml --cloud cloud_config.yml\n\n\nCloud Configuration Options\nprovider: # compute provider, currently only `modal` is supported\ngpu: # GPU type to use\ngpu_count: # Number of GPUs (default: 1)\nmemory: # RAM in GB (default: 128)\ntimeout: # Maximum runtime in seconds\ntimeout_preprocess: # Preprocessing timeout\nbranch: # Git branch to use\ndocker_tag: # Custom Docker image tag\nvolumes: # List of persistent storage volumes\n\n# Environment variables to pass. Can be specified in two ways:\n# 1. As a string: Will load the value from the host computer's environment variables\n# 2. As a key-value pair: Will use the specified value directly\n# Example:\n# env:\n# - CUSTOM_VAR # Loads from host's $CUSTOM_VAR\n# - {CUSTOM_VAR: \"value\"} # Uses \"value\" directly\nenv:\n\n# Secrets to inject. Same input format as `env` but for sensitive data.\nsecrets:\n # - HF_TOKEN\n # - WANDB_API_KEY",
"crumbs": [
"Getting Started",
"Command Line Interface (CLI)"
]
},
{
"objectID": "docs/api/utils.callbacks.mlflow_.html",
"href": "docs/api/utils.callbacks.mlflow_.html",
"title": "utils.callbacks.mlflow_",
"section": "",
"text": "utils.callbacks.mlflow_\nMLFlow module for trainer callbacks\n\n\n\n\n\nName\nDescription\n\n\n\n\nSaveAxolotlConfigtoMlflowCallback\nCallback to save axolotl config to mlflow\n\n\n\n\n\nutils.callbacks.mlflow_.SaveAxolotlConfigtoMlflowCallback(axolotl_config_path)\nCallback to save axolotl config to mlflow"
},
{
"objectID": "docs/api/utils.callbacks.mlflow_.html#classes",
"href": "docs/api/utils.callbacks.mlflow_.html#classes",
"title": "utils.callbacks.mlflow_",
"section": "",
"text": "Name\nDescription\n\n\n\n\nSaveAxolotlConfigtoMlflowCallback\nCallback to save axolotl config to mlflow\n\n\n\n\n\nutils.callbacks.mlflow_.SaveAxolotlConfigtoMlflowCallback(axolotl_config_path)\nCallback to save axolotl config to mlflow"
},
{
"objectID": "docs/api/models.mamba.modeling_mamba.html",
"href": "docs/api/models.mamba.modeling_mamba.html",
"title": "models.mamba.modeling_mamba",
"section": "",
"text": "models.mamba.modeling_mamba\nmodels.mamba.modeling_mamba"
},
{
"objectID": "docs/api/core.trainers.dpo.trainer.html",
"href": "docs/api/core.trainers.dpo.trainer.html",
"title": "core.trainers.dpo.trainer",
"section": "",
"text": "core.trainers.dpo.trainer\nDPO trainer for axolotl\n\n\n\n\n\nName\nDescription\n\n\n\n\nAxolotlDPOTrainer\nExtend the base DPOTrainer for axolotl helpers.\n\n\n\n\n\ncore.trainers.dpo.trainer.AxolotlDPOTrainer(*args, dataset_tags=None, **kwargs)\nExtend the base DPOTrainer for axolotl helpers.\n\n\n\n\n\nName\nDescription\n\n\n\n\npush_to_hub\nOverwrite the push_to_hub method in order to force-add the tags when pushing\n\n\n\n\n\ncore.trainers.dpo.trainer.AxolotlDPOTrainer.push_to_hub(*args, **kwargs)\nOverwrite the push_to_hub method in order to force-add the tags when pushing\nthe model on the Hub. Please refer to ~transformers.Trainer.push_to_hub\nfor more details."
},
{
"objectID": "docs/api/core.trainers.dpo.trainer.html#classes",
"href": "docs/api/core.trainers.dpo.trainer.html#classes",
"title": "core.trainers.dpo.trainer",
"section": "",
"text": "Name\nDescription\n\n\n\n\nAxolotlDPOTrainer\nExtend the base DPOTrainer for axolotl helpers.\n\n\n\n\n\ncore.trainers.dpo.trainer.AxolotlDPOTrainer(*args, dataset_tags=None, **kwargs)\nExtend the base DPOTrainer for axolotl helpers.\n\n\n\n\n\nName\nDescription\n\n\n\n\npush_to_hub\nOverwrite the push_to_hub method in order to force-add the tags when pushing\n\n\n\n\n\ncore.trainers.dpo.trainer.AxolotlDPOTrainer.push_to_hub(*args, **kwargs)\nOverwrite the push_to_hub method in order to force-add the tags when pushing\nthe model on the Hub. Please refer to ~transformers.Trainer.push_to_hub\nfor more details."
},
{
"objectID": "docs/api/cli.utils.fetch.html",
"href": "docs/api/cli.utils.fetch.html",
"title": "cli.utils.fetch",
"section": "",
"text": "cli.utils.fetch\nUtilities for axolotl fetch CLI command.\n\n\n\n\n\nName\nDescription\n\n\n\n\nfetch_from_github\nSync files from a specific directory in the GitHub repository.\n\n\n\n\n\ncli.utils.fetch.fetch_from_github(dir_prefix, dest_dir=None, max_workers=5)\nSync files from a specific directory in the GitHub repository.\nOnly downloads files that don't exist locally or have changed.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ndir_prefix\nstr\nDirectory prefix to filter files (e.g., examples/, deepspeed_configs/).\nrequired\n\n\ndest_dir\nstr | None\nLocal destination directory.\nNone\n\n\nmax_workers\nint\nMaximum number of concurrent downloads.\n5"
},
{
"objectID": "docs/api/cli.utils.fetch.html#functions",
"href": "docs/api/cli.utils.fetch.html#functions",
"title": "cli.utils.fetch",
"section": "",
"text": "Name\nDescription\n\n\n\n\nfetch_from_github\nSync files from a specific directory in the GitHub repository.\n\n\n\n\n\ncli.utils.fetch.fetch_from_github(dir_prefix, dest_dir=None, max_workers=5)\nSync files from a specific directory in the GitHub repository.\nOnly downloads files that don't exist locally or have changed.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ndir_prefix\nstr\nDirectory prefix to filter files (e.g., examples/, deepspeed_configs/).\nrequired\n\n\ndest_dir\nstr | None\nLocal destination directory.\nNone\n\n\nmax_workers\nint\nMaximum number of concurrent downloads.\n5"
},
{
"objectID": "docs/api/core.builders.causal.html",
"href": "docs/api/core.builders.causal.html",
"title": "core.builders.causal",
"section": "",
"text": "core.builders.causal\nBuilder for causal trainers\n\n\n\n\n\nName\nDescription\n\n\n\n\nHFCausalTrainerBuilder\nBuild the HuggingFace training args/trainer for causal models and reward modeling\n\n\n\n\n\ncore.builders.causal.HFCausalTrainerBuilder(\n cfg,\n model,\n tokenizer,\n processor=None,\n)\nBuild the HuggingFace training args/trainer for causal models and reward modeling\nusing TRL."
},
{
"objectID": "docs/api/core.builders.causal.html#classes",
"href": "docs/api/core.builders.causal.html#classes",
"title": "core.builders.causal",
"section": "",
"text": "Name\nDescription\n\n\n\n\nHFCausalTrainerBuilder\nBuild the HuggingFace training args/trainer for causal models and reward modeling\n\n\n\n\n\ncore.builders.causal.HFCausalTrainerBuilder(\n cfg,\n model,\n tokenizer,\n processor=None,\n)\nBuild the HuggingFace training args/trainer for causal models and reward modeling\nusing TRL."
},
{
"objectID": "docs/api/core.builders.rl.html",
"href": "docs/api/core.builders.rl.html",
"title": "core.builders.rl",
"section": "",
"text": "core.builders.rl\nBuilder for RLHF trainers\n\n\n\n\n\nName\nDescription\n\n\n\n\nHFRLTrainerBuilder\nTrainer factory class for TRL-based RLHF trainers (e.g. DPO)\n\n\n\n\n\ncore.builders.rl.HFRLTrainerBuilder(cfg, model, tokenizer, processor=None)\nTrainer factory class for TRL-based RLHF trainers (e.g. DPO)"
},
{
"objectID": "docs/api/core.builders.rl.html#classes",
"href": "docs/api/core.builders.rl.html#classes",
"title": "core.builders.rl",
"section": "",
"text": "Name\nDescription\n\n\n\n\nHFRLTrainerBuilder\nTrainer factory class for TRL-based RLHF trainers (e.g. DPO)\n\n\n\n\n\ncore.builders.rl.HFRLTrainerBuilder(cfg, model, tokenizer, processor=None)\nTrainer factory class for TRL-based RLHF trainers (e.g. DPO)"
},
{
"objectID": "docs/api/utils.bench.html",
"href": "docs/api/utils.bench.html",
"title": "utils.bench",
"section": "",
"text": "utils.bench\nBenchmarking and measurement utilities\n\n\n\n\n\nName\nDescription\n\n\n\n\ncheck_cuda_device\nwraps a function and returns the default value instead of running the\n\n\n\n\n\nutils.bench.check_cuda_device(default_value)\nwraps a function and returns the default value instead of running the\nwrapped function if CUDA isn't available or the device is auto\n:param default_value:\n:return:"
},
{
"objectID": "docs/api/utils.bench.html#functions",
"href": "docs/api/utils.bench.html#functions",
"title": "utils.bench",
"section": "",
"text": "Name\nDescription\n\n\n\n\ncheck_cuda_device\nwraps a function and returns the default value instead of running the\n\n\n\n\n\nutils.bench.check_cuda_device(default_value)\nwraps a function and returns the default value instead of running the\nwrapped function if CUDA isn't available or the device is auto\n:param default_value:\n:return:"
},
{
"objectID": "docs/api/prompt_strategies.kto.user_defined.html",
"href": "docs/api/prompt_strategies.kto.user_defined.html",
"title": "prompt_strategies.kto.user_defined",
"section": "",
"text": "prompt_strategies.kto.user_defined\nprompt_strategies.kto.user_defined\nUser-defined KTO strategies"
},
{
"objectID": "docs/api/prompt_strategies.alpaca_instruct.html",
"href": "docs/api/prompt_strategies.alpaca_instruct.html",
"title": "prompt_strategies.alpaca_instruct",
"section": "",
"text": "prompt_strategies.alpaca_instruct\nprompt_strategies.alpaca_instruct\nModule loading the AlpacaInstructPromptTokenizingStrategy class"
},
{
"objectID": "docs/api/prompt_strategies.alpaca_chat.html",
"href": "docs/api/prompt_strategies.alpaca_chat.html",
"title": "prompt_strategies.alpaca_chat",
"section": "",
"text": "prompt_strategies.alpaca_chat\nModule for Alpaca prompt strategy classes\n\n\n\n\n\nName\nDescription\n\n\n\n\nAlpacaChatPrompter\nAlpaca Chat Prompter extending the system prompt for chat-instruct answers\n\n\nAlpacaConcisePrompter\nAlpaca Prompter extending the system prompt to ask for concise chat-instruct answers\n\n\nAlpacaQAPromptTokenizingStrategy\nTokenizing strategy for AlpacaQA\n\n\nCamelAIPromptTokenizingStrategy\nTokenizing strategy for CamelAI datasets\n\n\nNoSystemPrompter\nNull Prompter with no system prompts\n\n\n\n\n\nprompt_strategies.alpaca_chat.AlpacaChatPrompter()\nAlpaca Chat Prompter extending the system prompt for chat-instruct answers\n\n\n\nprompt_strategies.alpaca_chat.AlpacaConcisePrompter(\n prompt_style=PromptStyle.INSTRUCT.value,\n)\nAlpaca Prompter extending the system prompt to ask for concise chat-instruct answers\n\n\n\nprompt_strategies.alpaca_chat.AlpacaQAPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for AlpacaQA\n\n\n\nprompt_strategies.alpaca_chat.CamelAIPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for CamelAI datasets\n\n\n\nprompt_strategies.alpaca_chat.NoSystemPrompter()\nNull Prompter with no system prompts"
},
{
"objectID": "docs/api/prompt_strategies.alpaca_chat.html#classes",
"href": "docs/api/prompt_strategies.alpaca_chat.html#classes",
"title": "prompt_strategies.alpaca_chat",
"section": "",
"text": "Name\nDescription\n\n\n\n\nAlpacaChatPrompter\nAlpaca Chat Prompter extending the system prompt for chat-instruct answers\n\n\nAlpacaConcisePrompter\nAlpaca Prompter extending the system prompt to ask for concise chat-instruct answers\n\n\nAlpacaQAPromptTokenizingStrategy\nTokenizing strategy for AlpacaQA\n\n\nCamelAIPromptTokenizingStrategy\nTokenizing strategy for CamelAI datasets\n\n\nNoSystemPrompter\nNull Prompter with no system prompts\n\n\n\n\n\nprompt_strategies.alpaca_chat.AlpacaChatPrompter()\nAlpaca Chat Prompter extending the system prompt for chat-instruct answers\n\n\n\nprompt_strategies.alpaca_chat.AlpacaConcisePrompter(\n prompt_style=PromptStyle.INSTRUCT.value,\n)\nAlpaca Prompter extending the system prompt to ask for concise chat-instruct answers\n\n\n\nprompt_strategies.alpaca_chat.AlpacaQAPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for AlpacaQA\n\n\n\nprompt_strategies.alpaca_chat.CamelAIPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for CamelAI datasets\n\n\n\nprompt_strategies.alpaca_chat.NoSystemPrompter()\nNull Prompter with no system prompts"
},
{
"objectID": "docs/api/utils.collators.mm_chat.html",
"href": "docs/api/utils.collators.mm_chat.html",
"title": "utils.collators.mm_chat",
"section": "",
"text": "utils.collators.mm_chat\nCollators for multi-modal chat messages and packing\n\n\n\n\n\nName\nDescription\n\n\n\n\nMultiModalChatDataCollator\nCollator for multi-modal chat messages\n\n\n\n\n\nutils.collators.mm_chat.MultiModalChatDataCollator(\n tokenizer,\n processing_strategy,\n packing=False,\n return_tensors='pt',\n padding=True,\n pad_to_multiple_of=None,\n)\nCollator for multi-modal chat messages"
},
{
"objectID": "docs/api/utils.collators.mm_chat.html#classes",
"href": "docs/api/utils.collators.mm_chat.html#classes",
"title": "utils.collators.mm_chat",
"section": "",
"text": "Name\nDescription\n\n\n\n\nMultiModalChatDataCollator\nCollator for multi-modal chat messages\n\n\n\n\n\nutils.collators.mm_chat.MultiModalChatDataCollator(\n tokenizer,\n processing_strategy,\n packing=False,\n return_tensors='pt',\n padding=True,\n pad_to_multiple_of=None,\n)\nCollator for multi-modal chat messages"
},
{
"objectID": "docs/api/utils.schedulers.html",
"href": "docs/api/utils.schedulers.html",
"title": "utils.schedulers",
"section": "",
"text": "utils.schedulers\nModule for custom LRScheduler class\n\n\n\n\n\nName\nDescription\n\n\n\n\nInterpolatingLogScheduler\nA scheduler that interpolates learning rates in a logarithmic fashion\n\n\nJaggedLRRestartScheduler\nWraps another scheduler to apply per-lora-restart learning rate warmups.\n\n\nRexLR\nReflected Exponential (REX) learning rate scheduler.\n\n\n\n\n\nutils.schedulers.InterpolatingLogScheduler(\n optimizer,\n num_steps,\n min_lr,\n max_lr,\n last_epoch=-1,\n)\nA scheduler that interpolates learning rates in a logarithmic fashion\n\n\n\nutils.schedulers.JaggedLRRestartScheduler(\n optimizer,\n inner_schedule,\n jagged_restart_steps,\n jagged_restart_warmup_steps,\n jagged_restart_anneal_steps=1,\n min_lr_scale=0.001,\n)\nWraps another scheduler to apply per-lora-restart learning rate warmups.\n\n\n\n\n\nName\nDescription\n\n\n\n\nload_state_dict\nRestore state, including inner_schedule.\n\n\nstate_dict\nReturn serializable state, saving inner_schedule as its own state_dict.\n\n\n\n\n\nutils.schedulers.JaggedLRRestartScheduler.load_state_dict(state_dict)\nRestore state, including inner_schedule.\n\n\n\nutils.schedulers.JaggedLRRestartScheduler.state_dict()\nReturn serializable state, saving inner_schedule as its own state_dict.\n\n\n\n\n\nutils.schedulers.RexLR(\n optimizer,\n max_lr,\n min_lr,\n total_steps=0,\n num_warmup_steps=0,\n last_step=0,\n)\nReflected Exponential (REX) learning rate scheduler.\n\nOriginal implementation: https://github.com/IvanVassi/REX_LR\nOriginal license: Apache 2.0\nBased on: https://arxiv.org/abs/2107.04197\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\noptimizer\ntorch.optim.Optimizer\nThe optimizer to schedule the learning rate for.\nrequired\n\n\nmax_lr\nfloat\nThe maximum learning rate.\nrequired\n\n\nmin_lr\nfloat\nThe minimum learning rate.\nrequired\n\n\ntotal_steps\nint\nThe total number of training steps.\n0\n\n\nnum_warmup_steps\nint\nThe number of warmup 
steps.\n0\n\n\nlast_step\nint\nThe index of last step.\n0\n\n\n\n\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\nget_cosine_schedule_with_min_lr\n\n\n\nget_cosine_schedule_with_quadratic_warmup\nCreate a schedule with a learning rate that decreases following the values of the cosine function between the\n\n\nget_cosine_schedule_with_warmup_decay_constant\nImplementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? (https://arxiv.org/pdf/2308.04014.pdf)\n\n\n\n\n\nutils.schedulers.get_cosine_schedule_with_min_lr(\n optimizer,\n num_warmup_steps,\n num_training_steps,\n min_lr_ratio=0.0,\n)\n\n\n\nlinear warmup from 0 -> max_lr over num_warmup_steps\ncosine learning rate annealing from max_lr -> min_lr over num_training_steps\n\n\n\n\n\nutils.schedulers.get_cosine_schedule_with_quadratic_warmup(\n optimizer,\n num_warmup_steps,\n num_training_steps,\n num_cycles=0.5,\n last_epoch=-1,\n)\nCreate a schedule with a learning rate that decreases following the values of the cosine function between the\ninitial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the\ninitial lr set in the optimizer.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\noptimizer\n[~torch.optim.Optimizer]\nThe optimizer for which to schedule the learning rate.\nrequired\n\n\nnum_warmup_steps\nint\nThe number of steps for the warmup phase.\nrequired\n\n\nnum_training_steps\nint\nThe total number of training steps.\nrequired\n\n\nnum_cycles\nfloat, optional, defaults to 0.5\nThe number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 following a half-cosine).\n0.5\n\n\nlast_epoch\nint, optional, defaults to -1\nThe index of the last epoch when resuming training.\n-1\n\n\n\n\n\n\ntorch.optim.lr_scheduler.LambdaLR with the appropriate schedule.\n\n\n\n\nutils.schedulers.get_cosine_schedule_with_warmup_decay_constant(\n optimizer,\n num_warmup_steps,\n 
num_training_steps,\n constant_lr_ratio,\n min_lr_ratio,\n num_cycles=0.5,\n last_epoch=-1,\n)\nImplementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? (https://arxiv.org/pdf/2308.04014.pdf)\nCreate a schedule with a learning rate that decreases following the values of the cosine function from the\ninitial lr set in the optimizer to min_lr_ratio until num_training_steps * constant_lr_ratio, after which it holds the constant value of min_rate,\nfollowing a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\noptimizer\n[~torch.optim.Optimizer]\nThe optimizer for which to schedule the learning rate.\nrequired\n\n\nnum_warmup_steps\nint\nThe number of steps for the warmup phase.\nrequired\n\n\nnum_training_steps\nint\nThe total number of training steps.\nrequired\n\n\nconstant_lr_ratio\nfloat\n(float): The ratio of num_training_steps to decrease by cosine function.\nrequired\n\n\nmin_lr_ratio\nfloat\n(float): The ratio of maximum learning rate for cosine function to decay to minimum learning rate.\nrequired\n\n\nnum_cycles\nfloat, optional, defaults to 0.5\nThe number of waves in the cosine schedule (the default is to just decrease from the max value to 0 following a half-cosine).\n0.5\n\n\nlast_epoch\nint, optional, defaults to -1\nThe index of the last epoch when resuming training.\n-1\n\n\n\n\n\n\ntorch.optim.lr_scheduler.LambdaLR with the appropriate schedule."
},
{
"objectID": "docs/api/utils.schedulers.html#classes",
"href": "docs/api/utils.schedulers.html#classes",
"title": "utils.schedulers",
"section": "",
"text": "Name\nDescription\n\n\n\n\nInterpolatingLogScheduler\nA scheduler that interpolates learning rates in a logarithmic fashion\n\n\nJaggedLRRestartScheduler\nWraps another scheduler to apply per-lora-restart learning rate warmups.\n\n\nRexLR\nReflected Exponential (REX) learning rate scheduler.\n\n\n\n\n\nutils.schedulers.InterpolatingLogScheduler(\n optimizer,\n num_steps,\n min_lr,\n max_lr,\n last_epoch=-1,\n)\nA scheduler that interpolates learning rates in a logarithmic fashion\n\n\n\nutils.schedulers.JaggedLRRestartScheduler(\n optimizer,\n inner_schedule,\n jagged_restart_steps,\n jagged_restart_warmup_steps,\n jagged_restart_anneal_steps=1,\n min_lr_scale=0.001,\n)\nWraps another scheduler to apply per-lora-restart learning rate warmups.\n\n\n\n\n\nName\nDescription\n\n\n\n\nload_state_dict\nRestore state, including inner_schedule.\n\n\nstate_dict\nReturn serializable state, saving inner_schedule as its own state_dict.\n\n\n\n\n\nutils.schedulers.JaggedLRRestartScheduler.load_state_dict(state_dict)\nRestore state, including inner_schedule.\n\n\n\nutils.schedulers.JaggedLRRestartScheduler.state_dict()\nReturn serializable state, saving inner_schedule as its own state_dict.\n\n\n\n\n\nutils.schedulers.RexLR(\n optimizer,\n max_lr,\n min_lr,\n total_steps=0,\n num_warmup_steps=0,\n last_step=0,\n)\nReflected Exponential (REX) learning rate scheduler.\n\nOriginal implementation: https://github.com/IvanVassi/REX_LR\nOriginal license: Apache 2.0\nBased on: https://arxiv.org/abs/2107.04197\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\noptimizer\ntorch.optim.Optimizer\nThe optimizer to schedule the learning rate for.\nrequired\n\n\nmax_lr\nfloat\nThe maximum learning rate.\nrequired\n\n\nmin_lr\nfloat\nThe minimum learning rate.\nrequired\n\n\ntotal_steps\nint\nThe total number of training steps.\n0\n\n\nnum_warmup_steps\nint\nThe number of warmup steps.\n0\n\n\nlast_step\nint\nThe index of last step.\n0"
},
{
"objectID": "docs/api/utils.schedulers.html#functions",
"href": "docs/api/utils.schedulers.html#functions",
"title": "utils.schedulers",
"section": "",
"text": "Name\nDescription\n\n\n\n\nget_cosine_schedule_with_min_lr\n\n\n\nget_cosine_schedule_with_quadratic_warmup\nCreate a schedule with a learning rate that decreases following the values of the cosine function between the\n\n\nget_cosine_schedule_with_warmup_decay_constant\nImplementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? (https://arxiv.org/pdf/2308.04014.pdf)\n\n\n\n\n\nutils.schedulers.get_cosine_schedule_with_min_lr(\n optimizer,\n num_warmup_steps,\n num_training_steps,\n min_lr_ratio=0.0,\n)\n\n\n\nlinear warmup from 0 -> max_lr over num_warmup_steps\ncosine learning rate annealing from max_lr -> min_lr over num_training_steps\n\n\n\n\n\nutils.schedulers.get_cosine_schedule_with_quadratic_warmup(\n optimizer,\n num_warmup_steps,\n num_training_steps,\n num_cycles=0.5,\n last_epoch=-1,\n)\nCreate a schedule with a learning rate that decreases following the values of the cosine function between the\ninitial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the\ninitial lr set in the optimizer.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\noptimizer\n[~torch.optim.Optimizer]\nThe optimizer for which to schedule the learning rate.\nrequired\n\n\nnum_warmup_steps\nint\nThe number of steps for the warmup phase.\nrequired\n\n\nnum_training_steps\nint\nThe total number of training steps.\nrequired\n\n\nnum_cycles\nfloat, optional, defaults to 0.5\nThe number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 following a half-cosine).\n0.5\n\n\nlast_epoch\nint, optional, defaults to -1\nThe index of the last epoch when resuming training.\n-1\n\n\n\n\n\n\ntorch.optim.lr_scheduler.LambdaLR with the appropriate schedule.\n\n\n\n\nutils.schedulers.get_cosine_schedule_with_warmup_decay_constant(\n optimizer,\n num_warmup_steps,\n num_training_steps,\n constant_lr_ratio,\n min_lr_ratio,\n num_cycles=0.5,\n 
last_epoch=-1,\n)\nImplementation of Continual Pre-Training of Large Language Models: How to (re)warm your model? (https://arxiv.org/pdf/2308.04014.pdf)\nCreate a schedule with a learning rate that decreases following the values of the cosine function from the\ninitial lr set in the optimizer to min_lr_ratio until num_training_steps * constant_lr_ratio, after which it holds the constant value of min_rate,\nfollowing a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\noptimizer\n[~torch.optim.Optimizer]\nThe optimizer for which to schedule the learning rate.\nrequired\n\n\nnum_warmup_steps\nint\nThe number of steps for the warmup phase.\nrequired\n\n\nnum_training_steps\nint\nThe total number of training steps.\nrequired\n\n\nconstant_lr_ratio\nfloat\n(float): The ratio of num_training_steps to decrease by cosine function.\nrequired\n\n\nmin_lr_ratio\nfloat\n(float): The ratio of maximum learning rate for cosine function to decay to minimum learning rate.\nrequired\n\n\nnum_cycles\nfloat, optional, defaults to 0.5\nThe number of waves in the cosine schedule (the default is to just decrease from the max value to 0 following a half-cosine).\n0.5\n\n\nlast_epoch\nint, optional, defaults to -1\nThe index of the last epoch when resuming training.\n-1\n\n\n\n\n\n\ntorch.optim.lr_scheduler.LambdaLR with the appropriate schedule."
},
{
"objectID": "docs/api/kernels.utils.html",
"href": "docs/api/kernels.utils.html",
"title": "kernels.utils",
"section": "",
"text": "kernels.utils\nkernels.utils\nUtilities for axolotl.kernels submodules."
},
{
"objectID": "docs/api/core.chat.format.chatml.html",
"href": "docs/api/core.chat.format.chatml.html",
"title": "core.chat.format.chatml",
"section": "",
"text": "core.chat.format.chatml\ncore.chat.format.chatml\nChatML transformation functions for MessageContents"
},
{
"objectID": "docs/api/loaders.constants.html",
"href": "docs/api/loaders.constants.html",
"title": "loaders.constants",
"section": "",
"text": "loaders.constants\nloaders.constants\nShared constants for axolotl.loaders module"
},
{
"objectID": "docs/api/utils.schemas.model.html",
"href": "docs/api/utils.schemas.model.html",
"title": "utils.schemas.model",
"section": "",
"text": "utils.schemas.model\nPydantic models for model input / output, etc. configuration\n\n\n\n\n\nName\nDescription\n\n\n\n\nModelInputConfig\nModel configuration subset\n\n\nModelOutputConfig\nmodel save configuration subset\n\n\nSpecialTokensConfig\nSpecial tokens configuration subset\n\n\n\n\n\nutils.schemas.model.ModelInputConfig()\nModel configuration subset\n\n\n\nutils.schemas.model.ModelOutputConfig()\nmodel save configuration subset\n\n\n\nutils.schemas.model.SpecialTokensConfig()\nSpecial tokens configuration subset"
},
{
"objectID": "docs/api/utils.schemas.model.html#classes",
"href": "docs/api/utils.schemas.model.html#classes",
"title": "utils.schemas.model",
"section": "",
"text": "Name\nDescription\n\n\n\n\nModelInputConfig\nModel configuration subset\n\n\nModelOutputConfig\nmodel save configuration subset\n\n\nSpecialTokensConfig\nSpecial tokens configuration subset\n\n\n\n\n\nutils.schemas.model.ModelInputConfig()\nModel configuration subset\n\n\n\nutils.schemas.model.ModelOutputConfig()\nmodel save configuration subset\n\n\n\nutils.schemas.model.SpecialTokensConfig()\nSpecial tokens configuration subset"
},
{
"objectID": "docs/api/integrations.grokfast.optimizer.html",
"href": "docs/api/integrations.grokfast.optimizer.html",
"title": "integrations.grokfast.optimizer",
"section": "",
"text": "integrations.grokfast.optimizer\nintegrations.grokfast.optimizer"
},
{
"objectID": "docs/api/cli.utils.load.html",
"href": "docs/api/cli.utils.load.html",
"title": "cli.utils.load",
"section": "",
"text": "cli.utils.load\nUtilities for model, tokenizer, etc. loading.\n\n\n\n\n\nName\nDescription\n\n\n\n\nload_model_and_tokenizer\nHelper function for loading a model, tokenizer, and processor specified in the\n\n\n\n\n\ncli.utils.load.load_model_and_tokenizer(cfg, inference=False)\nHelper function for loading a model, tokenizer, and processor specified in the\ngiven axolotl config.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ninference\nbool\nBoolean denoting inference mode.\nFalse\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any, ProcessorMixin | None]\nTuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin)."
},
{
"objectID": "docs/api/cli.utils.load.html#functions",
"href": "docs/api/cli.utils.load.html#functions",
"title": "cli.utils.load",
"section": "",
"text": "Name\nDescription\n\n\n\n\nload_model_and_tokenizer\nHelper function for loading a model, tokenizer, and processor specified in the\n\n\n\n\n\ncli.utils.load.load_model_and_tokenizer(cfg, inference=False)\nHelper function for loading a model, tokenizer, and processor specified in the\ngiven axolotl config.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ninference\nbool\nBoolean denoting inference mode.\nFalse\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any, ProcessorMixin | None]\nTuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin)."
},
{
"objectID": "docs/api/loaders.adapter.html",
"href": "docs/api/loaders.adapter.html",
"title": "loaders.adapter",
"section": "",
"text": "loaders.adapter\nAdapter loading functionality, including LoRA / QLoRA and associated utils\n\n\n\n\n\nName\nDescription\n\n\n\n\nsetup_quantized_meta_for_peft\nReplaces quant_state.to with a dummy function to prevent PEFT from moving quant_state to meta device\n\n\nsetup_quantized_peft_meta_for_training\nReplaces dummy quant_state.to method with the original function to allow training to continue\n\n\n\n\n\nloaders.adapter.setup_quantized_meta_for_peft(model)\nReplaces quant_state.to with a dummy function to prevent PEFT from moving quant_state to meta device\n\n\n\nloaders.adapter.setup_quantized_peft_meta_for_training(model)\nReplaces dummy quant_state.to method with the original function to allow training to continue"
},
{
"objectID": "docs/api/loaders.adapter.html#functions",
"href": "docs/api/loaders.adapter.html#functions",
"title": "loaders.adapter",
"section": "",
"text": "Name\nDescription\n\n\n\n\nsetup_quantized_meta_for_peft\nReplaces quant_state.to with a dummy function to prevent PEFT from moving quant_state to meta device\n\n\nsetup_quantized_peft_meta_for_training\nReplaces dummy quant_state.to method with the original function to allow training to continue\n\n\n\n\n\nloaders.adapter.setup_quantized_meta_for_peft(model)\nReplaces quant_state.to with a dummy function to prevent PEFT from moving quant_state to meta device\n\n\n\nloaders.adapter.setup_quantized_peft_meta_for_training(model)\nReplaces dummy quant_state.to method with the original function to allow training to continue"
},
{
"objectID": "docs/api/cli.train.html",
"href": "docs/api/cli.train.html",
"title": "cli.train",
"section": "",
"text": "cli.train\nCLI to run training on a model.\n\n\n\n\n\nName\nDescription\n\n\n\n\ndo_cli\nParses axolotl config, CLI args, and calls do_train.\n\n\ndo_train\nTrains a transformers model by first loading the dataset(s) specified in the\n\n\n\n\n\ncli.train.do_cli(config=Path('examples/'), **kwargs)\nParses axolotl config, CLI args, and calls do_train.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nUnion[Path, str]\nPath to axolotl config YAML file.\nPath('examples/')\n\n\nkwargs\n\nAdditional keyword arguments to override config file values.\n{}\n\n\n\n\n\n\n\ncli.train.do_train(cfg, cli_args)\nTrains a transformers model by first loading the dataset(s) specified in the\naxolotl config, and then calling axolotl.train.train. Also runs the plugin\nmanager's post_train_unload once training completes.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ncli_args\nTrainerCliArgs\nTraining-specific CLI arguments.\nrequired"
},
{
"objectID": "docs/api/cli.train.html#functions",
"href": "docs/api/cli.train.html#functions",
"title": "cli.train",
"section": "",
"text": "Name\nDescription\n\n\n\n\ndo_cli\nParses axolotl config, CLI args, and calls do_train.\n\n\ndo_train\nTrains a transformers model by first loading the dataset(s) specified in the\n\n\n\n\n\ncli.train.do_cli(config=Path('examples/'), **kwargs)\nParses axolotl config, CLI args, and calls do_train.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nUnion[Path, str]\nPath to axolotl config YAML file.\nPath('examples/')\n\n\nkwargs\n\nAdditional keyword arguments to override config file values.\n{}\n\n\n\n\n\n\n\ncli.train.do_train(cfg, cli_args)\nTrains a transformers model by first loading the dataset(s) specified in the\naxolotl config, and then calling axolotl.train.train. Also runs the plugin\nmanager's post_train_unload once training completes.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ncli_args\nTrainerCliArgs\nTraining-specific CLI arguments.\nrequired"
},
{
"objectID": "docs/api/monkeypatch.stablelm_attn_hijack_flash.html",
"href": "docs/api/monkeypatch.stablelm_attn_hijack_flash.html",
"title": "monkeypatch.stablelm_attn_hijack_flash",
"section": "",
"text": "monkeypatch.stablelm_attn_hijack_flash\nPyTorch StableLM Epoch model.\n\n\n\n\n\nName\nDescription\n\n\n\n\nrepeat_kv\nThis is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n\n\nrotate_half\nRotates half the hidden dims of the input.\n\n\n\n\n\nmonkeypatch.stablelm_attn_hijack_flash.repeat_kv(hidden_states, n_rep)\nThis is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\nnum_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n\n\n\nmonkeypatch.stablelm_attn_hijack_flash.rotate_half(x)\nRotates half the hidden dims of the input."
},
{
"objectID": "docs/api/monkeypatch.stablelm_attn_hijack_flash.html#functions",
"href": "docs/api/monkeypatch.stablelm_attn_hijack_flash.html#functions",
"title": "monkeypatch.stablelm_attn_hijack_flash",
"section": "",
"text": "Name\nDescription\n\n\n\n\nrepeat_kv\nThis is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n\n\nrotate_half\nRotates half the hidden dims of the input.\n\n\n\n\n\nmonkeypatch.stablelm_attn_hijack_flash.repeat_kv(hidden_states, n_rep)\nThis is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\nnum_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n\n\n\nmonkeypatch.stablelm_attn_hijack_flash.rotate_half(x)\nRotates half the hidden dims of the input."
},
{
"objectID": "docs/api/cli.checks.html",
"href": "docs/api/cli.checks.html",
"title": "cli.checks",
"section": "",
"text": "cli.checks\nVarious checks for Axolotl CLI.\n\n\n\n\n\nName\nDescription\n\n\n\n\ncheck_accelerate_default_config\nLogs at warning level if no accelerate config file is found.\n\n\ncheck_user_token\nChecks for HF user info. Check is skipped if HF_HUB_OFFLINE=1.\n\n\n\n\n\ncli.checks.check_accelerate_default_config()\nLogs at warning level if no accelerate config file is found.\n\n\n\ncli.checks.check_user_token()\nChecks for HF user info. Check is skipped if HF_HUB_OFFLINE=1.\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nbool\nBoolean indicating successful check (i.e., HF_HUB_OFFLINE=1 or HF user info is retrieved).\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nLocalTokenNotFoundError\nIf HF user info can't be retrieved."
},
{
"objectID": "docs/api/cli.checks.html#functions",
"href": "docs/api/cli.checks.html#functions",
"title": "cli.checks",
"section": "",
"text": "Name\nDescription\n\n\n\n\ncheck_accelerate_default_config\nLogs at warning level if no accelerate config file is found.\n\n\ncheck_user_token\nChecks for HF user info. Check is skipped if HF_HUB_OFFLINE=1.\n\n\n\n\n\ncli.checks.check_accelerate_default_config()\nLogs at warning level if no accelerate config file is found.\n\n\n\ncli.checks.check_user_token()\nChecks for HF user info. Check is skipped if HF_HUB_OFFLINE=1.\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nbool\nBoolean indicating successful check (i.e., HF_HUB_OFFLINE=1 or HF user info is retrieved).\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nLocalTokenNotFoundError\nIf HF user info can't be retrieved."
},
{
"objectID": "docs/api/prompt_strategies.dpo.user_defined.html",
"href": "docs/api/prompt_strategies.dpo.user_defined.html",
"title": "prompt_strategies.dpo.user_defined",
"section": "",
"text": "prompt_strategies.dpo.user_defined\nprompt_strategies.dpo.user_defined\nUser-defined DPO strategies"
},
{
"objectID": "docs/api/prompt_strategies.llama2_chat.html",
"href": "docs/api/prompt_strategies.llama2_chat.html",
"title": "prompt_strategies.llama2_chat",
"section": "",
"text": "prompt_strategies.llama2_chat\nPrompt Strategy for finetuning Llama2 chat models\nsee also https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L213 for a reference implementation.\nThis implementation is based on the Vicuna PR and the fastchat repo, see also:\nhttps://github.com/lm-sys/FastChat/blob/cdd7730686cb1bf9ae2b768ee171bdf7d1ff04f3/fastchat/conversation.py#L847\nUse dataset type: “llama2_chat” in config.yml to use this prompt style.\nE.g. in the config.yml:\ndatasets:\n  - path: llama_finetune_train.jsonl\n    type: llama2_chat\nThe dataset itself should look like this:\n{'conversations':[{\"from\": \"human\", \"value\": \"Who are you?\"}, {\"from\": \"gpt\", \"value\": \"I am Vicuna\"},...]}\nin a jsonl file. The first message should be from the human, the second from gpt.\nFor a custom system message, the first “from” can be “system” (followed by alternating “human” and “gpt” turns).\nImportant: Don't use “special_tokens:” in your config.yml if you are not sure what you are doing!\n\n\n\n\n\nName\nDescription\n\n\n\n\nLLama2ChatTokenizingStrategy\nTokenizing strategy for Llama2 prompts.\n\n\nLlama2ChatConversation\nA class that manages prompt templates and keeps all conversation history.\n\n\nLlama2ChatPrompter\nA prompter that generates prompts for Llama2 models.\n\n\n\n\n\nprompt_strategies.llama2_chat.LLama2ChatTokenizingStrategy(*args, **kwargs)\nTokenizing strategy for Llama2 prompts.\nadapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py\n\n\n\nprompt_strategies.llama2_chat.Llama2ChatConversation(\n    name='llama2',\n    system=\"[INST] <<SYS>>\\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
Please ensure that your responses are socially unbiased and positive in nature.\\n\\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\\n<</SYS>>\\n\\n\",\n roles=('[INST]', '[/INST]'),\n messages=list(),\n offset=0,\n)\nA class that manages prompt templates and keeps all conversation history.\ncopied from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py\n\n\n\n\n\nName\nDescription\n\n\n\n\nappend_message\nAppend a new message.\n\n\nget_prompt\nGet the prompt for generation.\n\n\n\n\n\nprompt_strategies.llama2_chat.Llama2ChatConversation.append_message(\n role,\n message,\n)\nAppend a new message.\n\n\n\nprompt_strategies.llama2_chat.Llama2ChatConversation.get_prompt()\nGet the prompt for generation.\n\n\n\n\n\nprompt_strategies.llama2_chat.Llama2ChatPrompter()\nA prompter that generates prompts for Llama2 models."
},
{
"objectID": "docs/api/prompt_strategies.llama2_chat.html#classes",
"href": "docs/api/prompt_strategies.llama2_chat.html#classes",
"title": "prompt_strategies.llama2_chat",
"section": "",
"text": "Name\nDescription\n\n\n\n\nLLama2ChatTokenizingStrategy\nTokenizing strategy for Llama2 prompts.\n\n\nLlama2ChatConversation\nA class that manages prompt templates and keeps all conversation history.\n\n\nLlama2ChatPrompter\nA prompter that generates prompts for Llama2 models.\n\n\n\n\n\nprompt_strategies.llama2_chat.LLama2ChatTokenizingStrategy(*args, **kwargs)\nTokenizing strategy for Llama2 prompts.\nadapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py\n\n\n\nprompt_strategies.llama2_chat.Llama2ChatConversation(\n name='llama2',\n system=\"[INST] <<SYS>>\\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\\n\\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\\n<</SYS>>\\n\\n\",\n roles=('[INST]', '[/INST]'),\n messages=list(),\n offset=0,\n)\nA class that manages prompt templates and keeps all conversation history.\ncopied from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py\n\n\n\n\n\nName\nDescription\n\n\n\n\nappend_message\nAppend a new message.\n\n\nget_prompt\nGet the prompt for generation.\n\n\n\n\n\nprompt_strategies.llama2_chat.Llama2ChatConversation.append_message(\n role,\n message,\n)\nAppend a new message.\n\n\n\nprompt_strategies.llama2_chat.Llama2ChatConversation.get_prompt()\nGet the prompt for generation.\n\n\n\n\n\nprompt_strategies.llama2_chat.Llama2ChatPrompter()\nA prompter that generates prompts for Llama2 models."
},
{
"objectID": "docs/api/core.trainers.trl.html",
"href": "docs/api/core.trainers.trl.html",
"title": "core.trainers.trl",
"section": "",
"text": "core.trainers.trl\nModule for TRL RL trainers\n\n\n\n\n\nName\nDescription\n\n\n\n\nAxolotlCPOTrainer\nExtend the base CPOTrainer for axolotl helpers\n\n\nAxolotlKTOTrainer\nExtend the base KTOTrainer for axolotl helpers\n\n\nAxolotlORPOTrainer\nExtend the base ORPOTrainer for axolotl helpers\n\n\nAxolotlPRMTrainer\nExtend the base trl.PRMTrainer for axolotl helpers\n\n\nAxolotlRewardTrainer\nExtend the base RewardTrainer for axolotl helpers\n\n\n\n\n\ncore.trainers.trl.AxolotlCPOTrainer(*args, **kwargs)\nExtend the base CPOTrainer for axolotl helpers\n\n\n\ncore.trainers.trl.AxolotlKTOTrainer(*args, **kwargs)\nExtend the base KTOTrainer for axolotl helpers\n\n\n\ncore.trainers.trl.AxolotlORPOTrainer(*args, **kwargs)\nExtend the base ORPOTrainer for axolotl helpers\n\n\n\ncore.trainers.trl.AxolotlPRMTrainer(*args, **kwargs)\nExtend the base trl.PRMTrainer for axolotl helpers\n\n\n\ncore.trainers.trl.AxolotlRewardTrainer(*args, **kwargs)\nExtend the base RewardTrainer for axolotl helpers"
},
{
"objectID": "docs/api/core.trainers.trl.html#classes",
"href": "docs/api/core.trainers.trl.html#classes",
"title": "core.trainers.trl",
"section": "",
"text": "Name\nDescription\n\n\n\n\nAxolotlCPOTrainer\nExtend the base CPOTrainer for axolotl helpers\n\n\nAxolotlKTOTrainer\nExtend the base KTOTrainer for axolotl helpers\n\n\nAxolotlORPOTrainer\nExtend the base ORPOTrainer for axolotl helpers\n\n\nAxolotlPRMTrainer\nExtend the base trl.PRMTrainer for axolotl helpers\n\n\nAxolotlRewardTrainer\nExtend the base RewardTrainer for axolotl helpers\n\n\n\n\n\ncore.trainers.trl.AxolotlCPOTrainer(*args, **kwargs)\nExtend the base CPOTrainer for axolotl helpers\n\n\n\ncore.trainers.trl.AxolotlKTOTrainer(*args, **kwargs)\nExtend the base KTOTrainer for axolotl helpers\n\n\n\ncore.trainers.trl.AxolotlORPOTrainer(*args, **kwargs)\nExtend the base ORPOTrainer for axolotl helpers\n\n\n\ncore.trainers.trl.AxolotlPRMTrainer(*args, **kwargs)\nExtend the base trl.PRMTrainer for axolotl helpers\n\n\n\ncore.trainers.trl.AxolotlRewardTrainer(*args, **kwargs)\nExtend the base RewardTrainer for axolotl helpers"
},
{
"objectID": "docs/api/monkeypatch.mistral_attn_hijack_flash.html",
"href": "docs/api/monkeypatch.mistral_attn_hijack_flash.html",
"title": "monkeypatch.mistral_attn_hijack_flash",
"section": "",
"text": "monkeypatch.mistral_attn_hijack_flash\nmonkeypatch.mistral_attn_hijack_flash\nFlash attention monkey patch for mistral model"
},
{
"objectID": "docs/api/core.trainers.mixins.scheduler.html",
"href": "docs/api/core.trainers.mixins.scheduler.html",
"title": "core.trainers.mixins.scheduler",
"section": "",
"text": "core.trainers.mixins.scheduler\nModule for Axolotl trainer scheduler mixin\n\n\n\n\n\nName\nDescription\n\n\n\n\nSchedulerMixin\nMixin class for scheduler setup in CausalTrainer.\n\n\n\n\n\ncore.trainers.mixins.scheduler.SchedulerMixin()\nMixin class for scheduler setup in CausalTrainer.\n\n\n\n\n\nName\nDescription\n\n\n\n\ncreate_scheduler\nSet up the scheduler. The optimizer of the trainer must have been set up either before this method is called or\n\n\n\n\n\ncore.trainers.mixins.scheduler.SchedulerMixin.create_scheduler(\n num_training_steps,\n optimizer=None,\n)\nSet up the scheduler. The optimizer of the trainer must have been set up either before this method is called or\npassed as an argument.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nnum_training_steps\nint\nThe number of training steps to do.\nrequired\n\n\noptimizer\ntorch.optim.Optimizer\nThe training optimizer\nNone"
},
{
"objectID": "docs/api/core.trainers.mixins.scheduler.html#classes",
"href": "docs/api/core.trainers.mixins.scheduler.html#classes",
"title": "core.trainers.mixins.scheduler",
"section": "",
"text": "Name\nDescription\n\n\n\n\nSchedulerMixin\nMixin class for scheduler setup in CausalTrainer.\n\n\n\n\n\ncore.trainers.mixins.scheduler.SchedulerMixin()\nMixin class for scheduler setup in CausalTrainer.\n\n\n\n\n\nName\nDescription\n\n\n\n\ncreate_scheduler\nSet up the scheduler. The optimizer of the trainer must have been set up either before this method is called or\n\n\n\n\n\ncore.trainers.mixins.scheduler.SchedulerMixin.create_scheduler(\n num_training_steps,\n optimizer=None,\n)\nSet up the scheduler. The optimizer of the trainer must have been set up either before this method is called or\npassed as an argument.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nnum_training_steps\nint\nThe number of training steps to do.\nrequired\n\n\noptimizer\ntorch.optim.Optimizer\nThe training optimizer\nNone"
},
{
"objectID": "docs/api/core.trainers.grpo.trainer.html",
"href": "docs/api/core.trainers.grpo.trainer.html",
"title": "core.trainers.grpo.trainer",
"section": "",
"text": "core.trainers.grpo.trainer\nAxolotl GRPO trainers (with and without sequence parallelism handling)\n\n\n\n\n\nName\nDescription\n\n\n\n\nAxolotlAsyncGRPOTrainer\nExtend AsyncGRPOTrainer with axolotl helpers\n\n\nAxolotlGRPOSequenceParallelTrainer\nExtend the base GRPOTrainer for sequence parallelism handling\n\n\nAxolotlGRPOTrainer\nExtend the base GRPOTrainer for axolotl helpers\n\n\n\n\n\ncore.trainers.grpo.trainer.AxolotlAsyncGRPOTrainer(*args, **kwargs)\nExtend AsyncGRPOTrainer with axolotl helpers\n\n\n\ncore.trainers.grpo.trainer.AxolotlGRPOSequenceParallelTrainer(\n model,\n reward_funcs,\n args=None,\n train_dataset=None,\n eval_dataset=None,\n processing_class=None,\n reward_processing_classes=None,\n callbacks=None,\n optimizers=(None, None),\n peft_config=None,\n optimizer_cls_and_kwargs=None,\n)\nExtend the base GRPOTrainer for sequence parallelism handling\n\n\n\n\n\nName\nDescription\n\n\n\n\nget_train_dataloader\nGet dataloader for training\n\n\n\n\n\ncore.trainers.grpo.trainer.AxolotlGRPOSequenceParallelTrainer.get_train_dataloader(\n)\nGet dataloader for training\n\n\n\n\n\ncore.trainers.grpo.trainer.AxolotlGRPOTrainer(*args, **kwargs)\nExtend the base GRPOTrainer for axolotl helpers"
},
{
"objectID": "docs/api/core.trainers.grpo.trainer.html#classes",
"href": "docs/api/core.trainers.grpo.trainer.html#classes",
"title": "core.trainers.grpo.trainer",
"section": "",
"text": "Name\nDescription\n\n\n\n\nAxolotlAsyncGRPOTrainer\nExtend AsyncGRPOTrainer with axolotl helpers\n\n\nAxolotlGRPOSequenceParallelTrainer\nExtend the base GRPOTrainer for sequence parallelism handling\n\n\nAxolotlGRPOTrainer\nExtend the base GRPOTrainer for axolotl helpers\n\n\n\n\n\ncore.trainers.grpo.trainer.AxolotlAsyncGRPOTrainer(*args, **kwargs)\nExtend AsyncGRPOTrainer with axolotl helpers\n\n\n\ncore.trainers.grpo.trainer.AxolotlGRPOSequenceParallelTrainer(\n model,\n reward_funcs,\n args=None,\n train_dataset=None,\n eval_dataset=None,\n processing_class=None,\n reward_processing_classes=None,\n callbacks=None,\n optimizers=(None, None),\n peft_config=None,\n optimizer_cls_and_kwargs=None,\n)\nExtend the base GRPOTrainer for sequence parallelism handling\n\n\n\n\n\nName\nDescription\n\n\n\n\nget_train_dataloader\nGet dataloader for training\n\n\n\n\n\ncore.trainers.grpo.trainer.AxolotlGRPOSequenceParallelTrainer.get_train_dataloader(\n)\nGet dataloader for training\n\n\n\n\n\ncore.trainers.grpo.trainer.AxolotlGRPOTrainer(*args, **kwargs)\nExtend the base GRPOTrainer for axolotl helpers"
},
{
"objectID": "docs/api/cli.merge_lora.html",
"href": "docs/api/cli.merge_lora.html",
"title": "cli.merge_lora",
"section": "",
"text": "cli.merge_lora\nCLI to merge a trained LoRA into a base model.\n\n\n\n\n\nName\nDescription\n\n\n\n\ndo_cli\nParses axolotl config, CLI args, and calls do_merge_lora. Note that various\n\n\ndo_merge_lora\nMerges LoRA adapters with base model using either memory-efficient or legacy approach.\n\n\n\n\n\ncli.merge_lora.do_cli(config=Path('examples/'), **kwargs)\nParses axolotl config, CLI args, and calls do_merge_lora. Note that various\nconfig values will be overwritten to allow the LoRA merge logic to work as expected\n(load_in_8bit=False, load_in_4bit=False, flash_attention=False, etc.).\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nUnion[Path, str]\nPath to axolotl config YAML file.\nPath('examples/')\n\n\nkwargs\n\nAdditional keyword arguments to override config file values.\n{}\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nValueError\nIf target directory for LoRA merged model does not exist.\n\n\n\n\n\n\n\ncli.merge_lora.do_merge_lora(cfg)\nMerges LoRA adapters with base model using either memory-efficient or legacy approach.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired"
},
{
"objectID": "docs/api/cli.merge_lora.html#functions",
"href": "docs/api/cli.merge_lora.html#functions",
"title": "cli.merge_lora",
"section": "",
"text": "Name\nDescription\n\n\n\n\ndo_cli\nParses axolotl config, CLI args, and calls do_merge_lora. Note that various\n\n\ndo_merge_lora\nMerges LoRA adapters with base model using either memory-efficient or legacy approach.\n\n\n\n\n\ncli.merge_lora.do_cli(config=Path('examples/'), **kwargs)\nParses axolotl config, CLI args, and calls do_merge_lora. Note that various\nconfig values will be overwritten to allow the LoRA merge logic to work as expected\n(load_in_8bit=False, load_in_4bit=False, flash_attention=False, etc.).\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nUnion[Path, str]\nPath to axolotl config YAML file.\nPath('examples/')\n\n\nkwargs\n\nAdditional keyword arguments to override config file values.\n{}\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nValueError\nIf target directory for LoRA merged model does not exist.\n\n\n\n\n\n\n\ncli.merge_lora.do_merge_lora(cfg)\nMerges LoRA adapters with base model using either memory-efficient or legacy approach.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired"
},
{
"objectID": "docs/api/datasets.html",
"href": "docs/api/datasets.html",
"title": "datasets",
"section": "",
"text": "datasets\nModule containing dataset functionality.\nWe want this to be a wrapper for an existing dataset that we have loaded. Let's use the\nconcept of middlewares to wrap each dataset. We'll use the collators later on to pad the\ndatasets.\n\n\n\n\n\nName\nDescription\n\n\n\n\nTokenizedPromptDataset\nDataset that returns tokenized prompts from a stream of text files.\n\n\n\n\n\ndatasets.TokenizedPromptDataset(\n    prompt_tokenizer,\n    dataset,\n    process_count=None,\n    keep_in_memory=False,\n    **kwargs,\n)\nDataset that returns tokenized prompts from a stream of text files.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nprompt_tokenizer\nPromptTokenizingStrategy\nThe prompt tokenizing method for processing the data.\nrequired\n\n\ndataset\nDataset\nDataset with text files.\nrequired\n\n\nprocess_count\nint | None\nNumber of processes to use for tokenizing.\nNone\n\n\nkeep_in_memory\nbool | None\nWhether to keep the tokenized dataset in memory.\nFalse"
},
{
"objectID": "docs/api/datasets.html#classes",
"href": "docs/api/datasets.html#classes",
"title": "datasets",
"section": "",
"text": "Name\nDescription\n\n\n\n\nTokenizedPromptDataset\nDataset that returns tokenized prompts from a stream of text files.\n\n\n\n\n\ndatasets.TokenizedPromptDataset(\n prompt_tokenizer,\n dataset,\n process_count=None,\n keep_in_memory=False,\n **kwargs,\n)\nDataset that returns tokenized prompts from a stream of text files.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nprompt_tokenizer\nPromptTokenizingStrategy\nThe prompt tokenizing method for processing the data.\nrequired\n\n\ndataset\nDataset\nDataset with text files.\nrequired\n\n\nprocess_count\nint | None\nNumber of processes to use for tokenizing.\nNone\n\n\nkeep_in_memory\nbool | None\nWhether to keep the tokenized dataset in memory.\nFalse"
},
{
"objectID": "docs/api/utils.schemas.training.html",
"href": "docs/api/utils.schemas.training.html",
"title": "utils.schemas.training",
"section": "",
"text": "utils.schemas.training\nPydantic models for training hyperparameters\n\n\n\n\n\nName\nDescription\n\n\n\n\nHyperparametersConfig\nTraining hyperparams configuration subset\n\n\nJaggedLRConfig\nJaggedLR configuration subset, can be used w/ ReLoRA training\n\n\nLrGroup\nCustom learning rate group configuration\n\n\n\n\n\nutils.schemas.training.HyperparametersConfig()\nTraining hyperparams configuration subset\n\n\n\nutils.schemas.training.JaggedLRConfig()\nJaggedLR configuration subset, can be used w/ ReLoRA training\n\n\n\nutils.schemas.training.LrGroup()\nCustom learning rate group configuration"
},
{
"objectID": "docs/api/utils.schemas.training.html#classes",
"href": "docs/api/utils.schemas.training.html#classes",
"title": "utils.schemas.training",
"section": "",
"text": "Name\nDescription\n\n\n\n\nHyperparametersConfig\nTraining hyperparams configuration subset\n\n\nJaggedLRConfig\nJaggedLR configuration subset, can be used w/ ReLoRA training\n\n\nLrGroup\nCustom learning rate group configuration\n\n\n\n\n\nutils.schemas.training.HyperparametersConfig()\nTraining hyperparams configuration subset\n\n\n\nutils.schemas.training.JaggedLRConfig()\nJaggedLR configuration subset, can be used w/ ReLoRA training\n\n\n\nutils.schemas.training.LrGroup()\nCustom learning rate group configuration"
},
{
"objectID": "docs/api/utils.distributed.html",
"href": "docs/api/utils.distributed.html",
"title": "utils.distributed",
"section": "",
"text": "utils.distributed\nUtilities for distributed functionality.\n\n\n\n\n\nName\nDescription\n\n\n\n\nbarrier\nActs as a barrier to wait for all processes. This ensures that all processes\n\n\ncleanup_distributed\nDestroy process group if torch distributed is initialized. Called in training early\n\n\ncompute_and_broadcast\nCompute a value using the function fn only on the specified rank (default is 0).\n\n\ngather_from_all_ranks\nRun a callable fn on all ranks and gather the results on the specified rank.\n\n\ngather_scalar_from_all_ranks\nRun a callable fn on all ranks and gather the results on the specified rank.\n\n\nis_distributed\nCheck if distributed training is initialized.\n\n\nis_main_process\nCheck if the current process is the main process. If not in distributed mode,\n\n\nreduce_and_broadcast\nRun a callable fn1 on all ranks, gather the results, reduce them using fn2,\n\n\nzero_first\nruns the wrapped context so that rank 0 runs first before other ranks\n\n\n\n\n\nutils.distributed.barrier()\nActs as a barrier to wait for all processes. This ensures that all processes\nreach the barrier before proceeding further.\n\n\n\nutils.distributed.cleanup_distributed()\nDestroy process group if torch distributed is initialized. Called in training early\ntermination or when training successfully completes.\n\n\n\nutils.distributed.compute_and_broadcast(fn)\nCompute a value using the function fn only on the specified rank (default is 0).\nThe value is then broadcasted to all other ranks.\nArgs:\n- fn (callable): A function that computes the value. This should not have any side effects.\n- rank (int, optional): The rank that computes the value. Default is 0.\nReturns:\n- The computed value (int or float).\n\n\n\nutils.distributed.gather_from_all_ranks(fn, world_size=1)\nRun a callable fn on all ranks and gather the results on the specified rank.\nArgs:\n- fn (callable): A function that computes the value. 
This should not have any side effects.\n- rank (int, optional): The rank that gathers the values. Default is 0.\n- world_size (int, optional): Total number of processes in the current distributed setup.\nReturns:\n- A list of computed values from all ranks if on the gathering rank, otherwise None.\n\n\n\nutils.distributed.gather_scalar_from_all_ranks(fn, world_size=1)\nRun a callable fn on all ranks and gather the results on the specified rank.\nArgs:\n- fn (callable): A function that computes the value. This should not have any side effects.\n- rank (int, optional): The rank that gathers the values. Default is 0.\n- world_size (int, optional): Total number of processes in the current distributed setup.\nReturns:\n- A list of computed values from all ranks if on the gathering rank, otherwise None.\n\n\n\nutils.distributed.is_distributed()\nCheck if distributed training is initialized.\n\n\n\nutils.distributed.is_main_process()\nCheck if the current process is the main process. If not in distributed mode,\nalways return True.\nWe use simpler logic when the distributed state is not initialized: we just log\non the 0-th local rank.\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nbool\nTrue if the current process is the main process, False otherwise.\n\n\n\n\n\n\n\nutils.distributed.reduce_and_broadcast(fn1, fn2)\nRun a callable fn1 on all ranks, gather the results, reduce them using fn2,\nand then broadcast the reduced result to all ranks.\nArgs:\n- fn1 (callable): A function that computes the value on each rank.\n- fn2 (callable): A reduction function that takes a list of values and returns a single value.\n- world_size (int, optional): Total number of processes in the current distributed setup.\nReturns:\n- The reduced and broadcast value.\n\n\n\nutils.distributed.zero_first(is_main)\nruns the wrapped context so that rank 0 runs first before other ranks"
},
{
"objectID": "docs/api/utils.distributed.html#functions",
"href": "docs/api/utils.distributed.html#functions",
"title": "utils.distributed",
"section": "",
"text": "Name\nDescription\n\n\n\n\nbarrier\nActs as a barrier to wait for all processes. This ensures that all processes\n\n\ncleanup_distributed\nDestroy process group if torch distributed is initialized. Called in training early\n\n\ncompute_and_broadcast\nCompute a value using the function fn only on the specified rank (default is 0).\n\n\ngather_from_all_ranks\nRun a callable fn on all ranks and gather the results on the specified rank.\n\n\ngather_scalar_from_all_ranks\nRun a callable fn on all ranks and gather the results on the specified rank.\n\n\nis_distributed\nCheck if distributed training is initialized.\n\n\nis_main_process\nCheck if the current process is the main process. If not in distributed mode,\n\n\nreduce_and_broadcast\nRun a callable fn1 on all ranks, gather the results, reduce them using fn2,\n\n\nzero_first\nruns the wrapped context so that rank 0 runs first before other ranks\n\n\n\n\n\nutils.distributed.barrier()\nActs as a barrier to wait for all processes. This ensures that all processes\nreach the barrier before proceeding further.\n\n\n\nutils.distributed.cleanup_distributed()\nDestroy process group if torch distributed is initialized. Called in training early\ntermination or when training successfully completes.\n\n\n\nutils.distributed.compute_and_broadcast(fn)\nCompute a value using the function fn only on the specified rank (default is 0).\nThe value is then broadcasted to all other ranks.\nArgs:\n- fn (callable): A function that computes the value. This should not have any side effects.\n- rank (int, optional): The rank that computes the value. Default is 0.\nReturns:\n- The computed value (int or float).\n\n\n\nutils.distributed.gather_from_all_ranks(fn, world_size=1)\nRun a callable fn on all ranks and gather the results on the specified rank.\nArgs:\n- fn (callable): A function that computes the value. This should not have any side effects.\n- rank (int, optional): The rank that gathers the values. 
Default is 0.\n- world_size (int, optional): Total number of processes in the current distributed setup.\nReturns:\n- A list of computed values from all ranks if on the gathering rank, otherwise None.\n\n\n\nutils.distributed.gather_scalar_from_all_ranks(fn, world_size=1)\nRun a callable fn on all ranks and gather the results on the specified rank.\nArgs:\n- fn (callable): A function that computes the value. This should not have any side effects.\n- rank (int, optional): The rank that gathers the values. Default is 0.\n- world_size (int, optional): Total number of processes in the current distributed setup.\nReturns:\n- A list of computed values from all ranks if on the gathering rank, otherwise None.\n\n\n\nutils.distributed.is_distributed()\nCheck if distributed training is initialized.\n\n\n\nutils.distributed.is_main_process()\nCheck if the current process is the main process. If not in distributed mode,\nalways return True.\nWe use simpler logic when the distributed state is not initialized: we just log\non the 0-th local rank.\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nbool\nTrue if the current process is the main process, False otherwise.\n\n\n\n\n\n\n\nutils.distributed.reduce_and_broadcast(fn1, fn2)\nRun a callable fn1 on all ranks, gather the results, reduce them using fn2,\nand then broadcast the reduced result to all ranks.\nArgs:\n- fn1 (callable): A function that computes the value on each rank.\n- fn2 (callable): A reduction function that takes a list of values and returns a single value.\n- world_size (int, optional): Total number of processes in the current distributed setup.\nReturns:\n- The reduced and broadcast value.\n\n\n\nutils.distributed.zero_first(is_main)\nruns the wrapped context so that rank 0 runs first before other ranks"
},
{
"objectID": "docs/api/cli.cloud.base.html",
"href": "docs/api/cli.cloud.base.html",
"title": "cli.cloud.base",
"section": "",
"text": "cli.cloud.base\nbase class for cloud platforms from cli\n\n\n\n\n\nName\nDescription\n\n\n\n\nCloud\nAbstract base class for cloud platforms.\n\n\n\n\n\ncli.cloud.base.Cloud()\nAbstract base class for cloud platforms."
},
{
"objectID": "docs/api/cli.cloud.base.html#classes",
"href": "docs/api/cli.cloud.base.html#classes",
"title": "cli.cloud.base",
"section": "",
"text": "Name\nDescription\n\n\n\n\nCloud\nAbstract base class for cloud platforms.\n\n\n\n\n\ncli.cloud.base.Cloud()\nAbstract base class for cloud platforms."
},
{
"objectID": "docs/api/kernels.geglu.html",
"href": "docs/api/kernels.geglu.html",
"title": "kernels.geglu",
"section": "",
"text": "kernels.geglu\nModule for definition of GEGLU Triton kernels.\nSee “GLU Variants Improve Transformer” (https://arxiv.org/abs/2002.05202).\nCredit to unsloth (https://unsloth.ai/) for inspiration for this implementation.\n\n\n\n\n\nName\nDescription\n\n\n\n\ngeglu_backward\nGEGLU backward pass using in-place operations.\n\n\ngeglu_forward\nGEGLU forward pass.\n\n\n\n\n\nkernels.geglu.geglu_backward(grad_output, gate, up)\nGEGLU backward pass using in-place operations.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ngrad_output\ntorch.Tensor\nGradient of loss with respect to output, shape [batch, seq_len, hidden_dim].\nrequired\n\n\ngate\ntorch.Tensor\nGate tensor from forward pass, shape [batch, seq_len, hidden_dim].\nrequired\n\n\nup\ntorch.Tensor\nUp-projection tensor from forward pass, shape [batch, seq_len, hidden_dim].\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[torch.Tensor, torch.Tensor, torch.Tensor]\nTuple containing: - GEGLU activation output (h) - Gradient with respect to gate (grad_gate) - Gradient with respect to up (grad_up)\n\n\n\n\n\n\nThis function modifies its input tensors in-place to store results.\n\n\n\n\nkernels.geglu.geglu_forward(gate, up)\nGEGLU forward pass.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ngate\ntorch.Tensor\nInput gate tensor of shape [batch, seq_len, hidden_dim].\nrequired\n\n\nup\ntorch.Tensor\nUp-projection tensor of shape [batch, seq_len, hidden_dim].\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntorch.Tensor\ntorch.Tensor: Output tensor of shape [batch, seq_len, hidden_dim]."
},
{
"objectID": "docs/api/kernels.geglu.html#functions",
"href": "docs/api/kernels.geglu.html#functions",
"title": "kernels.geglu",
"section": "",
"text": "Name\nDescription\n\n\n\n\ngeglu_backward\nGEGLU backward pass using in-place operations.\n\n\ngeglu_forward\nGEGLU forward pass.\n\n\n\n\n\nkernels.geglu.geglu_backward(grad_output, gate, up)\nGEGLU backward pass using in-place operations.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ngrad_output\ntorch.Tensor\nGradient of loss with respect to output, shape [batch, seq_len, hidden_dim].\nrequired\n\n\ngate\ntorch.Tensor\nGate tensor from forward pass, shape [batch, seq_len, hidden_dim].\nrequired\n\n\nup\ntorch.Tensor\nUp-projection tensor from forward pass, shape [batch, seq_len, hidden_dim].\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[torch.Tensor, torch.Tensor, torch.Tensor]\nTuple containing: - GEGLU activation output (h) - Gradient with respect to gate (grad_gate) - Gradient with respect to up (grad_up)\n\n\n\n\n\n\nThis function modifies its input tensors in-place to store results.\n\n\n\n\nkernels.geglu.geglu_forward(gate, up)\nGEGLU forward pass.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ngate\ntorch.Tensor\nInput gate tensor of shape [batch, seq_len, hidden_dim].\nrequired\n\n\nup\ntorch.Tensor\nUp-projection tensor of shape [batch, seq_len, hidden_dim].\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntorch.Tensor\ntorch.Tensor: Output tensor of shape [batch, seq_len, hidden_dim]."
},
{
"objectID": "docs/api/core.trainers.mixins.optimizer.html",
"href": "docs/api/core.trainers.mixins.optimizer.html",
"title": "core.trainers.mixins.optimizer",
"section": "",
"text": "core.trainers.mixins.optimizer\nModule for Axolotl trainer optimizer mixin\n\n\n\n\n\nName\nDescription\n\n\n\n\nOptimizerInitMixin\nMixin to handle common optimizer initialization logic for Trainers (mostly TRL) that do not\n\n\nOptimizerMixin\nMixin class for shared handling of building custom optimizers\n\n\n\n\n\ncore.trainers.mixins.optimizer.OptimizerInitMixin(*args, **kwargs)\nMixin to handle common optimizer initialization logic for Trainers (mostly TRL) that do not\naccept optimizer_cls_and_kwargs as kwarg in constructor.\n\n\n\ncore.trainers.mixins.optimizer.OptimizerMixin()\nMixin class for shared handling of building custom optimizers"
},
{
"objectID": "docs/api/core.trainers.mixins.optimizer.html#classes",
"href": "docs/api/core.trainers.mixins.optimizer.html#classes",
"title": "core.trainers.mixins.optimizer",
"section": "",
"text": "Name\nDescription\n\n\n\n\nOptimizerInitMixin\nMixin to handle common optimizer initialization logic for Trainers (mostly TRL) that do not\n\n\nOptimizerMixin\nMixin class for shared handling of building custom optimizers\n\n\n\n\n\ncore.trainers.mixins.optimizer.OptimizerInitMixin(*args, **kwargs)\nMixin to handle common optimizer initialization logic for Trainers (mostly TRL) that do not\naccept optimizer_cls_and_kwargs as kwarg in constructor.\n\n\n\ncore.trainers.mixins.optimizer.OptimizerMixin()\nMixin class for shared handling of building custom optimizers"
},
{
"objectID": "docs/api/index.html",
"href": "docs/api/index.html",
"title": "API Reference",
"section": "",
"text": "Core functionality for training\n\n\n\ntrain\nPrepare and train a model on a dataset. Can also infer from a model or merge lora\n\n\nevaluate\nModule for evaluating models.\n\n\ndatasets\nModule containing dataset functionality.\n\n\nconvert\nModule containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes\n\n\nprompt_tokenizers\nModule containing PromptTokenizingStrategy and Prompter classes\n\n\nlogging_config\nCommon logging module for axolotl.\n\n\ncore.builders.base\nBase class for trainer builder\n\n\ncore.builders.causal\nBuilder for causal trainers\n\n\ncore.builders.rl\nBuilder for RLHF trainers\n\n\ncore.training_args\nextra axolotl specific training args\n\n\ncore.chat.messages\ninternal message representations of chat messages\n\n\ncore.chat.format.chatml\nChatML transformation functions for MessageContents\n\n\ncore.chat.format.llama3x\nLlama 3.x chat formatting functions for MessageContents\n\n\ncore.chat.format.shared\nshared functions for format transforms\n\n\ncore.datasets.chat\nchat dataset module\n\n\ncore.datasets.transforms.chat_builder\nThis module contains a function that builds a transform that takes a row from the\n\n\n\n\n\n\nCommand-line interface\n\n\n\ncli.main\nClick CLI definitions for various axolotl commands.\n\n\ncli.train\nCLI to run training on a model.\n\n\ncli.evaluate\nCLI to run evaluation on a model.\n\n\ncli.args\nModule for axolotl CLI command arguments.\n\n\ncli.art\nAxolotl ASCII logo utils.\n\n\ncli.checks\nVarious checks for Axolotl CLI.\n\n\ncli.config\nConfiguration loading and processing.\n\n\ncli.delinearize_llama4\nCLI tool to delinearize quantized/Linearized Llama-4 models.\n\n\ncli.inference\nCLI to run inference on a trained model.\n\n\ncli.merge_lora\nCLI to merge a trained LoRA into a base model.\n\n\ncli.merge_sharded_fsdp_weights\nCLI to merge sharded FSDP model checkpoints into a single combined checkpoint.\n\n\ncli.preprocess\nCLI to run preprocessing of a 
dataset.\n\n\ncli.quantize\nCLI to post-training quantize a model using torchao\n\n\ncli.vllm_serve\nCLI to start the vllm server for online RL\n\n\ncli.cloud.base\nbase class for cloud platforms from cli\n\n\ncli.cloud.modal_\nModal Cloud support from CLI\n\n\ncli.utils\nInit for axolotl.cli.utils module.\n\n\ncli.utils.args\nUtilities for axolotl CLI args.\n\n\ncli.utils.fetch\nUtilities for axolotl fetch CLI command.\n\n\ncli.utils.load\nUtilities for model, tokenizer, etc. loading.\n\n\ncli.utils.sweeps\nUtilities for handling sweeps over configs for axolotl train CLI command\n\n\ncli.utils.train\nUtilities for axolotl train CLI command.\n\n\n\n\n\n\nTraining implementations\n\n\n\ncore.trainers.base\nModule for customized trainers\n\n\ncore.trainers.trl\nModule for TRL RL trainers\n\n\ncore.trainers.mamba\nModule for mamba trainer\n\n\ncore.trainers.dpo.trainer\nDPO trainer for axolotl\n\n\ncore.trainers.grpo.trainer\nAxolotl GRPO trainers (with and without sequence parallelism handling)\n\n\ncore.trainers.grpo.sampler\nRepeat random sampler (similar to the one implemented in\n\n\ncore.trainers.utils\nUtils for Axolotl trainers\n\n\n\n\n\n\nFunctionality for loading and patching models, tokenizers, etc.\n\n\n\nloaders.model\nModel loader class implementation for loading, configuring, and patching various models.\n\n\nloaders.tokenizer\nTokenizer loading functionality and associated utils\n\n\nloaders.processor\nProcessor loading functionality for multi-modal models\n\n\nloaders.adapter\nAdapter loading functionality, including LoRA / QLoRA and associated utils\n\n\nloaders.patch_manager\nPatch manager class implementation to complement axolotl.loaders.ModelLoader.\n\n\nloaders.constants\nShared constants for axolotl.loaders module\n\n\n\n\n\n\nMixin classes for augmenting trainers\n\n\n\ncore.trainers.mixins.optimizer\nModule for Axolotl trainer optimizer mixin\n\n\ncore.trainers.mixins.rng_state_loader\nTemporary fix/override for bug in resume from 
checkpoint\n\n\ncore.trainers.mixins.scheduler\nModule for Axolotl trainer scheduler mixin\n\n\n\n\n\n\nContext managers for altering trainer behaviors\n\n\n\nutils.ctx_managers.sequence_parallel\nModule for Axolotl trainer sequence parallelism manager and utilities\n\n\n\n\n\n\nPrompt formatting strategies\n\n\n\nprompt_strategies.base\nmodule for base dataset transform strategies\n\n\nprompt_strategies.chat_template\nHF Chat Templates prompt strategy\n\n\nprompt_strategies.alpaca_chat\nModule for Alpaca prompt strategy classes\n\n\nprompt_strategies.alpaca_instruct\nModule loading the AlpacaInstructPromptTokenizingStrategy class\n\n\nprompt_strategies.alpaca_w_system\nPrompt strategies loader for alpaca instruction datasets with system prompts\n\n\nprompt_strategies.user_defined\nUser Defined prompts with configuration from the YML config\n\n\nprompt_strategies.llama2_chat\nPrompt Strategy for finetuning Llama2 chat models\n\n\nprompt_strategies.completion\nBasic completion text\n\n\nprompt_strategies.input_output\nModule for plain input/output prompt pairs\n\n\nprompt_strategies.stepwise_supervised\nModule for stepwise datasets, typically including a prompt and reasoning traces,\n\n\nprompt_strategies.metharme\nModule containing the MetharmePromptTokenizingStrategy and MetharmePrompter class\n\n\nprompt_strategies.orcamini\nPrompt Strategy for finetuning Orca Mini (v2) models\n\n\nprompt_strategies.pygmalion\nModule containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class\n\n\nprompt_strategies.messages.chat\nChat dataset wrapping strategy for new internal messages representations\n\n\nprompt_strategies.dpo.chat_template\nDPO prompt strategies for using tokenizer chat templates.\n\n\nprompt_strategies.dpo.llama3\nDPO strategies for llama-3 chat template\n\n\nprompt_strategies.dpo.chatml\nDPO strategies for chatml\n\n\nprompt_strategies.dpo.zephyr\nDPO strategies for zephyr\n\n\nprompt_strategies.dpo.user_defined\nUser-defined DPO 
strategies\n\n\nprompt_strategies.dpo.passthrough\nDPO prompt strategies passthrough/zero-processing strategy\n\n\nprompt_strategies.kto.llama3\nKTO strategies for llama-3 chat template\n\n\nprompt_strategies.kto.chatml\nKTO strategies for chatml\n\n\nprompt_strategies.kto.user_defined\nUser-defined KTO strategies\n\n\nprompt_strategies.orpo.chat_template\nchatml prompt tokenization strategy for ORPO\n\n\nprompt_strategies.bradley_terry.llama3\nchatml transforms for datasets with system, input, chosen, rejected to match llama3 chat template\n\n\n\n\n\n\nLow-level performance optimizations\n\n\n\nkernels.lora\nModule for definition of Low-Rank Adaptation (LoRA) Triton kernels.\n\n\nkernels.geglu\nModule for definition of GEGLU Triton kernels.\n\n\nkernels.swiglu\nModule for definition of SwiGLU Triton kernels.\n\n\nkernels.quantize\nDequantization utilities for bitsandbytes and FP8 integration.\n\n\nkernels.utils\nUtilities for axolotl.kernels submodules.\n\n\n\n\n\n\nRuntime patches for model optimizations\n\n\n\nmonkeypatch.llama_attn_hijack_flash\nFlash attention monkey patch for llama model\n\n\nmonkeypatch.llama_attn_hijack_xformers\nDirectly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments\n\n\nmonkeypatch.mistral_attn_hijack_flash\nFlash attention monkey patch for mistral model\n\n\nmonkeypatch.multipack\nmultipack patching for v2 of sample packing\n\n\nmonkeypatch.relora\nImplements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune.\n\n\nmonkeypatch.lora_kernels\nModule for patching custom LoRA Triton kernels and torch.autograd functions.\n\n\nmonkeypatch.utils\nShared utils for the monkeypatches\n\n\nmonkeypatch.btlm_attn_hijack_flash\nFlash attention monkey patch for cerebras btlm model\n\n\nmonkeypatch.stablelm_attn_hijack_flash\nPyTorch StableLM Epoch model.\n\n\nmonkeypatch.trainer_fsdp_optim\nfix for FSDP 
optimizer save in trainer w 4.47.0\n\n\nmonkeypatch.transformers_fa_utils\nsee https://github.com/huggingface/transformers/pull/35834\n\n\nmonkeypatch.unsloth_\nmodule for patching with unsloth optimizations\n\n\nmonkeypatch.data.batch_dataset_fetcher\nMonkey patches for the dataset fetcher to handle batches of packed indexes.\n\n\nmonkeypatch.mixtral\nPatches to support multipack for mixtral\n\n\nmonkeypatch.gradient_checkpointing.offload_cpu\nCPU offloaded checkpointing\n\n\nmonkeypatch.gradient_checkpointing.offload_disk\nDISCO - DIsk-based Storage and Checkpointing with Optimized prefetching\n\n\n\n\n\n\nUtility functions\n\n\n\nutils.tokenization\nModule for tokenization utilities\n\n\nutils.chat_templates\nThis module provides functionality for selecting chat templates based on user choices.\n\n\nutils.lora\nmodule to get the state dict of a merged lora model\n\n\nutils.model_shard_quant\nmodule to handle loading model on cpu/meta device for FSDP\n\n\nutils.bench\nBenchmarking and measurement utilities\n\n\nutils.freeze\nmodule to freeze/unfreeze parameters by name\n\n\nutils.trainer\nModule containing the Trainer class and related functions\n\n\nutils.schedulers\nModule for custom LRScheduler class\n\n\nutils.distributed\nUtilities for distributed functionality.\n\n\nutils.dict\nModule containing the DictDefault class\n\n\nutils.optimizers.adopt\nCopied from https://github.com/iShohei220/adopt\n\n\nutils.data.streaming\nData handling specific to streaming datasets.\n\n\nutils.data.sft\nData handling specific to SFT.\n\n\nutils.quantization\nUtilities for quantization including QAT and PTQ using torchao.\n\n\n\n\n\n\nPydantic data models for Axolotl config\n\n\n\nutils.schemas.config\nModule with Pydantic models for configuration.\n\n\nutils.schemas.model\nPydantic models for model input / output, etc. 
configuration\n\n\nutils.schemas.training\nPydantic models for training hyperparameters\n\n\nutils.schemas.datasets\nPydantic models for datasets-related configuration\n\n\nutils.schemas.peft\nPydantic models for PEFT-related configuration\n\n\nutils.schemas.trl\nPydantic models for TRL trainer configuration\n\n\nutils.schemas.multimodal\nPydantic models for multimodal-related configuration\n\n\nutils.schemas.integrations\nPydantic models for Axolotl integrations\n\n\nutils.schemas.enums\nEnums for Axolotl input config\n\n\nutils.schemas.utils\nUtilities for Axolotl Pydantic models\n\n\n\n\n\n\nThird-party integrations and extensions\n\n\n\nintegrations.base\nBase class for all plugins.\n\n\nintegrations.cut_cross_entropy.args\nModule for handling Cut Cross Entropy input arguments.\n\n\nintegrations.grokfast.optimizer\n\n\n\nintegrations.kd.trainer\nKD trainer\n\n\nintegrations.liger.args\nModule for handling LIGER input arguments.\n\n\nintegrations.lm_eval.args\nModule for handling lm eval harness input arguments.\n\n\nintegrations.spectrum.args\nModule for handling Spectrum input arguments.\n\n\n\n\n\n\nCommon utilities and shared functionality\n\n\n\ncommon.architectures\nCommon architecture specific constants\n\n\ncommon.const\nVarious shared constants\n\n\ncommon.datasets\nDataset loading utilities.\n\n\n\n\n\n\nCustom model implementations\n\n\n\nmodels.mamba.modeling_mamba\n\n\n\n\n\n\n\nData processing utilities\n\n\n\nutils.collators.core\nbasic shared collator constants\n\n\nutils.collators.batching\nData collators for axolotl to pad labels and position_ids for packed sequences\n\n\nutils.collators.mamba\ncollators for Mamba\n\n\nutils.collators.mm_chat\nCollators for multi-modal chat messages and packing\n\n\nutils.samplers.multipack\nMultipack Batch Sampler - An efficient batch sampler for packing variable-length sequences\n\n\n\n\n\n\nTraining callbacks\n\n\n\nutils.callbacks.perplexity\ncallback to calculate perplexity as an evaluation 
metric.\n\n\nutils.callbacks.profiler\nHF Trainer callback for creating pytorch profiling snapshots\n\n\nutils.callbacks.lisa\nmodule for LISA\n\n\nutils.callbacks.mlflow_\nMLFlow module for trainer callbacks\n\n\nutils.callbacks.comet_\nComet module for trainer callbacks\n\n\nutils.callbacks.qat\nQAT Callback for HF Causal Trainer"
},
{
"objectID": "docs/api/index.html#core",
"href": "docs/api/index.html#core",
"title": "API Reference",
"section": "",
"text": "Core functionality for training\n\n\n\ntrain\nPrepare and train a model on a dataset. Can also infer from a model or merge lora\n\n\nevaluate\nModule for evaluating models.\n\n\ndatasets\nModule containing dataset functionality.\n\n\nconvert\nModule containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes\n\n\nprompt_tokenizers\nModule containing PromptTokenizingStrategy and Prompter classes\n\n\nlogging_config\nCommon logging module for axolotl.\n\n\ncore.builders.base\nBase class for trainer builder\n\n\ncore.builders.causal\nBuilder for causal trainers\n\n\ncore.builders.rl\nBuilder for RLHF trainers\n\n\ncore.training_args\nextra axolotl specific training args\n\n\ncore.chat.messages\ninternal message representations of chat messages\n\n\ncore.chat.format.chatml\nChatML transformation functions for MessageContents\n\n\ncore.chat.format.llama3x\nLlama 3.x chat formatting functions for MessageContents\n\n\ncore.chat.format.shared\nshared functions for format transforms\n\n\ncore.datasets.chat\nchat dataset module\n\n\ncore.datasets.transforms.chat_builder\nThis module contains a function that builds a transform that takes a row from the"
},
{
"objectID": "docs/api/index.html#cli",
"href": "docs/api/index.html#cli",
"title": "API Reference",
"section": "",
"text": "Command-line interface\n\n\n\ncli.main\nClick CLI definitions for various axolotl commands.\n\n\ncli.train\nCLI to run training on a model.\n\n\ncli.evaluate\nCLI to run evaluation on a model.\n\n\ncli.args\nModule for axolotl CLI command arguments.\n\n\ncli.art\nAxolotl ASCII logo utils.\n\n\ncli.checks\nVarious checks for Axolotl CLI.\n\n\ncli.config\nConfiguration loading and processing.\n\n\ncli.delinearize_llama4\nCLI tool to delinearize quantized/Linearized Llama-4 models.\n\n\ncli.inference\nCLI to run inference on a trained model.\n\n\ncli.merge_lora\nCLI to merge a trained LoRA into a base model.\n\n\ncli.merge_sharded_fsdp_weights\nCLI to merge sharded FSDP model checkpoints into a single combined checkpoint.\n\n\ncli.preprocess\nCLI to run preprocessing of a dataset.\n\n\ncli.quantize\nCLI to post-training quantize a model using torchao\n\n\ncli.vllm_serve\nCLI to start the vllm server for online RL\n\n\ncli.cloud.base\nbase class for cloud platforms from cli\n\n\ncli.cloud.modal_\nModal Cloud support from CLI\n\n\ncli.utils\nInit for axolotl.cli.utils module.\n\n\ncli.utils.args\nUtilities for axolotl CLI args.\n\n\ncli.utils.fetch\nUtilities for axolotl fetch CLI command.\n\n\ncli.utils.load\nUtilities for model, tokenizer, etc. loading.\n\n\ncli.utils.sweeps\nUtilities for handling sweeps over configs for axolotl train CLI command\n\n\ncli.utils.train\nUtilities for axolotl train CLI command."
},
{
"objectID": "docs/api/index.html#trainers",
"href": "docs/api/index.html#trainers",
"title": "API Reference",
"section": "",
"text": "Training implementations\n\n\n\ncore.trainers.base\nModule for customized trainers\n\n\ncore.trainers.trl\nModule for TRL RL trainers\n\n\ncore.trainers.mamba\nModule for mamba trainer\n\n\ncore.trainers.dpo.trainer\nDPO trainer for axolotl\n\n\ncore.trainers.grpo.trainer\nAxolotl GRPO trainers (with and without sequence parallelism handling)\n\n\ncore.trainers.grpo.sampler\nRepeat random sampler (similar to the one implemented in\n\n\ncore.trainers.utils\nUtils for Axolotl trainers"
},
{
"objectID": "docs/api/index.html#model-loading",
"href": "docs/api/index.html#model-loading",
"title": "API Reference",
"section": "",
"text": "Functionality for loading and patching models, tokenizers, etc.\n\n\n\nloaders.model\nModel loader class implementation for loading, configuring, and patching various models.\n\n\nloaders.tokenizer\nTokenizer loading functionality and associated utils\n\n\nloaders.processor\nProcessor loading functionality for multi-modal models\n\n\nloaders.adapter\nAdapter loading functionality, including LoRA / QLoRA and associated utils\n\n\nloaders.patch_manager\nPatch manager class implementation to complement axolotl.loaders.ModelLoader.\n\n\nloaders.constants\nShared constants for axolotl.loaders module"
},
{
"objectID": "docs/api/index.html#mixins",
"href": "docs/api/index.html#mixins",
"title": "API Reference",
"section": "",
"text": "Mixin classes for augmenting trainers\n\n\n\ncore.trainers.mixins.optimizer\nModule for Axolotl trainer optimizer mixin\n\n\ncore.trainers.mixins.rng_state_loader\nTemporary fix/override for bug in resume from checkpoint\n\n\ncore.trainers.mixins.scheduler\nModule for Axolotl trainer scheduler mixin"
},
{
"objectID": "docs/api/index.html#context-managers",
"href": "docs/api/index.html#context-managers",
"title": "API Reference",
"section": "",
"text": "Context managers for altering trainer behaviors\n\n\n\nutils.ctx_managers.sequence_parallel\nModule for Axolotl trainer sequence parallelism manager and utilities"
},
{
"objectID": "docs/api/index.html#prompt-strategies",
"href": "docs/api/index.html#prompt-strategies",
"title": "API Reference",
"section": "",
"text": "Prompt formatting strategies\n\n\n\nprompt_strategies.base\nmodule for base dataset transform strategies\n\n\nprompt_strategies.chat_template\nHF Chat Templates prompt strategy\n\n\nprompt_strategies.alpaca_chat\nModule for Alpaca prompt strategy classes\n\n\nprompt_strategies.alpaca_instruct\nModule loading the AlpacaInstructPromptTokenizingStrategy class\n\n\nprompt_strategies.alpaca_w_system\nPrompt strategies loader for alpaca instruction datasets with system prompts\n\n\nprompt_strategies.user_defined\nUser Defined prompts with configuration from the YML config\n\n\nprompt_strategies.llama2_chat\nPrompt Strategy for finetuning Llama2 chat models\n\n\nprompt_strategies.completion\nBasic completion text\n\n\nprompt_strategies.input_output\nModule for plain input/output prompt pairs\n\n\nprompt_strategies.stepwise_supervised\nModule for stepwise datasets, typically including a prompt and reasoning traces,\n\n\nprompt_strategies.metharme\nModule containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class\n\n\nprompt_strategies.orcamini\nPrompt Strategy for finetuning Orca Mini (v2) models\n\n\nprompt_strategies.pygmalion\nModule containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class\n\n\nprompt_strategies.messages.chat\nChat dataset wrapping strategy for new internal messages representations\n\n\nprompt_strategies.dpo.chat_template\nDPO prompt strategies for using tokenizer chat templates.\n\n\nprompt_strategies.dpo.llama3\nDPO strategies for llama-3 chat template\n\n\nprompt_strategies.dpo.chatml\nDPO strategies for chatml\n\n\nprompt_strategies.dpo.zephyr\nDPO strategies for zephyr\n\n\nprompt_strategies.dpo.user_defined\nUser-defined DPO strategies\n\n\nprompt_strategies.dpo.passthrough\nDPO prompt strategies passthrough/zero-processing strategy\n\n\nprompt_strategies.kto.llama3\nKTO strategies for llama-3 chat template\n\n\nprompt_strategies.kto.chatml\nKTO strategies for 
chatml\n\n\nprompt_strategies.kto.user_defined\nUser-defined KTO strategies\n\n\nprompt_strategies.orpo.chat_template\nchatml prompt tokenization strategy for ORPO\n\n\nprompt_strategies.bradley_terry.llama3\nchatml transforms for datasets with system, input, chosen, rejected to match llama3 chat template"
},
{
"objectID": "docs/api/index.html#kernels",
"href": "docs/api/index.html#kernels",
"title": "API Reference",
"section": "",
"text": "Low-level performance optimizations\n\n\n\nkernels.lora\nModule for definition of Low-Rank Adaptation (LoRA) Triton kernels.\n\n\nkernels.geglu\nModule for definition of GEGLU Triton kernels.\n\n\nkernels.swiglu\nModule for definition of SwiGLU Triton kernels.\n\n\nkernels.quantize\nDequantization utilities for bitsandbytes and FP8 integration.\n\n\nkernels.utils\nUtilities for axolotl.kernels submodules."
},
{
"objectID": "docs/api/index.html#monkey-patches",
"href": "docs/api/index.html#monkey-patches",
"title": "API Reference",
"section": "",
"text": "Runtime patches for model optimizations\n\n\n\nmonkeypatch.llama_attn_hijack_flash\nFlash attention monkey patch for llama model\n\n\nmonkeypatch.llama_attn_hijack_xformers\nDirectly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments\n\n\nmonkeypatch.mistral_attn_hijack_flash\nFlash attention monkey patch for mistral model\n\n\nmonkeypatch.multipack\nmultipack patching for v2 of sample packing\n\n\nmonkeypatch.relora\nImplements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune.\n\n\nmonkeypatch.lora_kernels\nModule for patching custom LoRA Triton kernels and torch.autograd functions.\n\n\nmonkeypatch.utils\nShared utils for the monkeypatches\n\n\nmonkeypatch.btlm_attn_hijack_flash\nFlash attention monkey patch for cerebras btlm model\n\n\nmonkeypatch.stablelm_attn_hijack_flash\nPyTorch StableLM Epoch model.\n\n\nmonkeypatch.trainer_fsdp_optim\nfix for FSDP optimizer save in trainer w 4.47.0\n\n\nmonkeypatch.transformers_fa_utils\nsee https://github.com/huggingface/transformers/pull/35834\n\n\nmonkeypatch.unsloth_\nmodule for patching with unsloth optimizations\n\n\nmonkeypatch.data.batch_dataset_fetcher\nMonkey patches for the dataset fetcher to handle batches of packed indexes.\n\n\nmonkeypatch.mixtral\nPatches to support multipack for mixtral\n\n\nmonkeypatch.gradient_checkpointing.offload_cpu\nCPU offloaded checkpointing\n\n\nmonkeypatch.gradient_checkpointing.offload_disk\nDISCO - DIsk-based Storage and Checkpointing with Optimized prefetching"
},
{
"objectID": "docs/api/index.html#utils",
"href": "docs/api/index.html#utils",
"title": "API Reference",
"section": "",
"text": "Utility functions\n\n\n\nutils.tokenization\nModule for tokenization utilities\n\n\nutils.chat_templates\nThis module provides functionality for selecting chat templates based on user choices.\n\n\nutils.lora\nmodule to get the state dict of a merged lora model\n\n\nutils.model_shard_quant\nmodule to handle loading model on cpu/meta device for FSDP\n\n\nutils.bench\nBenchmarking and measurement utilities\n\n\nutils.freeze\nmodule to freeze/unfreeze parameters by name\n\n\nutils.trainer\nModule containing the Trainer class and related functions\n\n\nutils.schedulers\nModule for custom LRScheduler class\n\n\nutils.distributed\nUtilities for distributed functionality.\n\n\nutils.dict\nModule containing the DictDefault class\n\n\nutils.optimizers.adopt\nCopied from https://github.com/iShohei220/adopt\n\n\nutils.data.streaming\nData handling specific to streaming datasets.\n\n\nutils.data.sft\nData handling specific to SFT.\n\n\nutils.quantization\nUtilities for quantization including QAT and PTQ using torchao."
},
{
"objectID": "docs/api/index.html#schemas",
"href": "docs/api/index.html#schemas",
"title": "API Reference",
"section": "",
"text": "Pydantic data models for Axolotl config\n\n\n\nutils.schemas.config\nModule with Pydantic models for configuration.\n\n\nutils.schemas.model\nPydantic models for model input / output, etc. configuration\n\n\nutils.schemas.training\nPydantic models for training hyperparameters\n\n\nutils.schemas.datasets\nPydantic models for datasets-related configuration\n\n\nutils.schemas.peft\nPydantic models for PEFT-related configuration\n\n\nutils.schemas.trl\nPydantic models for TRL trainer configuration\n\n\nutils.schemas.multimodal\nPydantic models for multimodal-related configuration\n\n\nutils.schemas.integrations\nPydantic models for Axolotl integrations\n\n\nutils.schemas.enums\nEnums for Axolotl input config\n\n\nutils.schemas.utils\nUtilities for Axolotl Pydantic models"
},
{
"objectID": "docs/api/index.html#integrations",
"href": "docs/api/index.html#integrations",
"title": "API Reference",
"section": "",
"text": "Third-party integrations and extensions\n\n\n\nintegrations.base\nBase class for all plugins.\n\n\nintegrations.cut_cross_entropy.args\nModule for handling Cut Cross Entropy input arguments.\n\n\nintegrations.grokfast.optimizer\n\n\n\nintegrations.kd.trainer\nKD trainer\n\n\nintegrations.liger.args\nModule for handling LIGER input arguments.\n\n\nintegrations.lm_eval.args\nModule for handling lm eval harness input arguments.\n\n\nintegrations.spectrum.args\nModule for handling Spectrum input arguments."
},
{
"objectID": "docs/api/index.html#common",
"href": "docs/api/index.html#common",
"title": "API Reference",
"section": "",
"text": "Common utilities and shared functionality\n\n\n\ncommon.architectures\nCommon architecture specific constants\n\n\ncommon.const\nVarious shared constants\n\n\ncommon.datasets\nDataset loading utilities."
},
{
"objectID": "docs/api/index.html#models",
"href": "docs/api/index.html#models",
"title": "API Reference",
"section": "",
"text": "Custom model implementations\n\n\n\nmodels.mamba.modeling_mamba"
},
{
"objectID": "docs/api/index.html#data-processing",
"href": "docs/api/index.html#data-processing",
"title": "API Reference",
"section": "",
"text": "Data processing utilities\n\n\n\nutils.collators.core\nbasic shared collator constants\n\n\nutils.collators.batching\nData collators for axolotl to pad labels and position_ids for packed sequences\n\n\nutils.collators.mamba\ncollators for Mamba\n\n\nutils.collators.mm_chat\nCollators for multi-modal chat messages and packing\n\n\nutils.samplers.multipack\nMultipack Batch Sampler - An efficient batch sampler for packing variable-length sequences"
},
{
"objectID": "docs/api/index.html#callbacks",
"href": "docs/api/index.html#callbacks",
"title": "API Reference",
"section": "",
"text": "Training callbacks\n\n\n\nutils.callbacks.perplexity\ncallback to calculate perplexity as an evaluation metric.\n\n\nutils.callbacks.profiler\nHF Trainer callback for creating pytorch profiling snapshots\n\n\nutils.callbacks.lisa\nmodule for LISA\n\n\nutils.callbacks.mlflow_\nMLFlow module for trainer callbacks\n\n\nutils.callbacks.comet_\nComet module for trainer callbacks\n\n\nutils.callbacks.qat\nQAT Callback for HF Causal Trainer"
},
{
"objectID": "docs/api/prompt_strategies.base.html",
"href": "docs/api/prompt_strategies.base.html",
"title": "prompt_strategies.base",
"section": "",
"text": "prompt_strategies.base\nprompt_strategies.base\nmodule for base dataset transform strategies"
},
{
"objectID": "docs/api/cli.evaluate.html",
"href": "docs/api/cli.evaluate.html",
"title": "cli.evaluate",
"section": "",
"text": "cli.evaluate\nCLI to run evaluation on a model.\n\n\n\n\n\nName\nDescription\n\n\n\n\ndo_cli\nParses axolotl config, CLI args, and calls do_evaluate.\n\n\ndo_evaluate\nEvaluates a transformers model by first loading the dataset(s) specified in the\n\n\n\n\n\ncli.evaluate.do_cli(config=Path('examples/'), **kwargs)\nParses axolotl config, CLI args, and calls do_evaluate.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nUnion[Path, str]\nPath to axolotl config YAML file.\nPath('examples/')\n\n\nkwargs\n\nAdditional keyword arguments to override config file values.\n{}\n\n\n\n\n\n\n\ncli.evaluate.do_evaluate(cfg, cli_args)\nEvaluates a transformers model by first loading the dataset(s) specified in the\naxolotl config, and then calling axolotl.evaluate.evaluate, which computes\nevaluation metrics on the given dataset(s) and writes them to disk.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ncli_args\nTrainerCliArgs\nCLI arguments.\nrequired"
},
{
"objectID": "docs/api/cli.evaluate.html#functions",
"href": "docs/api/cli.evaluate.html#functions",
"title": "cli.evaluate",
"section": "",
"text": "Name\nDescription\n\n\n\n\ndo_cli\nParses axolotl config, CLI args, and calls do_evaluate.\n\n\ndo_evaluate\nEvaluates a transformers model by first loading the dataset(s) specified in the\n\n\n\n\n\ncli.evaluate.do_cli(config=Path('examples/'), **kwargs)\nParses axolotl config, CLI args, and calls do_evaluate.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nUnion[Path, str]\nPath to axolotl config YAML file.\nPath('examples/')\n\n\nkwargs\n\nAdditional keyword arguments to override config file values.\n{}\n\n\n\n\n\n\n\ncli.evaluate.do_evaluate(cfg, cli_args)\nEvaluates a transformers model by first loading the dataset(s) specified in the\naxolotl config, and then calling axolotl.evaluate.evaluate, which computes\nevaluation metrics on the given dataset(s) and writes them to disk.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ncli_args\nTrainerCliArgs\nCLI arguments.\nrequired"
},
{
"objectID": "docs/api/train.html",
"href": "docs/api/train.html",
"title": "train",
"section": "",
"text": "train\nPrepare and train a model on a dataset. Can also infer from a model or merge lora\n\n\n\n\n\nName\nDescription\n\n\n\n\ncreate_model_card\nCreate a model card for the trained model if needed.\n\n\nexecute_training\nExecute the training process with appropriate SDP kernel configurations.\n\n\nhandle_untrained_tokens_fix\nApply fixes for untrained tokens if configured.\n\n\nsave_initial_configs\nSave initial configurations before training.\n\n\nsave_trained_model\nSave the trained model according to configuration and training setup.\n\n\nsetup_model_and_tokenizer\nLoad the tokenizer, processor (for multimodal models), and model based on\n\n\nsetup_model_and_trainer\nLoad model, tokenizer, trainer, etc. Helper function to encapsulate the full\n\n\nsetup_model_card\nSet up the Axolotl badge and add the Axolotl config to the model card if available.\n\n\nsetup_reference_model\nSet up the reference model for RL training if needed.\n\n\nsetup_signal_handler\nSet up signal handler for graceful termination.\n\n\ntrain\nTrain a model on the given dataset.\n\n\n\n\n\ntrain.create_model_card(cfg, trainer)\nCreate a model card for the trained model if needed.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ntrainer\nTrainer\nThe trainer object with model card creation capabilities.\nrequired\n\n\n\n\n\n\n\ntrain.execute_training(cfg, trainer, resume_from_checkpoint)\nExecute the training process with appropriate SDP kernel configurations.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ntrainer\nAny\nThe configured trainer object.\nrequired\n\n\nresume_from_checkpoint\nstr | None\nPath to checkpoint to resume from, if applicable.\nrequired\n\n\n\n\n\n\n\ntrain.handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset)\nApply fixes for untrained tokens if 
configured.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\nmodel\nPreTrainedModel\nThe model to apply fixes to.\nrequired\n\n\ntokenizer\nPreTrainedTokenizer\nThe tokenizer for token identification.\nrequired\n\n\ntrain_dataset\nDataset\nThe training dataset to use.\nrequired\n\n\n\n\n\n\n\ntrain.save_initial_configs(cfg, tokenizer, model, peft_config, processor)\nSave initial configurations before training.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ntokenizer\nPreTrainedTokenizer\nThe tokenizer to save.\nrequired\n\n\nmodel\nPreTrainedModel\nThe model to save configuration for.\nrequired\n\n\npeft_config\nPeftConfig | None\nThe PEFT configuration to save if applicable.\nrequired\n\n\n\n\n\n\n\ntrain.save_trained_model(cfg, trainer, model)\nSave the trained model according to configuration and training setup.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ntrainer\nAny\nThe trainer object.\nrequired\n\n\nmodel\nPreTrainedModel\nThe trained model to save.\nrequired\n\n\n\n\n\n\n\ntrain.setup_model_and_tokenizer(cfg)\nLoad the tokenizer, processor (for multimodal models), and model based on\nconfiguration.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None]\nTuple containing model, tokenizer, peft_config (if LoRA / QLoRA, else None), and processor (if multimodal, else None).\n\n\n\n\n\n\n\ntrain.setup_model_and_trainer(cfg, dataset_meta)\nLoad model, tokenizer, trainer, etc. 
Helper function to encapsulate the full\ntrainer setup.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration dictionary with training parameters.\nrequired\n\n\ndataset_meta\nTrainDatasetMeta\nObject with training, validation datasets and metadata.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple['HFRLTrainerBuilder' | 'HFCausalTrainerBuilder', PeftModel | PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None]\nTuple of: - Trainer (Causal or RLHF) - Model - Tokenizer - PEFT config - Processor\n\n\n\n\n\n\n\ntrain.setup_model_card(cfg)\nSet up the Axolotl badge and add the Axolotl config to the model card if available.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\n\n\n\n\n\ntrain.setup_reference_model(cfg, tokenizer)\nSet up the reference model for RL training if needed.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ntokenizer\nPreTrainedTokenizer\nThe tokenizer to use for the reference model.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nPreTrainedModel | None\nReference model if needed for RL training, None otherwise.\n\n\n\n\n\n\n\ntrain.setup_signal_handler(cfg, model)\nSet up signal handler for graceful termination.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\nmodel\nPreTrainedModel\nThe model to save on termination\nrequired\n\n\n\n\n\n\n\ntrain.train(cfg, dataset_meta)\nTrain a model on the given dataset.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration dictionary with training parameters\nrequired\n\n\ndataset_meta\nTrainDatasetMeta\nObject with training, validation datasets 
and metadata\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]\nTuple of (model, tokenizer, trainer) after training"
},
{
"objectID": "docs/api/train.html#functions",
"href": "docs/api/train.html#functions",
"title": "train",
"section": "",
"text": "Name\nDescription\n\n\n\n\ncreate_model_card\nCreate a model card for the trained model if needed.\n\n\nexecute_training\nExecute the training process with appropriate SDP kernel configurations.\n\n\nhandle_untrained_tokens_fix\nApply fixes for untrained tokens if configured.\n\n\nsave_initial_configs\nSave initial configurations before training.\n\n\nsave_trained_model\nSave the trained model according to configuration and training setup.\n\n\nsetup_model_and_tokenizer\nLoad the tokenizer, processor (for multimodal models), and model based on\n\n\nsetup_model_and_trainer\nLoad model, tokenizer, trainer, etc. Helper function to encapsulate the full\n\n\nsetup_model_card\nSet up the Axolotl badge and add the Axolotl config to the model card if available.\n\n\nsetup_reference_model\nSet up the reference model for RL training if needed.\n\n\nsetup_signal_handler\nSet up signal handler for graceful termination.\n\n\ntrain\nTrain a model on the given dataset.\n\n\n\n\n\ntrain.create_model_card(cfg, trainer)\nCreate a model card for the trained model if needed.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ntrainer\nTrainer\nThe trainer object with model card creation capabilities.\nrequired\n\n\n\n\n\n\n\ntrain.execute_training(cfg, trainer, resume_from_checkpoint)\nExecute the training process with appropriate SDP kernel configurations.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ntrainer\nAny\nThe configured trainer object.\nrequired\n\n\nresume_from_checkpoint\nstr | None\nPath to checkpoint to resume from, if applicable.\nrequired\n\n\n\n\n\n\n\ntrain.handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset)\nApply fixes for untrained tokens if configured.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary 
mapping axolotl config keys to values.\nrequired\n\n\nmodel\nPreTrainedModel\nThe model to apply fixes to.\nrequired\n\n\ntokenizer\nPreTrainedTokenizer\nThe tokenizer for token identification.\nrequired\n\n\ntrain_dataset\nDataset\nThe training dataset to use.\nrequired\n\n\n\n\n\n\n\ntrain.save_initial_configs(cfg, tokenizer, model, peft_config, processor)\nSave initial configurations before training.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ntokenizer\nPreTrainedTokenizer\nThe tokenizer to save.\nrequired\n\n\nmodel\nPreTrainedModel\nThe model to save configuration for.\nrequired\n\n\npeft_config\nPeftConfig | None\nThe PEFT configuration to save if applicable.\nrequired\n\n\n\n\n\n\n\ntrain.save_trained_model(cfg, trainer, model)\nSave the trained model according to configuration and training setup.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ntrainer\nAny\nThe trainer object.\nrequired\n\n\nmodel\nPreTrainedModel\nThe trained model to save.\nrequired\n\n\n\n\n\n\n\ntrain.setup_model_and_tokenizer(cfg)\nLoad the tokenizer, processor (for multimodal models), and model based on\nconfiguration.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None]\nTuple containing model, tokenizer, peft_config (if LoRA / QLoRA, else None), and processor (if multimodal, else None).\n\n\n\n\n\n\n\ntrain.setup_model_and_trainer(cfg, dataset_meta)\nLoad model, tokenizer, trainer, etc. 
Helper function to encapsulate the full\ntrainer setup.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration dictionary with training parameters.\nrequired\n\n\ndataset_meta\nTrainDatasetMeta\nObject with training, validation datasets and metadata.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple['HFRLTrainerBuilder' | 'HFCausalTrainerBuilder', PeftModel | PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None]\nTuple of: - Trainer (Causal or RLHF) - Model - Tokenizer - PEFT config - Processor\n\n\n\n\n\n\n\ntrain.setup_model_card(cfg)\nSet up the Axolotl badge and add the Axolotl config to the model card if available.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\n\n\n\n\n\ntrain.setup_reference_model(cfg, tokenizer)\nSet up the reference model for RL training if needed.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\ntokenizer\nPreTrainedTokenizer\nThe tokenizer to use for the reference model.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nPreTrainedModel | None\nReference model if needed for RL training, None otherwise.\n\n\n\n\n\n\n\ntrain.setup_signal_handler(cfg, model)\nSet up signal handler for graceful termination.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nDictionary mapping axolotl config keys to values.\nrequired\n\n\nmodel\nPreTrainedModel\nThe model to save on termination\nrequired\n\n\n\n\n\n\n\ntrain.train(cfg, dataset_meta)\nTrain a model on the given dataset.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration dictionary with training parameters\nrequired\n\n\ndataset_meta\nTrainDatasetMeta\nObject with training, validation datasets 
and metadata\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntuple[PeftModel | PreTrainedModel, PreTrainedTokenizer, Trainer]\nTuple of (model, tokenizer, trainer) after training"
},
{
"objectID": "docs/api/common.architectures.html",
"href": "docs/api/common.architectures.html",
"title": "common.architectures",
"section": "",
"text": "common.architectures\ncommon.architectures\nCommon architecture specific constants"
},
{
"objectID": "docs/api/prompt_strategies.kto.llama3.html",
"href": "docs/api/prompt_strategies.kto.llama3.html",
"title": "prompt_strategies.kto.llama3",
"section": "",
"text": "prompt_strategies.kto.llama3\nKTO strategies for llama-3 chat template\n\n\n\n\n\nName\nDescription\n\n\n\n\nargilla_chat\nfor argilla/kto-mix-15k conversations\n\n\nintel\nFor Intel Orca KTO\n\n\nultra\nfor ultrafeedback binarized conversations\n\n\n\n\n\nprompt_strategies.kto.llama3.argilla_chat(cfg, **kwargs)\nfor argilla/kto-mix-15k conversations\n\n\n\nprompt_strategies.kto.llama3.intel(cfg, **kwargs)\nFor Intel Orca KTO\nex: argilla/distilabel-intel-orca-kto\n\n\n\nprompt_strategies.kto.llama3.ultra(cfg, **kwargs)\nfor ultrafeedback binarized conversations\nex: argilla/ultrafeedback-binarized-preferences-cleaned-kto"
},
{
"objectID": "docs/api/prompt_strategies.kto.llama3.html#functions",
"href": "docs/api/prompt_strategies.kto.llama3.html#functions",
"title": "prompt_strategies.kto.llama3",
"section": "",
"text": "Name\nDescription\n\n\n\n\nargilla_chat\nfor argilla/kto-mix-15k conversations\n\n\nintel\nFor Intel Orca KTO\n\n\nultra\nfor ultrafeedback binarized conversations\n\n\n\n\n\nprompt_strategies.kto.llama3.argilla_chat(cfg, **kwargs)\nfor argilla/kto-mix-15k conversations\n\n\n\nprompt_strategies.kto.llama3.intel(cfg, **kwargs)\nFor Intel Orca KTO\nex: argilla/distilabel-intel-orca-kto\n\n\n\nprompt_strategies.kto.llama3.ultra(cfg, **kwargs)\nfor ultrafeedback binarized conversations\nex: argilla/ultrafeedback-binarized-preferences-cleaned-kto"
},
{
"objectID": "docs/api/utils.callbacks.lisa.html",
"href": "docs/api/utils.callbacks.lisa.html",
"title": "utils.callbacks.lisa",
"section": "",
"text": "utils.callbacks.lisa\nutils.callbacks.lisa\nmodule for LISA\nAdapted from https://github.com/OptimalScale/LMFlow/pull/701 for HF transformers & Axolotl\narXiv: https://arxiv.org/abs/2403.17919\nLicense: Apache 2.0"
},
{
"objectID": "docs/api/cli.utils.train.html",
"href": "docs/api/cli.utils.train.html",
"title": "cli.utils.train",
"section": "",
"text": "cli.utils.train\nUtilities for axolotl train CLI command.\n\n\n\n\n\nName\nDescription\n\n\n\n\nbuild_command\nBuild command list from base command and options.\n\n\ngenerate_config_files\nGenerate list of configuration files to process. Yields a tuple of the configuration file name and a boolean indicating\n\n\nlaunch_training\nExecute training with the given configuration.\n\n\n\n\n\ncli.utils.train.build_command(base_cmd, options)\nBuild command list from base command and options.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nbase_cmd\nlist[str]\nCommand without options.\nrequired\n\n\noptions\ndict[str, Any]\nOptions to parse and append to base command.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nlist[str]\nList of strings giving shell command.\n\n\n\n\n\n\n\ncli.utils.train.generate_config_files(config, sweep)\nGenerate list of configuration files to process. Yields a tuple of the configuration file name and a boolean indicating\nwhether this is a group of configurations (i.e., a sweep).\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nstr\nBase configuration file\nrequired\n\n\nsweep\nstr | None\nSweep configuration file\nrequired\n\n\n\n\n\n\n\ncli.utils.train.launch_training(\n cfg_file,\n launcher,\n cloud,\n kwargs,\n launcher_args=None,\n use_exec=False,\n)\nExecute training with the given configuration."
},
{
"objectID": "docs/api/cli.utils.train.html#functions",
"href": "docs/api/cli.utils.train.html#functions",
"title": "cli.utils.train",
"section": "",
"text": "Name\nDescription\n\n\n\n\nbuild_command\nBuild command list from base command and options.\n\n\ngenerate_config_files\nGenerate list of configuration files to process. Yields a tuple of the configuration file name and a boolean indicating\n\n\nlaunch_training\nExecute training with the given configuration.\n\n\n\n\n\ncli.utils.train.build_command(base_cmd, options)\nBuild command list from base command and options.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nbase_cmd\nlist[str]\nCommand without options.\nrequired\n\n\noptions\ndict[str, Any]\nOptions to parse and append to base command.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nlist[str]\nList of strings giving shell command.\n\n\n\n\n\n\n\ncli.utils.train.generate_config_files(config, sweep)\nGenerate list of configuration files to process. Yields a tuple of the configuration file name and a boolean indicating\nwhether this is a group of configurations (i.e., a sweep).\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nstr\nBase configuration file\nrequired\n\n\nsweep\nstr | None\nSweep configuration file\nrequired\n\n\n\n\n\n\n\ncli.utils.train.launch_training(\n cfg_file,\n launcher,\n cloud,\n kwargs,\n launcher_args=None,\n use_exec=False,\n)\nExecute training with the given configuration."
},
{
"objectID": "docs/api/integrations.liger.args.html",
"href": "docs/api/integrations.liger.args.html",
"title": "integrations.liger.args",
"section": "",
"text": "integrations.liger.args\nModule for handling LIGER input arguments.\n\n\n\n\n\nName\nDescription\n\n\n\n\nLigerArgs\nInput args for LIGER.\n\n\n\n\n\nintegrations.liger.args.LigerArgs()\nInput args for LIGER."
},
{
"objectID": "docs/api/integrations.liger.args.html#classes",
"href": "docs/api/integrations.liger.args.html#classes",
"title": "integrations.liger.args",
"section": "",
"text": "Name\nDescription\n\n\n\n\nLigerArgs\nInput args for LIGER.\n\n\n\n\n\nintegrations.liger.args.LigerArgs()\nInput args for LIGER."
},
{
"objectID": "docs/api/prompt_tokenizers.html",
"href": "docs/api/prompt_tokenizers.html",
"title": "prompt_tokenizers",
"section": "",
"text": "prompt_tokenizers\nModule containing PromptTokenizingStrategy and Prompter classes\n\n\n\n\n\nName\nDescription\n\n\n\n\nAlpacaMultipleChoicePromptTokenizingStrategy\nTokenizing strategy for Alpaca Multiple Choice prompts.\n\n\nAlpacaPromptTokenizingStrategy\nTokenizing strategy for Alpaca prompts.\n\n\nAlpacaReflectionPTStrategy\nTokenizing strategy for Alpaca Reflection prompts.\n\n\nDatasetWrappingStrategy\nAbstract class for wrapping datasets for Chat Messages\n\n\nGPTeacherPromptTokenizingStrategy\nTokenizing strategy for GPTeacher prompts.\n\n\nInstructionPromptTokenizingStrategy\nTokenizing strategy for instruction-based prompts.\n\n\nInvalidDataException\nException raised when the data is invalid\n\n\nJeopardyPromptTokenizingStrategy\nTokenizing strategy for Jeopardy prompts.\n\n\nNomicGPT4AllPromptTokenizingStrategy\nTokenizing strategy for NomicGPT4All prompts.\n\n\nOpenAssistantPromptTokenizingStrategy\nTokenizing strategy for OpenAssistant prompts.\n\n\nPromptTokenizingStrategy\nAbstract class for tokenizing strategies\n\n\nReflectionPromptTokenizingStrategy\nTokenizing strategy for Reflection prompts.\n\n\nSummarizeTLDRPromptTokenizingStrategy\nTokenizing strategy for SummarizeTLDR prompts.\n\n\n\n\n\nprompt_tokenizers.AlpacaMultipleChoicePromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for Alpaca Multiple Choice prompts.\n\n\n\nprompt_tokenizers.AlpacaPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for Alpaca prompts.\n\n\n\nprompt_tokenizers.AlpacaReflectionPTStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for Alpaca Reflection prompts.\n\n\n\nprompt_tokenizers.DatasetWrappingStrategy()\nAbstract class for wrapping datasets for Chat Messages\n\n\n\nprompt_tokenizers.GPTeacherPromptTokenizingStrategy(\n prompter,\n tokenizer,\n 
train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for GPTeacher prompts.\n\n\n\nprompt_tokenizers.InstructionPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for instruction-based prompts.\n\n\n\nprompt_tokenizers.InvalidDataException()\nException raised when the data is invalid\n\n\n\nprompt_tokenizers.JeopardyPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for Jeopardy prompts.\n\n\n\nprompt_tokenizers.NomicGPT4AllPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for NomicGPT4All prompts.\n\n\n\nprompt_tokenizers.OpenAssistantPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for OpenAssistant prompts.\n\n\n\nprompt_tokenizers.PromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nAbstract class for tokenizing strategies\n\n\n\nprompt_tokenizers.ReflectionPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for Reflection prompts.\n\n\n\nprompt_tokenizers.SummarizeTLDRPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for SummarizeTLDR prompts.\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\nparse_tokenized_to_result\nParses the tokenized prompt and appends the tokenized input_ids, attention_mask and labels to the result\n\n\ntokenize_prompt_default\nReturns the default values for the tokenize prompt function\n\n\n\n\n\nprompt_tokenizers.parse_tokenized_to_result(\n result,\n current_len,\n res,\n labels,\n pad_token_id=None,\n)\nParses the tokenized prompt and appends the tokenized input_ids, attention_mask and labels to the result\n\n\n\nprompt_tokenizers.tokenize_prompt_default()\nReturns 
the default values for the tokenize prompt function"
},
{
"objectID": "docs/api/prompt_tokenizers.html#classes",
"href": "docs/api/prompt_tokenizers.html#classes",
"title": "prompt_tokenizers",
"section": "",
"text": "Name\nDescription\n\n\n\n\nAlpacaMultipleChoicePromptTokenizingStrategy\nTokenizing strategy for Alpaca Multiple Choice prompts.\n\n\nAlpacaPromptTokenizingStrategy\nTokenizing strategy for Alpaca prompts.\n\n\nAlpacaReflectionPTStrategy\nTokenizing strategy for Alpaca Reflection prompts.\n\n\nDatasetWrappingStrategy\nAbstract class for wrapping datasets for Chat Messages\n\n\nGPTeacherPromptTokenizingStrategy\nTokenizing strategy for GPTeacher prompts.\n\n\nInstructionPromptTokenizingStrategy\nTokenizing strategy for instruction-based prompts.\n\n\nInvalidDataException\nException raised when the data is invalid\n\n\nJeopardyPromptTokenizingStrategy\nTokenizing strategy for Jeopardy prompts.\n\n\nNomicGPT4AllPromptTokenizingStrategy\nTokenizing strategy for NomicGPT4All prompts.\n\n\nOpenAssistantPromptTokenizingStrategy\nTokenizing strategy for OpenAssistant prompts.\n\n\nPromptTokenizingStrategy\nAbstract class for tokenizing strategies\n\n\nReflectionPromptTokenizingStrategy\nTokenizing strategy for Reflection prompts.\n\n\nSummarizeTLDRPromptTokenizingStrategy\nTokenizing strategy for SummarizeTLDR prompts.\n\n\n\n\n\nprompt_tokenizers.AlpacaMultipleChoicePromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for Alpaca Multiple Choice prompts.\n\n\n\nprompt_tokenizers.AlpacaPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for Alpaca prompts.\n\n\n\nprompt_tokenizers.AlpacaReflectionPTStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for Alpaca Reflection prompts.\n\n\n\nprompt_tokenizers.DatasetWrappingStrategy()\nAbstract class for wrapping datasets for Chat Messages\n\n\n\nprompt_tokenizers.GPTeacherPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for GPTeacher 
prompts.\n\n\n\nprompt_tokenizers.InstructionPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for instruction-based prompts.\n\n\n\nprompt_tokenizers.InvalidDataException()\nException raised when the data is invalid\n\n\n\nprompt_tokenizers.JeopardyPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for Jeopardy prompts.\n\n\n\nprompt_tokenizers.NomicGPT4AllPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for NomicGPT4All prompts.\n\n\n\nprompt_tokenizers.OpenAssistantPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for OpenAssistant prompts.\n\n\n\nprompt_tokenizers.PromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nAbstract class for tokenizing strategies\n\n\n\nprompt_tokenizers.ReflectionPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for Reflection prompts.\n\n\n\nprompt_tokenizers.SummarizeTLDRPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for SummarizeTLDR prompts."
},
{
"objectID": "docs/api/prompt_tokenizers.html#functions",
"href": "docs/api/prompt_tokenizers.html#functions",
"title": "prompt_tokenizers",
"section": "",
"text": "Name\nDescription\n\n\n\n\nparse_tokenized_to_result\nParses the tokenized prompt and appends the tokenized input_ids, attention_mask and labels to the result\n\n\ntokenize_prompt_default\nReturns the default values for the tokenize prompt function\n\n\n\n\n\nprompt_tokenizers.parse_tokenized_to_result(\n result,\n current_len,\n res,\n labels,\n pad_token_id=None,\n)\nParses the tokenized prompt and appends the tokenized input_ids, attention_mask and labels to the result\n\n\n\nprompt_tokenizers.tokenize_prompt_default()\nReturns the default values for the tokenize prompt function"
},
{
"objectID": "docs/api/cli.utils.sweeps.html",
"href": "docs/api/cli.utils.sweeps.html",
"title": "cli.utils.sweeps",
"section": "",
"text": "cli.utils.sweeps\nUtilities for handling sweeps over configs for axolotl train CLI command\n\n\n\n\n\nName\nDescription\n\n\n\n\ngenerate_sweep_configs\nRecursively generates all possible configurations by applying sweeps to the base config.\n\n\n\n\n\ncli.utils.sweeps.generate_sweep_configs(base_config, sweeps_config)\nRecursively generates all possible configurations by applying sweeps to the base config.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nbase_config\ndict\nThe original configuration dictionary\nrequired\n\n\nsweeps_config\ndict\nDictionary where keys are parameters and values are either: - lists of values to sweep independently - or for paired values, a list of dicts under the '_' key\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\nlist\nlist[dict[str, Any]]\nList of all possible configuration dictionaries\n\n\n\n\n\n\nsweeps_config = {\n'learning_rate': [0.1, 0.01],\n'_': [\n{'load_in_8bit': True, 'adapter': 'lora'},\n{'load_in_4bit': True, 'adapter': 'qlora'}\n]\n}"
},
{
"objectID": "docs/api/cli.utils.sweeps.html#functions",
"href": "docs/api/cli.utils.sweeps.html#functions",
"title": "cli.utils.sweeps",
"section": "",
"text": "Name\nDescription\n\n\n\n\ngenerate_sweep_configs\nRecursively generates all possible configurations by applying sweeps to the base config.\n\n\n\n\n\ncli.utils.sweeps.generate_sweep_configs(base_config, sweeps_config)\nRecursively generates all possible configurations by applying sweeps to the base config.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nbase_config\ndict\nThe original configuration dictionary\nrequired\n\n\nsweeps_config\ndict\nDictionary where keys are parameters and values are either: - lists of values to sweep independently - or for paired values, a list of dicts under the '_' key\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\nlist\nlist[dict[str, Any]]\nList of all possible configuration dictionaries\n\n\n\n\n\n\nsweeps_config = {\n'learning_rate': [0.1, 0.01],\n'_': [\n{'load_in_8bit': True, 'adapter': 'lora'},\n{'load_in_4bit': True, 'adapter': 'qlora'}\n]\n}"
},
{
"objectID": "docs/api/cli.utils.args.html",
"href": "docs/api/cli.utils.args.html",
"title": "cli.utils.args",
"section": "",
"text": "cli.utils.args\nUtilities for axolotl CLI args.\n\n\n\n\n\nName\nDescription\n\n\n\n\nadd_options_from_config\nCreate Click options from the fields of a Pydantic model.\n\n\nadd_options_from_dataclass\nCreate Click options from the fields of a dataclass.\n\n\nfilter_none_kwargs\nWraps function to remove None-valued kwargs.\n\n\n\n\n\ncli.utils.args.add_options_from_config(config_class)\nCreate Click options from the fields of a Pydantic model.\nFor fields whose type is itself a Pydantic BaseModel, dot-notation CLI options are\ngenerated for each sub-field (e.g., --trl.beta=0.1).\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig_class\nType[BaseModel]\nPydantic model with fields to parse from the CLI\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nCallable\nFunction decorator for Axolotl CLI command.\n\n\n\n\n\n\n\ncli.utils.args.add_options_from_dataclass(config_class)\nCreate Click options from the fields of a dataclass.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig_class\nType[Any]\nDataclass with fields to parse from the CLI.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nCallable\nFunction decorator for Axolotl CLI command.\n\n\n\n\n\n\n\ncli.utils.args.filter_none_kwargs(func)\nWraps function to remove None-valued kwargs.\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nfunc\nCallable\nFunction to wrap.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nCallable\nWrapped function."
},
{
"objectID": "docs/api/cli.utils.args.html#functions",
"href": "docs/api/cli.utils.args.html#functions",
"title": "cli.utils.args",
"section": "",
"text": "Name\nDescription\n\n\n\n\nadd_options_from_config\nCreate Click options from the fields of a Pydantic model.\n\n\nadd_options_from_dataclass\nCreate Click options from the fields of a dataclass.\n\n\nfilter_none_kwargs\nWraps function to remove None-valued kwargs.\n\n\n\n\n\ncli.utils.args.add_options_from_config(config_class)\nCreate Click options from the fields of a Pydantic model.\nFor fields whose type is itself a Pydantic BaseModel, dot-notation CLI options are\ngenerated for each sub-field (e.g., --trl.beta=0.1).\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig_class\nType[BaseModel]\nPydantic model with fields to parse from the CLI\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nCallable\nFunction decorator for Axolotl CLI command.\n\n\n\n\n\n\n\ncli.utils.args.add_options_from_dataclass(config_class)\nCreate Click options from the fields of a dataclass.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig_class\nType[Any]\nDataclass with fields to parse from the CLI.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nCallable\nFunction decorator for Axolotl CLI command.\n\n\n\n\n\n\n\ncli.utils.args.filter_none_kwargs(func)\nWraps function to remove None-valued kwargs.\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nfunc\nCallable\nFunction to wrap.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nCallable\nWrapped function."
},
{
"objectID": "docs/api/utils.chat_templates.html",
"href": "docs/api/utils.chat_templates.html",
"title": "utils.chat_templates",
"section": "",
"text": "utils.chat_templates\nutils.chat_templates\nThis module provides functionality for selecting chat templates based on user choices.\nThese templates are used for formatting messages in a conversation."
},
{
"objectID": "docs/api/utils.schemas.config.html",
"href": "docs/api/utils.schemas.config.html",
"title": "utils.schemas.config",
"section": "",
"text": "utils.schemas.config\nModule with Pydantic models for configuration.\n\n\n\n\n\nName\nDescription\n\n\n\n\nAxolotlConfigWCapabilities\nWrapper to validate GPU capabilities with the configured options\n\n\nAxolotlInputConfig\nWrapper of all config options.\n\n\nEBFTConfig\nConfiguration for Energy-Based Fine-Tuning (EBFT)\n\n\n\n\n\nutils.schemas.config.AxolotlConfigWCapabilities()\nWrapper to validate GPU capabilities with the configured options\n\n\n\nutils.schemas.config.AxolotlInputConfig()\nWrapper of all config options.\n\n\n\nutils.schemas.config.EBFTConfig()\nConfiguration for Energy-Based Fine-Tuning (EBFT)"
},
{
"objectID": "docs/api/utils.schemas.config.html#classes",
"href": "docs/api/utils.schemas.config.html#classes",
"title": "utils.schemas.config",
"section": "",
"text": "Name\nDescription\n\n\n\n\nAxolotlConfigWCapabilities\nWrapper to validate GPU capabilities with the configured options\n\n\nAxolotlInputConfig\nWrapper of all config options.\n\n\nEBFTConfig\nConfiguration for Energy-Based Fine-Tuning (EBFT)\n\n\n\n\n\nutils.schemas.config.AxolotlConfigWCapabilities()\nWrapper to validate GPU capabilities with the configured options\n\n\n\nutils.schemas.config.AxolotlInputConfig()\nWrapper of all config options.\n\n\n\nutils.schemas.config.EBFTConfig()\nConfiguration for Energy-Based Fine-Tuning (EBFT)"
},
{
"objectID": "docs/api/prompt_strategies.user_defined.html",
"href": "docs/api/prompt_strategies.user_defined.html",
"title": "prompt_strategies.user_defined",
"section": "",
"text": "prompt_strategies.user_defined\nUser Defined prompts with configuration from the YML config\n\n\n\n\n\nName\nDescription\n\n\n\n\nUserDefinedDatasetConfig\nDataclass configuration representing a user-defined dataset type\n\n\nUserDefinedPromptTokenizationStrategy\nPrompt Tokenization Strategy for user defined prompts\n\n\n\n\n\nprompt_strategies.user_defined.UserDefinedDatasetConfig(\n system_prompt='',\n field_system='system',\n field_instruction='instruction',\n field_input='input',\n field_output='output',\n format='{instruction} {input} ',\n no_input_format='{instruction} ',\n system_format='{system}',\n)\nDataclass configuration representing a user-defined dataset type\n\n\n\nprompt_strategies.user_defined.UserDefinedPromptTokenizationStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nPrompt Tokenization Strategy for user defined prompts"
},
{
"objectID": "docs/api/prompt_strategies.user_defined.html#classes",
"href": "docs/api/prompt_strategies.user_defined.html#classes",
"title": "prompt_strategies.user_defined",
"section": "",
"text": "Name\nDescription\n\n\n\n\nUserDefinedDatasetConfig\nDataclass configuration representing a user-defined dataset type\n\n\nUserDefinedPromptTokenizationStrategy\nPrompt Tokenization Strategy for user defined prompts\n\n\n\n\n\nprompt_strategies.user_defined.UserDefinedDatasetConfig(\n system_prompt='',\n field_system='system',\n field_instruction='instruction',\n field_input='input',\n field_output='output',\n format='{instruction} {input} ',\n no_input_format='{instruction} ',\n system_format='{system}',\n)\nDataclass configuration representing a user-defined dataset type\n\n\n\nprompt_strategies.user_defined.UserDefinedPromptTokenizationStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nPrompt Tokenization Strategy for user defined prompts"
},
{
"objectID": "docs/api/utils.schemas.datasets.html",
"href": "docs/api/utils.schemas.datasets.html",
"title": "utils.schemas.datasets",
"section": "",
"text": "utils.schemas.datasets\nPydantic models for datasets-related configuration\n\n\n\n\n\nName\nDescription\n\n\n\n\nDPODataset\nDPO configuration subset\n\n\nKTODataset\nKTO configuration subset\n\n\nPretrainingDataset\nPretraining dataset configuration subset\n\n\nSFTDataset\nSFT configuration subset\n\n\nStepwiseSupervisedDataset\nStepwise supervised dataset configuration subset\n\n\nSyntheticDataset\nSynthetic dataset configuration for benchmarking and testing.\n\n\nUserDefinedDPOType\nUser defined typing for DPO\n\n\nUserDefinedKTOType\nUser defined typing for KTO\n\n\nUserDefinedPrompterType\nStructure for user defined prompt types\n\n\n\n\n\nutils.schemas.datasets.DPODataset()\nDPO configuration subset\n\n\n\nutils.schemas.datasets.KTODataset()\nKTO configuration subset\n\n\n\nutils.schemas.datasets.PretrainingDataset()\nPretraining dataset configuration subset\n\n\n\nutils.schemas.datasets.SFTDataset()\nSFT configuration subset\n\n\n\n\n\nName\nDescription\n\n\n\n\nhandle_legacy_message_fields\nHandle backwards compatibility between legacy message field mapping and new property mapping system.\n\n\n\n\n\nutils.schemas.datasets.SFTDataset.handle_legacy_message_fields(data)\nHandle backwards compatibility between legacy message field mapping and new property mapping system.\n\n\n\n\n\nutils.schemas.datasets.StepwiseSupervisedDataset()\nStepwise supervised dataset configuration subset\n\n\n\nutils.schemas.datasets.SyntheticDataset()\nSynthetic dataset configuration for benchmarking and testing.\nGenerates datasets with configurable sequence length, dataset size, and token ID\nranges. Useful for benchmarking memory usage and speed by sequence length, and for\nvalidating weighted dataset mixes.\n\n\n\nutils.schemas.datasets.UserDefinedDPOType()\nUser defined typing for DPO\n\n\n\nutils.schemas.datasets.UserDefinedKTOType()\nUser defined typing for KTO\n\n\n\nutils.schemas.datasets.UserDefinedPrompterType()\nStructure for user defined prompt types"
},
{
"objectID": "docs/api/utils.schemas.datasets.html#classes",
"href": "docs/api/utils.schemas.datasets.html#classes",
"title": "utils.schemas.datasets",
"section": "",
"text": "Name\nDescription\n\n\n\n\nDPODataset\nDPO configuration subset\n\n\nKTODataset\nKTO configuration subset\n\n\nPretrainingDataset\nPretraining dataset configuration subset\n\n\nSFTDataset\nSFT configuration subset\n\n\nStepwiseSupervisedDataset\nStepwise supervised dataset configuration subset\n\n\nSyntheticDataset\nSynthetic dataset configuration for benchmarking and testing.\n\n\nUserDefinedDPOType\nUser defined typing for DPO\n\n\nUserDefinedKTOType\nUser defined typing for KTO\n\n\nUserDefinedPrompterType\nStructure for user defined prompt types\n\n\n\n\n\nutils.schemas.datasets.DPODataset()\nDPO configuration subset\n\n\n\nutils.schemas.datasets.KTODataset()\nKTO configuration subset\n\n\n\nutils.schemas.datasets.PretrainingDataset()\nPretraining dataset configuration subset\n\n\n\nutils.schemas.datasets.SFTDataset()\nSFT configuration subset\n\n\n\n\n\nName\nDescription\n\n\n\n\nhandle_legacy_message_fields\nHandle backwards compatibility between legacy message field mapping and new property mapping system.\n\n\n\n\n\nutils.schemas.datasets.SFTDataset.handle_legacy_message_fields(data)\nHandle backwards compatibility between legacy message field mapping and new property mapping system.\n\n\n\n\n\nutils.schemas.datasets.StepwiseSupervisedDataset()\nStepwise supervised dataset configuration subset\n\n\n\nutils.schemas.datasets.SyntheticDataset()\nSynthetic dataset configuration for benchmarking and testing.\nGenerates datasets with configurable sequence length, dataset size, and token ID\nranges. Useful for benchmarking memory usage and speed by sequence length, and for\nvalidating weighted dataset mixes.\n\n\n\nutils.schemas.datasets.UserDefinedDPOType()\nUser defined typing for DPO\n\n\n\nutils.schemas.datasets.UserDefinedKTOType()\nUser defined typing for KTO\n\n\n\nutils.schemas.datasets.UserDefinedPrompterType()\nStructure for user defined prompt types"
},
{
"objectID": "docs/api/integrations.base.html",
"href": "docs/api/integrations.base.html",
"title": "integrations.base",
"section": "",
"text": "integrations.base\nBase class for all plugins.\nA plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl.\nPlugins can be used to integrate third-party models, modify the training process, or add new features.\nTo create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.\n\n\n\n\n\nName\nDescription\n\n\n\n\nBaseOptimizerFactory\nBase class for factories to create custom optimizers\n\n\nBasePlugin\nBase class for all plugins. Defines the interface for plugin methods.\n\n\nPluginManager\nThe PluginManager class is responsible for loading and managing plugins. It\n\n\n\n\n\nintegrations.base.BaseOptimizerFactory()\nBase class for factories to create custom optimizers\n\n\n\n\n\nName\nDescription\n\n\n\n\nget_decay_parameter_names\nGet all parameter names that weight decay will be applied to.\n\n\n\n\n\nintegrations.base.BaseOptimizerFactory.get_decay_parameter_names(model)\nGet all parameter names that weight decay will be applied to.\nThis function filters out parameters in two ways:\n1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS)\n2. By parameter name patterns (containing bias, or variation of norm)\n\n\n\n\n\nintegrations.base.BasePlugin()\nBase class for all plugins. Defines the interface for plugin methods.\nA plugin is a reusable, modular, and self-contained piece of code that extends\nthe functionality of Axolotl. 
Plugins can be used to integrate third-party models,\nmodify the training process, or add new features.\nTo create a new plugin, you need to inherit from the BasePlugin class and\nimplement the required methods.\n\n\nPlugin methods include:\n- register(cfg): Registers the plugin with the given configuration.\n- load_datasets(cfg): Loads and preprocesses the dataset for training.\n- pre_model_load(cfg): Performs actions before the model is loaded.\n- post_model_build(cfg, model): Performs actions after the model is loaded, but\nbefore LoRA adapters are applied.\n- pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.\n- post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.\n- post_model_load(cfg, model): Performs actions after the model is loaded,\ninclusive of any adapters.\n- post_trainer_create(cfg, trainer): Performs actions after the trainer is\ncreated.\n- create_optimizer(cfg, trainer): Creates and returns an optimizer for training.\n- create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and\nreturns a learning rate scheduler.\n- add_callbacks_pre_trainer(cfg, model): Sets up callbacks before the trainer is\ncreated.\n- add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after\nit is created.\n\n\n\n\n\n\nName\nDescription\n\n\n\n\nadd_callbacks_post_trainer\nAdds callbacks to the trainer after creating the trainer. 
This is useful for\n\n\nadd_callbacks_pre_trainer\nSet up callbacks before creating the trainer.\n\n\ncreate_lr_scheduler\nCreates and returns a learning rate scheduler.\n\n\ncreate_optimizer\nCreates and returns an optimizer for training.\n\n\nget_collator_cls_and_kwargs\nReturns a custom class for the collator.\n\n\nget_input_args\nReturns a pydantic model for the plugin's input arguments.\n\n\nget_trainer_cls\nReturns a custom class for the trainer.\n\n\nget_training_args\nReturns custom training arguments to set on TrainingArgs.\n\n\nget_training_args_mixin\nReturns a dataclass model for the plugin's training arguments.\n\n\nload_datasets\nLoads and preprocesses the dataset for training.\n\n\non_rollouts_scored\nCalled after rollouts are scored during online RL (GRPO/PPO).\n\n\npost_lora_load\nPerforms actions after LoRA weights are loaded.\n\n\npost_model_build\nPerforms actions after the model is built/loaded, but before any adapters are applied.\n\n\npost_model_load\nPerforms actions after the model is loaded.\n\n\npost_train\nPerforms actions after training is complete.\n\n\npost_train_unload\nPerforms actions after training is complete and the model is unloaded.\n\n\npost_trainer_create\nPerforms actions after the trainer is created.\n\n\npre_lora_load\nPerforms actions before LoRA weights are loaded.\n\n\npre_model_load\nPerforms actions before the model is loaded.\n\n\nregister\nRegisters the plugin with the given configuration as an unparsed dict.\n\n\n\n\n\nintegrations.base.BasePlugin.add_callbacks_post_trainer(cfg, trainer)\nAdds callbacks to the trainer after creating the trainer. 
This is useful for\ncallbacks that require access to the model or trainer.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\ntrainer\nTrainer\nThe trainer object for training.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nlist[Callable]\nA list of callback functions to be added\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.add_callbacks_pre_trainer(cfg, model)\nSet up callbacks before creating the trainer.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\nmodel\nPreTrainedModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nlist[Callable]\nA list of callback functions to be added to the TrainingArgs.\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.create_lr_scheduler(\n cfg,\n trainer,\n optimizer,\n num_training_steps,\n)\nCreates and returns a learning rate scheduler.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\ntrainer\nTrainer\nThe trainer object for training.\nrequired\n\n\noptimizer\nOptimizer\nThe optimizer for training.\nrequired\n\n\nnum_training_steps\nint\nTotal number of training steps\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nLRScheduler | None\nThe created learning rate scheduler.\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.create_optimizer(cfg, trainer)\nCreates and returns an optimizer for training.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\ntrainer\nTrainer\nThe trainer object for training.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nOptimizer | None\nThe created optimizer.\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.get_collator_cls_and_kwargs(cfg, is_eval=False)\nReturns a custom class for the 
collator.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe global axolotl configuration.\nrequired\n\n\nis_eval\nbool\nWhether this is an eval split.\nFalse\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\nclass\n\nThe class for the collator.\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.get_input_args()\nReturns a pydantic model for the plugin's input arguments.\n\n\n\nintegrations.base.BasePlugin.get_trainer_cls(cfg)\nReturns a custom class for the trainer.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe global axolotl configuration.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntype[Trainer] | None\nThe first non-None trainer class returned by a plugin.\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.get_training_args(cfg)\nReturns custom training arguments to set on TrainingArgs.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe global axolotl configuration.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\nobject\n\ndict containing the training arguments.\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.get_training_args_mixin()\nReturns a dataclass model for the plugin's training arguments.\n\n\n\nintegrations.base.BasePlugin.load_datasets(cfg, preprocess=False)\nLoads and preprocesses the dataset for training.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\npreprocess\nbool\nWhether this is the preprocess step of the datasets.\nFalse\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\ndataset_meta\nUnion['TrainDatasetMeta', None]\nThe metadata for the training dataset.\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.on_rollouts_scored(\n cfg,\n trainer,\n prompts,\n completions,\n rewards,\n advantages,\n)\nCalled after rollouts are scored during online RL (GRPO/PPO).\nProvides access to the full scored rollout data for logging, 
trace\nstorage, or analysis. Called once per scoring step with all samples\nfrom that step.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe axolotl configuration.\nrequired\n\n\ntrainer\n\nThe trainer instance.\nrequired\n\n\nprompts\nlist[str]\nList of prompt texts (one per sample).\nrequired\n\n\ncompletions\nlist[str]\nList of completion texts (one per sample).\nrequired\n\n\nrewards\ndict[str, list[float]]\nDict mapping reward function name to list of reward values.\nrequired\n\n\nadvantages\nlist[float]\nList of advantage values (one per sample).\nrequired\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.post_lora_load(cfg, model)\nPerforms actions after LoRA weights are loaded.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\nmodel\nPreTrainedModel | PeftModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.post_model_build(cfg, model)\nPerforms actions after the model is built/loaded, but before any adapters are applied.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.post_model_load(cfg, model)\nPerforms actions after the model is loaded.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\nmodel\nPreTrainedModel | PeftModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.post_train(cfg, model)\nPerforms actions after training is complete.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe axolotl configuration.\nrequired\n\n\nmodel\nPreTrainedModel | PeftModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.post_train_unload(cfg)\nPerforms actions after training is complete and the model is 
unloaded.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.post_trainer_create(cfg, trainer)\nPerforms actions after the trainer is created.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\ntrainer\nTrainer\nThe trainer object for training.\nrequired\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.pre_lora_load(cfg, model)\nPerforms actions before LoRA weights are loaded.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\nmodel\nPreTrainedModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.pre_model_load(cfg)\nPerforms actions before the model is loaded.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.register(cfg)\nRegisters the plugin with the given configuration as an unparsed dict.\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\ndict\nThe configuration for the plugin.\nrequired\n\n\n\n\n\n\n\n\n\nintegrations.base.PluginManager()\nThe PluginManager class is responsible for loading and managing plugins. 
It\nshould be a singleton so it can be accessed from anywhere in the codebase.\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\nplugins\nOrderedDict[str, BasePlugin]\nA list of loaded plugins.\n\n\n\n\n\n\nKey methods include:\n- get_instance(): Static method to get the singleton instance of PluginManager.\n- register(plugin_name: str): Registers a new plugin by its name.\n- pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.\n\n\n\n\n\n\nName\nDescription\n\n\n\n\nadd_callbacks_post_trainer\nCalls the add_callbacks_post_trainer method of all registered plugins.\n\n\nadd_callbacks_pre_trainer\nCalls the add_callbacks_pre_trainer method of all registered plugins.\n\n\ncreate_lr_scheduler\nCalls the create_lr_scheduler method of all registered plugins and returns\n\n\ncreate_optimizer\nCalls the create_optimizer method of all registered plugins and returns\n\n\nget_collator_cls_and_kwargs\nCalls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.\n\n\nget_input_args\nReturns a list of Pydantic classes for all registered plugins input arguments.\n\n\nget_instance\nReturns the singleton instance of PluginManager. 
If the instance doesn't\n\n\nget_trainer_cls\nCalls the get_trainer_cls method of all registered plugins and returns the\n\n\nget_training_args\nCalls the get_training_args method of all registered plugins and returns the combined training arguments.\n\n\nget_training_args_mixin\nReturns a list of dataclasses for all registered plugins' training args mixins\n\n\nload_datasets\nCalls the load_datasets method of each registered plugin.\n\n\non_rollouts_scored\nCalls the on_rollouts_scored method of all registered plugins.\n\n\npost_lora_load\nCalls the post_lora_load method of all registered plugins.\n\n\npost_model_build\nCalls the post_model_build method of all registered plugins after the\n\n\npost_model_load\nCalls the post_model_load method of all registered plugins after the model\n\n\npost_train\nCalls the post_train method of all registered plugins.\n\n\npost_train_unload\nCalls the post_train_unload method of all registered plugins.\n\n\npost_trainer_create\nCalls the post_trainer_create method of all registered plugins.\n\n\npre_lora_load\nCalls the pre_lora_load method of all registered plugins.\n\n\npre_model_load\nCalls the pre_model_load method of all registered plugins.\n\n\nregister\nRegisters a new plugin by its name.\n\n\n\n\n\nintegrations.base.PluginManager.add_callbacks_post_trainer(cfg, trainer)\nCalls the add_callbacks_post_trainer method of all registered plugins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\ntrainer\nTrainer\nThe trainer object for training.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nlist[Callable]\nA list of callback functions to be added to the TrainingArgs.\n\n\n\n\n\n\n\nintegrations.base.PluginManager.add_callbacks_pre_trainer(cfg, model)\nCalls the add_callbacks_pre_trainer method of all registered plugins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration 
for the plugins.\nrequired\n\n\nmodel\nPreTrainedModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nlist[Callable]\nA list of callback functions to be added to the TrainingArgs.\n\n\n\n\n\n\n\nintegrations.base.PluginManager.create_lr_scheduler(\n trainer,\n optimizer,\n num_training_steps,\n)\nCalls the create_lr_scheduler method of all registered plugins and returns\nthe first non-None scheduler.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ntrainer\nTrainer\nThe trainer object for training.\nrequired\n\n\noptimizer\nOptimizer\nThe optimizer for training.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nLRScheduler | None\nThe created learning rate scheduler, or None if not found.\n\n\n\n\n\n\n\nintegrations.base.PluginManager.create_optimizer(trainer)\nCalls the create_optimizer method of all registered plugins and returns\nthe first non-None optimizer.\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ntrainer\nTrainer\nThe trainer object for training.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nOptimizer | None\nThe created optimizer, or None if none was found.\n\n\n\n\n\n\n\nintegrations.base.PluginManager.get_collator_cls_and_kwargs(cfg, is_eval=False)\nCalls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.\nParameters:\ncfg (dict): The configuration for the plugins.\nis_eval (bool): Whether this is an eval split.\nReturns:\nobject: The collator class, or None if none was found.\n\n\n\nintegrations.base.PluginManager.get_input_args()\nReturns a list of Pydantic classes for all registered plugins input arguments.\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nlist[str]\nA list of Pydantic classes for all registered plugins input arguments.\n\n\n\n\n\n\n\nintegrations.base.PluginManager.get_instance()\nReturns the singleton instance of PluginManager. 
If the instance doesn't\nexist, it creates a new one.\n\n\n\nintegrations.base.PluginManager.get_trainer_cls(cfg)\nCalls the get_trainer_cls method of all registered plugins and returns the\nfirst non-None trainer class.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nTrainer | None\nThe first non-None trainer class returned by a plugin.\n\n\n\n\n\n\n\nintegrations.base.PluginManager.get_training_args(cfg)\nCalls the get_training_args method of all registered plugins and returns the combined training arguments.\nParameters:\ncfg (dict): The configuration for the plugins.\nReturns:\nobject: The training arguments\n\n\n\nintegrations.base.PluginManager.get_training_args_mixin()\nReturns a list of dataclasses for all registered plugins' training args mixins\nReturns:\nlist[str]: A list of dataclasses\n\n\n\nintegrations.base.PluginManager.load_datasets(cfg, preprocess=False)\nCalls the load_datasets method of each registered plugin.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\npreprocess\nbool\nWhether this is the preprocess step of the datasets.\nFalse\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nUnion['TrainDatasetMeta', None]\nThe dataset metadata loaded from all registered plugins.\n\n\n\n\n\n\n\nintegrations.base.PluginManager.on_rollouts_scored(\n cfg,\n trainer,\n prompts,\n completions,\n rewards,\n advantages,\n)\nCalls the on_rollouts_scored method of all registered plugins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\ntrainer\n\nThe trainer instance.\nrequired\n\n\nprompts\nlist[str]\nList of prompt texts.\nrequired\n\n\ncompletions\nlist[str]\nList of completion texts.\nrequired\n\n\nrewards\ndict[str, list[float]]\nDict mapping 
reward function name to list of rewards.\nrequired\n\n\nadvantages\nlist[float]\nList of advantage values.\nrequired\n\n\n\n\n\n\n\nintegrations.base.PluginManager.post_lora_load(cfg, model)\nCalls the post_lora_load method of all registered plugins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\nmodel\nPreTrainedModel | PeftModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\nintegrations.base.PluginManager.post_model_build(cfg, model)\nCalls the post_model_build method of all registered plugins after the\nmodel has been built / loaded, but before any adapters have been applied.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\nmodel\nPreTrainedModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\nintegrations.base.PluginManager.post_model_load(cfg, model)\nCalls the post_model_load method of all registered plugins after the model\nhas been loaded inclusive of any adapters.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\nmodel\nPreTrainedModel | PeftModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\nintegrations.base.PluginManager.post_train(cfg, model)\nCalls the post_train method of all registered plugins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\nmodel\nPreTrainedModel | PeftModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\nintegrations.base.PluginManager.post_train_unload(cfg)\nCalls the post_train_unload method of all registered plugins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\n\n\n\n\n\nintegrations.base.PluginManager.post_trainer_create(cfg, trainer)\nCalls the post_trainer_create method of all registered 
plugins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\ntrainer\nTrainer\nThe trainer object for training.\nrequired\n\n\n\n\n\n\n\nintegrations.base.PluginManager.pre_lora_load(cfg, model)\nCalls the pre_lora_load method of all registered plugins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\nmodel\nPreTrainedModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\nintegrations.base.PluginManager.pre_model_load(cfg)\nCalls the pre_model_load method of all registered plugins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\n\n\n\n\n\nintegrations.base.PluginManager.register(plugin_name)\nRegisters a new plugin by its name.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nplugin_name\nstr\nThe name of the plugin to be registered.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nImportError\nIf the plugin module cannot be imported.\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\nload_plugin\nLoads a plugin based on the given plugin name.\n\n\n\n\n\nintegrations.base.load_plugin(plugin_name)\nLoads a plugin based on the given plugin name.\nThe plugin name should be in the format “module_name.class_name”. This function\nsplits the plugin name into module and class, imports the module, retrieves the\nclass from the module, and creates an instance of the class.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nplugin_name\nstr\nThe name of the plugin to be loaded. The name should be in the format “module_name.class_name”.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nBasePlugin\nAn instance of the loaded plugin.\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nImportError\nIf the plugin module cannot be imported."
},
{
"objectID": "docs/api/integrations.base.html#classes",
"href": "docs/api/integrations.base.html#classes",
"title": "integrations.base",
"section": "",
"text": "Name\nDescription\n\n\n\n\nBaseOptimizerFactory\nBase class for factories to create custom optimizers\n\n\nBasePlugin\nBase class for all plugins. Defines the interface for plugin methods.\n\n\nPluginManager\nThe PluginManager class is responsible for loading and managing plugins. It\n\n\n\n\n\nintegrations.base.BaseOptimizerFactory()\nBase class for factories to create custom optimizers\n\n\n\n\n\nName\nDescription\n\n\n\n\nget_decay_parameter_names\nGet all parameter names that weight decay will be applied to.\n\n\n\n\n\nintegrations.base.BaseOptimizerFactory.get_decay_parameter_names(model)\nGet all parameter names that weight decay will be applied to.\nThis function filters out parameters in two ways:\n1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS)\n2. By parameter name patterns (containing bias, or variation of norm)\n\n\n\n\n\nintegrations.base.BasePlugin()\nBase class for all plugins. Defines the interface for plugin methods.\nA plugin is a reusable, modular, and self-contained piece of code that extends\nthe functionality of Axolotl. 
Plugins can be used to integrate third-party models,\nmodify the training process, or add new features.\nTo create a new plugin, you need to inherit from the BasePlugin class and\nimplement the required methods.\n\n\nPlugin methods include:\n- register(cfg): Registers the plugin with the given configuration.\n- load_datasets(cfg): Loads and preprocesses the dataset for training.\n- pre_model_load(cfg): Performs actions before the model is loaded.\n- post_model_build(cfg, model): Performs actions after the model is loaded, but\nbefore LoRA adapters are applied.\n- pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.\n- post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.\n- post_model_load(cfg, model): Performs actions after the model is loaded,\ninclusive of any adapters.\n- post_trainer_create(cfg, trainer): Performs actions after the trainer is\ncreated.\n- create_optimizer(cfg, trainer): Creates and returns an optimizer for training.\n- create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and\nreturns a learning rate scheduler.\n- add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before\ntraining.\n- add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after\ntraining.\n\n\n\n\n\n\nName\nDescription\n\n\n\n\nadd_callbacks_post_trainer\nAdds callbacks to the trainer after creating the trainer. 
This is useful for\n\n\nadd_callbacks_pre_trainer\nSet up callbacks before creating the trainer.\n\n\ncreate_lr_scheduler\nCreates and returns a learning rate scheduler.\n\n\ncreate_optimizer\nCreates and returns an optimizer for training.\n\n\nget_collator_cls_and_kwargs\nReturns a custom class for the collator.\n\n\nget_input_args\nReturns a pydantic model for the plugins input arguments.\n\n\nget_trainer_cls\nReturns a custom class for the trainer.\n\n\nget_training_args\nReturns custom training arguments to set on TrainingArgs.\n\n\nget_training_args_mixin\nReturns a dataclass model for the plugins training arguments.\n\n\nload_datasets\nLoads and preprocesses the dataset for training.\n\n\non_rollouts_scored\nCalled after rollouts are scored during online RL (GRPO/PPO).\n\n\npost_lora_load\nPerforms actions after LoRA weights are loaded.\n\n\npost_model_build\nPerforms actions after the model is built/loaded, but before any adapters are applied.\n\n\npost_model_load\nPerforms actions after the model is loaded.\n\n\npost_train\nPerforms actions after training is complete.\n\n\npost_train_unload\nPerforms actions after training is complete and the model is unloaded.\n\n\npost_trainer_create\nPerforms actions after the trainer is created.\n\n\npre_lora_load\nPerforms actions before LoRA weights are loaded.\n\n\npre_model_load\nPerforms actions before the model is loaded.\n\n\nregister\nRegisters the plugin with the given configuration as an unparsed dict.\n\n\n\n\n\nintegrations.base.BasePlugin.add_callbacks_post_trainer(cfg, trainer)\nAdds callbacks to the trainer after creating the trainer. 
This is useful for\ncallbacks that require access to the model or trainer.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\ntrainer\nTrainer\nThe trainer object for training.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nlist[Callable]\nA list of callback functions to be added\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.add_callbacks_pre_trainer(cfg, model)\nSet up callbacks before creating the trainer.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\nmodel\nPreTrainedModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nlist[Callable]\nA list of callback functions to be added to the TrainingArgs.\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.create_lr_scheduler(\n cfg,\n trainer,\n optimizer,\n num_training_steps,\n)\nCreates and returns a learning rate scheduler.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\ntrainer\nTrainer\nThe trainer object for training.\nrequired\n\n\noptimizer\nOptimizer\nThe optimizer for training.\nrequired\n\n\nnum_training_steps\nint\nTotal number of training steps\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nLRScheduler | None\nThe created learning rate scheduler.\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.create_optimizer(cfg, trainer)\nCreates and returns an optimizer for training.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\ntrainer\nTrainer\nThe trainer object for training.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nOptimizer | None\nThe created optimizer.\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.get_collator_cls_and_kwargs(cfg, is_eval=False)\nReturns a custom class for the 
collator.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe global axolotl configuration.\nrequired\n\n\nis_eval\nbool\nWhether this is an eval split.\nFalse\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\nclass\n\nThe class for the collator.\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.get_input_args()\nReturns a pydantic model for the plugins input arguments.\n\n\n\nintegrations.base.BasePlugin.get_trainer_cls(cfg)\nReturns a custom class for the trainer.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe global axolotl configuration.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\ntype[Trainer] | None\nThe first non-None trainer class returned by a plugin.\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.get_training_args(cfg)\nReturns custom training arguments to set on TrainingArgs.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe global axolotl configuration.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\nobject\n\ndict containing the training arguments.\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.get_training_args_mixin()\nReturns a dataclass model for the plugins training arguments.\n\n\n\nintegrations.base.BasePlugin.load_datasets(cfg, preprocess=False)\nLoads and preprocesses the dataset for training.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\npreprocess\nbool\nWhether this is the preprocess step of the datasets.\nFalse\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\ndataset_meta\nUnion['TrainDatasetMeta', None]\nThe metadata for the training dataset.\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.on_rollouts_scored(\n cfg,\n trainer,\n prompts,\n completions,\n rewards,\n advantages,\n)\nCalled after rollouts are scored during online RL (GRPO/PPO).\nProvides access to the full scored rollout data for logging, 
trace\nstorage, or analysis. Called once per scoring step with all samples\nfrom that step.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe axolotl configuration.\nrequired\n\n\ntrainer\n\nThe trainer instance.\nrequired\n\n\nprompts\nlist[str]\nList of prompt texts (one per sample).\nrequired\n\n\ncompletions\nlist[str]\nList of completion texts (one per sample).\nrequired\n\n\nrewards\ndict[str, list[float]]\nDict mapping reward function name to list of reward values.\nrequired\n\n\nadvantages\nlist[float]\nList of advantage values (one per sample).\nrequired\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.post_lora_load(cfg, model)\nPerforms actions after LoRA weights are loaded.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\nmodel\nPreTrainedModel | PeftModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.post_model_build(cfg, model)\nPerforms actions after the model is built/loaded, but before any adapters are applied.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.post_model_load(cfg, model)\nPerforms actions after the model is loaded.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\nmodel\nPreTrainedModel | PeftModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.post_train(cfg, model)\nPerforms actions after training is complete.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe axolotl configuration.\nrequired\n\n\nmodel\nPreTrainedModel | PeftModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.post_train_unload(cfg)\nPerforms actions after training is complete and the model is 
unloaded.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.post_trainer_create(cfg, trainer)\nPerforms actions after the trainer is created.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\ntrainer\nTrainer\nThe trainer object for training.\nrequired\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.pre_lora_load(cfg, model)\nPerforms actions before LoRA weights are loaded.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\nmodel\nPreTrainedModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.pre_model_load(cfg)\nPerforms actions before the model is loaded.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugin.\nrequired\n\n\n\n\n\n\n\nintegrations.base.BasePlugin.register(cfg)\nRegisters the plugin with the given configuration as an unparsed dict.\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\ndict\nThe configuration for the plugin.\nrequired\n\n\n\n\n\n\n\n\n\nintegrations.base.PluginManager()\nThe PluginManager class is responsible for loading and managing plugins. 
It\nshould be a singleton so it can be accessed from anywhere in the codebase.\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\nplugins\nOrderedDict[str, BasePlugin]\nA list of loaded plugins.\n\n\n\n\n\n\nKey methods include:\n- get_instance(): Static method to get the singleton instance of PluginManager.\n- register(plugin_name: str): Registers a new plugin by its name.\n- pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.\n\n\n\n\n\n\nName\nDescription\n\n\n\n\nadd_callbacks_post_trainer\nCalls the add_callbacks_post_trainer method of all registered plugins.\n\n\nadd_callbacks_pre_trainer\nCalls the add_callbacks_pre_trainer method of all registered plugins.\n\n\ncreate_lr_scheduler\nCalls the create_lr_scheduler method of all registered plugins and returns\n\n\ncreate_optimizer\nCalls the create_optimizer method of all registered plugins and returns\n\n\nget_collator_cls_and_kwargs\nCalls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.\n\n\nget_input_args\nReturns a list of Pydantic classes for all registered plugins input arguments.\n\n\nget_instance\nReturns the singleton instance of PluginManager. 
If the instance doesnt\n\n\nget_trainer_cls\nCalls the get_trainer_cls method of all registered plugins and returns the\n\n\nget_training_args\nCalls the get_training_args method of all registered plugins and returns the combined training arguments.\n\n\nget_training_args_mixin\nReturns a list of dataclasses for all registered plugins training args mixins\n\n\nload_datasets\nCalls the load_datasets method of each registered plugin.\n\n\non_rollouts_scored\nCalls the on_rollouts_scored method of all registered plugins.\n\n\npost_lora_load\nCalls the post_lora_load method of all registered plugins.\n\n\npost_model_build\nCalls the post_model_build method of all registered plugins after the\n\n\npost_model_load\nCalls the post_model_load method of all registered plugins after the model\n\n\npost_train\nCalls the post_train method of all registered plugins.\n\n\npost_train_unload\nCalls the post_train_unload method of all registered plugins.\n\n\npost_trainer_create\nCalls the post_trainer_create method of all registered plugins.\n\n\npre_lora_load\nCalls the pre_lora_load method of all registered plugins.\n\n\npre_model_load\nCalls the pre_model_load method of all registered plugins.\n\n\nregister\nRegisters a new plugin by its name.\n\n\n\n\n\nintegrations.base.PluginManager.add_callbacks_post_trainer(cfg, trainer)\nCalls the add_callbacks_post_trainer method of all registered plugins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\ntrainer\nTrainer\nThe trainer object for training.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nlist[Callable]\nA list of callback functions to be added to the TrainingArgs.\n\n\n\n\n\n\n\nintegrations.base.PluginManager.add_callbacks_pre_trainer(cfg, model)\nCalls the add_callbacks_pre_trainer method of all registered plugins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration 
for the plugins.\nrequired\n\n\nmodel\nPreTrainedModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nlist[Callable]\nA list of callback functions to be added to the TrainingArgs.\n\n\n\n\n\n\n\nintegrations.base.PluginManager.create_lr_scheduler(\n trainer,\n optimizer,\n num_training_steps,\n)\nCalls the create_lr_scheduler method of all registered plugins and returns\nthe first non-None scheduler.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ntrainer\nTrainer\nThe trainer object for training.\nrequired\n\n\noptimizer\nOptimizer\nThe optimizer for training.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nLRScheduler | None\nThe created learning rate scheduler, or None if not found.\n\n\n\n\n\n\n\nintegrations.base.PluginManager.create_optimizer(trainer)\nCalls the create_optimizer method of all registered plugins and returns\nthe first non-None optimizer.\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ntrainer\nTrainer\nThe trainer object for training.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nOptimizer | None\nThe created optimizer, or None if none was found.\n\n\n\n\n\n\n\nintegrations.base.PluginManager.get_collator_cls_and_kwargs(cfg, is_eval=False)\nCalls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.\nParameters:\ncfg (dict): The configuration for the plugins.\nis_eval (bool): Whether this is an eval split.\nReturns:\nobject: The collator class, or None if none was found.\n\n\n\nintegrations.base.PluginManager.get_input_args()\nReturns a list of Pydantic classes for all registered plugins input arguments.\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nlist[str]\nA list of Pydantic classes for all registered plugins input arguments.\n\n\n\n\n\n\n\nintegrations.base.PluginManager.get_instance()\nReturns the singleton instance of PluginManager. 
If the instance doesn't\nexist, it creates a new one.\n\n\n\nintegrations.base.PluginManager.get_trainer_cls(cfg)\nCalls the get_trainer_cls method of all registered plugins and returns the\nfirst non-None trainer class.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nTrainer | None\nThe first non-None trainer class returned by a plugin.\n\n\n\n\n\n\n\nintegrations.base.PluginManager.get_training_args(cfg)\nCalls the get_training_args method of all registered plugins and returns the combined training arguments.\nParameters:\ncfg (dict): The configuration for the plugins.\nReturns:\nobject: The training arguments\n\n\n\nintegrations.base.PluginManager.get_training_args_mixin()\nReturns a list of dataclasses for all registered plugins' training args mixins\nReturns:\nlist[str]: A list of dataclasses\n\n\n\nintegrations.base.PluginManager.load_datasets(cfg, preprocess=False)\nCalls the load_datasets method of each registered plugin.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\npreprocess\nbool\nWhether this is the preprocess step of the datasets.\nFalse\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nUnion['TrainDatasetMeta', None]\nThe dataset metadata loaded from all registered plugins.\n\n\n\n\n\n\n\nintegrations.base.PluginManager.on_rollouts_scored(\n cfg,\n trainer,\n prompts,\n completions,\n rewards,\n advantages,\n)\nCalls the on_rollouts_scored method of all registered plugins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\ntrainer\n\nThe trainer instance.\nrequired\n\n\nprompts\nlist[str]\nList of prompt texts.\nrequired\n\n\ncompletions\nlist[str]\nList of completion texts.\nrequired\n\n\nrewards\ndict[str, list[float]]\nDict mapping 
reward function name to list of rewards.\nrequired\n\n\nadvantages\nlist[float]\nList of advantage values.\nrequired\n\n\n\n\n\n\n\nintegrations.base.PluginManager.post_lora_load(cfg, model)\nCalls the post_lora_load method of all registered plugins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\nmodel\nPreTrainedModel | PeftModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\nintegrations.base.PluginManager.post_model_build(cfg, model)\nCalls the post_model_build method of all registered plugins after the\nmodel has been built / loaded, but before any adapters have been applied.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\nmodel\nPreTrainedModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\nintegrations.base.PluginManager.post_model_load(cfg, model)\nCalls the post_model_load method of all registered plugins after the model\nhas been loaded inclusive of any adapters.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\nmodel\nPreTrainedModel | PeftModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\nintegrations.base.PluginManager.post_train(cfg, model)\nCalls the post_train method of all registered plugins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\nmodel\nPreTrainedModel | PeftModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\nintegrations.base.PluginManager.post_train_unload(cfg)\nCalls the post_train_unload method of all registered plugins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\n\n\n\n\n\nintegrations.base.PluginManager.post_trainer_create(cfg, trainer)\nCalls the post_trainer_create method of all registered 
plugins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\ntrainer\nTrainer\nThe trainer object for training.\nrequired\n\n\n\n\n\n\n\nintegrations.base.PluginManager.pre_lora_load(cfg, model)\nCalls the pre_lora_load method of all registered plugins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\nmodel\nPreTrainedModel\nThe loaded model.\nrequired\n\n\n\n\n\n\n\nintegrations.base.PluginManager.pre_model_load(cfg)\nCalls the pre_model_load method of all registered plugins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ncfg\nDictDefault\nThe configuration for the plugins.\nrequired\n\n\n\n\n\n\n\nintegrations.base.PluginManager.register(plugin_name)\nRegisters a new plugin by its name.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nplugin_name\nstr\nThe name of the plugin to be registered.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nImportError\nIf the plugin module cannot be imported."
},
{
"objectID": "docs/api/integrations.base.html#functions",
"href": "docs/api/integrations.base.html#functions",
"title": "integrations.base",
"section": "",
"text": "Name\nDescription\n\n\n\n\nload_plugin\nLoads a plugin based on the given plugin name.\n\n\n\n\n\nintegrations.base.load_plugin(plugin_name)\nLoads a plugin based on the given plugin name.\nThe plugin name should be in the format “module_name.class_name”. This function\nsplits the plugin name into module and class, imports the module, retrieves the\nclass from the module, and creates an instance of the class.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nplugin_name\nstr\nThe name of the plugin to be loaded. The name should be in the format “module_name.class_name”.\nrequired\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nBasePlugin\nAn instance of the loaded plugin.\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nImportError\nIf the plugin module cannot be imported."
},
{
"objectID": "docs/api/utils.tokenization.html",
"href": "docs/api/utils.tokenization.html",
"title": "utils.tokenization",
"section": "",
"text": "utils.tokenization\nModule for tokenization utilities\n\n\n\n\n\nName\nDescription\n\n\n\n\ncolor_token_for_rl_debug\nHelper function to color tokens based on their type.\n\n\nprocess_tokens_for_rl_debug\nHelper function to process and color tokens.\n\n\n\n\n\nutils.tokenization.color_token_for_rl_debug(\n decoded_token,\n encoded_token,\n color,\n text_only,\n)\nHelper function to color tokens based on their type.\n\n\n\nutils.tokenization.process_tokens_for_rl_debug(\n tokens,\n color,\n tokenizer,\n text_only,\n)\nHelper function to process and color tokens."
},
{
"objectID": "docs/api/utils.tokenization.html#functions",
"href": "docs/api/utils.tokenization.html#functions",
"title": "utils.tokenization",
"section": "",
"text": "Name\nDescription\n\n\n\n\ncolor_token_for_rl_debug\nHelper function to color tokens based on their type.\n\n\nprocess_tokens_for_rl_debug\nHelper function to process and color tokens.\n\n\n\n\n\nutils.tokenization.color_token_for_rl_debug(\n decoded_token,\n encoded_token,\n color,\n text_only,\n)\nHelper function to color tokens based on their type.\n\n\n\nutils.tokenization.process_tokens_for_rl_debug(\n tokens,\n color,\n tokenizer,\n text_only,\n)\nHelper function to process and color tokens."
},
{
"objectID": "docs/api/monkeypatch.multipack.html",
"href": "docs/api/monkeypatch.multipack.html",
"title": "monkeypatch.multipack",
"section": "",
"text": "monkeypatch.multipack\nmonkeypatch.multipack\nmultipack patching for v2 of sample packing"
},
{
"objectID": "docs/api/integrations.kd.trainer.html",
"href": "docs/api/integrations.kd.trainer.html",
"title": "integrations.kd.trainer",
"section": "",
"text": "integrations.kd.trainer\nKD trainer\n\n\n\n\n\nName\nDescription\n\n\n\n\nAxolotlKDTrainer\nCustom trainer subclass for Knowledge Distillation (KD)\n\n\n\n\n\nintegrations.kd.trainer.AxolotlKDTrainer(*args, **kwargs)\nCustom trainer subclass for Knowledge Distillation (KD)\n\n\n\n\n\nName\nDescription\n\n\n\n\ncompute_loss\nHow the loss is computed by Trainer. By default, all models return the loss in the first element.\n\n\n\n\n\nintegrations.kd.trainer.AxolotlKDTrainer.compute_loss(\n model,\n inputs,\n return_outputs=False,\n num_items_in_batch=None,\n)\nHow the loss is computed by Trainer. By default, all models return the loss in the first element.\nSubclass and override for custom behavior."
},
{
"objectID": "docs/api/integrations.kd.trainer.html#classes",
"href": "docs/api/integrations.kd.trainer.html#classes",
"title": "integrations.kd.trainer",
"section": "",
"text": "Name\nDescription\n\n\n\n\nAxolotlKDTrainer\nCustom trainer subclass for Knowledge Distillation (KD)\n\n\n\n\n\nintegrations.kd.trainer.AxolotlKDTrainer(*args, **kwargs)\nCustom trainer subclass for Knowledge Distillation (KD)\n\n\n\n\n\nName\nDescription\n\n\n\n\ncompute_loss\nHow the loss is computed by Trainer. By default, all models return the loss in the first element.\n\n\n\n\n\nintegrations.kd.trainer.AxolotlKDTrainer.compute_loss(\n model,\n inputs,\n return_outputs=False,\n num_items_in_batch=None,\n)\nHow the loss is computed by Trainer. By default, all models return the loss in the first element.\nSubclass and override for custom behavior."
},
{
"objectID": "docs/api/monkeypatch.mixtral.html",
"href": "docs/api/monkeypatch.mixtral.html",
"title": "monkeypatch.mixtral",
"section": "",
"text": "monkeypatch.mixtral\nmonkeypatch.mixtral\nPatches to support multipack for mixtral"
},
{
"objectID": "docs/api/core.trainers.base.html",
"href": "docs/api/core.trainers.base.html",
"title": "core.trainers.base",
"section": "",
"text": "core.trainers.base\nModule for customized trainers\n\n\n\n\n\nName\nDescription\n\n\n\n\nAxolotlTrainer\nExtend the base Trainer for axolotl helpers\n\n\n\n\n\ncore.trainers.base.AxolotlTrainer(\n *_args,\n bench_data_collator=None,\n eval_data_collator=None,\n dataset_tags=None,\n **kwargs,\n)\nExtend the base Trainer for axolotl helpers\n\n\n\n\n\nName\nDescription\n\n\n\n\nlog\nLog logs on the various objects watching training, including stored metrics.\n\n\npush_to_hub\nOverwrite the push_to_hub method in order to force-add the tags when pushing the\n\n\nstore_metrics\nStore metrics with specified reduction type.\n\n\n\n\n\ncore.trainers.base.AxolotlTrainer.log(logs, start_time=None)\nLog logs on the various objects watching training, including stored metrics.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nlogs\ndict[str, float]\nThe values to log.\nrequired\n\n\nstart_time\nfloat | None\nThe start of training.\nNone\n\n\n\n\n\n\n\ncore.trainers.base.AxolotlTrainer.push_to_hub(*args, **kwargs)\nOverwrite the push_to_hub method in order to force-add the tags when pushing the\nmodel on the Hub. Please refer to ~transformers.Trainer.push_to_hub for more details.\n\n\n\ncore.trainers.base.AxolotlTrainer.store_metrics(\n metrics,\n train_eval='train',\n reduction='mean',\n)\nStore metrics with specified reduction type.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nmetrics\ndict[str, float] | dict[str, tuple[int | float, str]]\nDictionary of metric names to values, or metric names to (value, reduction_type) tuples.\nrequired\n\n\ntrain_eval\nLiteral['train', 'eval']\nWhether this is for training or evaluation.\n'train'"
},
{
"objectID": "docs/api/core.trainers.base.html#classes",
"href": "docs/api/core.trainers.base.html#classes",
"title": "core.trainers.base",
"section": "",
"text": "Name\nDescription\n\n\n\n\nAxolotlTrainer\nExtend the base Trainer for axolotl helpers\n\n\n\n\n\ncore.trainers.base.AxolotlTrainer(\n *_args,\n bench_data_collator=None,\n eval_data_collator=None,\n dataset_tags=None,\n **kwargs,\n)\nExtend the base Trainer for axolotl helpers\n\n\n\n\n\nName\nDescription\n\n\n\n\nlog\nLog logs on the various objects watching training, including stored metrics.\n\n\npush_to_hub\nOverwrite the push_to_hub method in order to force-add the tags when pushing the\n\n\nstore_metrics\nStore metrics with specified reduction type.\n\n\n\n\n\ncore.trainers.base.AxolotlTrainer.log(logs, start_time=None)\nLog logs on the various objects watching training, including stored metrics.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nlogs\ndict[str, float]\nThe values to log.\nrequired\n\n\nstart_time\nfloat | None\nThe start of training.\nNone\n\n\n\n\n\n\n\ncore.trainers.base.AxolotlTrainer.push_to_hub(*args, **kwargs)\nOverwrite the push_to_hub method in order to force-add the tags when pushing the\nmodel on the Hub. Please refer to ~transformers.Trainer.push_to_hub for more details.\n\n\n\ncore.trainers.base.AxolotlTrainer.store_metrics(\n metrics,\n train_eval='train',\n reduction='mean',\n)\nStore metrics with specified reduction type.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nmetrics\ndict[str, float] | dict[str, tuple[int | float, str]]\nDictionary of metric names to values, or metric names to (value, reduction_type) tuples.\nrequired\n\n\ntrain_eval\nLiteral['train', 'eval']\nWhether this is for training or evaluation.\n'train'"
},
{
"objectID": "docs/api/utils.schemas.integrations.html",
"href": "docs/api/utils.schemas.integrations.html",
"title": "utils.schemas.integrations",
"section": "",
"text": "utils.schemas.integrations\nPydantic models for Axolotl integrations\n\n\n\n\n\nName\nDescription\n\n\n\n\nCometConfig\nComet configuration subset\n\n\nGradioConfig\nGradio configuration subset\n\n\nLISAConfig\nLISA configuration subset\n\n\nMLFlowConfig\nMLFlow configuration subset\n\n\nOpenTelemetryConfig\nOpenTelemetry configuration subset\n\n\nRayConfig\nRay launcher configuration subset\n\n\nTrackioConfig\nTrackio configuration subset\n\n\nWandbConfig\nWandb configuration subset\n\n\n\n\n\nutils.schemas.integrations.CometConfig()\nComet configuration subset\n\n\n\nutils.schemas.integrations.GradioConfig()\nGradio configuration subset\n\n\n\nutils.schemas.integrations.LISAConfig()\nLISA configuration subset\n\n\n\nutils.schemas.integrations.MLFlowConfig()\nMLFlow configuration subset\n\n\n\nutils.schemas.integrations.OpenTelemetryConfig()\nOpenTelemetry configuration subset\n\n\n\nutils.schemas.integrations.RayConfig()\nRay launcher configuration subset\n\n\n\nutils.schemas.integrations.TrackioConfig()\nTrackio configuration subset\n\n\n\nutils.schemas.integrations.WandbConfig()\nWandb configuration subset"
},
{
"objectID": "docs/api/utils.schemas.integrations.html#classes",
"href": "docs/api/utils.schemas.integrations.html#classes",
"title": "utils.schemas.integrations",
"section": "",
"text": "Name\nDescription\n\n\n\n\nCometConfig\nComet configuration subset\n\n\nGradioConfig\nGradio configuration subset\n\n\nLISAConfig\nLISA configuration subset\n\n\nMLFlowConfig\nMLFlow configuration subset\n\n\nOpenTelemetryConfig\nOpenTelemetry configuration subset\n\n\nRayConfig\nRay launcher configuration subset\n\n\nTrackioConfig\nTrackio configuration subset\n\n\nWandbConfig\nWandb configuration subset\n\n\n\n\n\nutils.schemas.integrations.CometConfig()\nComet configuration subset\n\n\n\nutils.schemas.integrations.GradioConfig()\nGradio configuration subset\n\n\n\nutils.schemas.integrations.LISAConfig()\nLISA configuration subset\n\n\n\nutils.schemas.integrations.MLFlowConfig()\nMLFlow configuration subset\n\n\n\nutils.schemas.integrations.OpenTelemetryConfig()\nOpenTelemetry configuration subset\n\n\n\nutils.schemas.integrations.RayConfig()\nRay launcher configuration subset\n\n\n\nutils.schemas.integrations.TrackioConfig()\nTrackio configuration subset\n\n\n\nutils.schemas.integrations.WandbConfig()\nWandb configuration subset"
},
{
"objectID": "docs/api/core.trainers.mixins.rng_state_loader.html",
"href": "docs/api/core.trainers.mixins.rng_state_loader.html",
"title": "core.trainers.mixins.rng_state_loader",
"section": "",
"text": "core.trainers.mixins.rng_state_loader\nTemporary fix/override for bug in resume from checkpoint\nSee https://github.com/huggingface/transformers/pull/37162\nTODO: Remove when upstream added PR to release\n\n\n\n\n\nName\nDescription\n\n\n\n\nRngLoaderMixin\nmixin for method override to load RNG states from a checkpoint\n\n\n\n\n\ncore.trainers.mixins.rng_state_loader.RngLoaderMixin()\nmixin for method override to load RNG states from a checkpoint"
},
{
"objectID": "docs/api/core.trainers.mixins.rng_state_loader.html#classes",
"href": "docs/api/core.trainers.mixins.rng_state_loader.html#classes",
"title": "core.trainers.mixins.rng_state_loader",
"section": "",
"text": "Name\nDescription\n\n\n\n\nRngLoaderMixin\nmixin for method override to load RNG states from a checkpoint\n\n\n\n\n\ncore.trainers.mixins.rng_state_loader.RngLoaderMixin()\nmixin for method override to load RNG states from a checkpoint"
},
{
"objectID": "docs/api/cli.main.html",
"href": "docs/api/cli.main.html",
"title": "cli.main",
"section": "",
"text": "cli.main\nClick CLI definitions for various axolotl commands.\n\n\n\n\n\nName\nDescription\n\n\n\n\ncli\nAxolotl CLI - Train and fine-tune large language models\n\n\nevaluate\nEvaluate a model.\n\n\nfetch\nFetch example configs or other resources.\n\n\ninference\nRun inference with a trained model.\n\n\nmerge_lora\nMerge trained LoRA adapters into a base model.\n\n\nmerge_sharded_fsdp_weights\nMerge sharded FSDP model weights.\n\n\npreprocess\nPreprocess datasets before training.\n\n\ntrain\nTrain or fine-tune a model.\n\n\n\n\n\ncli.main.cli()\nAxolotl CLI - Train and fine-tune large language models\n\n\n\ncli.main.evaluate(ctx, config, launcher, **kwargs)\nEvaluate a model.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nctx\nclick.Context\nClick context for extra args.\nrequired\n\n\nconfig\nstr\nPath to axolotl config YAML file.\nrequired\n\n\nlauncher\nstr\nLauncher to use for multi-GPU evaluation (“accelerate”, “torchrun”, or “python”).\nrequired\n\n\nkwargs\n\nAdditional keyword arguments which correspond to CLI args or axolotl config options.\n{}\n\n\n\n\n\n\n\ncli.main.fetch(directory, dest)\nFetch example configs or other resources.\nAvailable directories:\n- examples: Example configuration files\n- deepspeed_configs: DeepSpeed configuration files\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ndirectory\nstr\nOne of examples, deepspeed_configs.\nrequired\n\n\ndest\nOptional[str]\nOptional destination directory.\nrequired\n\n\n\n\n\n\n\ncli.main.inference(ctx, config, launcher, gradio, **kwargs)\nRun inference with a trained model.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nctx\nclick.Context\nClick context for extra args.\nrequired\n\n\nconfig\nstr\nPath to axolotl config YAML file.\nrequired\n\n\nlauncher\nstr\nLauncher to use for multi-GPU inference (“accelerate”, “torchrun”, or “python”).\nrequired\n\n\ngradio\nbool\nWhether to use Gradio browser interface or command line for 
inference.\nrequired\n\n\nkwargs\n\nAdditional keyword arguments which correspond to CLI args or axolotl config options.\n{}\n\n\n\n\n\n\n\ncli.main.merge_lora(config, **kwargs)\nMerge trained LoRA adapters into a base model.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nstr\nPath to axolotl config YAML file.\nrequired\n\n\nkwargs\n\nAdditional keyword arguments which correspond to CLI args or axolotl config options.\n{}\n\n\n\n\n\n\n\ncli.main.merge_sharded_fsdp_weights(ctx, config, launcher, **kwargs)\nMerge sharded FSDP model weights.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nctx\nclick.Context\nClick context for extra args.\nrequired\n\n\nconfig\nstr\nPath to axolotl config YAML file.\nrequired\n\n\nlauncher\nstr\nLauncher to use for weight merging (“accelerate”, “torchrun”, or “python”).\nrequired\n\n\nkwargs\n\nAdditional keyword arguments which correspond to CLI args or axolotl config options.\n{}\n\n\n\n\n\n\n\ncli.main.preprocess(config, cloud=None, **kwargs)\nPreprocess datasets before training.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nstr\nPath to axolotl config YAML file.\nrequired\n\n\ncloud\nOptional[str]\nPath to a cloud accelerator configuration file.\nNone\n\n\nkwargs\n\nAdditional keyword arguments which correspond to CLI args or axolotl config options.\n{}\n\n\n\n\n\n\n\ncli.main.train(\n ctx,\n config,\n launcher='accelerate',\n cloud=None,\n sweep=None,\n **kwargs,\n)\nTrain or fine-tune a model.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nctx\nclick.Context\nClick context for extra args.\nrequired\n\n\nconfig\nstr\nPath to axolotl config YAML file.\nrequired\n\n\nlauncher\nLiteral['accelerate', 'torchrun', 'python']\nLauncher to use for multi-GPU training (“accelerate”, “torchrun”, or “python”).\n'accelerate'\n\n\ncloud\nstr | None\nPath to a cloud accelerator configuration file\nNone\n\n\nsweep\nstr | None\nPath to YAML config for sweeping 
hyperparameters.\nNone\n\n\nkwargs\n\nAdditional keyword arguments which correspond to CLI args or axolotl config options.\n{}"
},
{
"objectID": "docs/api/cli.main.html#functions",
"href": "docs/api/cli.main.html#functions",
"title": "cli.main",
"section": "",
"text": "Name\nDescription\n\n\n\n\ncli\nAxolotl CLI - Train and fine-tune large language models\n\n\nevaluate\nEvaluate a model.\n\n\nfetch\nFetch example configs or other resources.\n\n\ninference\nRun inference with a trained model.\n\n\nmerge_lora\nMerge trained LoRA adapters into a base model.\n\n\nmerge_sharded_fsdp_weights\nMerge sharded FSDP model weights.\n\n\npreprocess\nPreprocess datasets before training.\n\n\ntrain\nTrain or fine-tune a model.\n\n\n\n\n\ncli.main.cli()\nAxolotl CLI - Train and fine-tune large language models\n\n\n\ncli.main.evaluate(ctx, config, launcher, **kwargs)\nEvaluate a model.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nctx\nclick.Context\nClick context for extra args.\nrequired\n\n\nconfig\nstr\nPath to axolotl config YAML file.\nrequired\n\n\nlauncher\nstr\nLauncher to use for multi-GPU evaluation (“accelerate”, “torchrun”, or “python”).\nrequired\n\n\nkwargs\n\nAdditional keyword arguments which correspond to CLI args or axolotl config options.\n{}\n\n\n\n\n\n\n\ncli.main.fetch(directory, dest)\nFetch example configs or other resources.\nAvailable directories:\n- examples: Example configuration files\n- deepspeed_configs: DeepSpeed configuration files\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ndirectory\nstr\nOne of examples, deepspeed_configs.\nrequired\n\n\ndest\nOptional[str]\nOptional destination directory.\nrequired\n\n\n\n\n\n\n\ncli.main.inference(ctx, config, launcher, gradio, **kwargs)\nRun inference with a trained model.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nctx\nclick.Context\nClick context for extra args.\nrequired\n\n\nconfig\nstr\nPath to axolotl config YAML file.\nrequired\n\n\nlauncher\nstr\nLauncher to use for multi-GPU inference (“accelerate”, “torchrun”, or “python”).\nrequired\n\n\ngradio\nbool\nWhether to use Gradio browser interface or command line for inference.\nrequired\n\n\nkwargs\n\nAdditional keyword arguments which correspond to 
CLI args or axolotl config options.\n{}\n\n\n\n\n\n\n\ncli.main.merge_lora(config, **kwargs)\nMerge trained LoRA adapters into a base model.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nstr\nPath to axolotl config YAML file.\nrequired\n\n\nkwargs\n\nAdditional keyword arguments which correspond to CLI args or axolotl config options.\n{}\n\n\n\n\n\n\n\ncli.main.merge_sharded_fsdp_weights(ctx, config, launcher, **kwargs)\nMerge sharded FSDP model weights.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nctx\nclick.Context\nClick context for extra args.\nrequired\n\n\nconfig\nstr\nPath to axolotl config YAML file.\nrequired\n\n\nlauncher\nstr\nLauncher to use for weight merging (“accelerate”, “torchrun”, or “python”).\nrequired\n\n\nkwargs\n\nAdditional keyword arguments which correspond to CLI args or axolotl config options.\n{}\n\n\n\n\n\n\n\ncli.main.preprocess(config, cloud=None, **kwargs)\nPreprocess datasets before training.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nconfig\nstr\nPath to axolotl config YAML file.\nrequired\n\n\ncloud\nOptional[str]\nPath to a cloud accelerator configuration file.\nNone\n\n\nkwargs\n\nAdditional keyword arguments which correspond to CLI args or axolotl config options.\n{}\n\n\n\n\n\n\n\ncli.main.train(\n ctx,\n config,\n launcher='accelerate',\n cloud=None,\n sweep=None,\n **kwargs,\n)\nTrain or fine-tune a model.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nctx\nclick.Context\nClick context for extra args.\nrequired\n\n\nconfig\nstr\nPath to axolotl config YAML file.\nrequired\n\n\nlauncher\nLiteral['accelerate', 'torchrun', 'python']\nLauncher to use for multi-GPU training (“accelerate”, “torchrun”, or “python”).\n'accelerate'\n\n\ncloud\nstr | None\nPath to a cloud accelerator configuration file\nNone\n\n\nsweep\nstr | None\nPath to YAML config for sweeping hyperparameters.\nNone\n\n\nkwargs\n\nAdditional keyword arguments which correspond to 
CLI args or axolotl config options.\n{}"
},
{
"objectID": "docs/api/monkeypatch.trainer_fsdp_optim.html",
"href": "docs/api/monkeypatch.trainer_fsdp_optim.html",
"title": "monkeypatch.trainer_fsdp_optim",
"section": "",
"text": "monkeypatch.trainer_fsdp_optim\nfix for FSDP optimizer save in trainer w 4.47.0\n\n\n\n\n\nName\nDescription\n\n\n\n\npatch_training_loop_for_fsdp\nmonkeypatch for fixing the training loop for fsdp with optimizer save\n\n\n\n\n\nmonkeypatch.trainer_fsdp_optim.patch_training_loop_for_fsdp()\nmonkeypatch for fixing the training loop for fsdp with optimizer save"
},
{
"objectID": "docs/api/monkeypatch.trainer_fsdp_optim.html#functions",
"href": "docs/api/monkeypatch.trainer_fsdp_optim.html#functions",
"title": "monkeypatch.trainer_fsdp_optim",
"section": "",
"text": "Name\nDescription\n\n\n\n\npatch_training_loop_for_fsdp\nmonkeypatch for fixing the training loop for fsdp with optimizer save\n\n\n\n\n\nmonkeypatch.trainer_fsdp_optim.patch_training_loop_for_fsdp()\nmonkeypatch for fixing the training loop for fsdp with optimizer save"
},
{
"objectID": "docs/api/core.datasets.transforms.chat_builder.html",
"href": "docs/api/core.datasets.transforms.chat_builder.html",
"title": "core.datasets.transforms.chat_builder",
"section": "",
"text": "core.datasets.transforms.chat_builder\nThis module contains a function that builds a transform that takes a row from the\ndataset and converts it to a Chat.\n\n\n\n\n\nName\nDescription\n\n\n\n\nchat_message_transform_builder\nBuilds a transform that takes a row from the dataset and converts it to a Chat\n\n\n\n\n\ncore.datasets.transforms.chat_builder.chat_message_transform_builder(\n train_on_inputs=False,\n conversations_field='messages',\n message_field_role=None,\n message_field_content=None,\n message_field_training=None,\n)\nBuilds a transform that takes a row from the dataset and converts it to a Chat\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ntrain_on_inputs\nbool\nIf True, the transform will train on the inputs. If False, the transform will train on the targets. Defaults to False.\nFalse\n\n\nconversations_field\nstr\nThe field name of the conversations. Defaults to “messages”.\n'messages'\n\n\nmessage_field_role\nstr | list[str]\nThe field name of the role.\nNone\n\n\nmessage_field_content\nstr | list[str]\nThe field name of the message content.\nNone\n\n\nmessage_field_training\nstr | list[str]\nThe field name of the train/weight.\nNone\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\nCallable\n\nA function that takes a list of conversations and returns a list of messages."
},
{
"objectID": "docs/api/core.datasets.transforms.chat_builder.html#functions",
"href": "docs/api/core.datasets.transforms.chat_builder.html#functions",
"title": "core.datasets.transforms.chat_builder",
"section": "",
"text": "Name\nDescription\n\n\n\n\nchat_message_transform_builder\nBuilds a transform that takes a row from the dataset and converts it to a Chat\n\n\n\n\n\ncore.datasets.transforms.chat_builder.chat_message_transform_builder(\n train_on_inputs=False,\n conversations_field='messages',\n message_field_role=None,\n message_field_content=None,\n message_field_training=None,\n)\nBuilds a transform that takes a row from the dataset and converts it to a Chat\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ntrain_on_inputs\nbool\nIf True, the transform will train on the inputs. If False, the transform will train on the targets. Defaults to False.\nFalse\n\n\nconversations_field\nstr\nThe field name of the conversations. Defaults to “messages”.\n'messages'\n\n\nmessage_field_role\nstr | list[str]\nThe field name of the role.\nNone\n\n\nmessage_field_content\nstr | list[str]\nThe field name of the message content.\nNone\n\n\nmessage_field_training\nstr | list[str]\nThe field name of the train/weight.\nNone\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\nCallable\n\nA function that takes a list of conversations and returns a list of messages."
},
{
"objectID": "docs/api/prompt_strategies.alpaca_w_system.html",
"href": "docs/api/prompt_strategies.alpaca_w_system.html",
"title": "prompt_strategies.alpaca_w_system",
"section": "",
"text": "prompt_strategies.alpaca_w_system\nPrompt strategies loader for alpaca instruction datasets with system prompts\n\n\n\n\n\nName\nDescription\n\n\n\n\nInstructionWSystemPromptTokenizingStrategy\nTokenizing strategy for instruction-based prompts.\n\n\nOpenOrcaPromptTokenizingStrategy\nTokenizing strategy for OpenOrca datasets\n\n\nOpenOrcaSystemDataPrompter\nAlpaca Style Prompter that uses system prompts from the dataset, with OpenOrca prompts\n\n\nSystemDataPrompter\nAlpaca Style Prompter that uses system prompts from the dataset\n\n\n\n\n\nprompt_strategies.alpaca_w_system.InstructionWSystemPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for instruction-based prompts.\n\n\n\nprompt_strategies.alpaca_w_system.OpenOrcaPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for OpenOrca datasets\n\n\n\nprompt_strategies.alpaca_w_system.OpenOrcaSystemDataPrompter(\n prompt_style=PromptStyle.INSTRUCT.value,\n)\nAlpaca Style Prompter that uses system prompts from the dataset, with OpenOrca prompts\n\n\n\nprompt_strategies.alpaca_w_system.SystemDataPrompter(\n prompt_style=PromptStyle.INSTRUCT.value,\n)\nAlpaca Style Prompter that uses system prompts from the dataset"
},
{
"objectID": "docs/api/prompt_strategies.alpaca_w_system.html#classes",
"href": "docs/api/prompt_strategies.alpaca_w_system.html#classes",
"title": "prompt_strategies.alpaca_w_system",
"section": "",
"text": "Name\nDescription\n\n\n\n\nInstructionWSystemPromptTokenizingStrategy\nTokenizing strategy for instruction-based prompts.\n\n\nOpenOrcaPromptTokenizingStrategy\nTokenizing strategy for OpenOrca datasets\n\n\nOpenOrcaSystemDataPrompter\nAlpaca Style Prompter that uses system prompts from the dataset, with OpenOrca prompts\n\n\nSystemDataPrompter\nAlpaca Style Prompter that uses system prompts from the dataset\n\n\n\n\n\nprompt_strategies.alpaca_w_system.InstructionWSystemPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for instruction-based prompts.\n\n\n\nprompt_strategies.alpaca_w_system.OpenOrcaPromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for OpenOrca datasets\n\n\n\nprompt_strategies.alpaca_w_system.OpenOrcaSystemDataPrompter(\n prompt_style=PromptStyle.INSTRUCT.value,\n)\nAlpaca Style Prompter that uses system prompts from the dataset, with OpenOrca prompts\n\n\n\nprompt_strategies.alpaca_w_system.SystemDataPrompter(\n prompt_style=PromptStyle.INSTRUCT.value,\n)\nAlpaca Style Prompter that uses system prompts from the dataset"
},
{
"objectID": "docs/api/integrations.cut_cross_entropy.args.html",
"href": "docs/api/integrations.cut_cross_entropy.args.html",
"title": "integrations.cut_cross_entropy.args",
"section": "",
"text": "integrations.cut_cross_entropy.args\nModule for handling Cut Cross Entropy input arguments.\n\n\n\n\n\nName\nDescription\n\n\n\n\nCutCrossEntropyArgs\nInput args for Cut Cross Entropy.\n\n\n\n\n\nintegrations.cut_cross_entropy.args.CutCrossEntropyArgs()\nInput args for Cut Cross Entropy."
},
{
"objectID": "docs/api/integrations.cut_cross_entropy.args.html#classes",
"href": "docs/api/integrations.cut_cross_entropy.args.html#classes",
"title": "integrations.cut_cross_entropy.args",
"section": "",
"text": "Name\nDescription\n\n\n\n\nCutCrossEntropyArgs\nInput args for Cut Cross Entropy.\n\n\n\n\n\nintegrations.cut_cross_entropy.args.CutCrossEntropyArgs()\nInput args for Cut Cross Entropy."
},
{
"objectID": "docs/api/monkeypatch.transformers_fa_utils.html",
"href": "docs/api/monkeypatch.transformers_fa_utils.html",
"title": "monkeypatch.transformers_fa_utils",
"section": "",
"text": "monkeypatch.transformers_fa_utils\nsee https://github.com/huggingface/transformers/pull/35834\n\n\n\n\n\nName\nDescription\n\n\n\n\nfixed_fa_peft_integration_check\nPEFT usually casts the layer norms in float32 for training stability reasons\n\n\n\n\n\nmonkeypatch.transformers_fa_utils.fixed_fa_peft_integration_check(\n query,\n key,\n value,\n target_dtype=None,\n preferred_dtype=None,\n)\nPEFT usually casts the layer norms in float32 for training stability reasons\ntherefore the input hidden states gets silently casted in float32. Hence, we need\ncast them back in float16 / bfloat16 just to be sure everything works as expected.\nThis might slowdown training & inference so it is recommended to not cast the LayerNorms!\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nquery\ntorch.Tensor\nInput query states to be passed to Flash Attention API\nrequired\n\n\nkey\ntorch.Tensor\nInput key states to be passed to Flash Attention API\nrequired\n\n\nvalue\ntorch.Tensor\nInput value states to be passed to Flash Attention API\nrequired\n\n\ntarget_dtype\ntorch.dtype, optional\nThe dtype to convert the attention tensors to. Conversion can be ignored by not providing the target dtype.\nNone\n\n\npreferred_dtype\ntorch.dtype, optional\nThe preferred dtype to convert the attention tensors to regardless of the target dtype.\nNone"
},
{
"objectID": "docs/api/monkeypatch.transformers_fa_utils.html#functions",
"href": "docs/api/monkeypatch.transformers_fa_utils.html#functions",
"title": "monkeypatch.transformers_fa_utils",
"section": "",
"text": "Name\nDescription\n\n\n\n\nfixed_fa_peft_integration_check\nPEFT usually casts the layer norms in float32 for training stability reasons\n\n\n\n\n\nmonkeypatch.transformers_fa_utils.fixed_fa_peft_integration_check(\n query,\n key,\n value,\n target_dtype=None,\n preferred_dtype=None,\n)\nPEFT usually casts the layer norms in float32 for training stability reasons\ntherefore the input hidden states gets silently casted in float32. Hence, we need\ncast them back in float16 / bfloat16 just to be sure everything works as expected.\nThis might slowdown training & inference so it is recommended to not cast the LayerNorms!\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nquery\ntorch.Tensor\nInput query states to be passed to Flash Attention API\nrequired\n\n\nkey\ntorch.Tensor\nInput key states to be passed to Flash Attention API\nrequired\n\n\nvalue\ntorch.Tensor\nInput value states to be passed to Flash Attention API\nrequired\n\n\ntarget_dtype\ntorch.dtype, optional\nThe dtype to convert the attention tensors to. Conversion can be ignored by not providing the target dtype.\nNone\n\n\npreferred_dtype\ntorch.dtype, optional\nThe preferred dtype to convert the attention tensors to regardless of the target dtype.\nNone"
},
{
"objectID": "docs/api/utils.data.streaming.html",
"href": "docs/api/utils.data.streaming.html",
"title": "utils.data.streaming",
"section": "",
"text": "utils.data.streaming\nutils.data.streaming\nData handling specific to streaming datasets."
},
{
"objectID": "docs/api/utils.collators.batching.html",
"href": "docs/api/utils.collators.batching.html",
"title": "utils.collators.batching",
"section": "",
"text": "utils.collators.batching\nData collators for axolotl to pad labels and position_ids for packed sequences\n\n\n\n\n\nName\nDescription\n\n\n\n\nBatchSamplerDataCollatorForSeq2Seq\nCollator for multipack specific to the using the BatchSampler\n\n\nDataCollatorForSeq2Seq\nData collator that will dynamically pad the inputs received, as well as the labels and position_ids\n\n\nPretrainingBatchSamplerDataCollatorForSeq2Seq\nCollator for multipack specific to the using the BatchSampler\n\n\nV2BatchSamplerDataCollatorForSeq2Seq\nCollator for multipack specific to the using the BatchSampler\n\n\n\n\n\nutils.collators.batching.BatchSamplerDataCollatorForSeq2Seq(\n tokenizer,\n model=None,\n padding=True,\n max_length=None,\n pad_to_multiple_of=None,\n label_pad_token_id=-100,\n position_pad_token_id=0,\n return_tensors='pt',\n)\nCollator for multipack specific to the using the BatchSampler\n\n\n\nutils.collators.batching.DataCollatorForSeq2Seq(\n tokenizer,\n model=None,\n padding=True,\n max_length=None,\n pad_to_multiple_of=None,\n label_pad_token_id=-100,\n position_pad_token_id=0,\n return_tensors='pt',\n)\nData collator that will dynamically pad the inputs received, as well as the labels and position_ids\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ntokenizer\n[PreTrainedTokenizer] or [PreTrainedTokenizerFast]\nThe tokenizer used for encoding the data.\nrequired\n\n\nmodel\n[PreTrainedModel]\nThe model that is being trained. If set and has the prepare_decoder_input_ids_from_labels, use it to prepare the decoder_input_ids This is useful when using label_smoothing to avoid calculating loss twice.\nNone\n\n\npadding\nbool, str or [~utils.PaddingStrategy], optional, defaults to True\nSelect a strategy to pad the returned sequences (according to the models padding side and padding index) among: - True or 'longest' (default): Pad to the longest sequence in the batch (or no padding if only a single sequence is provided). 
- 'max_length': Pad to a maximum length specified with the argument max_length or to the maximum acceptable input length for the model if that argument is not provided. - False or 'do_not_pad': No padding (i.e., can output a batch with sequences of different lengths).\nTrue\n\n\nmax_length\nint, optional\nMaximum length of the returned list and optionally padding length (see above).\nNone\n\n\npad_to_multiple_of\nint, optional\nIf set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).\nNone\n\n\nlabel_pad_token_id\nint, optional, defaults to -100\nThe id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).\n-100\n\n\nreturn_tensors\nstr\nThe type of Tensor to return. Allowable values are “np”, “pt” and “tf”.\n'pt'\n\n\n\n\n\n\n\nutils.collators.batching.PretrainingBatchSamplerDataCollatorForSeq2Seq(\n *args,\n multipack_attn=True,\n **kwargs,\n)\nCollator for multipack specific to the using the BatchSampler\n\n\n\nutils.collators.batching.V2BatchSamplerDataCollatorForSeq2Seq(\n tokenizer,\n model=None,\n padding=True,\n max_length=None,\n pad_to_multiple_of=None,\n label_pad_token_id=-100,\n position_pad_token_id=0,\n return_tensors='pt',\n squash_position_ids=False,\n)\nCollator for multipack specific to the using the BatchSampler"
},
{
"objectID": "docs/api/utils.collators.batching.html#classes",
"href": "docs/api/utils.collators.batching.html#classes",
"title": "utils.collators.batching",
"section": "",
"text": "Name\nDescription\n\n\n\n\nBatchSamplerDataCollatorForSeq2Seq\nCollator for multipack specific to the using the BatchSampler\n\n\nDataCollatorForSeq2Seq\nData collator that will dynamically pad the inputs received, as well as the labels and position_ids\n\n\nPretrainingBatchSamplerDataCollatorForSeq2Seq\nCollator for multipack specific to the using the BatchSampler\n\n\nV2BatchSamplerDataCollatorForSeq2Seq\nCollator for multipack specific to the using the BatchSampler\n\n\n\n\n\nutils.collators.batching.BatchSamplerDataCollatorForSeq2Seq(\n tokenizer,\n model=None,\n padding=True,\n max_length=None,\n pad_to_multiple_of=None,\n label_pad_token_id=-100,\n position_pad_token_id=0,\n return_tensors='pt',\n)\nCollator for multipack specific to the using the BatchSampler\n\n\n\nutils.collators.batching.DataCollatorForSeq2Seq(\n tokenizer,\n model=None,\n padding=True,\n max_length=None,\n pad_to_multiple_of=None,\n label_pad_token_id=-100,\n position_pad_token_id=0,\n return_tensors='pt',\n)\nData collator that will dynamically pad the inputs received, as well as the labels and position_ids\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\ntokenizer\n[PreTrainedTokenizer] or [PreTrainedTokenizerFast]\nThe tokenizer used for encoding the data.\nrequired\n\n\nmodel\n[PreTrainedModel]\nThe model that is being trained. If set and has the prepare_decoder_input_ids_from_labels, use it to prepare the decoder_input_ids This is useful when using label_smoothing to avoid calculating loss twice.\nNone\n\n\npadding\nbool, str or [~utils.PaddingStrategy], optional, defaults to True\nSelect a strategy to pad the returned sequences (according to the models padding side and padding index) among: - True or 'longest' (default): Pad to the longest sequence in the batch (or no padding if only a single sequence is provided). 
- 'max_length': Pad to a maximum length specified with the argument max_length or to the maximum acceptable input length for the model if that argument is not provided. - False or 'do_not_pad': No padding (i.e., can output a batch with sequences of different lengths).\nTrue\n\n\nmax_length\nint, optional\nMaximum length of the returned list and optionally padding length (see above).\nNone\n\n\npad_to_multiple_of\nint, optional\nIf set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).\nNone\n\n\nlabel_pad_token_id\nint, optional, defaults to -100\nThe id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).\n-100\n\n\nreturn_tensors\nstr\nThe type of Tensor to return. Allowable values are “np”, “pt” and “tf”.\n'pt'\n\n\n\n\n\n\n\nutils.collators.batching.PretrainingBatchSamplerDataCollatorForSeq2Seq(\n *args,\n multipack_attn=True,\n **kwargs,\n)\nCollator for multipack specific to the using the BatchSampler\n\n\n\nutils.collators.batching.V2BatchSamplerDataCollatorForSeq2Seq(\n tokenizer,\n model=None,\n padding=True,\n max_length=None,\n pad_to_multiple_of=None,\n label_pad_token_id=-100,\n position_pad_token_id=0,\n return_tensors='pt',\n squash_position_ids=False,\n)\nCollator for multipack specific to the using the BatchSampler"
},
{
"objectID": "docs/api/utils.samplers.multipack.html",
"href": "docs/api/utils.samplers.multipack.html",
"title": "utils.samplers.multipack",
"section": "",
"text": "utils.samplers.multipack\nMultipack Batch Sampler - An efficient batch sampler for packing variable-length sequences\ninto fixed-capacity batches to optimize memory usage and training throughput.\n\n\n\n\n\nName\nDescription\n\n\n\n\nMultipackBatchSampler\nBatch sampler class for efficient packing of variable-length sequences\n\n\n\n\n\nutils.samplers.multipack.MultipackBatchSampler(\n sampler,\n batch_size,\n batch_max_len,\n lengths,\n bin_size,\n packing_efficiency_estimate=1.0,\n drop_last=True,\n num_count_samples=4,\n sequential=False,\n group_size=100000,\n num_processes=None,\n safe_mode=True,\n mp_start_method='fork',\n **kwargs,\n)\nBatch sampler class for efficient packing of variable-length sequences\nThis sampler packs sequences into fixed-capacity bins (batches) to maximize\nGPU memory utilization and training throughput by reducing padding.\nIt supports both parallel packing (using FFD algorithm) and\nsequential packing (preserving original sequence order).\n\n\n\n\n\nName\nDescription\n\n\n\n\nefficiency\nCalculate the packing efficiency (ratio of tokens used to total token slots).\n\n\ngather_efficiency\nGather and synchronize packing efficiency estimates across all distributed\n\n\ngather_len_batches\nGather and synchronize batch counts across all distributed ranks. 
Returns\n\n\ngenerate_batches\nGenerate packed batches for training.\n\n\nset_epoch\nSet the epoch number, used for reproducible shuffling across epochs\n\n\n\n\n\nutils.samplers.multipack.MultipackBatchSampler.efficiency()\nCalculate the packing efficiency (ratio of tokens used to total token slots).\nHigher is better - 1.0 would mean perfect packing with no wasted space.\n\n\n\nutils.samplers.multipack.MultipackBatchSampler.gather_efficiency()\nGather and synchronize packing efficiency estimates across all distributed\nranks.\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nfloat\nA conservative efficiency estimate based on the measurements.\n\n\n\n\n\n\n\nutils.samplers.multipack.MultipackBatchSampler.gather_len_batches(num)\nGather and synchronize batch counts across all distributed ranks. Returns\nthe minimum number of batches available on any rank.\n\n\n\nutils.samplers.multipack.MultipackBatchSampler.generate_batches(set_stats=False)\nGenerate packed batches for training.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nset_stats\nbool\nWhether to update efficiency statistics.\nFalse\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nlist[list[list[int]]]\nList of batches, where each batch contains multiple bins, and each bin contains multiple sequence indices.\n\n\n\n\n\n\n\nutils.samplers.multipack.MultipackBatchSampler.set_epoch(epoch)\nSet the epoch number, used for reproducible shuffling across epochs\n\n\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\nallocate_sequentially\nSequential allocator that preserves example order.\n\n\nffd_check\nFirst-fit-decreasing bin packing algorithm check.\n\n\npack_group\nPack a group of sequences into bins using First-Fit Decreasing algorithm.\n\n\npack_parallel\nPack sequences into bins using parallel processing.\n\n\n\n\n\nutils.samplers.multipack.allocate_sequentially(\n sequence_lengths,\n rank,\n bin_capacity,\n num_ranks,\n)\nSequential allocator that preserves example 
order.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nsequence_lengths\nnp.ndarray\nThe lengths of all examples.\nrequired\n\n\nrank\nint\nThe current rank (for distributed training).\nrequired\n\n\nbin_capacity\nint\nThe capacity of each bin (maximum sequence length).\nrequired\n\n\nnum_ranks\nint\nNumber of ranks (processes / GPUs).\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\nrank_batches\nlist[list[int]]\nList of batches for the current rank.\n\n\ntotal_tokens_used\nint\nNumber of actual example tokens.\n\n\ntotal_token_slots\nint\nMaximum theoretical number of example tokens (number of bins * bin capacity).\n\n\n\n\n\n\n\nutils.samplers.multipack.ffd_check(sequence_lengths, bin_capacity, num_bins)\nFirst-fit-decreasing bin packing algorithm check.\nChecks if sequences with the given lengths could fit in the specified number of\nbins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nsequence_lengths\nnp.ndarray\nArray of sequence lengths.\nrequired\n\n\nbin_capacity\nint\nMaximum capacity of each bin.\nrequired\n\n\nnum_bins\nint\nNumber of bins available.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nbool\nTrue if all sequences can be packed, False otherwise.\n\n\n\n\n\n\n\nutils.samplers.multipack.pack_group(\n sequence_lengths,\n group_offset,\n bin_capacity,\n max_bins,\n bin_size,\n safe_mode=True,\n)\nPack a group of sequences into bins using First-Fit Decreasing algorithm.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nsequence_lengths\nnp.ndarray\nArray of sequence lengths.\nrequired\n\n\ngroup_offset\nint\nOffset to apply to indices when returning results.\nrequired\n\n\nbin_capacity\nint\nMaximum capacity of each bin.\nrequired\n\n\nmax_bins\nint\nMaximum number of bins to use.\nrequired\n\n\nbin_size\nint\nMaximum number of sequences per bin.\nrequired\n\n\nsafe_mode\nbool\nIf True, use a more conservative packing 
approach.\nTrue\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nlist[list[int]]\nList of bins, where each bin contains indices of sequences assigned to it.\n\n\n\n\n\n\n\nutils.samplers.multipack.pack_parallel(\n sequence_lengths,\n bin_capacity,\n group_size,\n bin_size,\n num_processes=None,\n safe_mode=True,\n mp_start_method='fork',\n)\nPack sequences into bins using parallel processing.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nsequence_lengths\nnp.ndarray\nArray of sequence lengths.\nrequired\n\n\nbin_capacity\nint\nMaximum capacity of each bin as total number of tokens.\nrequired\n\n\ngroup_size\nint\nNumber of sequences to process in each group.\nrequired\n\n\nbin_size\nint\nMaximum number of bins to use.\nrequired\n\n\nnum_processes\nint | None\nNumber of parallel processes to use.\nNone\n\n\nsafe_mode\nbool\nIf True, use a more conservative packing approach.\nTrue\n\n\nmp_start_method\nstr | None\nMultiprocessing start method (fork, spawn, forkserver). spawn is often safer with Numba/PyTorch. Set to None to use system default.\n'fork'\n\n\n\nReturns:\nList of bins, where each bin contains indices of sequences assigned to it."
},
{
"objectID": "docs/api/utils.samplers.multipack.html#classes",
"href": "docs/api/utils.samplers.multipack.html#classes",
"title": "utils.samplers.multipack",
"section": "",
"text": "Name\nDescription\n\n\n\n\nMultipackBatchSampler\nBatch sampler class for efficient packing of variable-length sequences\n\n\n\n\n\nutils.samplers.multipack.MultipackBatchSampler(\n sampler,\n batch_size,\n batch_max_len,\n lengths,\n bin_size,\n packing_efficiency_estimate=1.0,\n drop_last=True,\n num_count_samples=4,\n sequential=False,\n group_size=100000,\n num_processes=None,\n safe_mode=True,\n mp_start_method='fork',\n **kwargs,\n)\nBatch sampler class for efficient packing of variable-length sequences\nThis sampler packs sequences into fixed-capacity bins (batches) to maximize\nGPU memory utilization and training throughput by reducing padding.\nIt supports both parallel packing (using FFD algorithm) and\nsequential packing (preserving original sequence order).\n\n\n\n\n\nName\nDescription\n\n\n\n\nefficiency\nCalculate the packing efficiency (ratio of tokens used to total token slots).\n\n\ngather_efficiency\nGather and synchronize packing efficiency estimates across all distributed\n\n\ngather_len_batches\nGather and synchronize batch counts across all distributed ranks. Returns\n\n\ngenerate_batches\nGenerate packed batches for training.\n\n\nset_epoch\nSet the epoch number, used for reproducible shuffling across epochs\n\n\n\n\n\nutils.samplers.multipack.MultipackBatchSampler.efficiency()\nCalculate the packing efficiency (ratio of tokens used to total token slots).\nHigher is better - 1.0 would mean perfect packing with no wasted space.\n\n\n\nutils.samplers.multipack.MultipackBatchSampler.gather_efficiency()\nGather and synchronize packing efficiency estimates across all distributed\nranks.\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nfloat\nA conservative efficiency estimate based on the measurements.\n\n\n\n\n\n\n\nutils.samplers.multipack.MultipackBatchSampler.gather_len_batches(num)\nGather and synchronize batch counts across all distributed ranks. 
Returns\nthe minimum number of batches available on any rank.\n\n\n\nutils.samplers.multipack.MultipackBatchSampler.generate_batches(set_stats=False)\nGenerate packed batches for training.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nset_stats\nbool\nWhether to update efficiency statistics.\nFalse\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nlist[list[list[int]]]\nList of batches, where each batch contains multiple bins, and each bin contains multiple sequence indices.\n\n\n\n\n\n\n\nutils.samplers.multipack.MultipackBatchSampler.set_epoch(epoch)\nSet the epoch number, used for reproducible shuffling across epochs"
},
{
"objectID": "docs/api/utils.samplers.multipack.html#functions",
"href": "docs/api/utils.samplers.multipack.html#functions",
"title": "utils.samplers.multipack",
"section": "",
"text": "Name\nDescription\n\n\n\n\nallocate_sequentially\nSequential allocator that preserves example order.\n\n\nffd_check\nFirst-fit-decreasing bin packing algorithm check.\n\n\npack_group\nPack a group of sequences into bins using First-Fit Decreasing algorithm.\n\n\npack_parallel\nPack sequences into bins using parallel processing.\n\n\n\n\n\nutils.samplers.multipack.allocate_sequentially(\n sequence_lengths,\n rank,\n bin_capacity,\n num_ranks,\n)\nSequential allocator that preserves example order.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nsequence_lengths\nnp.ndarray\nThe lengths of all examples.\nrequired\n\n\nrank\nint\nThe current rank (for distributed training).\nrequired\n\n\nbin_capacity\nint\nThe capacity of each bin (maximum sequence length).\nrequired\n\n\nnum_ranks\nint\nNumber of ranks (processes / GPUs).\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\nrank_batches\nlist[list[int]]\nList of batches for the current rank.\n\n\ntotal_tokens_used\nint\nNumber of actual example tokens.\n\n\ntotal_token_slots\nint\nMaximum theoretical number of example tokens (number of bins * bin capacity).\n\n\n\n\n\n\n\nutils.samplers.multipack.ffd_check(sequence_lengths, bin_capacity, num_bins)\nFirst-fit-decreasing bin packing algorithm check.\nChecks if sequences with the given lengths could fit in the specified number of\nbins.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nsequence_lengths\nnp.ndarray\nArray of sequence lengths.\nrequired\n\n\nbin_capacity\nint\nMaximum capacity of each bin.\nrequired\n\n\nnum_bins\nint\nNumber of bins available.\nrequired\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nbool\nTrue if all sequences can be packed, False otherwise.\n\n\n\n\n\n\n\nutils.samplers.multipack.pack_group(\n sequence_lengths,\n group_offset,\n bin_capacity,\n max_bins,\n bin_size,\n safe_mode=True,\n)\nPack a group of sequences into bins using First-Fit Decreasing 
algorithm.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nsequence_lengths\nnp.ndarray\nArray of sequence lengths.\nrequired\n\n\ngroup_offset\nint\nOffset to apply to indices when returning results.\nrequired\n\n\nbin_capacity\nint\nMaximum capacity of each bin.\nrequired\n\n\nmax_bins\nint\nMaximum number of bins to use.\nrequired\n\n\nbin_size\nint\nMaximum number of sequences per bin.\nrequired\n\n\nsafe_mode\nbool\nIf True, use a more conservative packing approach.\nTrue\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\n\n\n\n\n\nlist[list[int]]\nList of bins, where each bin contains indices of sequences assigned to it.\n\n\n\n\n\n\n\nutils.samplers.multipack.pack_parallel(\n sequence_lengths,\n bin_capacity,\n group_size,\n bin_size,\n num_processes=None,\n safe_mode=True,\n mp_start_method='fork',\n)\nPack sequences into bins using parallel processing.\n\n\n\n\n\n\n\n\n\n\n\nName\nType\nDescription\nDefault\n\n\n\n\nsequence_lengths\nnp.ndarray\nArray of sequence lengths.\nrequired\n\n\nbin_capacity\nint\nMaximum capacity of each bin as total number of tokens.\nrequired\n\n\ngroup_size\nint\nNumber of sequences to process in each group.\nrequired\n\n\nbin_size\nint\nMaximum number of bins to use.\nrequired\n\n\nnum_processes\nint | None\nNumber of parallel processes to use.\nNone\n\n\nsafe_mode\nbool\nIf True, use a more conservative packing approach.\nTrue\n\n\nmp_start_method\nstr | None\nMultiprocessing start method (fork, spawn, forkserver). spawn is often safer with Numba/PyTorch. Set to None to use system default.\n'fork'\n\n\n\nReturns:\nList of bins, where each bin contains indices of sequences assigned to it."
},
{
"objectID": "docs/api/prompt_strategies.dpo.chatml.html",
"href": "docs/api/prompt_strategies.dpo.chatml.html",
"title": "prompt_strategies.dpo.chatml",
"section": "",
"text": "prompt_strategies.dpo.chatml\nDPO strategies for chatml\n\n\n\n\n\nName\nDescription\n\n\n\n\nargilla_chat\nfor argilla/dpo-mix-7k conversations\n\n\nicr\nchatml transforms for datasets with system, input, chosen, rejected\n\n\nintel\nFor Intel Orca DPO Pairs\n\n\nultra\nfor ultrafeedback binarized conversations\n\n\n\n\n\nprompt_strategies.dpo.chatml.argilla_chat(cfg, **kwargs)\nfor argilla/dpo-mix-7k conversations\n\n\n\nprompt_strategies.dpo.chatml.icr(cfg, **kwargs)\nchatml transforms for datasets with system, input, chosen, rejected\nex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs\n\n\n\nprompt_strategies.dpo.chatml.intel(cfg, **kwargs)\nFor Intel Orca DPO Pairs\n\n\n\nprompt_strategies.dpo.chatml.ultra(cfg, **kwargs)\nfor ultrafeedback binarized conversations"
},
{
"objectID": "docs/api/prompt_strategies.dpo.chatml.html#functions",
"href": "docs/api/prompt_strategies.dpo.chatml.html#functions",
"title": "prompt_strategies.dpo.chatml",
"section": "",
"text": "Name\nDescription\n\n\n\n\nargilla_chat\nfor argilla/dpo-mix-7k conversations\n\n\nicr\nchatml transforms for datasets with system, input, chosen, rejected\n\n\nintel\nFor Intel Orca DPO Pairs\n\n\nultra\nfor ultrafeedback binarized conversations\n\n\n\n\n\nprompt_strategies.dpo.chatml.argilla_chat(cfg, **kwargs)\nfor argilla/dpo-mix-7k conversations\n\n\n\nprompt_strategies.dpo.chatml.icr(cfg, **kwargs)\nchatml transforms for datasets with system, input, chosen, rejected\nex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs\n\n\n\nprompt_strategies.dpo.chatml.intel(cfg, **kwargs)\nFor Intel Orca DPO Pairs\n\n\n\nprompt_strategies.dpo.chatml.ultra(cfg, **kwargs)\nfor ultrafeedback binarized conversations"
},
{
"objectID": "docs/api/utils.dict.html",
"href": "docs/api/utils.dict.html",
"title": "utils.dict",
"section": "",
"text": "utils.dict\nModule containing the DictDefault class\n\n\n\n\n\nName\nDescription\n\n\n\n\nDictDefault\nA Dict that returns None instead of returning empty Dict for missing keys.\n\n\n\n\n\nutils.dict.DictDefault()\nA Dict that returns None instead of returning empty Dict for missing keys.\n\n\n\n\n\n\n\nName\nDescription\n\n\n\n\nremove_none_values\nRemove null from a dictionary-like obj or list.\n\n\n\n\n\nutils.dict.remove_none_values(obj)\nRemove null from a dictionary-like obj or list.\nThese can appear due to Dataset loading causing schema merge.\nSee https://github.com/axolotl-ai-cloud/axolotl/pull/2909"
},
{
"objectID": "docs/api/utils.dict.html#classes",
"href": "docs/api/utils.dict.html#classes",
"title": "utils.dict",
"section": "",
"text": "Name\nDescription\n\n\n\n\nDictDefault\nA Dict that returns None instead of returning empty Dict for missing keys.\n\n\n\n\n\nutils.dict.DictDefault()\nA Dict that returns None instead of returning empty Dict for missing keys."
},
{
"objectID": "docs/api/utils.dict.html#functions",
"href": "docs/api/utils.dict.html#functions",
"title": "utils.dict",
"section": "",
"text": "Name\nDescription\n\n\n\n\nremove_none_values\nRemove null from a dictionary-like obj or list.\n\n\n\n\n\nutils.dict.remove_none_values(obj)\nRemove null from a dictionary-like obj or list.\nThese can appear due to Dataset loading causing schema merge.\nSee https://github.com/axolotl-ai-cloud/axolotl/pull/2909"
},
{
"objectID": "docs/api/prompt_strategies.dpo.zephyr.html",
"href": "docs/api/prompt_strategies.dpo.zephyr.html",
"title": "prompt_strategies.dpo.zephyr",
"section": "",
"text": "prompt_strategies.dpo.zephyr\nprompt_strategies.dpo.zephyr\nDPO strategies for zephyr"
},
{
"objectID": "docs/api/utils.optimizers.adopt.html",
"href": "docs/api/utils.optimizers.adopt.html",
"title": "utils.optimizers.adopt",
"section": "",
"text": "utils.optimizers.adopt\nCopied from https://github.com/iShohei220/adopt\nADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate (2024)\nTaniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka\n\n\n\n\n\nName\nDescription\n\n\n\n\nadopt\nFunctional API that performs ADOPT algorithm computation.\n\n\n\n\n\nutils.optimizers.adopt.adopt(\n params,\n grads,\n exp_avgs,\n exp_avg_sqs,\n state_steps,\n foreach=None,\n capturable=False,\n differentiable=False,\n fused=None,\n grad_scale=None,\n found_inf=None,\n has_complex=False,\n *,\n beta1,\n beta2,\n lr,\n clip_lambda,\n weight_decay,\n decouple,\n eps,\n maximize,\n)\nFunctional API that performs ADOPT algorithm computation."
},
{
"objectID": "docs/api/utils.optimizers.adopt.html#functions",
"href": "docs/api/utils.optimizers.adopt.html#functions",
"title": "utils.optimizers.adopt",
"section": "",
"text": "Name\nDescription\n\n\n\n\nadopt\nFunctional API that performs ADOPT algorithm computation.\n\n\n\n\n\nutils.optimizers.adopt.adopt(\n params,\n grads,\n exp_avgs,\n exp_avg_sqs,\n state_steps,\n foreach=None,\n capturable=False,\n differentiable=False,\n fused=None,\n grad_scale=None,\n found_inf=None,\n has_complex=False,\n *,\n beta1,\n beta2,\n lr,\n clip_lambda,\n weight_decay,\n decouple,\n eps,\n maximize,\n)\nFunctional API that performs ADOPT algorithm computation."
},
{
"objectID": "docs/api/prompt_strategies.metharme.html",
"href": "docs/api/prompt_strategies.metharme.html",
"title": "prompt_strategies.metharme",
"section": "",
"text": "prompt_strategies.metharme\nModule containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class\n\n\n\n\n\nName\nDescription\n\n\n\n\nMetharmePromptTokenizingStrategy\nTokenizing strategy for the Metharme models\n\n\nMetharmePrompter\nPrompter for the Metharme models.\n\n\n\n\n\nprompt_strategies.metharme.MetharmePromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for the Metharme models\n\n\n\nprompt_strategies.metharme.MetharmePrompter(*args, **kwargs)\nPrompter for the Metharme models."
},
{
"objectID": "docs/api/prompt_strategies.metharme.html#classes",
"href": "docs/api/prompt_strategies.metharme.html#classes",
"title": "prompt_strategies.metharme",
"section": "",
"text": "Name\nDescription\n\n\n\n\nMetharmePromptTokenizingStrategy\nTokenizing strategy for the Metharme models\n\n\nMetharmePrompter\nPrompter for the Metharme models.\n\n\n\n\n\nprompt_strategies.metharme.MetharmePromptTokenizingStrategy(\n prompter,\n tokenizer,\n train_on_inputs=False,\n sequence_len=2048,\n)\nTokenizing strategy for the Metharme models\n\n\n\nprompt_strategies.metharme.MetharmePrompter(*args, **kwargs)\nPrompter for the Metharme models."
},
{
"objectID": "docs/api/monkeypatch.gradient_checkpointing.offload_cpu.html",
"href": "docs/api/monkeypatch.gradient_checkpointing.offload_cpu.html",
"title": "monkeypatch.gradient_checkpointing.offload_cpu",
"section": "",
"text": "monkeypatch.gradient_checkpointing.offload_cpu\nCPU offloaded checkpointing\n\n\n\n\n\nName\nDescription\n\n\n\n\nCPU_Offloaded_Gradient_Checkpointer\nSaves VRAM by smartly offloading to RAM.\n\n\n\n\n\nmonkeypatch.gradient_checkpointing.offload_cpu.CPU_Offloaded_Gradient_Checkpointer(\n)\nSaves VRAM by smartly offloading to RAM.\nTiny hit to performance, since we mask the movement via non blocking calls."
},
{
"objectID": "docs/api/monkeypatch.gradient_checkpointing.offload_cpu.html#classes",
"href": "docs/api/monkeypatch.gradient_checkpointing.offload_cpu.html#classes",
"title": "monkeypatch.gradient_checkpointing.offload_cpu",
"section": "",
"text": "Name\nDescription\n\n\n\n\nCPU_Offloaded_Gradient_Checkpointer\nSaves VRAM by smartly offloading to RAM.\n\n\n\n\n\nmonkeypatch.gradient_checkpointing.offload_cpu.CPU_Offloaded_Gradient_Checkpointer(\n)\nSaves VRAM by smartly offloading to RAM.\nTiny hit to performance, since we mask the movement via non blocking calls."
},
{
"objectID": "docs/rlhf.html",
"href": "docs/rlhf.html",
"title": "RLHF (Beta)",
"section": "",
"text": "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human\nfeedback. Various methods include, but not limited to:\n\nDirect Preference Optimization (DPO)\nIdentity Preference Optimization (IPO)\nKahneman-Tversky Optimization (KTO)\nOdds Ratio Preference Optimization (ORPO)\nGroup Relative Policy Optimization (GRPO) — see also the GRPO deep dive for async features, custom rewards, and scaling\nGroup Reward-Decoupled Policy Optimization (GDPO)\nEnergy-Based Fine-Tuning (EBFT) — see also the EBFT guide for detailed mode comparisons and configuration\nNeMo Gym Integration\n\nFor help choosing between these methods, see Choosing a Fine-Tuning Method.",
"crumbs": [
"How To Guides",
"RLHF (Beta)"
]
},
{
"objectID": "docs/rlhf.html#overview",
"href": "docs/rlhf.html#overview",
"title": "RLHF (Beta)",
"section": "",
"text": "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human\nfeedback. Various methods include, but not limited to:\n\nDirect Preference Optimization (DPO)\nIdentity Preference Optimization (IPO)\nKahneman-Tversky Optimization (KTO)\nOdds Ratio Preference Optimization (ORPO)\nGroup Relative Policy Optimization (GRPO) — see also the GRPO deep dive for async features, custom rewards, and scaling\nGroup Reward-Decoupled Policy Optimization (GDPO)\nEnergy-Based Fine-Tuning (EBFT) — see also the EBFT guide for detailed mode comparisons and configuration\nNeMo Gym Integration\n\nFor help choosing between these methods, see Choosing a Fine-Tuning Method.",
"crumbs": [
"How To Guides",
"RLHF (Beta)"
]
},
{
"objectID": "docs/rlhf.html#rlhf-using-axolotl",
"href": "docs/rlhf.html#rlhf-using-axolotl",
"title": "RLHF (Beta)",
"section": "RLHF using Axolotl",
"text": "RLHF using Axolotl\n\n\n\n\n\n\nImportant\n\n\n\nThis is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.\n\n\nWe rely on the TRL library for implementations of various RL training methods, which we wrap around to expose in axolotl. Each method has their own supported ways of loading datasets and prompt formats.\n\n\n\n\n\n\nTip\n\n\n\nYou can find what each method supports by going into src/axolotl/prompt_strategies/{method} where {method} is one of our supported methods. The type: can be retrieved from {method}.{function_name}.\n\n\n\nDPO\nExample config:\nrl: dpo\ndatasets:\n - path: Intel/orca_dpo_pairs\n split: train\n type: chatml.intel\n - path: argilla/ultrafeedback-binarized-preferences\n split: train\n type: chatml\nDPO supports the following types with the following dataset format:\n\nchatml.argilla\n{\n \"system\": \"...\", // optional\n \"instruction\": \"...\",\n \"chosen_response\": \"...\",\n \"rejected_response\": \"...\"\n}\n\n\nchatml.argilla_chat\n{\n \"chosen\": [\n {\"role\": \"user\", \"content\": \"...\"},\n {\"role\": \"assistant\", \"content\": \"...\"}\n ],\n \"rejected\": [\n {\"role\": \"user\", \"content\": \"...\"},\n {\"role\": \"assistant\", \"content\": \"...\"}\n ]\n}\n\n\nchatml.icr\n{\n \"system\": \"...\", // optional\n \"input\": \"...\",\n \"chosen\": \"...\",\n \"rejected\": \"...\"\n}\n\n\nchatml.intel\n{\n \"system\": \"...\", // optional\n \"question\": \"...\",\n \"chosen\": \"...\",\n \"rejected\": \"...\"\n}\n\n\nchatml.prompt_pairs\n{\n \"system\": \"...\", // optional\n \"prompt\": \"...\",\n \"chosen\": \"...\",\n \"rejected\": \"...\"\n}\n\n\nchatml.ultra\n{\n \"system\": \"...\", // optional\n \"prompt\": \"...\",\n \"chosen\": [\n {\"role\": \"user\", \"content\": \"...\"},\n {\"role\": \"assistant\", \"content\": \"...\"}\n ],\n \"rejected\": [\n {\"role\": \"user\", \"content\": \"...\"},\n {\"role\": 
\"assistant\", \"content\": \"...\"}\n ]\n}\n\n\nllama3.argilla\n{\n \"system\": \"...\", // optional\n \"instruction\": \"...\",\n \"chosen_response\": \"...\",\n \"rejected_response\": \"...\"\n}\n\n\nllama3.argilla_chat\n{\n \"chosen\": [\n {\"role\": \"user\", \"content\": \"...\"},\n {\"role\": \"assistant\", \"content\": \"...\"}\n ],\n \"rejected\": [\n {\"role\": \"user\", \"content\": \"...\"},\n {\"role\": \"assistant\", \"content\": \"...\"}\n ]\n}\n\n\nllama3.icr\n{\n \"system\": \"...\", // optional\n \"input\": \"...\",\n \"chosen\": \"...\",\n \"rejected\": \"...\"\n}\n\n\nllama3.intel\n{\n \"system\": \"...\", // optional\n \"question\": \"...\",\n \"chosen\": \"...\",\n \"rejected\": \"...\"\n}\n\n\nllama3.prompt_pairs\n{\n \"system\": \"...\", // optional\n \"prompt\": \"...\",\n \"chosen\": \"...\",\n \"rejected\": \"...\"\n}\n\n\nllama3.ultra\n{\n \"system\": \"...\", // optional\n \"prompt\": \"...\",\n \"chosen\": [\n {\"role\": \"user\", \"content\": \"...\"},\n {\"role\": \"assistant\", \"content\": \"...\"}\n ],\n \"rejected\": [\n {\"role\": \"user\", \"content\": \"...\"},\n {\"role\": \"assistant\", \"content\": \"...\"}\n ]\n}\n\n\nzephyr.nectar\n{\n \"prompt\": \"...\",\n \"answers\": [\n {\n \"answer\": \"...\",\n \"rank\": 1\n },\n {\n \"answer\": \"...\",\n \"rank\": 2\n }\n // ... 
more answers with ranks\n ]\n}\n\n\nchat_template.argilla_chat\n{\n \"chosen\": [\n {\"role\": \"user\", \"content\": \"...\"},\n {\"role\": \"assistant\", \"content\": \"...\"}\n ],\n \"rejected\": [\n {\"role\": \"user\", \"content\": \"...\"},\n {\"role\": \"assistant\", \"content\": \"...\"}\n ]\n}\n\n\nchat_template.default\nrl: dpo\ndatasets:\n - path: ...\n split: train\n type: chat_template.default\n field_messages: \"messages\"\n field_chosen: \"chosen\"\n field_rejected: \"rejected\"\n message_property_mappings:\n role: role\n content: content\n roles:\n user: [\"user\"]\n assistant: [\"assistant\"]\n system: [\"system\"]\nSample input format:\n{\n \"messages\": [\n {\n \"role\": \"system\",\n \"content\": \"...\"\n },\n {\n \"role\": \"user\",\n \"content\": \"...\"\n },\n // ... more messages\n ],\n \"chosen\": {\n \"role\": \"assistant\",\n \"content\": \"...\"\n },\n \"rejected\": {\n \"role\": \"assistant\",\n \"content\": \"...\"\n }\n}\n\n\nuser_defined.default\nFor custom behaviors,\nrl: dpo\ndatasets:\n - path: ...\n split: train\n type:\n field_prompt: \"prompt\"\n field_system: \"system\"\n field_chosen: \"chosen\"\n field_rejected: \"rejected\"\n prompt_format: \"{prompt}\"\n chosen_format: \"{chosen}\"\n rejected_format: \"{rejected}\"\nThe input format is a simple JSON input with customizable fields based on the above config.\n{\n \"system\": \"...\", // optional\n \"prompt\": \"...\",\n \"chosen\": \"...\",\n \"rejected\": \"...\"\n}\n\n\n\nIPO\nAs IPO is just DPO with a different loss function, all supported dataset formats for DPO are also supported for IPO.\nrl: ipo\n\n\nORPO\nPaper: https://arxiv.org/abs/2403.07691\nrl: orpo\norpo_alpha: 0.1\nremove_unused_columns: false\n\nchat_template: chatml\ndatasets:\n - path: argilla/ultrafeedback-binarized-preferences-cleaned\n type: chat_template.argilla\nORPO supports the following types with the following dataset format:\n\nchat_template.argilla\n{\n \"system\": \"...\", // optional\n 
\"prompt\": \"...\", // if available, will be taken as user message for single-turn instead of from list below\n\n // chosen/rejected should be same till last content and only even-number of alternating user/assistant turns\n \"chosen\": [\n {\"role\": \"user\", \"content\": \"...\"},\n {\"role\": \"assistant\", \"content\": \"...\"}\n ],\n \"rejected\": [\n {\"role\": \"user\", \"content\": \"...\"},\n {\"role\": \"assistant\", \"content\": \"...\"}\n ]\n}\n\n\n\nKTO\nrl: kto\nrl_beta: 0.1 # default\nkto_desirable_weight: 1.0 # default\nkto_undesirable_weight: 1.0 # default\n\nremove_unused_columns: false\n\ndatasets:\n - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto\n type: llama3.ultra\n split: train\n\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n use_reentrant: true\nKTO supports the following types with the following dataset format:\n\nchatml.argilla\n{\n \"system\": \"...\", // optional\n \"instruction\": \"...\",\n \"completion\": \"...\"\n}\n\n\nchatml.argilla_chat\n{\n \"chosen\": [\n {\"role\": \"user\", \"content\": \"...\"}\n ],\n \"completion\": [\n {\"role\": \"assistant\", \"content\": \"...\"}\n ]\n}\n\n\nchatml.intel\n{\n \"system\": \"...\", // optional\n \"question\": \"...\",\n \"completion\": \"...\"\n}\n\n\nchatml.prompt_pairs\n{\n \"system\": \"...\", // optional\n \"prompt\": \"...\",\n \"completion\": \"...\"\n}\n\n\nchatml.ultra\n{\n \"system\": \"...\", // optional\n \"prompt\": \"...\",\n \"completion\": \"...\"\n}\n\n\nllama3.argilla\n{\n \"system\": \"...\", // optional\n \"instruction\": \"...\",\n \"completion\": \"...\"\n}\n\n\nllama3.argilla_chat\n{\n \"completion\": [\n {\"role\": \"user\", \"content\": \"...\"},\n {\"role\": \"assistant\", \"content\": \"...\"}\n ]\n}\n\n\nllama3.intel\n{\n \"system\": \"...\", // optional\n \"question\": \"...\",\n \"completion\": \"...\"\n}\n\n\nllama3.prompt_pairs\n{\n \"system\": \"...\", // optional\n \"prompt\": \"...\",\n \"completion\": 
\"...\"\n}\n\n\nllama3.ultra\n{\n \"system\": \"...\", // optional\n \"prompt\": \"...\",\n \"completion\": \"...\"\n}\n\n\nuser_defined.default\nFor custom behaviors,\nrl: kto\ndatasets:\n - path: ...\n split: train\n type:\n field_prompt: \"prompt\"\n field_system: \"system\"\n field_completion: \"completion\"\n field_label: \"label\"\n prompt_format: \"{prompt}\"\n completion_format: \"{completion}\"\nThe input format is a simple JSON input with customizable fields based on the above config.\n{\n \"system\": \"...\", // optional\n \"prompt\": \"...\",\n \"completion\": \"...\",\n \"label\": \"...\"\n}\n\n\n\nGRPO\n\n\n\n\n\n\nTip\n\n\n\nCheck out our GRPO cookbook. For a comprehensive guide covering async training, custom rewards, importance sampling, and scaling, see the GRPO deep dive.\n\n\nIn the latest GRPO implementation, vLLM is used to significantly speedup trajectory generation during training. In this example, were using 4 GPUs - 2 for training, and 2 for vLLM:\n\n\n\n\n\n\nImportant\n\n\n\nMake sure youve installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. pip install axolotl[vllm].\n\n\nbase_model: Qwen/Qwen2.5-1.5B-Instruct\n\nvllm:\n host: 0.0.0.0\n port: 8000\n tensor_parallel_size: 2\n gpu_memory_utilization: 0.85\n dtype: auto\n # max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand\n\nrl: grpo\ntrl:\n use_vllm: true\n vllm_server_host: 0.0.0.0\n vllm_server_port: 8000\n vllm_server_timeout: 300\nCUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve grpo.yaml\nYour vLLM instance will now attempt to spin up, and its time to kick off training utilizing our remaining two GPUs. In another terminal, execute:\nCUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2\n\n\n\n\n\n\nNote\n\n\n\nDue to TRLs implementation with vLLM, the vLLM instance must use the last N GPUs instead of the first N GPUs. 
This is why in the example above, we use CUDA_VISIBLE_DEVICES=2,3 for the vLLM instance.\n\n\n\nReward functions\nGRPO uses custom reward functions and transformations. Please have them ready locally.\nFor example, to load OpenAI's GSM8K and use a random reward for completions:\n# rewards.py\nimport random\n\ndef rand_reward_func(completions, **kwargs) -> list[float]:\n return [random.uniform(0, 1) for _ in completions]\n\ndef oai_gsm8k_transform(cfg, *args, **kwargs):\n def transform_fn(example, tokenizer=None):\n label = example[\"answer\"].split(\"####\")[-1].strip().replace(\",\", \"\")\n return {\n \"prompt\": [{\"role\": \"user\", \"content\": example[\"question\"]},],\n \"answer\": label,\n }\n return transform_fn, {\"remove_columns\": [\"question\"]}\nrl: grpo\n\ntrl:\n beta: 0.001\n max_completion_length: 256\n use_vllm: True\n num_generations: 4\n reward_funcs: [\"rewards.rand_reward_func\"] # format: '{file_name}.{fn_name}'\n reward_weights: [1.0]\ndatasets:\n - path: openai/gsm8k\n name: main\n type: rewards.oai_gsm8k_transform # format: '{file_name}.{fn_name}'\nTo see other examples of custom reward functions, please see TRL GRPO Docs.\nTo see all configs, please see TRLConfig.\n\n\nOpenEnv Rollout Functions\nGRPO supports custom rollout functions for OpenEnv-style environments, enabling interactive tasks like web browsing, code execution, or tool use. 
This allows you to implement custom generation logic that interacts with external environments.\nFor example, to implement a simple math-solving environment with step-by-step verification:\n# math_env.py\nimport re\n\ndef math_solver_rollout(model, processing_class, prompts, generation_config=None):\n \"\"\"\n Custom rollout function that generates step-by-step math solutions.\n\n Args:\n model: The language model\n processing_class: The tokenizer/processing_class\n prompts: List of prompt dicts (with 'messages' key for chat format)\n generation_config: Optional generation configuration\n\n Returns:\n List of completion strings\n \"\"\"\n completions = []\n\n for prompt in prompts:\n # Apply chat template to prompt\n messages = prompt.get(\"messages\", [])\n formatted_prompt = processing_class.apply_chat_template(\n messages, tokenize=False, add_generation_prompt=True\n )\n\n # Generate step-by-step solution\n full_response = \"\"\n for step in range(5): # Max 5 reasoning steps\n current_input = formatted_prompt + full_response + \"\\nNext step:\"\n inputs = processing_class(current_input, return_tensors=\"pt\").to(model.device)\n\n outputs = model.generate(\n **inputs,\n max_new_tokens=100,\n generation_config=generation_config,\n )\n step_text = processing_class.decode(\n outputs[0][inputs.input_ids.shape[1]:],\n skip_special_tokens=True\n )\n\n # Check if solution is complete\n if \"FINAL ANSWER:\" in step_text:\n full_response += step_text\n break\n full_response += step_text + \"\\n\"\n\n completions.append(full_response)\n\n return completions\n\ndef math_reward(prompts, completions, answers, **kwargs):\n \"\"\"Reward function that checks mathematical correctness\"\"\"\n rewards = []\n for completion, correct_answer in zip(completions, answers):\n # Extract predicted answer\n match = re.search(r\"FINAL ANSWER:\\s*(.+)\", completion)\n predicted = match.group(1).strip() if match else \"\"\n\n # Compare with correct answer\n reward = 1.0 if predicted == 
str(correct_answer) else 0.0\n rewards.append(reward)\n\n return rewards\n\ndef math_transform(cfg, *args, **kwargs):\n \"\"\"Transform dataset to GRPO format with answer field\"\"\"\n def transform_fn(example, processing_class=None):\n return {\n \"prompt\": [{\"role\": \"user\", \"content\": example[\"question\"]}],\n \"answer\": str(example[\"answer\"]),\n }\n return transform_fn, {\"remove_columns\": [\"question\"]}\nrl: grpo\n\ntrl:\n beta: 0.001\n max_completion_length: 512\n num_generations: 4\n rollout_func: \"math_env.math_solver_rollout\" # Custom rollout function\n reward_funcs: [\"math_env.math_reward\"]\n reward_weights: [1.0]\n\ndatasets:\n - path: openai/gsm8k\n name: main\n type: math_env.math_transform\nThe rollout_func parameter accepts a fully qualified name (e.g., module_name.function_name) that points to a callable function in your local directory. The function receives:\n\nmodel: The language model\nprocessing_class: The tokenizer/processing class\nprompts: List of prompt dictionaries\ngeneration_config (optional): Generation configuration\n\nAnd should return a list of completion strings.\nFor more OpenEnv examples, see TRL OpenEnv Documentation.\n\n\nGRPO with DAPO/Dr. GRPO loss\nThe DAPO paper and subsequently Dr. GRPO paper proposed an alternative loss function for GRPO to remediate the penalty in longer responses.\ntrl:\n loss_type: dr_grpo\n # Normalizes loss based on max completion length (default: 256)\n max_completion_length:\nFor more information, see GRPO docs.\n\n\nAsync GRPO\nAsync GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. 
This can significantly reduce wall-clock time per step.\ntrl:\n use_data_producer: true # Enable data producer protocol\n use_vllm: true\n async_prefetch: true # Generate rollouts in background thread\n prefetch_depth: 1 # Number of rollouts to prefetch\n vllm_sync_interval: 2 # Sync weights to vLLM every N steps\n\n\n\n\n\n\nNote\n\n\n\nBecause the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by vllm_importance_sampling_correction: true (default when async is enabled).\n\n\n\nvLLM LoRA Sync\nBy default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.\nadapter: lora\nlora_r: 32\nlora_alpha: 64\nlora_target_linear: true\n\ntrl:\n vllm_lora_sync: true # Enable native LoRA sync\nWhen vllm_lora_sync: true is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual:\nCUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml\nThen start training on a separate GPU:\nCUDA_VISIBLE_DEVICES=1 axolotl train config.yaml\n\n\n\n\n\n\nTip\n\n\n\nLoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.\n\n\n\n\nStreaming Partial Batch\nInstead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.\ntrl:\n streaming_partial_batch: true\n\n\nImportance Sampling Correction\nWhen using async prefetch, completions are generated from a slightly older version of the model. 
Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.\ntrl:\n vllm_importance_sampling_correction: true # Enable IS correction\n importance_sampling_level: token # 'token' or 'sequence'\n off_policy_mask_threshold: 0.5 # Mask sequences with IS ratio below this\n\nimportance_sampling_level: token applies per-token IS ratios (recommended with Liger kernel)\nimportance_sampling_level: sequence applies per-sequence IS ratios\noff_policy_mask_threshold masks out sequences where the IS ratio indicates they are too far off-policy\n\n\n\nReplay Buffer\nThe replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.\ntrl:\n replay_buffer_size: 100 # Max cached groups (0 = disabled)\n replay_recompute_logps: true # Recompute log-probs for replayed data (recommended)\n\n\n\n\n\n\nNote\n\n\n\nWhen replay_recompute_logps: true (default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data.\n\n\n\n\nDeferred Re-rolling\nFailed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches when the model may be better equipped to solve them.\ntrl:\n reroll_start_fraction: 0.5 # Start re-rolling after 50% of training\n reroll_max_groups: 1 # Max groups to replace per batch\n\n\nZero-Advantage Batch Skipping\nWhen all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as skipped_zero_adv_batches=1.\ntrl:\n skip_zero_advantage_batches: true # default\n\n\nParallel Reward Workers\nReward functions that use signal.alarm() (e.g., math_verify) must run in the main thread. 
Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation.\ntrl:\n reward_num_workers: 4 # Number of subprocess workers (1 = no parallelism)\n\n\nFull Async GRPO Example\nbase_model: Qwen/Qwen2.5-1.5B-Instruct\n\nvllm:\n host: 0.0.0.0\n port: 8000\n gpu_memory_utilization: 0.35\n dtype: auto\n\nadapter: lora\nlora_r: 32\nlora_alpha: 64\nlora_target_linear: true\n\nrl: grpo\ntrl:\n use_data_producer: true\n use_vllm: true\n async_prefetch: true\n prefetch_depth: 1\n vllm_sync_interval: 2\n vllm_lora_sync: true\n streaming_partial_batch: true\n vllm_importance_sampling_correction: true\n off_policy_mask_threshold: 0.5\n importance_sampling_level: token\n num_generations: 8\n max_completion_length: 512\n reward_funcs:\n - rewards.accuracy_reward\n reroll_start_fraction: 0.5\n replay_buffer_size: 100\n reward_num_workers: 4\n skip_zero_advantage_batches: true\n\ndatasets:\n - path: AI-MO/NuminaMath-TIR\n type: rewards.prompt_transform\n split: train\n\ngradient_accumulation_steps: 4\nmicro_batch_size: 2\nmax_steps: 500\nlearning_rate: 1e-5\nbf16: true\ngradient_checkpointing: true\n# Terminal 1: Start vLLM on GPU 0\nCUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml\n\n# Terminal 2: Train on GPU 1\nCUDA_VISIBLE_DEVICES=1 axolotl train config.yaml\n\n\nMulti-GPU Async GRPO\nAsync GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. 
vLLM runs on one GPU while training is distributed across the remaining GPUs.\nFSDP:\nfsdp:\n - full_shard\n - auto_wrap\nfsdp_config:\n fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer\ngradient_checkpointing_kwargs:\n use_reentrant: false\nDeepSpeed ZeRO-3:\ndeepspeed: deepspeed_configs/zero3_bf16.json\ngradient_checkpointing_kwargs:\n use_reentrant: true # Required for ZeRO-3\n# Terminal 1: Start vLLM on GPU 0\nCUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml\n\n# Terminal 2: Train on GPUs 1,2\nCUDA_VISIBLE_DEVICES=1,2 axolotl train config.yaml\n\n\n\n\n\n\nImportant\n\n\n\nWith multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.\n\n\n\n\n\n\nGDPO\nGDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the reward advantage collapse problem by normalizing each reward function independently before combining them.\n\n\n\n\n\n\nTip\n\n\n\nUse GDPO when training with multiple reward functions. With a single reward, GRPO and GDPO produce equivalent results.\n\n\nPaper: https://arxiv.org/pdf/2501.05242\nGDPO uses TRL's native multi_objective_aggregation parameter under the hood. 
When you set rl: gdpo, axolotl automatically configures TRL to use normalize_then_sum aggregation.\nbase_model: Qwen/Qwen2.5-1.5B-Instruct\n\nvllm:\n host: 0.0.0.0\n port: 8000\n tensor_parallel_size: 2\n gpu_memory_utilization: 0.85\n\nrl: gdpo\n\ntrl:\n beta: 0.001\n max_completion_length: 256\n use_vllm: true\n num_generations: 4\n reward_funcs:\n - rewards.format_reward\n - rewards.correctness_reward\n reward_weights: [1.0, 2.0]\n\ndatasets:\n - path: openai/gsm8k\n name: main\n type: rewards.oai_gsm8k_transform\nYou can also use GRPO with explicit aggregation control:\nrl: grpo\ntrl:\n multi_objective_aggregation: normalize_then_sum # GDPO behavior\n # or: sum_then_normalize # Default GRPO behavior\n\nGDPO vs GRPO\n\n\n\n\n\n\n\n\nAspect\nGRPO\nGDPO\n\n\n\n\nAggregation\nsum_then_normalize\nnormalize_then_sum\n\n\nMulti-reward\nMay collapse advantages\nPreserves reward signals\n\n\nSingle reward\nStandard behavior\nEquivalent to GRPO\n\n\n\n\n\nWhy GDPO?\nWhen using multiple rewards with GRPO, different reward combinations can produce identical advantages:\n# Example: format + correctness rewards\n[format=0, correct=3] → sum=3\n[format=1, correct=2] → sum=3 ← GRPO sees these as equal!\n[format=2, correct=1] → sum=3\n[format=3, correct=0] → sum=3\nGDPO normalizes each reward independently, preserving their relative differences.\n\n\nReward Functions\nGDPO uses the same reward function format as GRPO:\n# rewards.py\ndef format_reward(completions, **kwargs) -> list[float]:\n return [1.0 if len(c) > 10 else 0.0 for c in completions]\n\ndef correctness_reward(completions, answers, **kwargs) -> list[float]:\n rewards = []\n for completion, answer in zip(completions, answers):\n # Your scoring logic here\n rewards.append(score)\n return rewards\n\n\nSequence Parallelism\nGDPO supports sequence parallelism for long-context training:\nrl: gdpo\ncontext_parallel_size: 2\n\n\n\nSimPO\nSimPO uses CPOTrainer, but with an alternative loss function.\nrl: simpo\nrl_beta: 0.1 # 
default in CPOTrainer\ncpo_alpha: 1.0 # default in CPOTrainer\nsimpo_gamma: 0.5 # default in CPOTrainer\nThis method uses the same dataset format as DPO.\n\n\nEBFT\n\n\n\n\n\n\nTip\n\n\n\nFor a detailed guide on EBFT modes, feature extraction, and configuration, see the EBFT guide.\n\n\nEBFT (Energy-Based Fine-Tuning) fine-tunes language models by optimizing a feature-matching loss rather than relying on external reward functions. A frozen copy of the model extracts embeddings from both generated and ground-truth completions, and the generator is updated via REINFORCE to match the ground-truth feature moments.\nPaper: “Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models” (Jelassi et al., 2026)\nKey advantages:\n\nNo reward model or verifier required — works on any (prompt, completion) data\nApplicable to non-verifiable tasks (code, translation, creative writing)\nOperates on model rollouts (not teacher forcing), reducing distribution shift\n\nEBFT supports two modes:\n\nStructured mode: For QA/instruction data with prompt + completion pairs. Uses vLLM for generation (like GRPO).\nStrided mode: For unstructured text without prompt/completion splits. 
Uses strided block-parallel generation with flex_attention — no vLLM needed.\n\n\nStructured Mode\nbase_model: Qwen/Qwen3-4B\n\nrl: ebft\n\nebft:\n feature_layers: [0.25, 0.5, 0.75] # Extract features at 25%, 50%, 75% depth\n embed_method: last_token\n use_whitening: false\n alignment_coef: 1.0 # Cosine similarity reward weight\n diversity_coef: 1.0 # Pairwise dot product penalty\n ce_coef: 0.0 # Cross-entropy on GT tokens (0 = off)\n\ntrl:\n num_generations: 4\n max_completion_length: 256\n temperature: 0.7\n use_vllm: true\n vllm_server_host: 0.0.0.0\n vllm_server_port: 8000\n vllm_lora_sync: true # LoRA adapter sync (recommended)\n vllm_sync_interval: 3\n use_data_producer: true\n async_prefetch: true # Set false for sync mode\n scale_rewards: true\n loss_type: grpo\n epsilon: 0.2\n\nvllm:\n gpu_memory_utilization: 0.5\n max_model_len: 2048\n\ndatasets:\n - path: nvidia/OpenCodeInstruct\n type: ebft_opencode.transform\n split: train[:500]\n\nadapter: lora\nlora_r: 16\nlora_alpha: 32\nlora_target_linear: true\n# Terminal 1: Start vLLM\nCUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml\n\n# Terminal 2: Train\nCUDA_VISIBLE_DEVICES=1 axolotl train config.yaml\n\n\nStrided Mode\nFor unstructured text (raw code, prose). 
No vLLM needed — runs on a single GPU.\nbase_model: meta-llama/Llama-3.2-1B\n\nrl: ebft\n\nebft:\n mode: strided\n stride: 8\n context_length: 8\n generate_max_len: 8\n n_samples_per_prompt: 4\n temperature: 0.6\n feature_layers: [0.25, 0.5, 0.75]\n embed_method: last_token\n use_whitening: true\n alignment_coef: 1.0\n diversity_coef: 1.0\n rl_coef: 1.0\n ce_coef: 0.03\n advantage_estimator: rloo\n\ndatasets:\n - path: nvidia/OpenCodeInstruct\n type: ebft_strided_structured.transform\n split: train[:1%]\n\nflash_attention: false\nflex_attention: true # Strided mode uses flex_attention\ngradient_checkpointing: true\ngradient_checkpointing_kwargs:\n use_reentrant: true # Required for flex_attention\nCUDA_VISIBLE_DEVICES=0 axolotl train config.yaml\n\n\n\n\n\n\nTip\n\n\n\nSee examples/ebft/ for complete example configs covering Llama 1B/3B/8B and Qwen3 4B/8B models in both modes.\n\n\n\n\nEBFT Configuration Reference\n\n\n\n\n\n\n\n\nParameter\nDefault\nDescription\n\n\n\n\nebft.feature_layers\n[0.25, 0.5, 0.75]\nLayer depths for feature extraction (fractional)\n\n\nebft.embed_method\nlast_token\nFeature pooling: last_token, mean_pooling, concat\n\n\nebft.use_whitening\nfalse\nSVD whitening of feature dimensions\n\n\nebft.alignment_coef\n1.0\nCosine similarity reward weight\n\n\nebft.diversity_coef\n1.0\nPairwise dot product penalty weight\n\n\nebft.ce_coef\n0.0\nCross-entropy loss on ground-truth tokens\n\n\nebft.mode\nstructured\nstructured (vLLM) or strided (no vLLM)\n\n\nebft.stride\n—\nTokens between anchor points (strided mode)\n\n\nebft.context_length\n—\nContext window per block (strided mode)\n\n\nebft.generate_max_len\n—\nTokens to generate per block (strided mode)\n\n\nebft.n_samples_per_prompt\n—\nRollouts per document (strided mode)\n\n\nebft.advantage_estimator\ngrpo\ngrpo or rloo (strided mode)\n\n\n\n\n\n\nNeMo Gym Integration\nNeMo Gym provides 50+ verified RL environments (math, coding, tool-use, reasoning) with deterministic reward signals. 
The axolotl integration supports both single-turn (call /verify after generation) and multi-turn (agent-based tool execution via /run).\n\nSingle-Turn (Simplest)\nFor environments that only need answer verification (math, coding challenges). No agent server needed — the reward function calls /verify directly on the resource server.\nbase_model: Qwen/Qwen2.5-0.5B-Instruct\n\nrl: grpo\nchat_template: tokenizer_default\n\ntrl:\n use_vllm: false # Colocate mode (single GPU)\n num_generations: 4\n max_completion_length: 128\n temperature: 0.9\n reward_funcs:\n - axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify\n\nplugins:\n - axolotl.integrations.nemo_gym.NemoGymPlugin\n\nnemo_gym_enabled: true\nnemo_gym_dir: ~/Gym\nnemo_gym_auto_start: false\nnemo_gym_head_port: 11000\nnemo_gym_datasets:\n - path: resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl\n server_name: reasoning_gym\n\ndatasets:\n - path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl\n type: chat_template\n field_messages: responses_create_params.input\n message_field_content: content\n message_field_role: role\n# Terminal 1: Start NeMo Gym resource server\ncd ~/Gym && .venv/bin/ng_run \\\n \"+config_paths=[resources_servers/reasoning_gym/configs/resources_only.yaml]\" \\\n \"+skip_venv_if_present=true\"\n\n# Terminal 2: Train\nCUDA_VISIBLE_DEVICES=0 axolotl train config.yaml\n\n\n\n\n\n\nNote\n\n\n\nnemo_gym_datasets.path is relative to nemo_gym_dir. Don't use absolute paths or they will be double-joined.\n\n\n\n\nMulti-Turn with Async GRPO (Recommended)\nFor environments with tool-use (weather, search, databases). 
An agent server orchestrates multi-turn interactions: generate → parse tool calls → execute tools → feed results back → repeat until done.\nbase_model: Qwen/Qwen3-0.6B\n\nrl: grpo\nchat_template: tokenizer_default\n\nadapter: lora\nlora_r: 16\nlora_alpha: 32\nlora_target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]\n\ntrl:\n use_vllm: true\n vllm_mode: server\n vllm_server_host: localhost\n vllm_server_port: 8000\n vllm_lora_sync: true\n vllm_sync_interval: 5\n use_data_producer: true\n async_prefetch: true # 3x speedup\n num_generations: 4\n max_completion_length: 512\n temperature: 0.8\n reward_funcs:\n - axolotl.integrations.nemo_gym.rewards.reward_env\n\nplugins:\n - axolotl.integrations.nemo_gym.NemoGymPlugin\n\nnemo_gym_enabled: true\nnemo_gym_auto_start: false\nnemo_gym_head_port: 11000\nnemo_gym_multi_turn: true\nnemo_gym_verify_timeout: 120\nnemo_gym_datasets:\n - path: resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl\n server_name: example_single_tool_call\n\ndatasets:\n - path: ~/Gym/resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl\n type: chat_template\n field_messages: responses_create_params.input\n message_field_content: content\n message_field_role: role\n\nvllm:\n gpu_memory_utilization: 0.85\n max_model_len: 2048\nMulti-turn requires three services running:\n# Terminal 1: vLLM with LoRA + tool calling\nVLLM_ALLOW_RUNTIME_LORA_UPDATING=1 CUDA_VISIBLE_DEVICES=0 \\\n python -m vllm.entrypoints.openai.api_server \\\n --model Qwen/Qwen3-0.6B --max-model-len 2048 \\\n --gpu-memory-utilization 0.85 \\\n --enable-lora --max-lora-rank 64 \\\n --enable-auto-tool-choice --tool-call-parser hermes\n\n# Terminal 2: NeMo Gym servers (resource + model proxy + agent)\ncd ~/Gym && .venv/bin/ng_run \\\n \"+config_paths=[configs/axolotl_tool_calling.yaml]\" \\\n \"+skip_venv_if_present=true\"\n\n# Terminal 3: Training\nCUDA_VISIBLE_DEVICES=1 axolotl train 
config.yaml\n\n\n\n\n\n\nImportant\n\n\n\nMulti-turn requires a NeMo Gym agent config YAML that defines three components: a resource server (tools + /verify), a model server proxy (forwards to your vLLM), and an agent server (orchestrates /run). See the NeMo Gym README for agent config format.\n\n\n\n\nNeMo Gym Prerequisites\n# Clone and set up NeMo Gym\ngit clone https://github.com/NVIDIA-NeMo/Gym.git ~/Gym\ncd ~/Gym\nuv venv --python 3.12 && source .venv/bin/activate && uv sync\n\n# Fix pycosat build (GCC 13+)\nCFLAGS=\"\" uv pip install pycosat --python .venv/bin/python --no-build-isolation\n\n\nNeMo Gym Configuration Reference\n\n\n\n\n\n\n\n\n\nParameter\nType\nDefault\nDescription\n\n\n\n\nnemo_gym_enabled\nbool\n—\nEnable the NeMo Gym integration\n\n\nnemo_gym_dir\nstr\n~/Gym\nPath to NeMo Gym repo\n\n\nnemo_gym_auto_start\nbool\ntrue\nAuto-start resource servers\n\n\nnemo_gym_head_port\nint\n11000\nHead server port\n\n\nnemo_gym_multi_turn\nbool\nfalse\nEnable multi-turn via agent /run\n\n\nnemo_gym_verify_timeout\nint\n30\nPer-request timeout (seconds)\n\n\nnemo_gym_datasets\nlist\nrequired\nDataset configs with path and server_name\n\n\n\n\n\nReward Functions\n\n\n\n\n\n\n\n\nFunction\nMode\nDescription\n\n\n\n\naxolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify\nSingle-turn\nCalls /verify, returns binary reward\n\n\naxolotl.integrations.nemo_gym.rewards.reward_env\nMulti-turn\nPassthrough reward from agent /run\n\n\n\n\n\n\nUsing local dataset files\ndatasets:\n - ds_type: json\n data_files:\n - orca_rlhf.jsonl\n split: train\n type: chatml.intel\n\n\nTRL auto-unwrapping for PEFT\nTRL supports auto-unwrapping PEFT models for RL training paradigms which rely on a reference model. This significantly reduces memory pressure as an additional reference model does not need to be loaded, and reference model log-probabilities can be obtained by disabling PEFT adapters. This is enabled by default. 
To turn it off, pass the following config:\n# load ref model when adapter training.\nrl_adapter_ref_model: true",
"crumbs": [
"How To Guides",
"RLHF (Beta)"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html",
"href": "docs/dataset-formats/inst_tune.html",
"title": "Instruction Tuning",
"section": "",
"text": "instruction; input(optional)\n\n\ndata.jsonl\n\n{\"instruction\": \"...\", \"input\": \"...\", \"output\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#alpaca",
"href": "docs/dataset-formats/inst_tune.html#alpaca",
"title": "Instruction Tuning",
"section": "",
"text": "instruction; input(optional)\n\n\ndata.jsonl\n\n{\"instruction\": \"...\", \"input\": \"...\", \"output\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#jeopardy",
"href": "docs/dataset-formats/inst_tune.html#jeopardy",
"title": "Instruction Tuning",
"section": "jeopardy",
"text": "jeopardy\nquestion and answer\n\n\ndata.jsonl\n\n{\"question\": \"...\", \"category\": \"...\", \"answer\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#oasst",
"href": "docs/dataset-formats/inst_tune.html#oasst",
"title": "Instruction Tuning",
"section": "oasst",
"text": "oasst\ninstruction\n\n\ndata.jsonl\n\n{\"INSTRUCTION\": \"...\", \"RESPONSE\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#gpteacher",
"href": "docs/dataset-formats/inst_tune.html#gpteacher",
"title": "Instruction Tuning",
"section": "gpteacher",
"text": "gpteacher\ninstruction; input(optional)\n\n\ndata.jsonl\n\n{\"instruction\": \"...\", \"input\": \"...\", \"response\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#reflection",
"href": "docs/dataset-formats/inst_tune.html#reflection",
"title": "Instruction Tuning",
"section": "reflection",
"text": "reflection\ninstruction with reflect; input(optional)\n\n\ndata.jsonl\n\n{\"instruction\": \"...\", \"input\": \"...\", \"output\": \"...\", \"reflection\": \"...\", \"corrected\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#explainchoice",
"href": "docs/dataset-formats/inst_tune.html#explainchoice",
"title": "Instruction Tuning",
"section": "explainchoice",
"text": "explainchoice\nquestion, choices, (solution OR explanation)\n\n\ndata.jsonl\n\n{\"question\": \"...\", \"choices\": [\"...\"], \"solution\": \"...\", \"explanation\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#concisechoice",
"href": "docs/dataset-formats/inst_tune.html#concisechoice",
"title": "Instruction Tuning",
"section": "concisechoice",
"text": "concisechoice\nquestion, choices, (solution OR explanation)\n\n\ndata.jsonl\n\n{\"question\": \"...\", \"choices\": [\"...\"], \"solution\": \"...\", \"explanation\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#summarizetldr",
"href": "docs/dataset-formats/inst_tune.html#summarizetldr",
"title": "Instruction Tuning",
"section": "summarizetldr",
"text": "summarizetldr\narticle and summary\n\n\ndata.jsonl\n\n{\"article\": \"...\", \"summary\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#alpaca_chat",
"href": "docs/dataset-formats/inst_tune.html#alpaca_chat",
"title": "Instruction Tuning",
"section": "alpaca_chat",
"text": "alpaca_chat\nbasic instruct for alpaca chat\n\n\ndata.jsonl\n\n{\"instruction\": \"...\", \"input\": \"...\", \"response\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#alpaca_chat.load_qa",
"href": "docs/dataset-formats/inst_tune.html#alpaca_chat.load_qa",
"title": "Instruction Tuning",
"section": "alpaca_chat.load_qa",
"text": "alpaca_chat.load_qa\nquestion and answer for alpaca chat\n\n\ndata.jsonl\n\n{\"question\": \"...\", \"answer\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#alpaca_chat.load_concise",
"href": "docs/dataset-formats/inst_tune.html#alpaca_chat.load_concise",
"title": "Instruction Tuning",
"section": "alpaca_chat.load_concise",
"text": "alpaca_chat.load_concise\nquestion and answer for alpaca chat, for concise answers\n\n\ndata.jsonl\n\n{\"instruction\": \"...\", \"input\": \"...\", \"response\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#alpaca_chat.load_camel_ai",
"href": "docs/dataset-formats/inst_tune.html#alpaca_chat.load_camel_ai",
"title": "Instruction Tuning",
"section": "alpaca_chat.load_camel_ai",
"text": "alpaca_chat.load_camel_ai\nquestion and answer for alpaca chat, for load_camel_ai\n\n\ndata.jsonl\n\n{\"message_1\": \"...\", \"message_2\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#alpaca_w_system.load_open_orca",
"href": "docs/dataset-formats/inst_tune.html#alpaca_w_system.load_open_orca",
"title": "Instruction Tuning",
"section": "alpaca_w_system.load_open_orca",
"text": "alpaca_w_system.load_open_orca\nsupport for open orca datasets with included system prompts, instruct\n\n\ndata.jsonl\n\n{\"system_prompt\": \"...\", \"question\": \"...\", \"response\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#context_qa",
"href": "docs/dataset-formats/inst_tune.html#context_qa",
"title": "Instruction Tuning",
"section": "context_qa",
"text": "context_qa\nin context question answering from an article\n\n\ndata.jsonl\n\n{\"article\": \"...\", \"question\": \"...\", \"answer\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#context_qa.load_v2",
"href": "docs/dataset-formats/inst_tune.html#context_qa.load_v2",
"title": "Instruction Tuning",
"section": "context_qa.load_v2",
"text": "context_qa.load_v2\nin context question answering (alternate)\n\n\ndata.jsonl\n\n{\"context\": \"...\", \"question\": \"...\", \"answer\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#context_qa.load_404",
"href": "docs/dataset-formats/inst_tune.html#context_qa.load_404",
"title": "Instruction Tuning",
"section": "context_qa.load_404",
"text": "context_qa.load_404\nin context question answering from an article, with default response for no answer from context\n\n\ndata.jsonl\n\n{\"article\": \"...\", \"unanswerable_question\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#creative_acr.load_answer",
"href": "docs/dataset-formats/inst_tune.html#creative_acr.load_answer",
"title": "Instruction Tuning",
"section": "creative_acr.load_answer",
"text": "creative_acr.load_answer\ninstruction and revision\n\n\ndata.jsonl\n\n{\"instruction\": \"...\", \"revision\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#creative_acr.load_critique",
"href": "docs/dataset-formats/inst_tune.html#creative_acr.load_critique",
"title": "Instruction Tuning",
"section": "creative_acr.load_critique",
"text": "creative_acr.load_critique\ncritique\n\n\ndata.jsonl\n\n{\"scores\": \"...\", \"critiques\": \"...\", \"instruction\": \"...\", \"answer\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#creative_acr.load_revise",
"href": "docs/dataset-formats/inst_tune.html#creative_acr.load_revise",
"title": "Instruction Tuning",
"section": "creative_acr.load_revise",
"text": "creative_acr.load_revise\ncritique and revise\n\n\ndata.jsonl\n\n{\"scores\": \"...\", \"critiques\": \"...\", \"instruction\": \"...\", \"answer\": \"...\", \"revision\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#metharme",
"href": "docs/dataset-formats/inst_tune.html#metharme",
"title": "Instruction Tuning",
"section": "metharme",
"text": "metharme\ninstruction, adds additional eos tokens\n\n\ndata.jsonl\n\n{\"prompt\": \"...\", \"generation\": \"...\"}",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/inst_tune.html#how-to-add-custom-prompt-format",
"href": "docs/dataset-formats/inst_tune.html#how-to-add-custom-prompt-format",
"title": "Instruction Tuning",
"section": "How to add custom prompt format",
"text": "How to add custom prompt format\nFor a dataset that is preprocessed for instruction purposes:\n\n\ndata.jsonl\n\n{\"input\": \"...\", \"output\": \"...\"}\n\nYou can use this example in your YAML config:\n\n\nconfig.yaml\n\ndatasets:\n - path: repo\n type:\n system_prompt: \"\"\n field_system: system\n field_instruction: input\n field_output: output\n format: \"[INST] {instruction} [/INST]\"\n no_input_format: \"[INST] {instruction} [/INST]\"\n\nSee full config options under here.",
"crumbs": [
"Dataset Formats",
"Instruction Tuning"
]
},
{
"objectID": "docs/dataset-formats/stepwise_supervised.html",
"href": "docs/dataset-formats/stepwise_supervised.html",
"title": "Stepwise Supervised Format",
"section": "",
"text": "The stepwise supervised format is designed for chain-of-thought (COT) reasoning\ndatasets where each example contains multiple completion steps and a preference label\nfor each step.\n\n\nHeres a simple example of a stepwise supervised dataset entry:\n{\n \"prompt\": \"Which number is larger, 9.8 or 9.11?\",\n \"completions\": [\n \"The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.\",\n \"Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8.\"\n ],\n \"labels\": [true, false]\n}",
"crumbs": [
"Dataset Formats",
"Stepwise Supervised Format"
]
},
{
"objectID": "docs/dataset-formats/stepwise_supervised.html#stepwise-supervised",
"href": "docs/dataset-formats/stepwise_supervised.html#stepwise-supervised",
"title": "Stepwise Supervised Format",
"section": "",
"text": "The stepwise supervised format is designed for chain-of-thought (COT) reasoning\ndatasets where each example contains multiple completion steps and a preference label\nfor each step.\n\n\nHeres a simple example of a stepwise supervised dataset entry:\n{\n \"prompt\": \"Which number is larger, 9.8 or 9.11?\",\n \"completions\": [\n \"The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.\",\n \"Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8.\"\n ],\n \"labels\": [true, false]\n}",
"crumbs": [
"Dataset Formats",
"Stepwise Supervised Format"
]
},
{
"objectID": "docs/dataset-formats/tokenized.html",
"href": "docs/dataset-formats/tokenized.html",
"title": "Custom Pre-Tokenized Dataset",
"section": "",
"text": "Pass an empty type: in your axolotl config.\nColumns in Dataset must be exactly input_ids, attention_mask, labels\nTo indicate that a token should be ignored during training, set its corresponding label to -100.\nYou must add BOS and EOS, and make sure that you are training on EOS by not setting its label to -100.\nFor pretraining, do not truncate/pad documents to the context window length.\nFor instruction training, documents must be truncated/padded as desired.\n\nSample config:\n\n\nconfig.yml\n\ndatasets:\n - path: /path/to/your/file.jsonl\n ds_type: json\n type:\n\nSample jsonl:\n{\"input_ids\":[271,299,99],\"attention_mask\":[1,1,1],\"labels\":[271,-100,99]}\n{\"input_ids\":[87,227,8383,12],\"attention_mask\":[1,1,1,1],\"labels\":[87,227,8383,12]}",
"crumbs": [
"Dataset Formats",
"Custom Pre-Tokenized Dataset"
]
},
{
"objectID": "docs/multimodal.html",
"href": "docs/multimodal.html",
"title": "MultiModal / Vision Language Models (BETA)",
"section": "",
"text": "Mllama\nLlama4\nPixtral\nLlava-1.5\nMistral-Small-3.1\nMistral-Small-4\nMagistral-Small-2509\nVoxtral\nGemma-3\nGemma-3n\nQwen2-VL\nQwen2.5-VL\nQwen3.5\nGLM-4.6V\nSmolVLM2\nLFM2-VL\nIntern-VL",
"crumbs": [
"How To Guides",
"MultiModal / Vision Language Models (BETA)"
]
},
{
"objectID": "docs/multimodal.html#supported-models",
"href": "docs/multimodal.html#supported-models",
"title": "MultiModal / Vision Language Models (BETA)",
"section": "",
"text": "Mllama\nLlama4\nPixtral\nLlava-1.5\nMistral-Small-3.1\nMistral-Small-4\nMagistral-Small-2509\nVoxtral\nGemma-3\nGemma-3n\nQwen2-VL\nQwen2.5-VL\nQwen3.5\nGLM-4.6V\nSmolVLM2\nLFM2-VL\nIntern-VL",
"crumbs": [
"How To Guides",
"MultiModal / Vision Language Models (BETA)"
]
},
{
"objectID": "docs/multimodal.html#usage",
"href": "docs/multimodal.html#usage",
"title": "MultiModal / Vision Language Models (BETA)",
"section": "Usage",
"text": "Usage\nMultimodal support is limited and doesnt have full feature parity.\nHere are the hyperparams youll need to use to finetune a multimodal model.\nprocessor_type: AutoProcessor\n\nskip_prepare_dataset: true\nremove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training\nsample_packing: false # not yet supported with multimodal\n\nchat_template: # see in next section if specified\n\n# example dataset\ndatasets:\n - path: HuggingFaceH4/llava-instruct-mix-vsft\n type: chat_template\n split: train[:1%]\n\n# (optional) if doing lora, only finetune the Language model,\n# leave the vision model and vision tower frozen\n# load_in_8bit: true\nadapter: lora\nlora_target_modules: 'model.language_model.layers.[\\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'\n\n# (optional) if you want to resize images to a set size\nimage_size: 512\nimage_resize_algorithm: bilinear\nPlease see examples folder for full configs.\n\n\n\n\n\n\nTip\n\n\n\nSome of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.\n\n\n\n\n\n\n\n\nNote\n\n\n\nAs of now, we do not truncate nor drop samples based on sequence_len as each arch has different ways to process non-text tokens. 
We are looking for help on this.\n\n\n\nMllama\nbase_model: meta-llama/Llama-3.2-11B-Vision-Instruct\n\nchat_template: llama3_2_vision\n\n\nLlama4\nbase_model: meta-llama/Llama-4-Scout-17B-16E-Instruct\n\nchat_template: llama4\n\n\nPixtral\nbase_model: mistralai/Pixtral-12B-2409\n\nchat_template: pixtral\n\n\nLlava-1.5\nbase_model: llava-hf/llava-1.5-7b-hf\n\nchat_template: llava\n\n\nMistral-Small-3.1\n\n\n\n\n\n\nTip\n\n\n\nPlease make sure to install the vision lib via pip install 'mistral-common[opencv]==1.8.5'\n\n\nbase_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503\n\n\nMistral-Small-4\nbase_model: mistralai/Mistral-Small-4-119B-2603\n\n\nMagistral-Small-2509\n\n\n\n\n\n\nTip\n\n\n\nPlease make sure to install the vision lib via pip install 'mistral-common[opencv]==1.8.5'\n\n\nbase_model: mistralai/Magistral-Small-2509\n\n\nVoxtral\n\n\n\n\n\n\nTip\n\n\n\nPlease make sure to install the audio lib via pip3 install librosa==0.11.0 'mistral_common[audio]==1.8.3'\n\n\nbase_model: mistralai/Voxtral-Mini-3B-2507\n\nprocessor_type: VoxtralProcessor\n\n\nGemma-3\n\n\n\n\n\n\nTip\n\n\n\nThe Gemma3-1B model is a text-only model, so please train it as a regular text model.\n\n\nFor multi-modal 4B/12B/27B models, use the following config:\nbase_model: google/gemma-3-4b-it\n\nchat_template: gemma3\n\n\nGemma-3n\n\n\n\n\n\n\nWarning\n\n\n\nThe model's initial loss and grad norm will be very high. 
We suspect this to be due to the Conv in the vision layers.\n\n\n\n\n\n\n\n\nTip\n\n\n\nPlease make sure to install timm via pip3 install timm==1.0.17\n\n\nbase_model: google/gemma-3n-E2B-it\n\nchat_template: gemma3n\n\n\nQwen2-VL\nbase_model: Qwen/Qwen2-VL-7B-Instruct\n\nchat_template: qwen2_vl\n\n\nQwen2.5-VL\nbase_model: Qwen/Qwen2.5-VL-7B-Instruct\n\nchat_template: qwen2_vl # same as qwen2-vl\n\n\nQwen3-VL\nbase_model: Qwen/Qwen3-VL-4B-Instruct\n\nchat_template: qwen2_vl # same as qwen2-vl\n\n\nQwen3.5\nbase_model: Qwen/Qwen3.5-9B\n\nchat_template: qwen3_5\n\n\nGLM-4.6V\nBoth GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.\n# GLM-4.6V (106B MoE version)\nbase_model: zai-org/GLM-4.6V\n\n# OR GLM-4.6V-Flash (9B version)\nbase_model: zai-org/GLM-4.6V-Flash\n\n\nSmolVLM2\n\n\n\n\n\n\nTip\n\n\n\nPlease make sure to install num2words via pip3 install num2words==0.5.14\n\n\nbase_model: HuggingFaceTB/SmolVLM2-500M-Video-Instruct\n\n\nLFM2-VL\n\n\n\n\n\n\nWarning\n\n\n\nPlease uninstall causal-conv1d via pip3 uninstall -y causal-conv1d\n\n\nbase_model: LiquidAI/LFM2-VL-450M\n\n\nIntern-VL\n\n\n\n\n\n\nTip\n\n\n\nPlease make sure to install timm via pip3 install timm==1.0.19\n\n\nbase_model: OpenGVLab/InternVL3_5-8B",
"crumbs": [
"How To Guides",
"MultiModal / Vision Language Models (BETA)"
]
},
{
"objectID": "docs/multimodal.html#dataset-format",
"href": "docs/multimodal.html#dataset-format",
"title": "MultiModal / Vision Language Models (BETA)",
"section": "Dataset Format",
"text": "Dataset Format\nFor multi-modal datasets, we adopt an extended chat_template format similar to OpenAIs Message format.\n\nA message is a list of role and content.\nrole can be system, user, assistant, etc.\ncontent is a list of type and (text, image, path, url, base64, or audio).\n\n\nImage\n\n\n\n\n\n\nNote\n\n\n\nFor backwards compatibility:\n\nIf the dataset has a images or image column of list[Image], it will be appended to the first content list as {\"type\": \"image\", \"image\": ...}. However, if the content already has a {\"type\": \"image\"} but no image key, it will be set the image key.\nIf content is a string, it will be converted to a list with type as text.\n\n\n\nFor image loading, you can use the following keys within content alongside \"type\": \"image\":\n\n\"path\": \"/path/to/image.jpg\"\n\"url\": \"https://example.com/image.jpg\"\n\"base64\": \"...\"\n\"image\": PIL.Image\n\n\n\nAudio\nFor audio loading, you can use the following keys within content alongside \"type\": \"audio\":\n\n\"path\": \"/path/to/audio.mp3\"\n\"url\": \"https://example.com/audio.mp3\"\n\"audio\": np.ndarray\n\n\n\n\n\n\n\nTip\n\n\n\nYou may need to install librosa via pip3 install librosa==0.11.0.\n\n\n\n\nVideo\n\n\n\n\n\n\nWarning\n\n\n\nThis is not well tested at the moment. 
We welcome contributors!\n\n\nFor video loading, you can use the following keys within content alongside \"type\": \"video\":\n\n\"path\": \"/path/to/video.mp4\"\n\"url\": \"https://example.com/video.mp4\"\n\"video\": np.ndarray | list[PIL.Image.Image] | torch.Tensor (or list of the aforementioned)\n\n\n\nExample\nHere is an example of a multi-modal dataset:\n[\n {\n \"messages\": [\n {\n \"role\": \"system\",\n \"content\": [\n {\"type\": \"text\", \"text\": \"You are a helpful assistant.\"}\n ]\n },\n {\n \"role\": \"user\",\n \"content\": [\n {\"type\": \"image\", \"url\": \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg\"},\n {\"type\": \"text\", \"text\": \"Describe this image in detail.\"}\n ]\n },\n {\n \"role\": \"assistant\",\n \"content\": [\n {\"type\": \"text\", \"text\": \"The image is a bee.\"}\n ]\n }\n ]\n }\n]",
"crumbs": [
"How To Guides",
"MultiModal / Vision Language Models (BETA)"
]
},
{
"objectID": "docs/multimodal.html#faq",
"href": "docs/multimodal.html#faq",
"title": "MultiModal / Vision Language Models (BETA)",
"section": "FAQ",
"text": "FAQ\n\nPIL.UnidentifiedImageError: cannot identify image file ...\n\nPIL could not retrieve the file at url using requests. Please check for typo. One alternative reason is that the request is blocked by the server.",
"crumbs": [
"How To Guides",
"MultiModal / Vision Language Models (BETA)"
]
}
]