Compare commits
10 Commits
topk-logpr
...
train-refa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5c0510a876 | ||
|
|
e1bc18763a | ||
|
|
ed5178cd3d | ||
|
|
a3224c7c3c | ||
|
|
c4104fc10c | ||
|
|
75cbd15301 | ||
|
|
2efe1b4c09 | ||
|
|
1110a37e21 | ||
|
|
9850f42204 | ||
|
|
00fc8109e4 |
@@ -50,13 +50,14 @@ Features:
|
||||
## 🚀 Quick Start
|
||||
|
||||
**Requirements**:
|
||||
|
||||
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
||||
- Python 3.11
|
||||
- PyTorch ≥2.4.1
|
||||
|
||||
### Installation
|
||||
|
||||
```shell
|
||||
```bash
|
||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
|
||||
# Download example axolotl configs, deepspeed configs
|
||||
@@ -68,7 +69,7 @@ Other installation approaches are described [here](https://axolotl-ai-cloud.gith
|
||||
|
||||
### Your First Fine-tune
|
||||
|
||||
```shell
|
||||
```bash
|
||||
# Fetch axolotl examples
|
||||
axolotl fetch examples
|
||||
|
||||
|
||||
63
_quarto.yml
63
_quarto.yml
@@ -3,10 +3,12 @@ project:
|
||||
|
||||
website:
|
||||
title: "Axolotl"
|
||||
description: "Fine-tuning"
|
||||
description: "We make fine-tuning accessible, scalable, and fun"
|
||||
favicon: favicon.jpg
|
||||
|
||||
navbar:
|
||||
title: Axolotl
|
||||
logo: image/axolotl_logo_digital_white.svg
|
||||
title: false
|
||||
background: dark
|
||||
pinned: false
|
||||
collapse: false
|
||||
@@ -25,33 +27,58 @@ website:
|
||||
contents:
|
||||
- text: Home
|
||||
href: index.qmd
|
||||
- section: "How-To Guides"
|
||||
|
||||
- section: "Getting Started"
|
||||
contents:
|
||||
# TODO Edit folder structure after we have more docs.
|
||||
- docs/getting-started.qmd
|
||||
- docs/installation.qmd
|
||||
- docs/debugging.qmd
|
||||
- docs/cli.qmd
|
||||
- docs/inference.qmd
|
||||
- docs/multipack.qmd
|
||||
- docs/fsdp_qlora.qmd
|
||||
- docs/input_output.qmd
|
||||
- docs/rlhf.qmd
|
||||
- docs/nccl.qmd
|
||||
- docs/mac.qmd
|
||||
- docs/multi-gpu.qmd
|
||||
- docs/multi-node.qmd
|
||||
- docs/unsloth.qmd
|
||||
- docs/amd_hpc.qmd
|
||||
- docs/ray-integration.qmd
|
||||
|
||||
- section: "Dataset Formats"
|
||||
contents: docs/dataset-formats/*
|
||||
|
||||
- section: "Deployments"
|
||||
contents:
|
||||
- docs/multi-gpu.qmd
|
||||
- docs/multi-node.qmd
|
||||
- docs/ray-integration.qmd
|
||||
- docs/amd_hpc.qmd
|
||||
- docs/mac.qmd
|
||||
|
||||
- section: "How To Guides"
|
||||
contents:
|
||||
- docs/multimodal.qmd
|
||||
- docs/rlhf.qmd
|
||||
- docs/reward_modelling.qmd
|
||||
- docs/lr_groups.qmd
|
||||
- docs/lora_optims.qmd
|
||||
|
||||
- section: "Core Concepts"
|
||||
contents:
|
||||
- docs/batch_vs_grad.qmd
|
||||
- docs/dataset_preprocessing.qmd
|
||||
- docs/multipack.qmd
|
||||
|
||||
- section: "Advanced Features"
|
||||
contents:
|
||||
- docs/fsdp_qlora.qmd
|
||||
- docs/unsloth.qmd
|
||||
- docs/torchao.qmd
|
||||
- docs/custom_integrations.qmd
|
||||
|
||||
- section: "Troubleshooting"
|
||||
contents:
|
||||
- docs/faq.qmd
|
||||
- docs/debugging.qmd
|
||||
- docs/nccl.qmd
|
||||
|
||||
- section: "Reference"
|
||||
contents:
|
||||
- docs/config.qmd
|
||||
- docs/faq.qmd
|
||||
|
||||
format:
|
||||
html:
|
||||
theme: materia
|
||||
theme: darkly
|
||||
css: styles.css
|
||||
toc: true
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: Training with AMD GPUs on HPC Systems
|
||||
title: AMD GPUs on HPC Systems
|
||||
description: A comprehensive guide for using Axolotl on distributed systems with AMD GPUs
|
||||
---
|
||||
|
||||
|
||||
134
docs/cli.qmd
134
docs/cli.qmd
@@ -1,28 +1,19 @@
|
||||
# Axolotl CLI Documentation
|
||||
---
|
||||
title: "CLI Reference"
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
toc-expand: 2
|
||||
toc-depth: 3
|
||||
execute:
|
||||
enabled: false
|
||||
---
|
||||
|
||||
The Axolotl CLI provides a streamlined interface for training and fine-tuning large language models. This guide covers
|
||||
the CLI commands, their usage, and common examples.
|
||||
|
||||
### Table of Contents
|
||||
|
||||
- Basic Commands
|
||||
- Command Reference
|
||||
- fetch
|
||||
- preprocess
|
||||
- train
|
||||
- inference
|
||||
- merge-lora
|
||||
- merge-sharded-fsdp-weights
|
||||
- evaluate
|
||||
- lm-eval
|
||||
- Legacy CLI Usage
|
||||
- Remote Compute with Modal Cloud
|
||||
- Cloud Configuration
|
||||
- Running on Modal Cloud
|
||||
- Cloud Configuration Options
|
||||
|
||||
|
||||
### Basic Commands
|
||||
## Basic Commands
|
||||
|
||||
All Axolotl commands follow this general structure:
|
||||
|
||||
@@ -32,9 +23,9 @@ axolotl <command> [config.yml] [options]
|
||||
|
||||
The config file can be local or a URL to a raw YAML file.
|
||||
|
||||
### Command Reference
|
||||
## Command Reference
|
||||
|
||||
#### fetch
|
||||
### fetch
|
||||
|
||||
Downloads example configurations and deepspeed configs to your local machine.
|
||||
|
||||
@@ -49,7 +40,7 @@ axolotl fetch deepspeed_configs
|
||||
axolotl fetch examples --dest path/to/folder
|
||||
```
|
||||
|
||||
#### preprocess
|
||||
### preprocess
|
||||
|
||||
Preprocesses and tokenizes your dataset before training. This is recommended for large datasets.
|
||||
|
||||
@@ -74,7 +65,7 @@ dataset_prepared_path: Local folder for saving preprocessed data
|
||||
push_dataset_to_hub: HuggingFace repo to push preprocessed data (optional)
|
||||
```
|
||||
|
||||
#### train
|
||||
### train
|
||||
|
||||
Trains or fine-tunes a model using the configuration specified in your YAML file.
|
||||
|
||||
@@ -95,7 +86,38 @@ axolotl train config.yml --no-accelerate
|
||||
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
|
||||
```
|
||||
|
||||
#### inference
|
||||
It is possible to run sweeps over multiple hyperparameters by passing in a sweeps config.
|
||||
|
||||
```bash
|
||||
# Basic training with sweeps
|
||||
axolotl train config.yml --sweep path/to/sweep.yaml
|
||||
```
|
||||
|
||||
Example sweep config:
|
||||
```yaml
|
||||
_:
|
||||
# This section is for dependent variables we need to fix
|
||||
- load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
adapter: lora
|
||||
- load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
adapter: lora
|
||||
|
||||
# These are independent variables
|
||||
learning_rate: [0.0003, 0.0006]
|
||||
lora_r:
|
||||
- 16
|
||||
- 32
|
||||
lora_alpha:
|
||||
- 16
|
||||
- 32
|
||||
- 64
|
||||
```
|
||||
|
||||
|
||||
|
||||
### inference
|
||||
|
||||
Runs inference using your trained model in either CLI or Gradio interface mode.
|
||||
|
||||
@@ -115,7 +137,7 @@ cat prompt.txt | axolotl inference config.yml \
|
||||
--base-model="./completed-model"
|
||||
```
|
||||
|
||||
#### merge-lora
|
||||
### merge-lora
|
||||
|
||||
Merges trained LoRA adapters into the base model.
|
||||
|
||||
@@ -137,7 +159,7 @@ gpu_memory_limit: Limit GPU memory usage
|
||||
lora_on_cpu: Load LoRA weights on CPU
|
||||
```
|
||||
|
||||
#### merge-sharded-fsdp-weights
|
||||
### merge-sharded-fsdp-weights
|
||||
|
||||
Merges sharded FSDP model checkpoints into a single combined checkpoint.
|
||||
|
||||
@@ -146,7 +168,7 @@ Merges sharded FSDP model checkpoints into a single combined checkpoint.
|
||||
axolotl merge-sharded-fsdp-weights config.yml
|
||||
```
|
||||
|
||||
#### evaluate
|
||||
### evaluate
|
||||
|
||||
Evaluates a model's performance using metrics specified in the config.
|
||||
|
||||
@@ -155,27 +177,27 @@ Evaluates a model's performance using metrics specified in the config.
|
||||
axolotl evaluate config.yml
|
||||
```
|
||||
|
||||
#### lm-eval
|
||||
### lm-eval
|
||||
|
||||
Runs LM Evaluation Harness on your model.
|
||||
|
||||
```bash
|
||||
# Basic evaluation
|
||||
axolotl lm-eval config.yml
|
||||
|
||||
# Evaluate specific tasks
|
||||
axolotl lm-eval config.yml --tasks arc_challenge,hellaswag
|
||||
```
|
||||
|
||||
Configuration options:
|
||||
|
||||
```yaml
|
||||
lm_eval_tasks: List of tasks to evaluate
|
||||
lm_eval_batch_size: Batch size for evaluation
|
||||
output_dir: Directory to save evaluation results
|
||||
# List of tasks to evaluate
|
||||
lm_eval_tasks:
|
||||
- arc_challenge
|
||||
- hellaswag
|
||||
lm_eval_batch_size: # Batch size for evaluation
|
||||
output_dir: # Directory to save evaluation results
|
||||
```
|
||||
|
||||
### Legacy CLI Usage
|
||||
## Legacy CLI Usage
|
||||
|
||||
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:
|
||||
|
||||
@@ -195,12 +217,18 @@ accelerate launch -m axolotl.cli.inference config.yml \
|
||||
--lora_model_dir="./outputs/lora-out" --gradio
|
||||
```
|
||||
|
||||
### Remote Compute with Modal Cloud
|
||||
::: {.callout-important}
|
||||
When overriding CLI parameters in the legacy CLI, use same notation as in yaml file (e.g., `--lora_model_dir`).
|
||||
|
||||
**Note:** This differs from the new Click-based CLI, which uses dash notation (e.g., `--lora-model-dir`). Keep this in mind if you're referencing newer documentation or switching between CLI versions.
|
||||
:::
|
||||
|
||||
## Remote Compute with Modal Cloud
|
||||
|
||||
Axolotl supports running training and inference workloads on Modal cloud infrastructure. This is configured using a
|
||||
cloud YAML file alongside your regular Axolotl config.
|
||||
|
||||
#### Cloud Configuration
|
||||
### Cloud Configuration
|
||||
|
||||
Create a cloud config YAML with your Modal settings:
|
||||
|
||||
@@ -215,13 +243,17 @@ branch: main # Git branch to use (optional)
|
||||
volumes: # Persistent storage volumes
|
||||
- name: axolotl-cache
|
||||
mount: /workspace/cache
|
||||
- name: axolotl-data
|
||||
mount: /workspace/data
|
||||
- name: axolotl-artifacts
|
||||
mount: /workspace/artifacts
|
||||
|
||||
env: # Environment variables
|
||||
- WANDB_API_KEY
|
||||
- HF_TOKEN
|
||||
```
|
||||
|
||||
#### Running on Modal Cloud
|
||||
### Running on Modal Cloud
|
||||
|
||||
Commands that support the --cloud flag:
|
||||
|
||||
@@ -239,18 +271,18 @@ axolotl train config.yml --cloud cloud_config.yml --no-accelerate
|
||||
axolotl lm-eval config.yml --cloud cloud_config.yml
|
||||
```
|
||||
|
||||
#### Cloud Configuration Options
|
||||
### Cloud Configuration Options
|
||||
|
||||
```yaml
|
||||
provider: compute provider, currently only `modal` is supported
|
||||
gpu: GPU type to use
|
||||
gpu_count: Number of GPUs (default: 1)
|
||||
memory: RAM in GB (default: 128)
|
||||
timeout: Maximum runtime in seconds
|
||||
timeout_preprocess: Preprocessing timeout
|
||||
branch: Git branch to use
|
||||
docker_tag: Custom Docker image tag
|
||||
volumes: List of persistent storage volumes
|
||||
env: Environment variables to pass
|
||||
secrets: Secrets to inject
|
||||
provider: # compute provider, currently only `modal` is supported
|
||||
gpu: # GPU type to use
|
||||
gpu_count: # Number of GPUs (default: 1)
|
||||
memory: # RAM in GB (default: 128)
|
||||
timeout: # Maximum runtime in seconds
|
||||
timeout_preprocess: # Preprocessing timeout
|
||||
branch: # Git branch to use
|
||||
docker_tag: # Custom Docker image tag
|
||||
volumes: # List of persistent storage volumes
|
||||
env: # Environment variables to pass
|
||||
secrets: # Secrets to inject
|
||||
```
|
||||
|
||||
@@ -166,7 +166,7 @@ datasets:
|
||||
# IMPORTANT: The following fields determine which parts of the conversation to train on.
|
||||
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
|
||||
# See examples at `docs/dataset-formats/conversation.qmd`
|
||||
# Note: If the below 4 fields are empty, defaults to training only on the last message.
|
||||
# Note: If the below 4 fields are set to empty, defaults to training only on the last message.
|
||||
|
||||
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
|
||||
roles_to_train: ["assistant"] # default
|
||||
@@ -174,6 +174,7 @@ datasets:
|
||||
# - all: train on all EOS tokens
|
||||
# - turn (default): train on the EOS token at the end of each trainable turn
|
||||
# - last: train on the last EOS token in the conversation
|
||||
# TIP: Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
|
||||
train_on_eos: last
|
||||
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
|
||||
message_field_training: training
|
||||
|
||||
57
docs/custom_integrations.qmd
Normal file
57
docs/custom_integrations.qmd
Normal file
@@ -0,0 +1,57 @@
|
||||
---
|
||||
title: Custom Integrations
|
||||
toc: true
|
||||
toc-depth: 3
|
||||
---
|
||||
|
||||
```{python}
|
||||
#| echo: false
|
||||
|
||||
import re
|
||||
|
||||
def process_readme(integration_name):
|
||||
try:
|
||||
path = f'../src/axolotl/integrations/{integration_name}/README.md'
|
||||
with open(path, 'r') as f:
|
||||
txt = f.read()
|
||||
# Remove h1 headings
|
||||
txt = re.sub(r'^# .*\n?', '', txt, flags=re.MULTILINE)
|
||||
# Convert h2 to h3
|
||||
txt = re.sub(r'^## ', '### ', txt, flags=re.MULTILINE)
|
||||
return txt
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
def print_section(name, folder_name):
|
||||
output = f"\n## {name}\n"
|
||||
content = process_readme(folder_name)
|
||||
if content:
|
||||
output += content
|
||||
output += f"\nPlease see reference [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/{folder_name})\n"
|
||||
return output
|
||||
```
|
||||
|
||||
```{python}
|
||||
#| output: asis
|
||||
#| echo: false
|
||||
|
||||
# Introduction text
|
||||
print("""
|
||||
Axolotl adds custom features through `integrations`. They are located within the `src/axolotl/integrations` directory.
|
||||
|
||||
To enable them, please check the respective documentations.
|
||||
""")
|
||||
|
||||
# Sections
|
||||
sections = [
|
||||
("Cut Cross Entropy", "cut_cross_entropy"),
|
||||
("Grokfast", "grokfast"),
|
||||
("Knowledge Distillation (KD)", "kd"),
|
||||
("Liger Kernels", "liger"),
|
||||
("Language Model Evaluation Harness (LM Eval)", "lm_eval"),
|
||||
("Spectrum", "spectrum")
|
||||
]
|
||||
|
||||
for section_name, folder_name in sections:
|
||||
print(print_section(section_name, folder_name))
|
||||
```
|
||||
@@ -6,7 +6,9 @@ order: 3
|
||||
|
||||
## sharegpt
|
||||
|
||||
IMPORTANT: ShareGPT is deprecated!. Please see [chat_template](#chat_template) section below.
|
||||
::: {.callout-important}
|
||||
ShareGPT is deprecated!. Please see [chat_template](#chat_template) section below.
|
||||
:::
|
||||
|
||||
## pygmalion
|
||||
|
||||
@@ -102,6 +104,10 @@ datasets:
|
||||
type: chat_template
|
||||
```
|
||||
|
||||
::: {.callout-important}
|
||||
Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
|
||||
:::
|
||||
|
||||
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
||||
|
||||
For a data sample that looks like:
|
||||
@@ -149,4 +155,6 @@ datasets:
|
||||
message_field_training_detail: train_detail
|
||||
```
|
||||
|
||||
Tip: It is not necessary to use both `message_field_training` and `message_field_training_detail` at a time.
|
||||
::: {.callout-tip}
|
||||
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
|
||||
:::
|
||||
|
||||
@@ -13,7 +13,7 @@ As there are a lot of available options in Axolotl, this guide aims to provide a
|
||||
|
||||
Axolotl supports 3 kinds of training methods: pre-training, supervised fine-tuning, and preference-based post-training (e.g. DPO, ORPO, PRMs). Each method has their own dataset format which are described below.
|
||||
|
||||
## [Pre-training](pretraining.qmd)
|
||||
## Pre-training
|
||||
|
||||
When aiming to train on large corpora of text datasets, pre-training is your go-to choice. Due to the size of these datasets, downloading the entire-datasets before beginning training would be prohibitively time-consuming. Axolotl supports [streaming](https://huggingface.co/docs/datasets/en/stream) to only load batches into memory at a time.
|
||||
|
||||
@@ -96,6 +96,10 @@ One step is equal to `sequence_len * micro_batch_size * gradient_accumulation_st
|
||||
|
||||
It is recommended to leave this off if downloading from Hugging Face hub as it would download the entire dataset which can be very large.
|
||||
|
||||
### Reference
|
||||
|
||||
Please see docs [here](pretraining.qmd).
|
||||
|
||||
## Supervised fine-tuning (SFT)
|
||||
|
||||
Supervised fine-tuning is the process of training models to respond to an instruction or chat input.
|
||||
@@ -120,7 +124,7 @@ If you went through the flow chart and did not find one that matches, it is reco
|
||||
You can mix and match within each approach or across approaches to train a model on a variety of datasets.
|
||||
:::
|
||||
|
||||
### [Pre-Tokenized Dataset](tokenized.qmd)
|
||||
### Pre-Tokenized Dataset
|
||||
|
||||
We suggest this approach when you want to bring your own tokenized dataset.
|
||||
|
||||
@@ -145,7 +149,9 @@ datasets:
|
||||
`type: ` is empty!
|
||||
:::
|
||||
|
||||
### [Template Free Dataset](template_free.qmd)
|
||||
Reference: [Pre-Tokenized Dataset Documentation](tokenized.qmd).
|
||||
|
||||
### Template Free Dataset
|
||||
|
||||
We reccomend this approach when you want granular control over the prompt formatting, special tokens, and masking, whilst letting Axolotl handle the tokenization. This is very useful if your dataset has unique prompts that differ across samples and where one single general template wouldn't suffice.
|
||||
|
||||
@@ -182,7 +188,9 @@ datasets:
|
||||
type: input_output
|
||||
```
|
||||
|
||||
### [Conversation Dataset](conversation.qmd)
|
||||
Reference: [Template Free Documentation](template_free.qmd).
|
||||
|
||||
### Conversation Dataset
|
||||
|
||||
`conversation` messages are a list of messages which usually contain a `role` and `content` key.
|
||||
|
||||
@@ -258,7 +266,7 @@ Newer conversation datasets usually follow the OpenAI format.
|
||||
|
||||
Axolotl supports both as well as allowing customization of any kind of key.
|
||||
|
||||
#### [Chat Template Usage](conversation.qmd#chat_template)
|
||||
#### Chat Template Usage
|
||||
|
||||
To properly use this method, it is important to identify three things:
|
||||
|
||||
@@ -340,9 +348,19 @@ datasets:
|
||||
narrator: ["narrator"]
|
||||
```
|
||||
|
||||
#### Applying `chat_template`
|
||||
::: {.callout-tip}
|
||||
As chat_templates may use hardcoded EOS/EOT tokens that are different from the tokenizer's EOS, it is highly recommended to set them. For example, `ChatML` uses `<|im_end|>` to end turns.
|
||||
|
||||
Once all the above steps are completed, you could combine all these configs together to form a bespoke configuration for your custom dataset. The final step would be to correctly set the EOS token in your config:
|
||||
```yaml
|
||||
special_tokens:
|
||||
eos_token: <|im_end|>
|
||||
```
|
||||
|
||||
:::
|
||||
|
||||
##### Applying `chat_template`
|
||||
|
||||
Once all the above steps are completed, you could combine all these configs together to form a bespoke configuration for your custom dataset.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
@@ -391,7 +409,17 @@ If this config were to be applied to the sample dataset above, the output would
|
||||
|
||||
The first number refers to the label, the second refers to the `token_id`. For example, `-100` labels appear on non-assistant portions, meaning that they are masked during. For assistant portions, the label is the same as the `token_id`.
|
||||
|
||||
### [Instruction Dataset](inst_tune.qmd)
|
||||
::: {.callout-note}
|
||||
|
||||
If during `preprocess`, there are a lot of warnings of `Could not find content __ boundary`, please check the FAQ section for [chat_templates](../faq.qmd#chat-templates).
|
||||
|
||||
:::
|
||||
|
||||
#### Reference
|
||||
|
||||
Please see docs [here](conversation.qmd).
|
||||
|
||||
### Instruction Dataset
|
||||
|
||||
Instruction datasets are used to train instruction-following models and comprise a prompt, containing an instruction, and a single response. In contrast to chat datasets which may be multi-turn, instruct datasets are typically single-turn.
|
||||
|
||||
@@ -423,6 +451,9 @@ datasets:
|
||||
|
||||
Axolotl supports many kinds of instruction dataset. All of them can be found here (https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/inst_tune.html) with their respective type and sample row format.
|
||||
|
||||
|
||||
Reference: [Instruction Dataset Documentation](inst_tune.qmd).
|
||||
|
||||
#### Custom Instruct Prompt Format
|
||||
|
||||
Due to the myriad possibilities of instruction formats, Axolotl allows customizing your own instruction format without having to dive into the code directly.
|
||||
@@ -453,6 +484,8 @@ datasets:
|
||||
|
||||
The config sets that the `field_instruction` is actually named `input`, and the `field_input` is empty as we don't have an `input` in this sample. Generally, `instruction` can be thought as the question to the model, and `input` as the additional information with `output` being the response. It is not necessary to have an `input` nor `system`. In the end, the most important part is to understand what format you want it to look like and how you can customize this to your use case.
|
||||
|
||||
Reference: [Custom Instruct Prompt Format Documentation](inst_tune.qmd#how-to-add-custom-prompt-format).
|
||||
|
||||
## Reinforcement Learning from Human Feedback (RLHF)
|
||||
|
||||
As there are multiple RLHF methods with their own dataset requirements. Please see [RLHF datasets](../rlhf.qmd) documentation for more detail.
|
||||
As there are multiple RLHF methods with their own dataset requirements. Please see [RLHF documentation](../rlhf.qmd) for more detail.
|
||||
|
||||
@@ -27,7 +27,6 @@ pretraining_dataset:
|
||||
type: pretrain
|
||||
trust_remote_code:
|
||||
skip: # number of rows of data to skip over from the beginning
|
||||
...
|
||||
```
|
||||
|
||||
:::
|
||||
|
||||
@@ -1,7 +1,239 @@
|
||||
---
|
||||
title: Template-Free
|
||||
description: Construct prompts without a template.
|
||||
toc: true
|
||||
toc-depth: 3
|
||||
order: 4
|
||||
---
|
||||
|
||||
See [these docs](../input_output.qmd).
|
||||
## Background {#sec-background}
|
||||
|
||||
### Masking Inputs {#masking-inputs}
|
||||
|
||||
One of the most popular features of
|
||||
[axolotl](https://github.com/axolotl-ai-cloud/axolotl) is
|
||||
setting the following configuration value:
|
||||
|
||||
|
||||
```yaml
|
||||
train_on_inputs: false
|
||||
```
|
||||
|
||||
If you declare a [dataset formats](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#dataset)
|
||||
such as `alpaca` or `chatml`, axolotl knows what is an input
|
||||
(i.e. human) vs. an output (i.e. the assistant) and masks the input
|
||||
labels so that your model can focus on predicting the outputs only.
|
||||
|
||||
### You may not want prompt templates {#sec-you-may-not-want-prompt-templates}
|
||||
|
||||
However, there are many situations where you don't want to use one of
|
||||
these formats or templates. This is because they can:
|
||||
|
||||
- Add unnecessary boilerplate to your prompts.
|
||||
- Create artifacts like special delimiters `<|im_start|>` that can
|
||||
quickly become footguns if you don't include them correctly at
|
||||
inference time.
|
||||
- Enforce a *chat* interface when you do not want one. Sometimes you
|
||||
just want to fine-tune a model to a very specific task and do NOT
|
||||
want multi-turn conversations, roles, etc.
|
||||
- Limit you to only certain roles that the template allows.
|
||||
|
||||
### The `input_output` format {#sec-the-inputoutput-format}
|
||||
|
||||
You can construct your prompts without a template by using the
|
||||
`input_output` format, by setting `type: input_output` in your
|
||||
configuration file like this:
|
||||
|
||||
**config.yml**
|
||||
|
||||
```yaml
|
||||
train_on_inputs: false # Mask segments of your data
|
||||
datasets:
|
||||
- path: output.jsonl
|
||||
type: input_output # use template free prompt construction
|
||||
```
|
||||
|
||||
Unlike `type: completion`, which is also template-free,
|
||||
`type: input_output` allows you to mask segments of your text. More
|
||||
details on how this works are described below.
|
||||
|
||||
## Usage {#sec-usage}
|
||||
|
||||
This is how you can use the `input_output` format:
|
||||
|
||||
### 1. Prepare Data {#sec-1-prepare-data}
|
||||
|
||||
To use the `input_output` format, collect your data in the following
|
||||
format into a jsonl file (below is the first row from the file
|
||||
`output`.jsonl` pretty printed):
|
||||
|
||||
```bash
|
||||
$ head -n1 output.jsonl | python -m json.tool
|
||||
```
|
||||
|
||||
:::{.cell-output .cell-output-stdout}
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"label": true,
|
||||
"text": "<s>Hello\n"
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "hi there!. "
|
||||
},
|
||||
{
|
||||
"label": false,
|
||||
"text": "goodbye "
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "farewell</s>"
|
||||
}
|
||||
]
|
||||
}
|
||||
:::
|
||||
|
||||
Set `label:false` when you want to mask a segment of text so that the
|
||||
model isn't trained on it. Some things to keep in mind:
|
||||
|
||||
> [!IMPORTANT]
|
||||
> 1. **EOS, BOS, spaces, newlines etc. are entirely up to you. Axolotl
|
||||
concatenates all the segments as-is.** The tokenizer doesn't add
|
||||
anything additional. Notice how I added spaces, newlines, `<s>`
|
||||
(BOS), and `</s>` (EOS) myself.
|
||||
> 2. Make sure you check the materialized output to validate that the
|
||||
prompt is getting assembled how you like.
|
||||
|
||||
### 2. Use `type: input_output` {#sec-2-use-type-inputoutput}
|
||||
|
||||
Let's materialize data with our `output.jsonl` file by setting
|
||||
`type: input_output` in our axolotl config:
|
||||
|
||||
```yaml
|
||||
# training_config.yaml
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
data_seed: 49
|
||||
seed: 49
|
||||
|
||||
datasets:
|
||||
- path: output.jsonl
|
||||
type: input_output
|
||||
val_set_size: 0.1
|
||||
|
||||
sequence_len: 896
|
||||
sample_packing: false
|
||||
|
||||
micro_batch_size: 2
|
||||
gradient_accumulation_steps: 3
|
||||
eval_batch_size: 2
|
||||
num_epochs: 1
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
```
|
||||
|
||||
You can use the following command to materialize your data. The
|
||||
`--debug` flag will print the tokens, along with the labels so you can
|
||||
verify that the correct items are being ignored:
|
||||
|
||||
```bash
|
||||
axolotl preprocess training_config.yaml --debug
|
||||
|
||||
...
|
||||
[2024-03-05 23:36:46,969] [INFO] [axolotl.check_example_labels:35] [PID:607731] [RANK:0] <s>(1, 1) Hello(22557, 22557)
|
||||
(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)
|
||||
|
||||
```
|
||||
|
||||
The format is `decoded_token`(`label`, `token_id`), for example,
|
||||
`<s>(1, 1)` means that the token is `<s>`, the label is `1` and the
|
||||
token_id is `1`. When the label is `-100` then that token is ignored for
|
||||
training.
|
||||
|
||||
### 3. Check the prompts {#sec-3-check-the-prompts}
|
||||
|
||||
Here is another way to check the materialized output:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
from datasets import load_from_disk
|
||||
import yaml
|
||||
|
||||
directory = !ls last_run_prepared/
|
||||
with open('training_config.yaml', 'r') as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
model_id = cfg['base_model']
|
||||
tok = AutoTokenizer.from_pretrained(model_id)
|
||||
ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
|
||||
```
|
||||
|
||||
```python
|
||||
>>> row = ds[0]
|
||||
>>> print(tok.decode(row['input_ids']))
|
||||
<s> Hello
|
||||
hi there!. goodbye farewell</s>
|
||||
```
|
||||
|
||||
We can check that the right tokens are ignored by comparing the labels
|
||||
to each token:
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
pd.DataFrame([{'token': tok.decode(i), 'label': l, 'id':i} for i,l in
|
||||
zip(row['input_ids'], row['labels'])])
|
||||
```
|
||||
|
||||
| token | label | id |
|
||||
|-------|-------|-------|
|
||||
| 0 | \<s\> | 1 |
|
||||
| 1 | Hello | 22557 |
|
||||
| 2 | \\n | 13 |
|
||||
| 3 | hi | 12014 |
|
||||
| 4 | there | 736 |
|
||||
| 5 | ! | 28808 |
|
||||
| 6 | . | 28723 |
|
||||
| 7 | | 28705 |
|
||||
| 8 | good | -100 |
|
||||
| 9 | bye | -100 |
|
||||
| 10 | | -100 |
|
||||
| 11 | fare | 19111 |
|
||||
| 12 | well | 5458 |
|
||||
| 13 | \</s\>| 2 |
|
||||
|
||||
|
||||
|
||||
If we look at the input data, the above table seems correct! (The jsonl
|
||||
version is repeated below for reference):
|
||||
|
||||
|
||||
```bash
|
||||
$ head -n1 output.jsonl | python -m json.tool
|
||||
```
|
||||
|
||||
:::{.cell-output .cell-output-stdout}
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"label": true,
|
||||
"text": "<s>Hello\n"
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "hi there!. "
|
||||
},
|
||||
{
|
||||
"label": false,
|
||||
"text": "goodbye "
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "farewell</s>"
|
||||
}
|
||||
]
|
||||
}
|
||||
:::
|
||||
|
||||
@@ -3,8 +3,11 @@ title: Dataset Preprocessing
|
||||
description: How datasets are processed
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
|
||||
the (dataset format)[../dataset-formats/] and prompt strategies to:
|
||||
the [dataset format](docs/dataset-formats) and prompt strategies to:
|
||||
|
||||
- parse the dataset based on the *dataset format*
|
||||
- transform the dataset to how you would interact with the model based on the *prompt strategy*
|
||||
- tokenize the dataset based on the configured model & tokenizer
|
||||
@@ -12,10 +15,12 @@ the (dataset format)[../dataset-formats/] and prompt strategies to:
|
||||
|
||||
The processing of the datasets can happen one of two ways:
|
||||
|
||||
1. Before kicking off training by calling `python -m axolotl.cli.preprocess /path/to/your.yaml --debug`
|
||||
1. Before kicking off training by calling `axolotl preprocess config.yaml --debug`
|
||||
2. When training is started
|
||||
|
||||
What are the benefits of pre-processing? When training interactively or for sweeps
|
||||
### What are the benefits of pre-processing?
|
||||
|
||||
When training interactively or for sweeps
|
||||
(e.g. you are restarting the trainer often), processing the datasets can oftentimes be frustratingly
|
||||
slow. Pre-processing will cache the tokenized/formatted datasets according to a hash of dependent
|
||||
training parameters so that it will intelligently pull from its cache when possible.
|
||||
@@ -28,8 +33,12 @@ default path of `./last_run_prepared/`, but will ignore anything already cached
|
||||
setting `dataset_prepared_path: ./last_run_prepared`, the trainer will use whatever pre-processed
|
||||
data is in the cache.
|
||||
|
||||
What are the edge cases? Let's say you are writing a custom prompt strategy or using a user-defined
|
||||
### What are the edge cases?
|
||||
|
||||
Let's say you are writing a custom prompt strategy or using a user-defined
|
||||
prompt template. Because the trainer cannot readily detect these changes, we cannot change the
|
||||
calculated hash value for the pre-processed dataset. If you have `dataset_prepared_path: ...` set
|
||||
calculated hash value for the pre-processed dataset.
|
||||
|
||||
If you have `dataset_prepared_path: ...` set
|
||||
and change your prompt templating logic, it may not pick up the changes you made and you will be
|
||||
training over the old prompt.
|
||||
|
||||
@@ -31,11 +31,13 @@ While debugging it's helpful to simplify your test scenario as much as possible.
|
||||
- Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`.
|
||||
- Set `dataset_processes: 1` in your axolotl config or run the training command with `--dataset_processes=1`.
|
||||
2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure `sample_packing: False` and `eval_sample_packing: False` to avoid errors. If you are in a pinch and don't have time to construct a small dataset but want to use from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training. For example, to shard the dataset into 20 pieces, add the following to your axolotl config):
|
||||
|
||||
```yaml
|
||||
dataset:
|
||||
datasets:
|
||||
...
|
||||
shards: 20
|
||||
```
|
||||
|
||||
3. **Use a small model**: A good example of a small model is [TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0).
|
||||
4. **Minimize iteration time**: Make sure the training loop finishes as fast as possible, with these settings.
|
||||
- `micro_batch_size: 1`
|
||||
@@ -85,7 +87,7 @@ The easiest way to get started is to modify the [.vscode/launch.json](../.vscode
|
||||
|
||||
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
|
||||
|
||||
```jsonc
|
||||
```json
|
||||
// .vscode/launch.json
|
||||
{
|
||||
"version": "0.2.0",
|
||||
@@ -132,7 +134,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
|
||||
|
||||
Below is the [./vscode/tasks.json](../.vscode/tasks.json) file that defines the `cleanup-for-dataprep` task. This task is run before each debugging session when you use the above configuration. Note how there are two tasks that delete the two folders mentioned above. The third task `cleanup-for-dataprep` is a composite task that combines the two tasks. A composite task is necessary because VSCode does not allow you to specify multiple tasks in the `preLaunchTask` argument of the `launch.json` file.
|
||||
|
||||
```jsonc
|
||||
```json
|
||||
// .vscode/tasks.json
|
||||
// this file is used by launch.json
|
||||
{
|
||||
|
||||
23
docs/faq.qmd
23
docs/faq.qmd
@@ -3,6 +3,7 @@ title: FAQ
|
||||
description: Frequently asked questions
|
||||
---
|
||||
|
||||
### General
|
||||
|
||||
**Q: The trainer stopped and hasn't progressed in several minutes.**
|
||||
|
||||
@@ -24,6 +25,28 @@ description: Frequently asked questions
|
||||
|
||||
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
|
||||
|
||||
### Chat templates
|
||||
|
||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||
|
||||
> A: This means that the property mapping for the stated attribute does not exist when building `chat_template` prompt. For example, if `no attribute 'content'`, please check you have added the correct mapping for `content` under `message_property_mappings`.
|
||||
|
||||
**Q: `Empty template generated for turn ___`**
|
||||
|
||||
> A: The `content` is empty for that turn.
|
||||
|
||||
**Q: `Could not find content start/end boundary for turn __`**
|
||||
|
||||
> A: The specific turn's start/end could not be detected. Please ensure you have set the `eos_token` following your `chat_template`. Otherwise, this could be a `chat_template` which doesn't use proper boundaries for each turn (like system). On the rare occurrence, make sure your content is not `[[dummy_message]]`. Please let us know about this.
|
||||
|
||||
**Q: `Content end boundary is before start boundary for turn ___`**
|
||||
|
||||
> A: This is an edge case which should not occur. Please create an Issue if this happens.
|
||||
|
||||
**Q: `Content end boundary is the same as start boundary for turn ___. This is likely an empty turn.`**
|
||||
|
||||
> A: This is likely an empty turn.
|
||||
|
||||
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
|
||||
|
||||
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: "Getting Started with Axolotl"
|
||||
title: "Quickstart"
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
@@ -17,12 +17,12 @@ Let's start by fine-tuning a small language model using LoRA. This example uses
|
||||
Assuming `axolotl` is installed (if not, see our [Installation Guide](installation.qmd))
|
||||
|
||||
1. Download example configs:
|
||||
```shell
|
||||
```bash
|
||||
axolotl fetch examples
|
||||
```
|
||||
|
||||
2. Run the training:
|
||||
```shell
|
||||
```bash
|
||||
axolotl train examples/llama-3/lora-1b.yml
|
||||
```
|
||||
|
||||
@@ -108,7 +108,7 @@ Please consult the supported [Dataset Formats](dataset-formats/) for more detail
|
||||
|
||||
3. Run the training:
|
||||
|
||||
```shell
|
||||
```bash
|
||||
axolotl train my_training.yml
|
||||
```
|
||||
|
||||
@@ -118,7 +118,7 @@ axolotl train my_training.yml
|
||||
|
||||
After training, test your model:
|
||||
|
||||
```shell
|
||||
```bash
|
||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
|
||||
```
|
||||
|
||||
@@ -126,7 +126,7 @@ axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
|
||||
|
||||
For large datasets, preprocess first:
|
||||
|
||||
```shell
|
||||
```bash
|
||||
axolotl preprocess my_training.yml
|
||||
```
|
||||
|
||||
@@ -134,7 +134,7 @@ axolotl preprocess my_training.yml
|
||||
|
||||
Launch a Gradio interface:
|
||||
|
||||
```shell
|
||||
```bash
|
||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
|
||||
```
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
---
|
||||
title: "Inference Guide"
|
||||
title: "Inference"
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
toc-depth: 3
|
||||
number-sections: true
|
||||
code-tools: true
|
||||
execute:
|
||||
enabled: false
|
||||
---
|
||||
|
||||
@@ -3,263 +3,4 @@ title: Template-free prompt construction
|
||||
description: "Template-free prompt construction with the `input_output` format"
|
||||
---
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [Background](#background)
|
||||
- [Masking Inputs](#masking-inputs)
|
||||
- [You may not want prompt templates](#you-may-not-want-prompt-templates)
|
||||
- [The `input_output` format](#the-input_output-format)
|
||||
- [Usage](#usage)
|
||||
- [1. Prepare Data](#1-prepare-data)
|
||||
- [2. Use `type: input_output`](#2-use-type-input_output)
|
||||
- [3. Check the prompts](#3-check-the-prompts)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
<a id="markdown-background" name="background"></a>
|
||||
|
||||
## Background
|
||||
|
||||
<a id="markdown-masking-inputs" name="masking-inputs"></a>
|
||||
|
||||
### Masking Inputs
|
||||
|
||||
One of the most popular features of
|
||||
[axolotl](https://github.com/axolotl-ai-cloud/axolotl) is
|
||||
setting the following configuration value:
|
||||
|
||||
|
||||
```yaml
|
||||
train_on_inputs: false
|
||||
```
|
||||
|
||||
If you declare a [dataset formats](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#dataset)
|
||||
such as `alpaca` or `chatml`, axolotl knows what is an input
|
||||
(i.e. human) vs. an output (i.e. the assistant) and masks the input
|
||||
labels so that your model can focus on predicting the outputs only.
|
||||
|
||||
<a id="markdown-you-may-not-want-prompt-templates" name="you-may-not-want-prompt-templates"></a>
|
||||
|
||||
### You may not want prompt templates
|
||||
|
||||
However, there are many situations where you don't want to use one of
|
||||
these formats or templates. This is because they can:
|
||||
|
||||
- Add unnecessary boilerplate to your prompts.
|
||||
- Create artifacts like special delimiters `<|im_start|>` that can
|
||||
quickly become footguns if you don't include them correctly at
|
||||
inference time.
|
||||
- Enforce a *chat* interface when you do not want one. Sometimes you
|
||||
just want to fine-tune a model to a very specific task and do NOT
|
||||
want multi-turn conversations, roles, etc.
|
||||
- Limit you to only certain roles that the template allows.
|
||||
|
||||
<a id="markdown-the-inputoutput-format" name="the-inputoutput-format"></a>
|
||||
|
||||
### The `input_output` format
|
||||
|
||||
You can construct your prompts without a template by using the
|
||||
`input_output` format, by setting `type: input_output` in your
|
||||
configuration file like this:
|
||||
|
||||
**config.yml**
|
||||
|
||||
```yaml
|
||||
train_on_inputs: false # Mask segments of your data
|
||||
datasets:
|
||||
- path: output.jsonl
|
||||
type: input_output # use template free prompt construction
|
||||
```
|
||||
|
||||
Unlike `type: completion`, which is also template-free,
|
||||
`type: input_output` allows you to mask segments of your text. More
|
||||
details on how this works are described below.
|
||||
|
||||
<a id="markdown-usage" name="usage"></a>
|
||||
|
||||
## Usage
|
||||
|
||||
This is how you can use the `input_output` format:
|
||||
|
||||
<a id="markdown-1-prepare-data" name="1-prepare-data"></a>
|
||||
|
||||
### 1. Prepare Data
|
||||
|
||||
To use the `input_output` format, collect your data in the following
|
||||
format into a jsonl file (below is the first row from the file
|
||||
`output`.jsonl` pretty printed):
|
||||
|
||||
```bash
|
||||
$ head -n1 output.jsonl | python -m json.tool
|
||||
```
|
||||
|
||||
:::{.cell-output .cell-output-stdout}
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"label": true,
|
||||
"text": "<s>Hello\n"
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "hi there!. "
|
||||
},
|
||||
{
|
||||
"label": false,
|
||||
"text": "goodbye "
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "farewell</s>"
|
||||
}
|
||||
]
|
||||
}
|
||||
:::
|
||||
|
||||
Set `label:false` when you want to mask a segment of text so that the
|
||||
model isn't trained on it. Some things to keep in mind:
|
||||
|
||||
> [!IMPORTANT]
|
||||
> 1. **EOS, BOS, spaces, newlines etc. are entirely up to you. Axolotl
|
||||
concatenates all the segments as-is.** The tokenizer doesn't add
|
||||
anything additional. Notice how I added spaces, newlines, `<s>`
|
||||
(BOS), and `</s>` (EOS) myself.
|
||||
> 2. Make sure you check the materialized output to validate that the
|
||||
prompt is getting assembled how you like.
|
||||
|
||||
<a id="markdown-2-use-type-inputoutput" name="2-use-type-inputoutput"></a>
|
||||
|
||||
### 2. Use `type: input_output`
|
||||
|
||||
Let's materialize data with our `output.jsonl` file by setting
|
||||
`type: input_output` in our axolotl config:
|
||||
|
||||
```yaml
|
||||
# training_config.yaml
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
data_seed: 49
|
||||
seed: 49
|
||||
|
||||
datasets:
|
||||
- path: output.jsonl
|
||||
type: input_output
|
||||
val_set_size: 0.1
|
||||
|
||||
sequence_len: 896
|
||||
sample_packing: false
|
||||
|
||||
micro_batch_size: 2
|
||||
gradient_accumulation_steps: 3
|
||||
eval_batch_size: 2
|
||||
num_epochs: 1
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
```
|
||||
|
||||
You can use the following command to materialize your data. The
|
||||
`--debug` flag will print the tokens, along with the labels so you can
|
||||
verify that the correct items are being ignored:
|
||||
|
||||
```bash
|
||||
$ python -m axolotl.cli.preprocess training_config.yaml --debug
|
||||
|
||||
...
|
||||
[2024-03-05 23:36:46,969] [INFO] [axolotl.check_example_labels:35] [PID:607731] [RANK:0] <s>(1, 1) Hello(22557, 22557)
|
||||
(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)
|
||||
|
||||
```
|
||||
|
||||
The format is `decoded_token`(`label`, `token_id`), for example,
|
||||
`<s>(1, 1)` means that the token is `<s>`, the label is `1` and the
|
||||
token_id is `1`. When the label is `-100` then that token is ignored for
|
||||
training.
|
||||
|
||||
<a id="markdown-3-check-the-prompts" name="3-check-the-prompts"></a>
|
||||
|
||||
### 3. Check the prompts
|
||||
|
||||
Here is another way to check the materialized output:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
from datasets import load_from_disk
|
||||
import yaml
|
||||
|
||||
directory = !ls last_run_prepared/
|
||||
with open('training_config.yaml', 'r') as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
model_id = cfg['base_model']
|
||||
tok = AutoTokenizer.from_pretrained(model_id)
|
||||
ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
|
||||
```
|
||||
|
||||
```python
|
||||
>>> row = ds[0]
|
||||
>>> print(tok.decode(row['input_ids']))
|
||||
<s> Hello
|
||||
hi there!. goodbye farewell</s>
|
||||
```
|
||||
|
||||
We can check that the right tokens are ignored by comparing the labels
|
||||
to each token:
|
||||
|
||||
```python
|
||||
import pandas as pd
|
||||
pd.DataFrame([{'token': tok.decode(i), 'label': l, 'id':i} for i,l in
|
||||
zip(row['input_ids'], row['labels'])])
|
||||
```
|
||||
|
||||
| token | label | id |
|
||||
|-------|-------|-------|
|
||||
| 0 | \<s\> | 1 |
|
||||
| 1 | Hello | 22557 |
|
||||
| 2 | \\n | 13 |
|
||||
| 3 | hi | 12014 |
|
||||
| 4 | there | 736 |
|
||||
| 5 | ! | 28808 |
|
||||
| 6 | . | 28723 |
|
||||
| 7 | | 28705 |
|
||||
| 8 | good | -100 |
|
||||
| 9 | bye | -100 |
|
||||
| 10 | | -100 |
|
||||
| 11 | fare | 19111 |
|
||||
| 12 | well | 5458 |
|
||||
| 13 | \</s\>| 2 |
|
||||
|
||||
|
||||
|
||||
If we look at the input data, the above table seems correct! (The jsonl
|
||||
version is repeated below for reference):
|
||||
|
||||
|
||||
```bash
|
||||
$ head -n1 output.jsonl | python -m json.tool
|
||||
```
|
||||
|
||||
:::{.cell-output .cell-output-stdout}
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"label": true,
|
||||
"text": "<s>Hello\n"
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "hi there!. "
|
||||
},
|
||||
{
|
||||
"label": false,
|
||||
"text": "goodbye "
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "farewell</s>"
|
||||
}
|
||||
]
|
||||
}
|
||||
:::
|
||||
The documentation moved to [here](dataset-formats/template_free.qmd).
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
---
|
||||
title: "Installation Guide"
|
||||
title: "Installation"
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
toc-depth: 3
|
||||
number-sections: true
|
||||
code-tools: true
|
||||
execute:
|
||||
enabled: false
|
||||
---
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
---
|
||||
title: "LoRA Optimizations"
|
||||
description: "Custom autograd functions and Triton kernels in Axolotl for optimized
|
||||
LoRA fine-tuning"
|
||||
description: "Custom autograd functions and Triton kernels in Axolotl for optimized LoRA fine-tuning"
|
||||
---
|
||||
|
||||
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
|
||||
|
||||
@@ -19,4 +19,5 @@ Current support:
|
||||
- [ ] DeepSpeed
|
||||
|
||||
Untested:
|
||||
|
||||
- FSDP
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: "Multi-GPU Training Guide"
|
||||
title: "Multi-GPU"
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
@@ -35,7 +35,11 @@ deepspeed: deepspeed_configs/zero1.json
|
||||
### Usage {#sec-deepspeed-usage}
|
||||
|
||||
```{.bash}
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/config.yml --deepspeed deepspeed_configs/zero1.json
|
||||
# Passing arg via config
|
||||
axolotl train config.yml
|
||||
|
||||
# Passing arg via cli
|
||||
axolotl train config.yml --deepspeed deepspeed_configs/zero1.json
|
||||
```
|
||||
|
||||
### ZeRO Stages {#sec-zero-stages}
|
||||
@@ -70,25 +74,7 @@ For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
|
||||
|
||||
### Liger Kernel Integration {#sec-liger}
|
||||
|
||||
::: {.callout-note}
|
||||
Liger Kernel provides efficient Triton kernels for LLM training, offering:
|
||||
|
||||
- 20% increase in multi-GPU training throughput
|
||||
- 60% reduction in memory usage
|
||||
- Compatibility with both FSDP and DeepSpeed
|
||||
:::
|
||||
|
||||
Configuration:
|
||||
|
||||
```{.yaml}
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
```
|
||||
Please see [docs](custom_integrations.qmd#liger) for more info.
|
||||
|
||||
## Troubleshooting {#sec-troubleshooting}
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ You will also need to have the same configuration file for your model on each ma
|
||||
Make sure the main machine is reachable by other machines.
|
||||
:::
|
||||
|
||||
# Accelerate
|
||||
## Accelerate
|
||||
|
||||
You will need to create a configuration for accelerate, either by using `accelerate config` and follow the instructions or you can use one of the preset below:
|
||||
|
||||
@@ -51,17 +51,17 @@ fsdp_config:
|
||||
|
||||
All you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine.
|
||||
|
||||
# Raytrain
|
||||
## Raytrain
|
||||
|
||||
Please see ray train doc [here](ray-integration.qmd).
|
||||
|
||||
# Torchrun
|
||||
## Torchrun
|
||||
|
||||
If you are using Infiniband, we recommend torchrun to utilize the full bandwidth.
|
||||
|
||||
Set the following env (change buffersize/socketname depending on your system):
|
||||
|
||||
```yaml
|
||||
```bash
|
||||
export NCCL_IB_DISABLE=0
|
||||
export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond"
|
||||
export NCCL_BUFFSIZE=2097152
|
||||
|
||||
@@ -13,13 +13,13 @@ Often, this timeout will happen after 30 minutes (the default setting) and is ac
|
||||
|
||||
Forcing cross-GPU communication via [NVLink](https://en.wikipedia.org/wiki/NVLink) may help without increasing timeouts. To verify that your configuration is leveraging NVLink run the following command:
|
||||
|
||||
```shell
|
||||
```bash
|
||||
nvidia-smi nvlink --status
|
||||
```
|
||||
|
||||
To force NCCL to use NVLink, simply set this in the environment:
|
||||
|
||||
```shell
|
||||
```bash
|
||||
export NCCL_P2P_LEVEL=NVL
|
||||
```
|
||||
|
||||
@@ -33,13 +33,13 @@ If NVLink is not available in your environment there are other options for ``NCC
|
||||
|
||||
To validate that acceptable data transfer speeds exist for your training job, running [NCCL Tests](https://github.com/NVIDIA/nccl-tests/blob/master/README.md) can help pinpoint bottlenecks, for example:
|
||||
|
||||
```shell
|
||||
```bash
|
||||
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
|
||||
```
|
||||
|
||||
It can be useful when debugging NCCL communication timeouts to activate additional logging in both PyTorch and NCCL:
|
||||
|
||||
```shell
|
||||
```bash
|
||||
export NCCL_DEBUG=INFO
|
||||
export NCCL_DEBUG_SUBSYS=ALL
|
||||
export TORCH_DISTRIBUTED_DEBUG=INFO
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: Ray Train integration
|
||||
title: Ray Train
|
||||
description: How to use Axolotl with Ray Train
|
||||
---
|
||||
|
||||
@@ -9,7 +9,7 @@ With the `--use-ray` CLI flag, Axolotl will use Ray Train's [`TorchTrainer`](htt
|
||||
|
||||
## Ray cluster setup
|
||||
|
||||
A prerequisite using the Ray Train integration is to setup a Ray cluster on your desired node(s). For a detailed guide on how you can get started with ray clusters, check the official Ray docs here: https://docs.ray.io/en/latest/cluster/getting-started.html
|
||||
A prerequisite using the Ray Train integration is to setup a Ray cluster on your desired node(s). For a detailed guide on how you can get started with ray clusters, check the official Ray docs [here](https://docs.ray.io/en/latest/cluster/getting-started.html).
|
||||
|
||||
Every Ray cluster has one _head_ node and a set of worker nodes. The head node is just like any other worker node, but it also runs certain special processes related to scheduling and orchestration. Ray-enabled scripts are run on the head node and depending on the resources (number of CPUs, GPUs, etc) they request, will be scheduled to run certain tasks on the worker nodes. For more on key concepts behind a Ray cluster, you can refer this [doc](https://docs.ray.io/en/latest/cluster/key-concepts.html#cluster-key-concepts).
|
||||
|
||||
@@ -58,13 +58,11 @@ You can find an example configuration at `configs/llama-3/lora-1b-ray.yaml`.
|
||||
The key parameters to note here are:
|
||||
|
||||
```yaml
|
||||
...
|
||||
use_ray: true
|
||||
ray_num_workers: 4
|
||||
# optional
|
||||
resources_per_worker:
|
||||
GPU: 1
|
||||
...
|
||||
```
|
||||
|
||||
- `use_ray`: This is the flag that enables the Ray Train integration. You can either use the corresponding `--use-ray` flag in the CLI or set `use_ray` in the config file.
|
||||
|
||||
116
docs/rlhf.qmd
116
docs/rlhf.qmd
@@ -3,22 +3,22 @@ title: "RLHF (Beta)"
|
||||
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
|
||||
back-to-top-navigation: true
|
||||
toc: true
|
||||
toc-depth: 3
|
||||
toc-depth: 4
|
||||
---
|
||||
|
||||
# Overview
|
||||
## Overview
|
||||
|
||||
Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
|
||||
feedback. Various methods include, but not limited to:
|
||||
|
||||
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
|
||||
- [Direct Preference Optimization (DPO)](#dpo)
|
||||
- [Identity Preference Optimization (IPO)](#ipo)
|
||||
- [Kahneman-Tversky Optimization (KTO)](#kto)
|
||||
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
|
||||
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
|
||||
|
||||
|
||||
# RLHF using Axolotl
|
||||
## RLHF using Axolotl
|
||||
|
||||
::: {.callout-important}
|
||||
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
|
||||
@@ -30,7 +30,7 @@ We rely on the [TRL](https://github.com/huggingface/trl) library for implementat
|
||||
You can find what each method supports by going into `src/axolotl/prompt_strategies/{method}` where `{method}` is one of our supported methods. The `type: ` can be retrieved from `{method}.{function_name}`.
|
||||
:::
|
||||
|
||||
## DPO
|
||||
### DPO
|
||||
|
||||
Example config:
|
||||
|
||||
@@ -47,7 +47,7 @@ datasets:
|
||||
|
||||
DPO supports the following types with the following dataset format:
|
||||
|
||||
### chatml.argilla
|
||||
#### chatml.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -58,7 +58,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.argilla_chat
|
||||
#### chatml.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -73,7 +73,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.icr
|
||||
#### chatml.icr
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -84,7 +84,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.intel
|
||||
#### chatml.intel
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -95,7 +95,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.prompt_pairs
|
||||
#### chatml.prompt_pairs
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -106,7 +106,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.ultra
|
||||
#### chatml.ultra
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -123,7 +123,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.argilla
|
||||
#### llama3.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -134,7 +134,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.argilla_chat
|
||||
#### llama3.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -149,7 +149,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.icr
|
||||
#### llama3.icr
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -160,7 +160,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.intel
|
||||
#### llama3.intel
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -171,7 +171,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.prompt_pairs
|
||||
#### llama3.prompt_pairs
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -182,7 +182,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.ultra
|
||||
#### llama3.ultra
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -199,7 +199,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### zephyr.nectar
|
||||
#### zephyr.nectar
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -218,7 +218,7 @@ DPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chat_template.default
|
||||
#### chat_template.default
|
||||
|
||||
```yaml
|
||||
rl: dpo
|
||||
@@ -264,7 +264,7 @@ Sample input format:
|
||||
}
|
||||
```
|
||||
|
||||
### user_defined.default
|
||||
#### user_defined.default
|
||||
|
||||
For custom behaviors,
|
||||
|
||||
@@ -295,7 +295,7 @@ The input format is a simple JSON input with customizable fields based on the ab
|
||||
}
|
||||
```
|
||||
|
||||
## IPO
|
||||
### IPO
|
||||
|
||||
As IPO is just DPO with a different loss function, all supported options for DPO works here.
|
||||
|
||||
@@ -303,7 +303,7 @@ As IPO is just DPO with a different loss function, all supported options for DPO
|
||||
rl: ipo
|
||||
```
|
||||
|
||||
## ORPO
|
||||
### ORPO
|
||||
|
||||
Paper: https://arxiv.org/abs/2403.07691
|
||||
|
||||
@@ -320,7 +320,7 @@ datasets:
|
||||
|
||||
ORPO supports the following types with the following dataset format:
|
||||
|
||||
### chat_template.argilla
|
||||
#### chat_template.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -339,7 +339,7 @@ ORPO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
## KTO
|
||||
### KTO
|
||||
|
||||
```yaml
|
||||
rl: kto
|
||||
@@ -360,7 +360,7 @@ gradient_checkpointing_kwargs:
|
||||
|
||||
KTO supports the following types with the following dataset format:
|
||||
|
||||
### chatml.argilla
|
||||
#### chatml.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -370,7 +370,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.argilla_chat
|
||||
#### chatml.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -383,7 +383,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.intel
|
||||
#### chatml.intel
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -393,7 +393,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.prompt_pairs
|
||||
#### chatml.prompt_pairs
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -403,7 +403,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.ultra
|
||||
#### chatml.ultra
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -413,7 +413,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.argilla
|
||||
#### llama3.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -423,7 +423,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.argilla_chat
|
||||
#### llama3.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -434,7 +434,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.intel
|
||||
#### llama3.intel
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -444,7 +444,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.prompt_pairs
|
||||
#### llama3.prompt_pairs
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -454,7 +454,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.ultra
|
||||
#### llama3.ultra
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -464,7 +464,7 @@ KTO supports the following types with the following dataset format:
|
||||
}
|
||||
```
|
||||
|
||||
### user_defined.default
|
||||
#### user_defined.default
|
||||
|
||||
For custom behaviors,
|
||||
|
||||
@@ -494,7 +494,49 @@ The input format is a simple JSON input with customizable fields based on the ab
|
||||
}
|
||||
```
|
||||
|
||||
## Using local dataset files
|
||||
### GRPO
|
||||
|
||||
GRPO uses custom reward functions and transformations. Please have them ready locally.
|
||||
|
||||
For ex, to load OpenAI's GSM8K and use a random reward for completions:
|
||||
|
||||
```python
|
||||
# rewards.py
|
||||
import random
|
||||
|
||||
def rand_reward_func(completions, **kwargs) -> list[float]:
|
||||
return [random.uniform(0, 1) for _ in completions]
|
||||
|
||||
def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
def transform_fn(example, tokenizer=None):
|
||||
label = example["answer"].split("####")[-1].strip().replace(",", "")
|
||||
return {
|
||||
"prompt": [{"role": "user", "content": example["question"]},],
|
||||
"answer": label,
|
||||
}
|
||||
return transform_fn, {"remove_columns": ["question"]}
|
||||
```
|
||||
|
||||
```yaml
|
||||
rl: grpo
|
||||
|
||||
trl:
|
||||
beta: 0.001
|
||||
max_completion_length: 256
|
||||
use_vllm: True
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.15
|
||||
num_generations: 4
|
||||
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
|
||||
datasets:
|
||||
- path: openai/gsm8k
|
||||
name: main
|
||||
type: rewards.oai_gsm8k_transform # format: '{file_name}.{fn_name}'
|
||||
```
|
||||
|
||||
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
|
||||
|
||||
### Using local dataset files
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
@@ -505,7 +547,7 @@ datasets:
|
||||
type: chatml.intel
|
||||
```
|
||||
|
||||
## TRL auto-unwrapping for PEFT
|
||||
### TRL auto-unwrapping for PEFT
|
||||
|
||||
TRL supports auto-unwrapping PEFT models for RL training paradigms which rely on a reference model. This significantly reduces memory pressure as an additional refreference model does not need to be loaded, and reference model log-probabilities can be obtained by disabling PEFT adapters. This is enabled by default. To turn it off, pass the following config:
|
||||
|
||||
|
||||
@@ -3,6 +3,12 @@ title: "PyTorch ao"
|
||||
description: "Custom data types and layouts for training and inference"
|
||||
---
|
||||
|
||||
To use experimental optimizers (`AdamWFp8`, `AdamW4bit`, `AdamW8bit`) from Pytorch Ao, please install the package as shown below.
|
||||
|
||||
::: {.callout-tip}
|
||||
Some experimental optimizers are already present in regular Pytorch, so please re-check if you actually need this package!
|
||||
:::
|
||||
|
||||
### Installation
|
||||
|
||||
Stable Release from the PyTorch index
|
||||
|
||||
@@ -8,6 +8,12 @@ description: "Hyper-optimized QLoRA finetuning for single GPUs"
|
||||
Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and VRAM over
|
||||
standard industry baselines.
|
||||
|
||||
::: {.callout-important}
|
||||
Due to breaking changes in transformers `v4.48.0`, users will need to downgrade to `<=v4.47.1` to use this patch.
|
||||
|
||||
This will later be deprecated in favor of [LoRA Optimizations](lora_optims.qmd).
|
||||
:::
|
||||
|
||||
|
||||
### Installation
|
||||
|
||||
@@ -17,7 +23,7 @@ The following will install the correct unsloth and extras from source.
|
||||
python scripts/unsloth_install.py | sh
|
||||
```
|
||||
|
||||
### Using unsloth w Axolotl
|
||||
### Usage
|
||||
|
||||
Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
---
|
||||
toc-location: right-body
|
||||
toc-title: Table Of Contents
|
||||
toc-expand: 2
|
||||
# toc-location: right-body
|
||||
# toc-title: Table Of Contents
|
||||
# toc-expand: 2
|
||||
---
|
||||
|
||||
```{python}
|
||||
|
||||
@@ -7,7 +7,7 @@ mamba-ssm==1.2.0.post1
|
||||
flash-attn==2.7.4.post1
|
||||
xformers>=0.0.23.post1
|
||||
autoawq==0.2.7.post3
|
||||
liger-kernel==0.5.2
|
||||
liger-kernel==0.5.3
|
||||
# END section
|
||||
|
||||
packaging==23.2
|
||||
|
||||
@@ -258,25 +258,21 @@ class ModalCloud(Cloud):
|
||||
|
||||
|
||||
def _preprocess(config_yaml: str, volumes=None):
|
||||
Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
|
||||
with open(
|
||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||
) as f_out:
|
||||
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
|
||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/artifacts/axolotl"
|
||||
run_folder = "/workspace/mounts"
|
||||
run_cmd(
|
||||
"axolotl preprocess /workspace/artifacts/axolotl/config.yaml --dataset-processes=8",
|
||||
"axolotl preprocess /workspace/mounts/config.yaml --dataset-processes=8",
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
|
||||
|
||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||
with open(
|
||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||
) as f_out:
|
||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/artifacts/axolotl"
|
||||
run_folder = "/workspace/mounts"
|
||||
if accelerate:
|
||||
accelerate_args = "--accelerate"
|
||||
else:
|
||||
@@ -285,20 +281,18 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||
if num_processes := kwargs.pop("num_processes", None):
|
||||
num_processes_args = f"--num-processes {num_processes}"
|
||||
run_cmd(
|
||||
f"axolotl train {accelerate_args} {num_processes_args} /workspace/artifacts/axolotl/config.yaml",
|
||||
f"axolotl train {accelerate_args} {num_processes_args} /workspace/mounts/config.yaml",
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
|
||||
|
||||
def _lm_eval(config_yaml: str, volumes=None):
|
||||
with open(
|
||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||
) as f_out:
|
||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/artifacts/axolotl"
|
||||
run_folder = "/workspace/mounts"
|
||||
run_cmd(
|
||||
"axolotl lm-eval /workspace/artifacts/axolotl/config.yaml",
|
||||
"axolotl lm-eval /workspace/mounts/config.yaml",
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
|
||||
@@ -24,8 +24,8 @@ class TrainDatasetMeta:
|
||||
"""Dataclass with fields for training and validation datasets and metadata."""
|
||||
|
||||
train_dataset: Dataset
|
||||
eval_dataset: Optional[Dataset] = None
|
||||
total_num_steps: Optional[int] = None
|
||||
eval_dataset: Dataset | None = None
|
||||
total_num_steps: int | None = None
|
||||
|
||||
|
||||
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
|
||||
|
||||
@@ -91,13 +91,11 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
"""
|
||||
Base class for trainer builder
|
||||
"""
|
||||
"""Base class for trainer builder."""
|
||||
|
||||
_train_dataset = None
|
||||
_eval_dataset = None
|
||||
@@ -110,9 +108,9 @@ class TrainerBuilderBase(abc.ABC):
|
||||
self.tokenizer = tokenizer
|
||||
self.processor = processor
|
||||
|
||||
# in case the model supports tagging, add the axolotl tag.
|
||||
# If the model supports tagging, add the axolotl tag.
|
||||
# This makes sure the tag is correctly pushed even if a user calls
|
||||
# model.push_to_hub instad of trainer.push_to_hub.
|
||||
# model.push_to_hub instead of trainer.push_to_hub.
|
||||
if hasattr(model, "add_model_tags"):
|
||||
model.add_model_tags(["axolotl"])
|
||||
|
||||
@@ -227,8 +225,8 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
Build the HuggingFace training args/trainer for causal models
|
||||
and reward modelling using TRL.
|
||||
Build the HuggingFace training args/trainer for causal models and reward modeling
|
||||
using TRL.
|
||||
"""
|
||||
|
||||
def get_callbacks(self):
|
||||
@@ -872,9 +870,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
|
||||
class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
|
||||
"""
|
||||
"""Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
# Cut Cross Entropy
|
||||
|
||||
### Usage
|
||||
Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy operation during loss calculation.
|
||||
|
||||
See https://github.com/apple/ml-cross-entropy
|
||||
|
||||
## Usage
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
@@ -8,3 +12,19 @@ plugins:
|
||||
|
||||
cut_cross_entropy: true
|
||||
```
|
||||
|
||||
## Citation
|
||||
|
||||
```bib
|
||||
@article{wijmans2024cut,
|
||||
author = {Erik Wijmans and
|
||||
Brody Huval and
|
||||
Alexander Hertzberg and
|
||||
Vladlen Koltun and
|
||||
Philipp Kr\"ahenb\"uhl},
|
||||
title = {Cut Your Losses in Large-Vocabulary Language Models},
|
||||
journal = {arXiv},
|
||||
year = {2024},
|
||||
url = {https://arxiv.org/abs/2411.09009},
|
||||
}
|
||||
```
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
See https://github.com/ironjr/grokfast
|
||||
|
||||
### Usage
|
||||
## Usage
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
@@ -11,3 +11,14 @@ plugins:
|
||||
grokfast_alpha: 2.0
|
||||
grokfast_lamb: 0.98
|
||||
```
|
||||
|
||||
## Citation
|
||||
|
||||
```bib
|
||||
@article{lee2024grokfast,
|
||||
title={{Grokfast}: Accelerated Grokking by Amplifying Slow Gradients},
|
||||
author={Lee, Jaerin and Kang, Bong Gyun and Kim, Kihoon and Lee, Kyoung Mu},
|
||||
journal={arXiv preprint arXiv:2405.20233},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
|
||||
23
src/axolotl/integrations/kd/README.md
Normal file
23
src/axolotl/integrations/kd/README.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# Knowledge Distillation
|
||||
|
||||
## Usage
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- "axolotl.integrations.kd.KDPlugin"
|
||||
|
||||
kd_trainer: True
|
||||
kd_ce_alpha: 0.1
|
||||
kd_alpha: 0.9
|
||||
kd_temperature: 1.0
|
||||
|
||||
torch_compile: True # torch>=2.5.1, recommended to reduce vram
|
||||
|
||||
datasets:
|
||||
- path: ...
|
||||
type: "axolotl.integrations.kd.chat_template"
|
||||
field_messages: "messages_combined"
|
||||
logprobs_field: "llm_text_generation_vllm_logprobs" # for kd only, field of logprobs
|
||||
```
|
||||
|
||||
An example dataset can be found at [`axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample`](https://huggingface.co/datasets/axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample)
|
||||
@@ -1,391 +0,0 @@
|
||||
"""
|
||||
benchmark utility helper for benchmarking the KL divergence triton kernel
|
||||
"""
|
||||
import gc
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torch.utils.benchmark import Timer
|
||||
|
||||
from axolotl.integrations.kd.topk_logprob.forward_kl import loss as eager_loss
|
||||
from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss
|
||||
|
||||
|
||||
# pylint: disable=cell-var-from-loop
|
||||
def benchmark_kl_div_loss_with_backward():
|
||||
# Test configurations
|
||||
batch_sizes = [1, 4]
|
||||
seq_lens = [64, 512, 2048, 4096, 8192]
|
||||
vocab_size = 32000
|
||||
top_k = 64
|
||||
|
||||
# Store results
|
||||
results = []
|
||||
|
||||
# Run benchmarks
|
||||
for batch_size in batch_sizes:
|
||||
for seq_len in seq_lens:
|
||||
# Generate random test data
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Create tensors with gradients
|
||||
student_logits = torch.randn(
|
||||
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
target_token_ids = torch.randint(
|
||||
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
||||
)
|
||||
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
|
||||
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
|
||||
target_mask = torch.randint(
|
||||
0, 2, (batch_size, seq_len, top_k), device="cuda"
|
||||
).float()
|
||||
|
||||
# Clone student_logits for the two implementations
|
||||
student_logits_ref = student_logits.clone().detach().requires_grad_(True)
|
||||
student_logits_triton = student_logits.clone().detach().requires_grad_(True)
|
||||
|
||||
# Define functions for timing that include both forward and backward passes
|
||||
def run_reference():
|
||||
# Forward pass
|
||||
loss_ref = eager_loss(
|
||||
student_logits_ref, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
# Backward pass
|
||||
loss_ref.backward()
|
||||
|
||||
def run_triton():
|
||||
# Forward pass
|
||||
# pylint: disable=duplicate-code
|
||||
loss_triton = triton_loss(
|
||||
student_logits_triton,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
)
|
||||
# Backward pass
|
||||
loss_triton.backward()
|
||||
|
||||
# Benchmark reference implementation (forward + backward)
|
||||
t0 = Timer(
|
||||
stmt="run_reference()",
|
||||
globals={
|
||||
"run_reference": run_reference,
|
||||
},
|
||||
)
|
||||
# Reset gradients before timing
|
||||
student_logits_ref.grad = None
|
||||
ref_time = t0.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
# Benchmark Triton implementation (forward + backward)
|
||||
t1 = Timer(
|
||||
stmt="run_triton()",
|
||||
globals={
|
||||
"run_triton": run_triton,
|
||||
},
|
||||
)
|
||||
# Reset gradients before timing
|
||||
student_logits_triton.grad = None
|
||||
triton_time = t1.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
# Compute speedup
|
||||
speedup = ref_time / triton_time if triton_time > 0 else float("inf")
|
||||
|
||||
# Store results
|
||||
results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"seq_len": seq_len,
|
||||
"reference_time_ms": ref_time,
|
||||
"triton_time_ms": triton_time,
|
||||
"speedup": speedup,
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Batch size: {batch_size}, Seq len: {seq_len}")
|
||||
print(f" Reference time (fwd+bwd): {ref_time:.2f} ms")
|
||||
print(f" Triton time (fwd+bwd): {triton_time:.2f} ms")
|
||||
print(f" Speedup: {speedup:.2f}x")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def benchmark_forward_backward_separately():
|
||||
"""
|
||||
Benchmark forward and backward passes separately to identify where the speedup comes from.
|
||||
"""
|
||||
# Test configurations
|
||||
batch_sizes = [1, 4, 8]
|
||||
seq_lens = [64, 512, 2048]
|
||||
vocab_size = 32000
|
||||
top_k = 64
|
||||
|
||||
# Store results
|
||||
detailed_results = []
|
||||
|
||||
# Run benchmarks
|
||||
for batch_size in batch_sizes:
|
||||
for seq_len in seq_lens:
|
||||
# Generate random test data
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Create tensors with gradients
|
||||
student_logits = torch.randn(
|
||||
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
target_token_ids = torch.randint(
|
||||
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
||||
)
|
||||
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
|
||||
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
|
||||
target_mask = torch.randint(
|
||||
0, 2, (batch_size, seq_len, top_k), device="cuda"
|
||||
).float()
|
||||
|
||||
# Clone student_logits for the two implementations
|
||||
student_logits_ref = student_logits.clone().detach().requires_grad_(True)
|
||||
student_logits_triton = student_logits.clone().detach().requires_grad_(True)
|
||||
|
||||
# Forward-only reference
|
||||
def run_reference_forward():
|
||||
with torch.no_grad():
|
||||
return eager_loss(
|
||||
student_logits_ref,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
)
|
||||
|
||||
# Forward-only triton
|
||||
def run_triton_forward():
|
||||
with torch.no_grad():
|
||||
return triton_loss(
|
||||
student_logits_triton,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
)
|
||||
|
||||
# Benchmark forward pass only
|
||||
|
||||
t0_fwd = Timer(
|
||||
stmt="run_reference_forward()",
|
||||
globals={
|
||||
"run_reference_forward": run_reference_forward,
|
||||
},
|
||||
)
|
||||
ref_fwd_time = t0_fwd.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
t1_fwd = Timer(
|
||||
stmt="run_triton_forward()",
|
||||
globals={
|
||||
"run_triton_forward": run_triton_forward,
|
||||
},
|
||||
)
|
||||
triton_fwd_time = t1_fwd.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
# Pre-compute losses for backward pass benchmarking
|
||||
loss_ref = eager_loss(
|
||||
student_logits_ref, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
loss_triton = triton_loss(
|
||||
student_logits_triton, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
|
||||
# Backward-only reference
|
||||
def run_reference_backward():
|
||||
student_logits_ref.grad = None
|
||||
loss_ref.backward()
|
||||
|
||||
# Backward-only triton
|
||||
def run_triton_backward():
|
||||
student_logits_triton.grad = None
|
||||
loss_triton.backward()
|
||||
|
||||
# Benchmark backward pass only
|
||||
t0_bwd = Timer(
|
||||
stmt="run_reference_backward()",
|
||||
globals={
|
||||
"run_reference_backward": run_reference_backward,
|
||||
},
|
||||
)
|
||||
ref_bwd_time = t0_bwd.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
t1_bwd = Timer(
|
||||
stmt="run_triton_backward()",
|
||||
globals={
|
||||
"run_triton_backward": run_triton_backward,
|
||||
},
|
||||
)
|
||||
triton_bwd_time = t1_bwd.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
# Compute speedups
|
||||
fwd_speedup = (
|
||||
ref_fwd_time / triton_fwd_time if triton_fwd_time > 0 else float("inf")
|
||||
)
|
||||
bwd_speedup = (
|
||||
ref_bwd_time / triton_bwd_time if triton_bwd_time > 0 else float("inf")
|
||||
)
|
||||
total_ref_time = ref_fwd_time + ref_bwd_time
|
||||
total_triton_time = triton_fwd_time + triton_bwd_time
|
||||
total_speedup = (
|
||||
total_ref_time / total_triton_time
|
||||
if total_triton_time > 0
|
||||
else float("inf")
|
||||
)
|
||||
|
||||
# Store results
|
||||
detailed_results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"seq_len": seq_len,
|
||||
"ref_forward_ms": ref_fwd_time,
|
||||
"triton_forward_ms": triton_fwd_time,
|
||||
"forward_speedup": fwd_speedup,
|
||||
"ref_backward_ms": ref_bwd_time,
|
||||
"triton_backward_ms": triton_bwd_time,
|
||||
"backward_speedup": bwd_speedup,
|
||||
"total_ref_ms": total_ref_time,
|
||||
"total_triton_ms": total_triton_time,
|
||||
"total_speedup": total_speedup,
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Batch size: {batch_size}, Seq len: {seq_len}")
|
||||
print(
|
||||
f" Forward: Reference={ref_fwd_time:.2f}ms, Triton={triton_fwd_time:.2f}ms, Speedup={fwd_speedup:.2f}x"
|
||||
)
|
||||
print(
|
||||
f" Backward: Reference={ref_bwd_time:.2f}ms, Triton={triton_bwd_time:.2f}ms, Speedup={bwd_speedup:.2f}x"
|
||||
)
|
||||
print(
|
||||
f" Total: Reference={total_ref_time:.2f}ms, Triton={total_triton_time:.2f}ms, Speedup={total_speedup:.2f}x"
|
||||
)
|
||||
|
||||
return detailed_results
|
||||
|
||||
|
||||
def benchmark_memory_usage_with_backward():
|
||||
# Test configurations
|
||||
batch_sizes = [1, 2]
|
||||
seq_len = 8192
|
||||
vocab_size = 128000
|
||||
top_k = 64
|
||||
|
||||
# Store results
|
||||
mem_results = []
|
||||
|
||||
# Run benchmarks
|
||||
for batch_size in batch_sizes:
|
||||
# Generate random test data
|
||||
torch.manual_seed(42)
|
||||
student_logits = torch.randn(
|
||||
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
|
||||
)
|
||||
target_token_ids = torch.randint(
|
||||
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
||||
)
|
||||
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
|
||||
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
|
||||
target_mask = torch.randint(
|
||||
0, 2, (batch_size, seq_len, top_k), device="cuda"
|
||||
).float()
|
||||
|
||||
# Clone student_logits for the implementations
|
||||
student_logits_ref = student_logits.clone().detach().requires_grad_(True)
|
||||
student_logits_triton = student_logits.clone().detach().requires_grad_(True)
|
||||
|
||||
# Measure PyTorch memory usage (forward + backward)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
loss_ref = eager_loss(
|
||||
student_logits_ref, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
loss_ref.backward()
|
||||
torch.cuda.synchronize()
|
||||
pytorch_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
|
||||
|
||||
# Measure Triton memory usage (forward + backward)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
loss_triton = triton_loss(
|
||||
student_logits_triton, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
loss_triton.backward()
|
||||
torch.cuda.synchronize()
|
||||
triton_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
|
||||
|
||||
# Measure Triton memory usage with different chunk sizes (forward + backward)
|
||||
for n_chunks in [1, 2, 4, 8]:
|
||||
student_logits_chunk = student_logits.clone().detach().requires_grad_(True)
|
||||
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
loss_chunk = triton_loss(
|
||||
student_logits_chunk,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
)
|
||||
loss_chunk.backward()
|
||||
torch.cuda.synchronize()
|
||||
chunk_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
|
||||
|
||||
mem_results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"implementation": f"Triton (chunks={n_chunks})",
|
||||
"memory_mb": chunk_mem,
|
||||
}
|
||||
)
|
||||
|
||||
# Store results
|
||||
mem_results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"implementation": "PyTorch",
|
||||
"memory_mb": pytorch_mem,
|
||||
}
|
||||
)
|
||||
|
||||
mem_results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"implementation": "Triton (default)",
|
||||
"memory_mb": triton_mem,
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Batch size: {batch_size} (with backward pass)")
|
||||
print(f" PyTorch memory: {pytorch_mem:.2f} MB")
|
||||
print(f" Triton memory: {triton_mem:.2f} MB")
|
||||
print(f" Memory reduction: {(1 - triton_mem/pytorch_mem)*100:.2f}%")
|
||||
|
||||
return mem_results
|
||||
|
||||
|
||||
def main():
|
||||
print("Running benchmarks with forward and backward passes...")
|
||||
benchmark_kl_div_loss_with_backward()
|
||||
clean()
|
||||
|
||||
print("\nRunning detailed forward/backward benchmarks...")
|
||||
# benchmark_forward_backward_separately()
|
||||
# clean()
|
||||
|
||||
print("\nRunning memory usage benchmarks with backward passes...")
|
||||
benchmark_memory_usage_with_backward()
|
||||
clean()
|
||||
|
||||
|
||||
def clean():
|
||||
for _ in range(5):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,750 +0,0 @@
|
||||
"""
|
||||
Optimized Triton kernel for KL divergence loss between teacher and student models.
|
||||
"""
|
||||
# pylint: disable=invalid-name,unused-argument
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_logsumexp_logprobs_kernel(
|
||||
student_logits_ptr, # Input logits in original dtype
|
||||
student_logprobs_ptr, # Output logprobs (float32)
|
||||
token_ids_ptr, # Token IDs for top-k
|
||||
B,
|
||||
S,
|
||||
V,
|
||||
K, # batch size, seq len, vocab size, top-k
|
||||
temperature,
|
||||
stride_l_b,
|
||||
stride_l_s,
|
||||
stride_l_v,
|
||||
stride_lp_b,
|
||||
stride_lp_s,
|
||||
stride_lp_k,
|
||||
stride_t_b,
|
||||
stride_t_s,
|
||||
stride_t_k,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Fused kernel that computes logsumexp and logprobs for topk tokens.
|
||||
All computations are done in float32 for numerical stability.
|
||||
"""
|
||||
# Program ID
|
||||
pid = tl.program_id(0)
|
||||
batch_idx = pid // S
|
||||
seq_idx = pid % S
|
||||
|
||||
# Bounds check
|
||||
if batch_idx >= B or seq_idx >= S:
|
||||
return
|
||||
|
||||
# Compute logsumexp over the vocabulary
|
||||
max_val = -float("inf")
|
||||
|
||||
# Phase 1: Find max value across vocabulary
|
||||
for v_offset in range(0, V, BLOCK_SIZE):
|
||||
# Create block indices and mask
|
||||
block_size = min(BLOCK_SIZE, V - v_offset)
|
||||
block_idx = tl.arange(0, BLOCK_SIZE)
|
||||
mask = block_idx < block_size
|
||||
|
||||
# Load logits block and convert to float32 in-place
|
||||
ptrs = (
|
||||
student_logits_ptr
|
||||
+ batch_idx * stride_l_b
|
||||
+ seq_idx * stride_l_s
|
||||
+ (v_offset + block_idx) * stride_l_v
|
||||
)
|
||||
block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)
|
||||
|
||||
# Apply temperature scaling if needed
|
||||
if temperature != 1.0:
|
||||
block_logits = block_logits / temperature
|
||||
|
||||
# Update max value
|
||||
block_max = tl.max(block_logits, axis=0)
|
||||
max_val = tl.maximum(max_val, block_max)
|
||||
|
||||
# Phase 2: Compute sum of exp(logits - max_val)
|
||||
sum_exp = 0.0
|
||||
|
||||
for v_offset in range(0, V, BLOCK_SIZE):
|
||||
# Create block indices and mask
|
||||
block_size = min(BLOCK_SIZE, V - v_offset)
|
||||
block_idx = tl.arange(0, BLOCK_SIZE)
|
||||
mask = block_idx < block_size
|
||||
|
||||
# Load logits block and convert to float32 in-place
|
||||
ptrs = (
|
||||
student_logits_ptr
|
||||
+ batch_idx * stride_l_b
|
||||
+ seq_idx * stride_l_s
|
||||
+ (v_offset + block_idx) * stride_l_v
|
||||
)
|
||||
block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)
|
||||
|
||||
# Apply temperature scaling if needed
|
||||
if temperature != 1.0:
|
||||
block_logits = block_logits / temperature
|
||||
|
||||
# Compute exp(logits - max_val) and add to sum
|
||||
block_exp = tl.exp(block_logits - max_val)
|
||||
sum_exp += tl.sum(block_exp * mask, axis=0)
|
||||
|
||||
# Compute final logsumexp
|
||||
logsumexp = max_val + tl.log(sum_exp)
|
||||
|
||||
# Phase 3: Compute and store logprobs for the top-k tokens
|
||||
token_ids_base = token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
|
||||
logprobs_base = (
|
||||
student_logprobs_ptr + batch_idx * stride_lp_b + seq_idx * stride_lp_s
|
||||
)
|
||||
|
||||
for k in range(K):
|
||||
# Load token ID for position k
|
||||
token_id = tl.load(token_ids_base + k * stride_t_k)
|
||||
|
||||
# Load the corresponding logit and convert to float32
|
||||
token_logit_ptr = (
|
||||
student_logits_ptr
|
||||
+ batch_idx * stride_l_b
|
||||
+ seq_idx * stride_l_s
|
||||
+ token_id * stride_l_v
|
||||
)
|
||||
token_logit = tl.load(token_logit_ptr).to(tl.float32)
|
||||
|
||||
# Apply temperature scaling if needed
|
||||
if temperature != 1.0:
|
||||
token_logit = token_logit / temperature
|
||||
|
||||
# Compute logprob directly: logit - logsumexp
|
||||
token_logprob = token_logit - logsumexp
|
||||
|
||||
# Store the result
|
||||
tl.store(logprobs_base + k * stride_lp_k, token_logprob)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def grad_softmax_kernel(
|
||||
grad_student_logits_ptr,
|
||||
target_token_ids_ptr,
|
||||
teacher_probs_ptr,
|
||||
student_probs_ptr,
|
||||
mask_ptr,
|
||||
B,
|
||||
S,
|
||||
V,
|
||||
K, # batch size, seq len, vocab size, top-k
|
||||
scale,
|
||||
stride_gl_b,
|
||||
stride_gl_s,
|
||||
stride_gl_v,
|
||||
stride_t_b,
|
||||
stride_t_s,
|
||||
stride_t_k,
|
||||
stride_p_b,
|
||||
stride_p_s,
|
||||
stride_p_k,
|
||||
stride_sp_b,
|
||||
stride_sp_s,
|
||||
stride_sp_k,
|
||||
stride_m_b,
|
||||
stride_m_s,
|
||||
stride_m_k,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# Program ID
|
||||
pid = tl.program_id(0)
|
||||
batch_idx = pid // S
|
||||
seq_idx = pid % S
|
||||
|
||||
# Bounds check
|
||||
if batch_idx >= B or seq_idx >= S:
|
||||
return
|
||||
|
||||
# Base pointers for this (batch, seq) pair
|
||||
grad_logits_base = (
|
||||
grad_student_logits_ptr + batch_idx * stride_gl_b + seq_idx * stride_gl_s
|
||||
)
|
||||
token_ids_base = (
|
||||
target_token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
|
||||
)
|
||||
teacher_probs_base = (
|
||||
teacher_probs_ptr + batch_idx * stride_p_b + seq_idx * stride_p_s
|
||||
)
|
||||
student_probs_base = (
|
||||
student_probs_ptr + batch_idx * stride_sp_b + seq_idx * stride_sp_s
|
||||
)
|
||||
mask_base = mask_ptr + batch_idx * stride_m_b + seq_idx * stride_m_s
|
||||
|
||||
# Process each teacher probability one at a time, computing all gradients for it
|
||||
for k in range(0, K):
|
||||
# Load data for current position k
|
||||
teacher_prob = tl.load(teacher_probs_base + k * stride_p_k)
|
||||
student_prob_k = tl.load(student_probs_base + k * stride_sp_k)
|
||||
mask_val = tl.load(mask_base + k * stride_m_k)
|
||||
|
||||
# Precompute the self-influence term (multiplied by scale)
|
||||
self_term = teacher_prob * (1.0 - student_prob_k) * scale
|
||||
|
||||
# Calculate gradient contributions for all positions j
|
||||
for j in range(0, K):
|
||||
token_id_j = tl.load(token_ids_base + j * stride_t_k)
|
||||
student_prob_j = tl.load(student_probs_base + j * stride_sp_k)
|
||||
mask_j = tl.load(mask_base + j * stride_m_k)
|
||||
|
||||
# Calculate the masking factor
|
||||
combined_mask = mask_val * mask_j
|
||||
|
||||
# Determine if this is a diagonal or off-diagonal term
|
||||
is_k_equals_j = tl.where(k == j, 1.0, 0.0)
|
||||
|
||||
# Compute the gradient contribution
|
||||
# For diagonal (k==j): -teacher_prob * (1-student_prob_k) * scale * mask
|
||||
# For off-diagonal: -(-teacher_prob * student_prob_j) * scale * mask
|
||||
grad_contribution = (
|
||||
-(
|
||||
self_term * is_k_equals_j
|
||||
- teacher_prob * student_prob_j * scale * (1.0 - is_k_equals_j)
|
||||
)
|
||||
* combined_mask
|
||||
)
|
||||
|
||||
# Atomically update the gradient for this token
|
||||
tl.atomic_add(
|
||||
grad_logits_base + token_id_j * stride_gl_v, grad_contribution
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def grad_topk_softmax_kernel(
|
||||
grad_student_logits_ptr,
|
||||
student_logits_ptr,
|
||||
target_token_ids_ptr,
|
||||
teacher_probs_ptr,
|
||||
student_probs_ptr,
|
||||
mask_ptr,
|
||||
B,
|
||||
S,
|
||||
V,
|
||||
K, # batch size, seq len, vocab size, top-k
|
||||
scale,
|
||||
stride_gl_b,
|
||||
stride_gl_s,
|
||||
stride_gl_v,
|
||||
stride_l_b,
|
||||
stride_l_s,
|
||||
stride_l_v,
|
||||
stride_t_b,
|
||||
stride_t_s,
|
||||
stride_t_k,
|
||||
stride_p_b,
|
||||
stride_p_s,
|
||||
stride_p_k,
|
||||
stride_sp_b,
|
||||
stride_sp_s,
|
||||
stride_sp_k,
|
||||
stride_m_b,
|
||||
stride_m_s,
|
||||
stride_m_k,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# Program ID
|
||||
pid = tl.program_id(0)
|
||||
batch_idx = pid // S
|
||||
seq_idx = pid % S
|
||||
|
||||
# Bounds check
|
||||
if batch_idx >= B or seq_idx >= S:
|
||||
return
|
||||
|
||||
# Base pointers for this (batch, seq) pair
|
||||
grad_logits_base = (
|
||||
grad_student_logits_ptr + batch_idx * stride_gl_b + seq_idx * stride_gl_s
|
||||
)
|
||||
# logits_base = student_logits_ptr + batch_idx * stride_l_b + seq_idx * stride_l_s
|
||||
token_ids_base = (
|
||||
target_token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
|
||||
)
|
||||
teacher_probs_base = (
|
||||
teacher_probs_ptr + batch_idx * stride_p_b + seq_idx * stride_p_s
|
||||
)
|
||||
student_probs_base = (
|
||||
student_probs_ptr + batch_idx * stride_sp_b + seq_idx * stride_sp_s
|
||||
)
|
||||
mask_base = mask_ptr + batch_idx * stride_m_b + seq_idx * stride_m_s
|
||||
|
||||
# Load all token IDs, probs and masks for this position
|
||||
token_ids = tl.zeros([K], dtype=tl.int32)
|
||||
teacher_probs = tl.zeros([K], dtype=tl.float32)
|
||||
student_probs = tl.zeros([K], dtype=tl.float32)
|
||||
masks = tl.zeros([K], dtype=tl.float32)
|
||||
|
||||
for k in range(K):
|
||||
token_ids[k] = tl.load(token_ids_base + k * stride_t_k)
|
||||
teacher_probs[k] = tl.load(teacher_probs_base + k * stride_p_k)
|
||||
student_probs[k] = tl.load(student_probs_base + k * stride_sp_k)
|
||||
masks[k] = tl.load(mask_base + k * stride_m_k)
|
||||
|
||||
# Process gradients for all tokens in this position
|
||||
for k in range(K):
|
||||
# token_id = token_ids[k]
|
||||
mask_k = masks[k]
|
||||
|
||||
# Skip computation if mask is zero by multiplying gradient by mask
|
||||
for j in range(K):
|
||||
other_token_id = token_ids[j]
|
||||
mask_j = masks[j]
|
||||
combined_mask = mask_k * mask_j
|
||||
|
||||
# Compute gradient differently for diagonal vs off-diagonal entries
|
||||
# Using * 1.0 to convert boolean to float
|
||||
is_diagonal = tl.where(j == k, 1.0, 0.0)
|
||||
|
||||
# Self influence: gradient = teacher_prob * (1 - student_prob)
|
||||
self_grad = teacher_probs[k] * (1.0 - student_probs[k]) * is_diagonal
|
||||
|
||||
# Cross influence: gradient = -teacher_prob[k] * student_prob[j]
|
||||
cross_grad = -teacher_probs[k] * student_probs[j] * (1.0 - is_diagonal)
|
||||
|
||||
# Combined gradient scaled by mask
|
||||
grad_val = (self_grad + cross_grad) * scale * combined_mask
|
||||
|
||||
tl.atomic_add(grad_logits_base + other_token_id * stride_gl_v, grad_val)
|
||||
|
||||
|
||||
# Triton-accelerated implementation of KL divergence loss for top-k tokens
|
||||
# Chunking helper functions for handling long sequences
|
||||
def chunk_tensor(
|
||||
tensor: torch.Tensor, max_seq_len: int
|
||||
) -> Tuple[torch.Tensor, Optional[int]]:
|
||||
"""Split a tensor along sequence dimension if needed."""
|
||||
_, seq_len, *__ = tensor.shape
|
||||
|
||||
if seq_len <= max_seq_len:
|
||||
return tensor, None
|
||||
|
||||
num_chunks = (seq_len + max_seq_len - 1) // max_seq_len
|
||||
chunks = []
|
||||
|
||||
for i in range(num_chunks):
|
||||
start_idx = i * max_seq_len
|
||||
end_idx = min((i + 1) * max_seq_len, seq_len)
|
||||
chunks.append(tensor[:, start_idx:end_idx, ...])
|
||||
|
||||
return chunks, num_chunks
|
||||
|
||||
|
||||
def merge_chunks(chunks: list, original_shape: torch.Size):
|
||||
"""Merge chunks back into a single tensor with original shape."""
|
||||
return torch.cat(chunks, dim=1)
|
||||
|
||||
|
||||
# Triton-accelerated implementation of KL divergence loss for top-k tokens
|
||||
class TopKKLDivergence(torch.autograd.Function):
|
||||
"""
|
||||
Autograd function for KL divergence loss between top-k logprobs
|
||||
with support for chunking to handle very long sequences.
|
||||
"""
|
||||
|
||||
# Max sequence length to process in a single kernel launch
|
||||
# This is a tunable parameter that might need adjustment based on GPU memory
|
||||
MAX_SEQ_LEN = 8192
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
student_logits,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
num_items_in_batch=-1,
|
||||
kd_temperature=1.0,
|
||||
top_k_before_softmax=0,
|
||||
):
|
||||
"""
|
||||
Forward pass for KL divergence loss between top-k logprobs with chunking.
|
||||
"""
|
||||
# Only convert target_logprobs to float, leave student_logits as is
|
||||
target_logprobs = target_logprobs.float()
|
||||
|
||||
# Get dimensions
|
||||
batch_size, _, vocab_size = student_logits.shape
|
||||
_, teacher_seq_len, top_k = target_token_ids.shape
|
||||
|
||||
# Slice student logits to match teacher sequence length
|
||||
student_logits_for_kd = student_logits[:, :teacher_seq_len, :]
|
||||
|
||||
# Store original values for backward pass
|
||||
ctx.original_seq_len = teacher_seq_len
|
||||
ctx.original_dtype = student_logits.dtype
|
||||
|
||||
# Apply chunking for long sequences
|
||||
if teacher_seq_len > TopKKLDivergence.MAX_SEQ_LEN:
|
||||
# Chunk the inputs
|
||||
student_logits_chunks, num_chunks = chunk_tensor(
|
||||
student_logits_for_kd, TopKKLDivergence.MAX_SEQ_LEN
|
||||
)
|
||||
target_token_ids_chunks, _ = chunk_tensor(
|
||||
target_token_ids, TopKKLDivergence.MAX_SEQ_LEN
|
||||
)
|
||||
# target_logprobs_chunks, _ = chunk_tensor(
|
||||
# target_logprobs, TopKKLDivergence.MAX_SEQ_LEN
|
||||
# )
|
||||
# target_mask_chunks, _ = chunk_tensor(
|
||||
# target_mask, TopKKLDivergence.MAX_SEQ_LEN
|
||||
# )
|
||||
|
||||
# Process each chunk
|
||||
student_logprobs_chunks = []
|
||||
student_probs_chunks = []
|
||||
|
||||
for i in range(num_chunks):
|
||||
chunk_logits = student_logits_chunks[i]
|
||||
chunk_token_ids = target_token_ids_chunks[i]
|
||||
chunk_seq_len = chunk_logits.shape[1]
|
||||
|
||||
if top_k_before_softmax:
|
||||
# Apply temperature to student logits
|
||||
if kd_temperature != 1.0:
|
||||
chunk_logits = chunk_logits / kd_temperature
|
||||
|
||||
# Gather student logits for top-k tokens
|
||||
chunk_logits_topk = torch.gather(
|
||||
chunk_logits, dim=-1, index=chunk_token_ids
|
||||
)
|
||||
|
||||
# Compute softmax over gathered logits
|
||||
chunk_logprobs_topk = torch.log_softmax(chunk_logits_topk, dim=-1)
|
||||
chunk_probs_topk = torch.exp(chunk_logprobs_topk)
|
||||
else:
|
||||
# Allocate output tensor for logprobs directly (always in float32)
|
||||
chunk_logprobs_topk = torch.empty(
|
||||
(batch_size, chunk_seq_len, top_k),
|
||||
dtype=torch.float32,
|
||||
device=chunk_logits.device,
|
||||
)
|
||||
|
||||
# Launch fused kernel directly
|
||||
grid = (batch_size * chunk_seq_len,)
|
||||
fused_logsumexp_logprobs_kernel[grid](
|
||||
chunk_logits.contiguous(),
|
||||
chunk_logprobs_topk,
|
||||
chunk_token_ids.contiguous(),
|
||||
batch_size,
|
||||
chunk_seq_len,
|
||||
vocab_size,
|
||||
top_k,
|
||||
kd_temperature,
|
||||
chunk_logits.stride(0),
|
||||
chunk_logits.stride(1),
|
||||
chunk_logits.stride(2),
|
||||
chunk_logprobs_topk.stride(0),
|
||||
chunk_logprobs_topk.stride(1),
|
||||
chunk_logprobs_topk.stride(2),
|
||||
chunk_token_ids.stride(0),
|
||||
chunk_token_ids.stride(1),
|
||||
chunk_token_ids.stride(2),
|
||||
min(1024, triton.next_power_of_2(vocab_size)),
|
||||
)
|
||||
|
||||
# Calculate probs from logprobs
|
||||
chunk_probs_topk = torch.exp(chunk_logprobs_topk)
|
||||
|
||||
# Store results
|
||||
student_logprobs_chunks.append(chunk_logprobs_topk)
|
||||
student_probs_chunks.append(chunk_probs_topk)
|
||||
|
||||
# Merge results
|
||||
student_logprobs_topk = torch.cat(student_logprobs_chunks, dim=1)
|
||||
student_probs_topk = torch.cat(student_probs_chunks, dim=1)
|
||||
|
||||
# Save chunking info for backward pass
|
||||
ctx.used_chunking = True
|
||||
ctx.num_chunks = num_chunks
|
||||
|
||||
else:
|
||||
# Original code path for shorter sequences
|
||||
if top_k_before_softmax:
|
||||
# Apply temperature to student logits
|
||||
if kd_temperature != 1.0:
|
||||
student_logits_for_kd = student_logits_for_kd / kd_temperature
|
||||
|
||||
# Gather student logits for top-k tokens
|
||||
student_logits_topk = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
)
|
||||
|
||||
# Compute softmax over gathered logits
|
||||
student_logprobs_topk = torch.log_softmax(student_logits_topk, dim=-1)
|
||||
student_probs_topk = torch.exp(student_logprobs_topk)
|
||||
else:
|
||||
# Allocate output tensor for logprobs directly (always in float32)
|
||||
student_logprobs_topk = torch.empty(
|
||||
(batch_size, teacher_seq_len, top_k),
|
||||
dtype=torch.float32,
|
||||
device=student_logits.device,
|
||||
)
|
||||
|
||||
# Launch fused kernel directly
|
||||
grid = (batch_size * teacher_seq_len,)
|
||||
fused_logsumexp_logprobs_kernel[grid](
|
||||
student_logits_for_kd.contiguous(),
|
||||
student_logprobs_topk,
|
||||
target_token_ids.contiguous(),
|
||||
batch_size,
|
||||
teacher_seq_len,
|
||||
vocab_size,
|
||||
top_k,
|
||||
kd_temperature,
|
||||
student_logits_for_kd.stride(0),
|
||||
student_logits_for_kd.stride(1),
|
||||
student_logits_for_kd.stride(2),
|
||||
student_logprobs_topk.stride(0),
|
||||
student_logprobs_topk.stride(1),
|
||||
student_logprobs_topk.stride(2),
|
||||
target_token_ids.stride(0),
|
||||
target_token_ids.stride(1),
|
||||
target_token_ids.stride(2),
|
||||
min(1024, triton.next_power_of_2(vocab_size)),
|
||||
)
|
||||
|
||||
# Calculate probs from logprobs
|
||||
student_probs_topk = torch.exp(student_logprobs_topk)
|
||||
|
||||
# No chunking used
|
||||
ctx.used_chunking = False
|
||||
|
||||
# Save tensors for backward pass
|
||||
ctx.save_for_backward(
|
||||
student_logits_for_kd,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
student_probs_topk,
|
||||
)
|
||||
ctx.kd_temperature = kd_temperature
|
||||
ctx.top_k_before_softmax = top_k_before_softmax
|
||||
ctx.num_items_in_batch = num_items_in_batch
|
||||
|
||||
# Convert mask to boolean
|
||||
valid_mask = target_mask.bool()
|
||||
|
||||
# Extract valid tokens only - this is where the error was happening
|
||||
# Use cloned contiguous tensors and explicit indexing for safety
|
||||
student_logprobs_flat = student_logprobs_topk.view(-1, top_k)
|
||||
target_logprobs_flat = target_logprobs.view(-1, top_k)
|
||||
valid_mask_flat = valid_mask.view(-1, top_k)
|
||||
|
||||
# Gather valid indices explicitly to avoid illegal memory access
|
||||
valid_indices = torch.nonzero(valid_mask_flat.view(-1)).squeeze(-1)
|
||||
student_logprobs_valid = torch.index_select(
|
||||
student_logprobs_flat.view(-1), 0, valid_indices
|
||||
)
|
||||
target_logprobs_valid = torch.index_select(
|
||||
target_logprobs_flat.view(-1), 0, valid_indices
|
||||
)
|
||||
|
||||
# Convert teacher logprobs to probabilities
|
||||
teacher_probs_valid = torch.exp(target_logprobs_valid)
|
||||
|
||||
# Compute KL divergence loss
|
||||
token_losses = teacher_probs_valid * (
|
||||
target_logprobs_valid - student_logprobs_valid
|
||||
)
|
||||
kd_loss = token_losses.sum()
|
||||
|
||||
# Apply temperature scaling
|
||||
# pylint: disable=duplicate-code
|
||||
if kd_temperature != 1.0:
|
||||
kd_loss = kd_loss * (kd_temperature**2)
|
||||
|
||||
# Normalize by number of items or valid tokens
|
||||
if num_items_in_batch > 0:
|
||||
kd_loss = kd_loss / float(num_items_in_batch)
|
||||
else:
|
||||
num_valid_tokens = valid_indices.numel()
|
||||
kd_loss = kd_loss / float(num_valid_tokens if num_valid_tokens > 0 else 1)
|
||||
|
||||
return kd_loss
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
"""
|
||||
Optimized backward pass for KL divergence loss with proper dtype handling and chunking.
|
||||
"""
|
||||
(
|
||||
student_logits,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
student_probs,
|
||||
) = ctx.saved_tensors
|
||||
kd_temperature = ctx.kd_temperature
|
||||
num_items_in_batch = ctx.num_items_in_batch
|
||||
original_dtype = ctx.original_dtype
|
||||
|
||||
# Get dimensions
|
||||
batch_size, _, vocab_size = student_logits.shape
|
||||
_, teacher_seq_len, top_k = target_token_ids.shape
|
||||
|
||||
# Initialize gradient tensor in float32 to support atomic operations
|
||||
grad_student_logits = torch.zeros_like(student_logits, dtype=torch.float32)
|
||||
|
||||
# Compute scaling factor
|
||||
scale = grad_output.item()
|
||||
|
||||
# Apply temperature scaling from forward pass
|
||||
if kd_temperature != 1.0:
|
||||
scale = scale * (kd_temperature**2)
|
||||
|
||||
# Normalize by number of items or valid tokens
|
||||
if num_items_in_batch > 0:
|
||||
scale = scale / float(num_items_in_batch)
|
||||
else:
|
||||
scale = scale / float(target_mask.sum().item())
|
||||
|
||||
# Apply chain rule for temperature scaling (1/temperature)
|
||||
if kd_temperature != 1.0:
|
||||
scale = scale / kd_temperature
|
||||
|
||||
# Convert teacher logprobs to probabilities
|
||||
teacher_probs = torch.exp(target_logprobs)
|
||||
|
||||
# Use chunking for the backward pass if used in forward
|
||||
if getattr(ctx, "used_chunking", False):
|
||||
num_chunks = ctx.num_chunks
|
||||
max_seq = TopKKLDivergence.MAX_SEQ_LEN
|
||||
|
||||
# Process each chunk
|
||||
for i in range(num_chunks):
|
||||
start_idx = i * max_seq
|
||||
end_idx = min((i + 1) * max_seq, teacher_seq_len)
|
||||
chunk_len = end_idx - start_idx
|
||||
|
||||
# Get chunk slices
|
||||
# student_logits_chunk = student_logits[:, start_idx:end_idx, :]
|
||||
target_token_ids_chunk = target_token_ids[:, start_idx:end_idx, :]
|
||||
teacher_probs_chunk = teacher_probs[:, start_idx:end_idx, :]
|
||||
student_probs_chunk = student_probs[:, start_idx:end_idx, :]
|
||||
target_mask_chunk = target_mask[:, start_idx:end_idx, :]
|
||||
grad_student_logits_chunk = grad_student_logits[:, start_idx:end_idx, :]
|
||||
|
||||
# Launch gradient computation kernel for this chunk
|
||||
grid = (batch_size * chunk_len,)
|
||||
grad_softmax_kernel[grid](
|
||||
grad_student_logits_chunk.contiguous(),
|
||||
target_token_ids_chunk.contiguous(),
|
||||
teacher_probs_chunk.contiguous(),
|
||||
student_probs_chunk.contiguous(),
|
||||
target_mask_chunk.contiguous(),
|
||||
batch_size,
|
||||
chunk_len,
|
||||
vocab_size,
|
||||
top_k,
|
||||
scale,
|
||||
grad_student_logits_chunk.stride(0),
|
||||
grad_student_logits_chunk.stride(1),
|
||||
grad_student_logits_chunk.stride(2),
|
||||
target_token_ids_chunk.stride(0),
|
||||
target_token_ids_chunk.stride(1),
|
||||
target_token_ids_chunk.stride(2),
|
||||
teacher_probs_chunk.stride(0),
|
||||
teacher_probs_chunk.stride(1),
|
||||
teacher_probs_chunk.stride(2),
|
||||
student_probs_chunk.stride(0),
|
||||
student_probs_chunk.stride(1),
|
||||
student_probs_chunk.stride(2),
|
||||
target_mask_chunk.stride(0),
|
||||
target_mask_chunk.stride(1),
|
||||
target_mask_chunk.stride(2),
|
||||
min(1024, triton.next_power_of_2(top_k)),
|
||||
)
|
||||
|
||||
# Update the gradient tensor (already in-place)
|
||||
else:
|
||||
# Original code path for shorter sequences
|
||||
# Launch gradient computation kernel
|
||||
grid = (batch_size * teacher_seq_len,)
|
||||
grad_softmax_kernel[grid](
|
||||
grad_student_logits.contiguous(),
|
||||
target_token_ids.contiguous(),
|
||||
teacher_probs.contiguous(),
|
||||
student_probs.contiguous(),
|
||||
target_mask.contiguous(),
|
||||
batch_size,
|
||||
teacher_seq_len,
|
||||
vocab_size,
|
||||
top_k,
|
||||
scale,
|
||||
grad_student_logits.stride(0),
|
||||
grad_student_logits.stride(1),
|
||||
grad_student_logits.stride(2),
|
||||
target_token_ids.stride(0),
|
||||
target_token_ids.stride(1),
|
||||
target_token_ids.stride(2),
|
||||
teacher_probs.stride(0),
|
||||
teacher_probs.stride(1),
|
||||
teacher_probs.stride(2),
|
||||
student_probs.stride(0),
|
||||
student_probs.stride(1),
|
||||
student_probs.stride(2),
|
||||
target_mask.stride(0),
|
||||
target_mask.stride(1),
|
||||
target_mask.stride(2),
|
||||
min(1024, triton.next_power_of_2(top_k)),
|
||||
)
|
||||
|
||||
# Convert gradient back to original dtype if needed
|
||||
if original_dtype != torch.float32:
|
||||
grad_student_logits = grad_student_logits.to(original_dtype)
|
||||
|
||||
# Return gradients for student_logits and None for other inputs
|
||||
return grad_student_logits, None, None, None, None, None, None
|
||||
|
||||
|
||||
# Wrapper function for chunked computation
|
||||
def loss(
|
||||
student_logits: torch.Tensor,
|
||||
target_token_ids: torch.Tensor,
|
||||
target_logprobs: torch.Tensor,
|
||||
target_mask: torch.Tensor,
|
||||
num_items_in_batch: int = -1,
|
||||
kd_temperature: float = 1.0,
|
||||
top_k_before_softmax: int = 0,
|
||||
max_seq_len: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Triton-accelerated Memory-efficient KL divergence loss computation for knowledge distillation
|
||||
with support for very long sequences.
|
||||
|
||||
Args:
|
||||
student_logits: Student logits [B, seq_len, vocab_size]
|
||||
target_token_ids: Teacher token IDs [B, seq_len, top_k]
|
||||
target_logprobs: Teacher logprobs [B, seq_len, top_k]
|
||||
target_mask: Token mask [B, seq_len, top_k]
|
||||
num_items_in_batch: Number of items for normalization (-1 for auto)
|
||||
kd_temperature: Temperature for KD
|
||||
top_k_before_softmax: Flag for softmax application order
|
||||
max_seq_len: Override default MAX_SEQ_LEN value for chunking
|
||||
"""
|
||||
# Allow overriding the max sequence length
|
||||
if max_seq_len is not None and max_seq_len > 0:
|
||||
TopKKLDivergence.MAX_SEQ_LEN = max_seq_len
|
||||
|
||||
total_loss = TopKKLDivergence.apply(
|
||||
student_logits,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
-1 if num_items_in_batch <= 0 else num_items_in_batch,
|
||||
kd_temperature,
|
||||
top_k_before_softmax,
|
||||
)
|
||||
|
||||
return total_loss
|
||||
@@ -1,67 +0,0 @@
|
||||
"""
|
||||
Optimized Triton kernels for logsumexp
|
||||
"""
|
||||
# pylint: disable=invalid-name,unused-argument
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# Helper function for computing logsumexp
|
||||
@triton.jit
|
||||
def logsumexp_kernel(
|
||||
logits_ptr,
|
||||
output_ptr,
|
||||
B,
|
||||
S,
|
||||
V, # batch size, seq len, vocab size
|
||||
stride_b,
|
||||
stride_s,
|
||||
stride_v,
|
||||
out_stride_b,
|
||||
out_stride_s,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# Program ID
|
||||
# pylint: disable=duplicate-code
|
||||
pid = tl.program_id(0)
|
||||
batch_idx = pid // S
|
||||
seq_idx = pid % S
|
||||
|
||||
# Bounds check
|
||||
if batch_idx >= B or seq_idx >= S:
|
||||
return
|
||||
|
||||
# Pointers
|
||||
logits_base = logits_ptr + batch_idx * stride_b + seq_idx * stride_s
|
||||
|
||||
# Find maximum for numerical stability
|
||||
max_val = -float("inf")
|
||||
for v_offset in range(0, V, BLOCK_SIZE):
|
||||
v_size = min(BLOCK_SIZE, V - v_offset)
|
||||
mask = tl.arange(0, BLOCK_SIZE) < v_size
|
||||
|
||||
logits_block = tl.load(
|
||||
logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
|
||||
mask=mask,
|
||||
other=-float("inf"),
|
||||
)
|
||||
max_val = tl.maximum(max_val, tl.max(logits_block, axis=0))
|
||||
|
||||
# Compute sum of exp(logit - max_val)
|
||||
sum_exp = 0.0
|
||||
for v_offset in range(0, V, BLOCK_SIZE):
|
||||
v_size = min(BLOCK_SIZE, V - v_offset)
|
||||
mask = tl.arange(0, BLOCK_SIZE) < v_size
|
||||
|
||||
logits_block = tl.load(
|
||||
logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
|
||||
mask=mask,
|
||||
other=-float("inf"),
|
||||
)
|
||||
sum_exp += tl.sum(tl.exp(logits_block - max_val), axis=0)
|
||||
|
||||
# Compute logsumexp
|
||||
result = max_val + tl.log(sum_exp)
|
||||
|
||||
# Store result
|
||||
tl.store(output_ptr + batch_idx * out_stride_b + seq_idx * out_stride_s, result)
|
||||
@@ -20,7 +20,6 @@ from axolotl.core.trainers.base import AxolotlTrainer
|
||||
|
||||
from .topk_logprob.forward_kl import loss as topk_kd_loss
|
||||
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
|
||||
from .topk_logprob.forward_kl_triton import loss as topk_kd_loss_triton
|
||||
|
||||
|
||||
class AxolotlKDTrainer(AxolotlTrainer):
|
||||
@@ -86,12 +85,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
else:
|
||||
loss_fn = (
|
||||
topk_kd_loss
|
||||
if self.args.kd_top_k_before_softmax
|
||||
else topk_kd_loss_triton
|
||||
)
|
||||
loss_kd = loss_fn(
|
||||
loss_kd = topk_kd_loss(
|
||||
shift_logits,
|
||||
target_token_ids_for_loss,
|
||||
target_logprobs_for_loss,
|
||||
|
||||
36
src/axolotl/integrations/liger/README.md
Normal file
36
src/axolotl/integrations/liger/README.md
Normal file
@@ -0,0 +1,36 @@
|
||||
# Liger Kernel Integration
|
||||
|
||||
Liger Kernel provides efficient Triton kernels for LLM training, offering:
|
||||
|
||||
- 20% increase in multi-GPU training throughput
|
||||
- 60% reduction in memory usage
|
||||
- Compatibility with both FSDP and DeepSpeed
|
||||
|
||||
See https://github.com/linkedin/Liger-Kernel
|
||||
|
||||
## Usage
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
```
|
||||
|
||||
## Citation
|
||||
|
||||
```bib
|
||||
@article{hsu2024ligerkernelefficienttriton,
|
||||
title={Liger Kernel: Efficient Triton Kernels for LLM Training},
|
||||
author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
|
||||
year={2024},
|
||||
eprint={2410.10989},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.LG},
|
||||
url={https://arxiv.org/abs/2410.10989},
|
||||
journal={arXiv preprint arXiv:2410.10989},
|
||||
}
|
||||
```
|
||||
@@ -1,6 +1,10 @@
|
||||
# LM Eval Harness
|
||||
|
||||
### Usage
|
||||
Run evaluation on model using the popular lm-evaluation-harness library.
|
||||
|
||||
See https://github.com/EleutherAI/lm-evaluation-harness
|
||||
|
||||
## Usage
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
@@ -10,4 +14,22 @@ lm_eval_tasks:
|
||||
- gsm8k
|
||||
- hellaswag
|
||||
- arc_easy
|
||||
|
||||
lm_eval_batch_size: # Batch size for evaluation
|
||||
output_dir: # Directory to save evaluation results
|
||||
```
|
||||
|
||||
## Citation
|
||||
|
||||
```bib
|
||||
@misc{eval-harness,
|
||||
author = {Gao, Leo and Tow, Jonathan and Abbasi, Baber and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and Le Noac'h, Alain and Li, Haonan and McDonell, Kyle and Muennighoff, Niklas and Ociepa, Chris and Phang, Jason and Reynolds, Laria and Schoelkopf, Hailey and Skowron, Aviya and Sutawika, Lintang and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy},
|
||||
title = {A framework for few-shot language model evaluation},
|
||||
month = 07,
|
||||
year = 2024,
|
||||
publisher = {Zenodo},
|
||||
version = {v0.4.3},
|
||||
doi = {10.5281/zenodo.12608602},
|
||||
url = {https://zenodo.org/records/12608602}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
## Spectrum: Targeted Training on Signal to Noise Ratio
|
||||
# Spectrum: Targeted Training on Signal to Noise Ratio
|
||||
|
||||
by Eric Hartford, Lucas Atkins, Fernando Fernandes, David Golchinfar
|
||||
|
||||
This plugin contains code to freeze the bottom fraction of modules in a model, based on the Signal-to-Noise Ratio (SNR).
|
||||
|
||||
### Overview
|
||||
See https://github.com/cognitivecomputations/spectrum
|
||||
|
||||
## Overview
|
||||
|
||||
Spectrum is a tool for scanning and evaluating the Signal-to-Noise Ratio (SNR) of layers in large language models.
|
||||
By identifying the top n% of layers with the highest SNR, you can optimize training efficiency.
|
||||
|
||||
### Usage
|
||||
## Usage
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
@@ -19,3 +21,17 @@ spectrum_top_fraction: 0.5
|
||||
# Optional if using a pre-scanned model as your base_model. Useful if using a model mirror
|
||||
spectrum_model_name: meta-llama/Meta-Llama-3.1-8B
|
||||
```
|
||||
|
||||
## Citation
|
||||
|
||||
```bib
|
||||
@misc{hartford2024spectrumtargetedtrainingsignal,
|
||||
title={Spectrum: Targeted Training on Signal to Noise Ratio},
|
||||
author={Eric Hartford and Lucas Atkins and Fernando Fernandes Neto and David Golchinfar},
|
||||
year={2024},
|
||||
eprint={2406.06623},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.LG},
|
||||
url={https://arxiv.org/abs/2406.06623},
|
||||
}
|
||||
```
|
||||
|
||||
@@ -25,6 +25,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"gemmoe",
|
||||
"starcoder2",
|
||||
"deepseek_v2",
|
||||
"deepseek_v3",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,26 +1,29 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import weakref
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Union
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import transformers.modelcard
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import save_fsdp_model
|
||||
from peft import PeftModel
|
||||
from pkg_resources import get_distribution # type: ignore
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
from datasets import Dataset
|
||||
from peft import PeftConfig, PeftModel
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
from axolotl.common.datasets import TrainDatasetMeta
|
||||
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
|
||||
fix_untrained_tokens,
|
||||
)
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.freeze import freeze_layers_except
|
||||
@@ -32,17 +35,25 @@ try:
|
||||
except ImportError:
|
||||
BetterTransformer = None
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
configure_logging()
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def train(
|
||||
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
|
||||
def setup_model_and_tokenizer(
|
||||
cfg: DictDefault,
|
||||
) -> tuple[
|
||||
PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
|
||||
]:
|
||||
"""
|
||||
Load the tokenizer, processor (for multimodal models), and model based on configuration.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
Returns:
|
||||
Tuple containing model, tokenizer, `peft_config` (if LoRA / QLoRA, else
|
||||
`None`), and processor (if multimodal, else `None`).
|
||||
"""
|
||||
# Load tokenizer
|
||||
LOG.debug(
|
||||
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||
@@ -55,11 +66,58 @@ def train(
|
||||
if cfg.is_multimodal:
|
||||
processor = load_processor(cfg, tokenizer)
|
||||
|
||||
# Get datasets
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
eval_dataset = dataset_meta.eval_dataset
|
||||
total_num_steps = dataset_meta.total_num_steps
|
||||
# Load the model and peft_config
|
||||
msg = "loading model"
|
||||
if cfg.adapter:
|
||||
msg += " and peft_config..."
|
||||
LOG.debug(msg)
|
||||
|
||||
model, peft_config = load_model(cfg, tokenizer, processor=processor)
|
||||
if model.generation_config is not None:
|
||||
model.generation_config.do_sample = True
|
||||
|
||||
# Apply freezing if specified
|
||||
if cfg.unfrozen_parameters:
|
||||
freeze_layers_except(model, cfg.unfrozen_parameters)
|
||||
|
||||
return model, tokenizer, peft_config, processor
|
||||
|
||||
|
||||
def setup_reference_model(
|
||||
cfg: DictDefault, tokenizer: PreTrainedTokenizer
|
||||
) -> PreTrainedModel | None:
|
||||
"""
|
||||
Set up the reference model for RL training if needed.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
tokenizer: The tokenizer to use for the reference model.
|
||||
|
||||
Returns:
|
||||
Reference model if needed for RL training, `None` otherwise.
|
||||
"""
|
||||
model_ref = None
|
||||
if cfg.rl and cfg.rl != "orpo":
|
||||
if cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||
# use built-in trl autounwrap
|
||||
LOG.debug("Passing model_ref: None to RL trainer")
|
||||
model_ref = None # explicit setting to None
|
||||
else:
|
||||
# load the model again for model_ref/baseline
|
||||
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
|
||||
return model_ref
|
||||
|
||||
|
||||
def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
|
||||
"""
|
||||
Determine the checkpoint to resume from based on configuration.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
Returns:
|
||||
Path to the checkpoint to resume from, or `None` if not resuming.
|
||||
"""
|
||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
||||
possible_checkpoints = [
|
||||
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
||||
@@ -73,77 +131,22 @@ def train(
|
||||
LOG.info(
|
||||
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
||||
)
|
||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||
return cfg.resume_from_checkpoint
|
||||
|
||||
# Load the model and tokenizer
|
||||
msg = "loading model"
|
||||
if cfg.adapter:
|
||||
msg += " and peft_config..."
|
||||
LOG.debug(msg)
|
||||
model, peft_config = load_model(cfg, tokenizer, processor=processor)
|
||||
if model.generation_config is not None:
|
||||
model.generation_config.do_sample = True
|
||||
|
||||
model_ref = None
|
||||
if cfg.rl and cfg.rl != "orpo":
|
||||
if cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||
# use built-in trl autounwrap
|
||||
LOG.debug("Passing model_ref: None to RL trainer")
|
||||
model_ref = None # explicit setting to None
|
||||
else:
|
||||
# load the model again for model_ref/baseline
|
||||
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
|
||||
def setup_signal_handler(
|
||||
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
|
||||
):
|
||||
"""
|
||||
Set up signal handler for graceful termination.
|
||||
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
if cfg.unfrozen_parameters:
|
||||
freeze_layers_except(model, cfg.unfrozen_parameters)
|
||||
|
||||
trainer = setup_trainer(
|
||||
cfg,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
(model, model_ref, peft_config),
|
||||
tokenizer,
|
||||
processor,
|
||||
total_num_steps,
|
||||
)
|
||||
|
||||
if cfg.fix_untrained_tokens:
|
||||
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
||||
sig = inspect.signature(fix_untrained_tokens)
|
||||
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
||||
if "token_ids_to_fix" in sig.parameters and isinstance(
|
||||
cfg.fix_untrained_tokens, list
|
||||
):
|
||||
fix_untrained_tokens(
|
||||
model,
|
||||
tokenizer,
|
||||
train_dataset,
|
||||
token_ids_to_fix=cfg.fix_untrained_tokens,
|
||||
)
|
||||
else:
|
||||
fix_untrained_tokens(model, tokenizer, train_dataset)
|
||||
if cfg.local_rank == 0:
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
||||
)
|
||||
|
||||
# go ahead and presave, so we have the adapter config available to inspect
|
||||
if peft_config:
|
||||
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
||||
peft_config.save_pretrained(cfg.output_dir)
|
||||
# additionally presave the tokenizer and model configs
|
||||
if not Path(cfg.output_dir).is_dir():
|
||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
|
||||
if hasattr(model, "config"):
|
||||
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
||||
|
||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||
if (
|
||||
cfg.local_rank == 0 and not cfg.use_ray
|
||||
): # ray workers don't have access to this signal
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
model: The model to save on termination
|
||||
safe_serialization: Whether to use safe serialization when saving
|
||||
"""
|
||||
# ray workers don't have access to this signal
|
||||
if cfg.local_rank == 0 and not cfg.use_ray:
|
||||
|
||||
def terminate_handler(_, __, model_weakref):
|
||||
if model_weakref() is not None:
|
||||
@@ -161,21 +164,22 @@ def train(
|
||||
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
|
||||
)
|
||||
|
||||
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
|
||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
||||
|
||||
if getattr(cfg, "axolotl_config_path"):
|
||||
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
|
||||
version = get_distribution("axolotl").version
|
||||
if raw_axolotl_cfg.is_file():
|
||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
|
||||
def execute_training(
|
||||
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
|
||||
):
|
||||
"""
|
||||
Execute the training process with appropriate backend configurations.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
trainer: The configured trainer object.
|
||||
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
||||
"""
|
||||
LOG.info("Starting trainer...")
|
||||
if cfg.group_by_length:
|
||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||
|
||||
pretrain_hooks(cfg, trainer)
|
||||
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||
@@ -187,15 +191,30 @@ def train(
|
||||
else:
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
post_train_hooks(cfg, trainer)
|
||||
|
||||
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||
def save_trained_model(
|
||||
cfg: DictDefault,
|
||||
trainer: Any,
|
||||
model: PreTrainedModel,
|
||||
safe_serialization: bool,
|
||||
):
|
||||
"""
|
||||
Save the trained model according to configuration and training setup.
|
||||
|
||||
# post training
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
trainer: The trainer object.
|
||||
model: The trained model to save.
|
||||
safe_serialization: Whether to use safe serialization.
|
||||
"""
|
||||
LOG.info(f"Training completed! Saving pre-trained model to {cfg.output_dir}.")
|
||||
|
||||
# Post training module hooks
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, "_post_training"):
|
||||
module._post_training(model, name) # pylint: disable=protected-access
|
||||
|
||||
# Handle FSDP state dict type
|
||||
state_dict_type = "FULL_STATE_DICT"
|
||||
if trainer.is_fsdp_enabled:
|
||||
if cfg.fsdp_final_state_dict_type:
|
||||
@@ -203,16 +222,18 @@ def train(
|
||||
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
|
||||
LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")
|
||||
|
||||
# Handle ReLoRA early return case
|
||||
if cfg.relora_steps:
|
||||
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
|
||||
model = model.merge_and_unload()
|
||||
else:
|
||||
# final model weights have already been saved by `ReLoRACallback.on_train_end`
|
||||
return model, tokenizer
|
||||
return
|
||||
|
||||
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||
if cfg.fsdp:
|
||||
# TODO: do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple
|
||||
# processes attempt to write the same file
|
||||
if (
|
||||
state_dict_type == "SHARDED_STATE_DICT"
|
||||
and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT"
|
||||
@@ -244,7 +265,6 @@ def train(
|
||||
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
model = BetterTransformer.reverse(model)
|
||||
@@ -255,58 +275,239 @@ def train(
|
||||
)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
|
||||
def create_model_card(cfg: DictDefault, trainer: Trainer):
|
||||
"""
|
||||
Create a model card for the trained model if needed.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
trainer: The trainer object with model card creation capabilities.
|
||||
"""
|
||||
if not cfg.hub_model_id:
|
||||
# Guard since create_model_card may fail if dataset_tags is empty list
|
||||
try:
|
||||
model_card_kwarg = {
|
||||
"model_name": cfg.output_dir.lstrip("./")
|
||||
.encode("utf-8")
|
||||
.decode("utf-8")
|
||||
}
|
||||
if cfg.datasets is not None:
|
||||
if cfg.rl is not None or cfg.reward_model or cfg.process_reward_model:
|
||||
dataset_tags = [
|
||||
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
dataset_tags = [
|
||||
d for d in dataset_tags if not d.startswith("https://")
|
||||
]
|
||||
if dataset_tags:
|
||||
# guard as create_model_card may fail if dataset_tags is empty list
|
||||
model_card_kwarg["dataset_name"] = dataset_tags
|
||||
else:
|
||||
dataset_tags = [
|
||||
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
dataset_tags = [
|
||||
d for d in dataset_tags if not d.startswith("https://")
|
||||
]
|
||||
if dataset_tags:
|
||||
# guard as create_model_card may fail if dataset_tags is empty list
|
||||
model_card_kwarg["dataset_tags"] = dataset_tags
|
||||
|
||||
# We check if we're using a TRL trainer; if so, `dataset_tags` is not consumed.
|
||||
rl = cfg.rl is not None or cfg.reward_model or cfg.process_reward_model
|
||||
if cfg.datasets is not None and not rl:
|
||||
dataset_tags = [
|
||||
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
dataset_tags = [d for d in dataset_tags if not d.startswith("https://")]
|
||||
|
||||
if dataset_tags:
|
||||
model_card_kwarg["dataset_tags"] = dataset_tags
|
||||
|
||||
trainer.create_model_card(**model_card_kwarg)
|
||||
except (AttributeError, UnicodeDecodeError):
|
||||
pass
|
||||
elif cfg.hub_model_id:
|
||||
# defensively push to the hub to ensure the model card is updated
|
||||
# Defensively push to the hub to ensure the model card is updated
|
||||
trainer.push_to_hub()
|
||||
|
||||
|
||||
def save_initial_configs(
|
||||
cfg: DictDefault,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
model: PreTrainedModel,
|
||||
peft_config: PeftConfig | None,
|
||||
):
|
||||
"""
|
||||
Save initial configurations before training.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
tokenizer: The tokenizer to save.
|
||||
model: The model to save configuration for.
|
||||
peft_config: The PEFT configuration to save if applicable.
|
||||
"""
|
||||
# Create output_dir if it doesn't already exist
|
||||
output_dir = Path(cfg.output_dir)
|
||||
if not output_dir.is_dir():
|
||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||
|
||||
# Pre-save adapter config so it's available to inspect
|
||||
if peft_config:
|
||||
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}...")
|
||||
peft_config.save_pretrained(cfg.output_dir)
|
||||
|
||||
# Pre-save the tokenizer and model configs
|
||||
LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
|
||||
tokenizer.save_pretrained(str(output_dir))
|
||||
if hasattr(model, "config"):
|
||||
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
||||
model.config.save_pretrained(str(output_dir))
|
||||
|
||||
|
||||
def setup_model_card(cfg: DictDefault):
|
||||
"""
|
||||
Set up the Axolotl badge and add the Axolotl config to the model card if available.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
"""
|
||||
badge_markdown = """[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)"""
|
||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
||||
|
||||
if getattr(cfg, "axolotl_config_path"):
|
||||
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
|
||||
version = importlib.metadata.version("axolotl")
|
||||
if raw_axolotl_cfg.is_file():
|
||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
|
||||
|
||||
|
||||
def handle_untrained_tokens_fix(
|
||||
cfg: DictDefault,
|
||||
model: PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
train_dataset: Dataset,
|
||||
safe_serialization: bool,
|
||||
):
|
||||
"""
|
||||
Apply fixes for untrained tokens if configured.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
model: The model to apply fixes to.
|
||||
tokenizer: The tokenizer for token identification.
|
||||
train_dataset: The training dataset to use.
|
||||
safe_serialization: Whether to use safe serialization when saving.
|
||||
"""
|
||||
if not cfg.fix_untrained_tokens:
|
||||
return
|
||||
|
||||
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
|
||||
sig = inspect.signature(fix_untrained_tokens)
|
||||
|
||||
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
|
||||
if "token_ids_to_fix" in sig.parameters and isinstance(
|
||||
cfg.fix_untrained_tokens, list
|
||||
):
|
||||
fix_untrained_tokens(
|
||||
model,
|
||||
tokenizer,
|
||||
train_dataset,
|
||||
token_ids_to_fix=cfg.fix_untrained_tokens,
|
||||
)
|
||||
else:
|
||||
fix_untrained_tokens(model, tokenizer, train_dataset)
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
||||
)
|
||||
|
||||
|
||||
def setup_model_and_trainer(
|
||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||
) -> tuple[
|
||||
HFRLTrainerBuilder | HFCausalTrainerBuilder,
|
||||
PeftModel | PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PeftConfig | None,
|
||||
]:
|
||||
"""
|
||||
Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
|
||||
trainer setup.
|
||||
|
||||
Args:
|
||||
cfg: The configuration dictionary with training parameters.
|
||||
dataset_meta: Object with training, validation datasets and metadata.
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- Trainer (Causal or RLHF)
|
||||
- Model
|
||||
- Tokenizer
|
||||
- PEFT config
|
||||
"""
|
||||
# Load tokenizer, processor and model
|
||||
model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
|
||||
|
||||
# Set up reference model for RL if needed
|
||||
model_ref = setup_reference_model(cfg, tokenizer)
|
||||
|
||||
# Get datasets from metadata
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
eval_dataset = dataset_meta.eval_dataset
|
||||
total_num_steps = dataset_meta.total_num_steps
|
||||
|
||||
# Set up trainer
|
||||
trainer = setup_trainer(
|
||||
cfg=cfg,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
total_num_steps=total_num_steps,
|
||||
model_ref=model_ref,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
return (
|
||||
trainer,
|
||||
model,
|
||||
tokenizer,
|
||||
peft_config,
|
||||
)
|
||||
|
||||
|
||||
def train(
|
||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||
) -> tuple[PeftModel | PreTrainedModel, PreTrainedTokenizer]:
|
||||
"""
|
||||
Train a model on the given dataset.
|
||||
|
||||
Args:
|
||||
cfg: The configuration dictionary with training parameters
|
||||
dataset_meta: Object with training, validation datasets and metadata
|
||||
|
||||
Returns:
|
||||
Tuple of (model, tokenizer) after training
|
||||
"""
|
||||
# Setup model, tokenizer, (causal or RLHF) trainer etc.
|
||||
(
|
||||
trainer,
|
||||
model,
|
||||
tokenizer,
|
||||
peft_config,
|
||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||
|
||||
# Determine if we need to resume from a checkpoint
|
||||
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
||||
|
||||
# Configuration for saving
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
# Handle untrained tokens if configured
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
handle_untrained_tokens_fix(
|
||||
cfg, model, tokenizer, train_dataset, safe_serialization
|
||||
)
|
||||
|
||||
# Save initial configs
|
||||
save_initial_configs(cfg, tokenizer, model, peft_config)
|
||||
|
||||
# Set up signal handler for graceful termination
|
||||
setup_signal_handler(cfg, model, safe_serialization)
|
||||
|
||||
# Set up badges and config info for model card
|
||||
setup_model_card(cfg)
|
||||
|
||||
# Execute the training
|
||||
execute_training(cfg, trainer, resume_from_checkpoint)
|
||||
|
||||
# Save the trained model
|
||||
save_trained_model(cfg, trainer, model, safe_serialization)
|
||||
|
||||
# Create model card
|
||||
create_model_card(cfg, trainer)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def pretrain_hooks(_cfg, _trainer):
|
||||
"""
|
||||
Run hooks right before kicking off the training
|
||||
:param cfg:
|
||||
:param trainer:
|
||||
:return:
|
||||
"""
|
||||
|
||||
|
||||
def post_train_hooks(_cfg, _trainer):
|
||||
"""
|
||||
Run hooks right after training completes
|
||||
:param cfg:
|
||||
:param trainer:
|
||||
:return:
|
||||
"""
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -55,6 +55,7 @@ class ChatTemplate(str, Enum):
|
||||
phi_3 = "phi_3" # pylint: disable=invalid-name
|
||||
phi_35 = "phi_35" # pylint: disable=invalid-name
|
||||
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
||||
deepseek_v3 = "deepseek_v3" # pylint: disable=invalid-name
|
||||
jamba = "jamba" # pylint: disable=invalid-name
|
||||
jinja = "jinja" # pylint: disable=invalid-name
|
||||
qwen_25 = "qwen_25" # pylint: disable=invalid-name
|
||||
|
||||
@@ -574,14 +574,40 @@ def prepare_opinionated_env(cfg):
|
||||
|
||||
|
||||
def setup_trainer(
|
||||
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
||||
cfg,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
model,
|
||||
tokenizer,
|
||||
processor,
|
||||
total_num_steps,
|
||||
model_ref=None,
|
||||
peft_config=None,
|
||||
):
|
||||
"""
|
||||
Helper method for instantiating and building a (causal or RLHF) trainer.
|
||||
|
||||
Args:
|
||||
cfg: Axolotl config object containing training parameters.
|
||||
train_dataset: Dataset to use for training.
|
||||
eval_dataset: Dataset to use for evaluation.
|
||||
model: The model to train.
|
||||
tokenizer: Tokenizer for processing text input.
|
||||
processor: Processor for data preparation.
|
||||
total_num_steps: The total number of training steps.
|
||||
model_ref: Optional reference model for RLHF training. Default is None.
|
||||
peft_config: Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. Default is None.
|
||||
|
||||
Returns:
|
||||
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
|
||||
on the provided parameters.
|
||||
"""
|
||||
if cfg.rl:
|
||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
||||
trainer_builder.model_ref = model[1]
|
||||
trainer_builder.peft_config = model[2]
|
||||
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
|
||||
trainer_builder.model_ref = model_ref
|
||||
trainer_builder.peft_config = peft_config
|
||||
else:
|
||||
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor)
|
||||
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer, processor)
|
||||
|
||||
trainer_builder.train_dataset = train_dataset
|
||||
trainer_builder.eval_dataset = eval_dataset
|
||||
|
||||
194
styles.css
194
styles.css
@@ -1,5 +1,193 @@
|
||||
/* css styles */
|
||||
/* TYPOGRAPHY SECTION */
|
||||
|
||||
img[alt="Axolotl"] {
|
||||
content: url("https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg") !important;
|
||||
/* Import fonts */
|
||||
@import url('https://fonts.googleapis.com/css2?family=Be+Vietnam+Pro:wght@400;500&display=swap');
|
||||
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400&display=swap');
|
||||
|
||||
/* Typography hierarchy */
|
||||
:root {
|
||||
--font-title: 'Be Vietnam Pro', sans-serif;
|
||||
--font-body: 'JetBrains Mono', monospace;
|
||||
}
|
||||
|
||||
/* Title (h1) */
|
||||
h1 {
|
||||
font-family: var(--font-title);
|
||||
font-weight: 400;
|
||||
font-size: 6rem;
|
||||
line-height: 1.1;
|
||||
letter-spacing: -0.05em;
|
||||
font-feature-settings: "ss01" on;
|
||||
}
|
||||
|
||||
/* Heading (h2) */
|
||||
h2 {
|
||||
font-family: var(--font-title);
|
||||
font-weight: 500;
|
||||
font-size: 2rem;
|
||||
line-height: 1.2;
|
||||
letter-spacing: -0.03em;
|
||||
font-feature-settings: "ss01" on;
|
||||
}
|
||||
|
||||
/* Subtitle/Preamble */
|
||||
h3,
|
||||
h4 {
|
||||
font-family: var(--font-body);
|
||||
font-weight: 400;
|
||||
font-size: 1.5rem;
|
||||
line-height: 1.5;
|
||||
letter-spacing: -0.02em;
|
||||
}
|
||||
|
||||
/* Body text */
|
||||
body {
|
||||
font-family: var(--font-body);
|
||||
font-weight: 400;
|
||||
font-size: 1rem;
|
||||
line-height: 1.5;
|
||||
letter-spacing: -0.02em;
|
||||
}
|
||||
|
||||
/* Links */
|
||||
a {
|
||||
font-family: var(--font-body);
|
||||
font-weight: 400;
|
||||
font-size: 0.875rem;
|
||||
line-height: 1;
|
||||
letter-spacing: -0.02em;
|
||||
}
|
||||
|
||||
/* NAV BAR SECTION */
|
||||
|
||||
/* Navbar logo styling */
|
||||
.navbar-brand img {
|
||||
height: 32px;
|
||||
margin-right: 10px;
|
||||
}
|
||||
|
||||
/* COLORS SECTION */
|
||||
|
||||
/* Brand colors */
|
||||
:root {
|
||||
--white: #ffffff;
|
||||
--greige-300: #EEEEE7;
|
||||
--greige-600: #CCCAC0;
|
||||
--black: #141310;
|
||||
--lime: #E3F8A8;
|
||||
--cyan: #A0F4EA;
|
||||
--purple: #C8D0F8;
|
||||
}
|
||||
|
||||
/* Base styles */
|
||||
body {
|
||||
background-color: var(--black);
|
||||
color: var(--greige-300);
|
||||
}
|
||||
|
||||
/* Navigation */
|
||||
.navbar {
|
||||
background-color: var(--black) !important;
|
||||
}
|
||||
|
||||
.navbar-dark .navbar-nav .nav-link {
|
||||
color: var(--greige-300);
|
||||
}
|
||||
|
||||
.navbar-dark .navbar-nav .nav-link:hover {
|
||||
color: var(--lime);
|
||||
}
|
||||
|
||||
/* Sidebar */
|
||||
.sidebar-navigation {
|
||||
background-color: var(--black);
|
||||
border-right: 1px solid var(--greige-600);
|
||||
}
|
||||
|
||||
.sidebar nav[role="doc-toc"] ul>li>a {
|
||||
color: var(--greige-300);
|
||||
}
|
||||
|
||||
.sidebar nav[role="doc-toc"] ul>li>a:hover {
|
||||
color: var(--lime);
|
||||
}
|
||||
|
||||
/* Links */
|
||||
a {
|
||||
color: var(--lime);
|
||||
}
|
||||
|
||||
a:hover {
|
||||
color: var(--cyan);
|
||||
}
|
||||
|
||||
/* Headers */
|
||||
h1,
|
||||
h2,
|
||||
h3,
|
||||
h4,
|
||||
h5,
|
||||
h6 {
|
||||
color: var(--white);
|
||||
}
|
||||
|
||||
/* Code blocks */
|
||||
pre {
|
||||
background-color: #1a1a1a !important;
|
||||
border: 1px solid var(--greige-600);
|
||||
}
|
||||
|
||||
/* Tables */
|
||||
.table {
|
||||
color: var(--greige-300);
|
||||
}
|
||||
|
||||
/* TOC */
|
||||
#toc-title {
|
||||
color: var(--white);
|
||||
}
|
||||
|
||||
.toc-active {
|
||||
color: var(--lime) !important;
|
||||
}
|
||||
|
||||
/* Buttons */
|
||||
.btn-primary {
|
||||
background-color: var(--lime);
|
||||
color: var(--black);
|
||||
border: none;
|
||||
}
|
||||
|
||||
.btn-primary:hover {
|
||||
background-color: var(--cyan);
|
||||
color: var(--black);
|
||||
}
|
||||
|
||||
/* For inline code (single backtick) */
|
||||
code {
|
||||
background-color: #1a1a1a !important;
|
||||
color: var(--lime) !important;
|
||||
padding: 2px 4px;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
/* For inline code that is also a link */
|
||||
a code {
|
||||
color: var(--cyan) !important;
|
||||
}
|
||||
|
||||
/* For code blocks (triple backtick) */
|
||||
pre.sourceCode {
|
||||
background-color: #1a1a1a !important;
|
||||
}
|
||||
|
||||
/* Make comments in bash/shell scripts green */
|
||||
code span.co {
|
||||
color: #5cb85c !important;
|
||||
}
|
||||
|
||||
/* Remove underlines from JSON comments and make them green */
|
||||
code span.er {
|
||||
color: #5cb85c !important;
|
||||
text-decoration: none !important;
|
||||
}
|
||||
|
||||
@@ -90,12 +90,6 @@ class TestKnowledgeDistillation:
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||
)
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
|
||||
)
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"load_in_8bit",
|
||||
@@ -127,9 +121,3 @@ class TestKnowledgeDistillation:
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||
)
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
|
||||
)
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
|
||||
)
|
||||
|
||||
@@ -1,163 +0,0 @@
|
||||
"""
|
||||
sanity checks on kl loss and gradients
|
||||
"""
|
||||
import torch
|
||||
|
||||
# Import both implementations
|
||||
from axolotl.integrations.kd.topk_logprob.forward_kl import loss as eager_loss
|
||||
from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss
|
||||
|
||||
|
||||
def test_kl_loss_gradient():
|
||||
"""Test that the gradient of the Triton implementation matches the eager implementation."""
|
||||
|
||||
# Set the random seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Create random inputs
|
||||
batch_size = 2
|
||||
seq_len = 3
|
||||
vocab_size = 100
|
||||
top_k = 5
|
||||
|
||||
# Generate random student logits
|
||||
student_logits = torch.randn(
|
||||
batch_size, seq_len, vocab_size, requires_grad=True, device="cuda"
|
||||
)
|
||||
student_logits_triton = student_logits.detach().clone().requires_grad_(True)
|
||||
|
||||
# Generate random target token IDs, ensuring they're valid indices
|
||||
# pylint: disable=duplicate-code
|
||||
target_token_ids = torch.randint(
|
||||
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
||||
)
|
||||
|
||||
# Generate random target logprobs (before normalization)
|
||||
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
|
||||
|
||||
# Normalize the target logprobs to ensure they form a valid distribution
|
||||
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
|
||||
|
||||
# Create a random mask with some tokens masked out
|
||||
target_mask = torch.randint(
|
||||
0, 2, (batch_size, seq_len, top_k), device="cuda"
|
||||
).float()
|
||||
|
||||
# Additional parameters
|
||||
num_items_in_batch = batch_size * seq_len
|
||||
kd_temperature = 1.0
|
||||
top_k_before_softmax = 0 # Test both modes
|
||||
|
||||
# Compute the loss and gradients with eager implementation
|
||||
loss_eager = eager_loss(
|
||||
student_logits,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
num_items_in_batch,
|
||||
kd_temperature,
|
||||
top_k_before_softmax,
|
||||
)
|
||||
loss_eager.backward()
|
||||
grad_eager = student_logits.grad.clone()
|
||||
|
||||
# Reset gradients
|
||||
student_logits.grad.zero_()
|
||||
|
||||
# Compute the loss and gradients with Triton implementation
|
||||
loss_triton = triton_loss(
|
||||
student_logits_triton,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
num_items_in_batch,
|
||||
kd_temperature,
|
||||
top_k_before_softmax,
|
||||
)
|
||||
loss_triton.backward()
|
||||
grad_triton = student_logits_triton.grad.clone()
|
||||
|
||||
# Compare loss values
|
||||
print(f"Eager loss: {loss_eager.item()}")
|
||||
print(f"Triton loss: {loss_triton.item()}")
|
||||
loss_diff = abs(loss_eager.item() - loss_triton.item())
|
||||
print(f"Loss difference: {loss_diff}")
|
||||
assert loss_diff < 1e-5, "Loss values differ significantly!"
|
||||
|
||||
# Compare gradients
|
||||
grad_diff = (grad_eager - grad_triton).abs().max().item()
|
||||
print(f"Max gradient difference: {grad_diff}")
|
||||
|
||||
# Print some sample gradients
|
||||
sample_idx = (0, 0, 0) # (batch, seq, vocab)
|
||||
print(f"Sample eager gradient: {grad_eager[sample_idx].item()}")
|
||||
print(f"Sample triton gradient: {grad_triton[sample_idx].item()}")
|
||||
|
||||
# Compute relative difference for non-zero gradients
|
||||
mask = grad_eager.abs() > 1e-10
|
||||
if mask.sum() > 0:
|
||||
rel_diff = (
|
||||
(
|
||||
(grad_eager[mask] - grad_triton[mask]).abs()
|
||||
/ (grad_eager[mask].abs() + 1e-10)
|
||||
)
|
||||
.max()
|
||||
.item()
|
||||
)
|
||||
print(f"Max relative gradient difference: {rel_diff}")
|
||||
assert rel_diff < 1e-3, "Gradients differ significantly!"
|
||||
|
||||
# Also test top_k_before_softmax = 1 mode
|
||||
top_k_before_softmax = 1
|
||||
|
||||
# Reset the gradients
|
||||
student_logits = torch.randn(
|
||||
batch_size, seq_len, vocab_size, requires_grad=True, device="cuda"
|
||||
)
|
||||
student_logits_triton = student_logits.detach().clone().requires_grad_(True)
|
||||
|
||||
# Compute the loss and gradients with eager implementation
|
||||
loss_eager = eager_loss(
|
||||
student_logits,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
num_items_in_batch,
|
||||
kd_temperature,
|
||||
top_k_before_softmax,
|
||||
)
|
||||
loss_eager.backward()
|
||||
grad_eager = student_logits.grad.clone()
|
||||
|
||||
# Compute the loss and gradients with Triton implementation
|
||||
loss_triton = triton_loss(
|
||||
student_logits_triton,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
num_items_in_batch,
|
||||
kd_temperature,
|
||||
top_k_before_softmax,
|
||||
)
|
||||
loss_triton.backward()
|
||||
grad_triton = student_logits_triton.grad.clone()
|
||||
|
||||
# Compare gradients for top_k_before_softmax = 1
|
||||
grad_diff = (grad_eager - grad_triton).abs().max().item()
|
||||
print("\nWith top_k_before_softmax=1:")
|
||||
print(f"Max gradient difference: {grad_diff}")
|
||||
|
||||
# Compute relative difference for non-zero gradients
|
||||
mask = grad_eager.abs() > 1e-10
|
||||
if mask.sum() > 0:
|
||||
rel_diff = (
|
||||
(
|
||||
(grad_eager[mask] - grad_triton[mask]).abs()
|
||||
/ (grad_eager[mask].abs() + 1e-10)
|
||||
)
|
||||
.max()
|
||||
.item()
|
||||
)
|
||||
assert (
|
||||
rel_diff < 1e-3
|
||||
), f"Gradients differ significantly, Max relative gradient difference: {rel_diff}"
|
||||
@@ -1,204 +0,0 @@
|
||||
"""
|
||||
sanity checks on logsumexp kernel validity
|
||||
"""
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from axolotl.integrations.kd.topk_logprob.logsumexp import logsumexp_kernel
|
||||
|
||||
|
||||
# PyTorch implementation of logsumexp for reference
|
||||
def torch_logsumexp(logits):
|
||||
"""PyTorch implementation of logsumexp over last dimension"""
|
||||
return torch.logsumexp(logits, dim=-1)
|
||||
|
||||
|
||||
# Wrapper function for Triton logsumexp kernel
|
||||
def triton_logsumexp(logits):
|
||||
"""Triton implementation of logsumexp over last dimension"""
|
||||
B, S, V = logits.shape # pylint: disable=invalid-name
|
||||
output = torch.empty((B, S), dtype=torch.float32, device=logits.device)
|
||||
|
||||
grid = (B * S,)
|
||||
logsumexp_kernel[grid](
|
||||
logits.contiguous(),
|
||||
output,
|
||||
B,
|
||||
S,
|
||||
V,
|
||||
logits.stride(0),
|
||||
logits.stride(1),
|
||||
logits.stride(2),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
min(1024, triton.next_power_of_2(V)),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TritonLogSumExp(torch.autograd.Function):
|
||||
"""
|
||||
Wrap a custom autograd function to use the Triton logsumexp for gradient testing
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, logits):
|
||||
B, S, V = logits.shape # pylint: disable=invalid-name
|
||||
output = torch.empty((B, S), dtype=torch.float32, device=logits.device)
|
||||
|
||||
# Save inputs for backward pass
|
||||
ctx.save_for_backward(logits)
|
||||
ctx.shape = logits.shape
|
||||
|
||||
grid = (B * S,)
|
||||
logsumexp_kernel[grid](
|
||||
logits.contiguous(),
|
||||
output,
|
||||
B,
|
||||
S,
|
||||
V,
|
||||
logits.stride(0),
|
||||
logits.stride(1),
|
||||
logits.stride(2),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
min(1024, triton.next_power_of_2(V)),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
(logits,) = ctx.saved_tensors
|
||||
|
||||
# For logsumexp, the gradient is softmax(input) * grad_output
|
||||
# First compute the logsumexp
|
||||
lse = TritonLogSumExp.apply(logits)
|
||||
|
||||
# Compute softmax by exponentiating differences
|
||||
softmax_output = torch.exp(logits - lse.unsqueeze(-1))
|
||||
|
||||
# Compute gradient of logsumexp by multiplying the softmax output by the gradient
|
||||
grad_input = softmax_output * grad_output.unsqueeze(-1)
|
||||
|
||||
return grad_input
|
||||
|
||||
|
||||
def test_logsumexp_values():
|
||||
"""Test that the Triton logsumexp implementation matches PyTorch's"""
|
||||
# Set random seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Test with various input shapes
|
||||
test_shapes = [
|
||||
(2, 3, 10), # small vocab
|
||||
(4, 5, 100), # medium vocab
|
||||
(2, 2, 32000), # large vocab (typical for LLMs)
|
||||
]
|
||||
|
||||
for shape in test_shapes:
|
||||
# Create random input tensors
|
||||
logits = torch.randn(shape, device="cuda", requires_grad=False)
|
||||
|
||||
# Compute logsumexp using both implementations
|
||||
torch_result = torch_logsumexp(logits)
|
||||
triton_result = triton_logsumexp(logits)
|
||||
|
||||
# Compare results
|
||||
max_diff = (torch_result - triton_result).abs().max().item()
|
||||
print(f"Shape {shape}, Max diff: {max_diff}")
|
||||
|
||||
# Assert that the results are very close
|
||||
assert max_diff < 1e-5, f"Results differ for shape {shape}: max diff {max_diff}"
|
||||
|
||||
|
||||
def test_logsumexp_edge_cases():
|
||||
"""Test edge cases for numerical stability"""
|
||||
# Set random seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Case 1: Very large values that might cause overflow
|
||||
logits_large = torch.ones(2, 3, 100, device="cuda") * 1000
|
||||
|
||||
# Case 2: Very small values that might cause underflow
|
||||
logits_small = torch.ones(2, 3, 100, device="cuda") * -1000
|
||||
|
||||
# Case 3: Mix of large and small values
|
||||
logits_mixed = torch.zeros(2, 3, 100, device="cuda")
|
||||
logits_mixed[:, :, 0] = 1000 # One very large value
|
||||
|
||||
# Case 4: All identical values
|
||||
logits_identical = torch.ones(2, 3, 100, device="cuda") * 5
|
||||
|
||||
# Case 5: Extreme values with NaN check
|
||||
logits_extreme = torch.cat(
|
||||
[
|
||||
torch.full((1, 3, 50), 1e10, device="cuda"),
|
||||
torch.full((1, 3, 50), -1e10, device="cuda"),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
for i, logits in enumerate(
|
||||
[logits_large, logits_small, logits_mixed, logits_identical, logits_extreme]
|
||||
):
|
||||
# Compute logsumexp using both implementations
|
||||
torch_result = torch_logsumexp(logits)
|
||||
triton_result = triton_logsumexp(logits)
|
||||
|
||||
# Check for NaNs
|
||||
assert not torch.isnan(
|
||||
torch_result
|
||||
).any(), f"PyTorch produced NaNs for case {i+1}"
|
||||
assert not torch.isnan(
|
||||
triton_result
|
||||
).any(), f"Triton produced NaNs for case {i+1}"
|
||||
|
||||
# Compare results
|
||||
max_diff = (torch_result - triton_result).abs().max().item()
|
||||
print(f"Edge case {i+1}, Max diff: {max_diff}")
|
||||
|
||||
# For very extreme values, allow a bit more tolerance
|
||||
if i == 4: # extreme case
|
||||
assert max_diff < 1e-2, f"Results differ too much for edge case {i+1}"
|
||||
else:
|
||||
assert max_diff < 1e-5, f"Results differ too much for edge case {i+1}"
|
||||
|
||||
|
||||
def test_logsumexp_gradients():
|
||||
"""Test that the gradients of Triton logsumexp match PyTorch's"""
|
||||
# Set random seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Create input tensors with gradients enabled
|
||||
shapes = [(2, 3, 10), (4, 5, 100)]
|
||||
|
||||
for shape in shapes:
|
||||
# Create two identical tensors for PyTorch and Triton
|
||||
logits_torch = torch.randn(shape, device="cuda", requires_grad=True)
|
||||
logits_triton = logits_torch.clone().detach().requires_grad_(True)
|
||||
|
||||
# Forward pass
|
||||
torch_output = torch_logsumexp(logits_torch)
|
||||
triton_output = TritonLogSumExp.apply(logits_triton)
|
||||
|
||||
# Compare forward pass values
|
||||
max_diff_forward = (torch_output - triton_output).abs().max().item()
|
||||
assert max_diff_forward < 1e-5, f"Forward pass values differ for shape {shape}"
|
||||
|
||||
# Create random gradient
|
||||
grad_output = torch.randn_like(torch_output)
|
||||
|
||||
# Backward pass
|
||||
torch_output.backward(grad_output)
|
||||
triton_output.backward(grad_output)
|
||||
|
||||
# Compare gradients
|
||||
max_diff_grad = (logits_torch.grad - logits_triton.grad).abs().max().item()
|
||||
print(f"Shape {shape}, Max gradient diff: {max_diff_grad}")
|
||||
|
||||
# Assert that gradients are very close
|
||||
assert (
|
||||
max_diff_grad < 1e-5
|
||||
), f"Gradients differ for shape {shape}: max diff {max_diff_grad}"
|
||||
130
tests/e2e/test_deepseekv3.py
Normal file
130
tests/e2e/test_deepseekv3.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestDeepseekV3:
|
||||
"""
|
||||
Test case for DeepseekV3 models
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample_packing",
|
||||
[True, False],
|
||||
)
|
||||
def test_lora_deepseekv3(self, temp_dir, sample_packing):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/DeepSeek-V3-11M",
|
||||
"trust_remote_code": True,
|
||||
"sample_packing": sample_packing,
|
||||
"flash_attention": True,
|
||||
"sequence_len": 2048,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mlabonne/FineTome-100k",
|
||||
"type": "chat_template",
|
||||
"field_messages": "conversations",
|
||||
"message_property_mappings": {
|
||||
"role": "from",
|
||||
"content": "value",
|
||||
},
|
||||
"drop_system_message": True,
|
||||
"split": "train[:1%]",
|
||||
},
|
||||
],
|
||||
"special_tokens": {
|
||||
"bos_token": "<|begin▁of▁sentence|>",
|
||||
"eos_token": "<|end▁of▁sentence|>",
|
||||
},
|
||||
"chat_template": "deepseek_v3",
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample_packing",
|
||||
[True, False],
|
||||
)
|
||||
def test_fft_deepseekv3(self, temp_dir, sample_packing):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/DeepSeek-V3-11M",
|
||||
"trust_remote_code": True,
|
||||
"sample_packing": sample_packing,
|
||||
"flash_attention": True,
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mlabonne/FineTome-100k",
|
||||
"type": "chat_template",
|
||||
"field_messages": "conversations",
|
||||
"message_field_role": "from",
|
||||
"message_field_content": "value",
|
||||
"split": "train[:1%]",
|
||||
},
|
||||
],
|
||||
"chat_template": "deepseek_v3",
|
||||
"special_tokens": {
|
||||
"bos_token": "<|begin▁of▁sentence|>",
|
||||
"eos_token": "<|end▁of▁sentence|>",
|
||||
},
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
@@ -102,11 +102,7 @@ def is_hopper():
|
||||
|
||||
|
||||
def check_tensorboard(
|
||||
temp_run_dir: str,
|
||||
tag: str,
|
||||
comparison_val: float,
|
||||
assertion_err: str,
|
||||
lt: bool = True,
|
||||
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
|
||||
) -> None:
|
||||
"""
|
||||
helper function to parse and check tensorboard logs
|
||||
@@ -116,20 +112,10 @@ def check_tensorboard(
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == tag)] # pylint: disable=invalid-name
|
||||
if lt:
|
||||
if "%s" in assertion_err:
|
||||
assert df.value.values[-1] < comparison_val, (
|
||||
assertion_err % df.value.values[-1]
|
||||
)
|
||||
else:
|
||||
assert df.value.values[-1] < comparison_val, assertion_err
|
||||
if "%s" in assertion_err:
|
||||
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
|
||||
else:
|
||||
if "%s" in assertion_err:
|
||||
assert df.value.values[-1] > comparison_val, (
|
||||
assertion_err % df.value.values[-1]
|
||||
)
|
||||
else:
|
||||
assert df.value.values[-1] > comparison_val, assertion_err
|
||||
assert df.value.values[-1] < lt_val, assertion_err
|
||||
|
||||
|
||||
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
|
||||
|
||||
Reference in New Issue
Block a user