diff --git a/README.md b/README.md
index 395f52042..836aab2be 100644
--- a/README.md
+++ b/README.md
@@ -385,6 +385,15 @@ pretraining_dataset: # hf path only
+##### Template-Free
+
+- `input_output`: template-free prompt construction
+  ```json
+  {"segments": [{"label": true|false, "text": "..."}]}
+  ```
+
+This is a special format that allows you to construct prompts without using templates. It is intended for advanced users who want more freedom with prompt construction. See [these docs](docs/input_output.md) for more details.
+
 ##### Conversation
 
 - `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
diff --git a/docs/input_output.md b/docs/input_output.md
new file mode 100644
index 000000000..dbc6979c6
--- /dev/null
+++ b/docs/input_output.md
@@ -0,0 +1,260 @@
# Template-free prompt construction with the `input_output` format

- [Background](#background)
  - [Masking Inputs](#masking-inputs)
  - [You may not want prompt templates](#you-may-not-want-prompt-templates)
  - [The `input_output` format](#the-input_output-format)
- [Usage](#usage)
  - [1. Prepare Data](#1-prepare-data)
  - [2. Use `type: input_output`](#2-use-type-input_output)
  - [3. Check the prompts](#3-check-the-prompts)

## Background

### Masking Inputs

One of the most popular features of
[axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) is
setting the following configuration value:

```yaml
train_on_inputs: false
```

If you declare a [dataset format](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#dataset)
such as `alpaca` or `chatml`, axolotl knows what is an input
(i.e. human) vs. an output (i.e. the assistant) and masks the input
labels so that your model can focus on predicting the outputs only.
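Concretely, "masking the input labels" means setting those label positions to `-100`, the index that PyTorch's cross-entropy loss ignores by default. The sketch below illustrates the idea only; the helper and token ids are made up and are not axolotl's actual code:

```python
IGNORE_INDEX = -100  # label value that the cross-entropy loss skips


def mask_input_labels(token_ids, num_input_tokens):
    """Return labels where the first `num_input_tokens` positions
    (the human/input part) get IGNORE_INDEX, so the loss is only
    computed on the output (assistant) tokens."""
    return [
        IGNORE_INDEX if i < num_input_tokens else tok
        for i, tok in enumerate(token_ids)
    ]


# Hypothetical token ids: the first 3 belong to the input turn.
labels = mask_input_labels([1, 22557, 13, 12014, 736, 2], num_input_tokens=3)
print(labels)  # [-100, -100, -100, 12014, 736, 2]
```

With a template format like `alpaca` or `chatml`, axolotl can compute that input/output boundary for you; without a template, you have to tell it which spans to mask, which is what `input_output` is for.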
+ + + +### You may not want prompt templates + +However, there are many situations where you don't want to use one of +these formats or templates (I usually don't!). This is because they can: + +- Add unnecessary boilerplate to your prompts. +- Create artifacts like special delimiters `<|im_start|>` that can + quickly become footguns if you don't include them correctly at + inference time. +- Enforce a *chat* interface when you do not want one. Sometimes you + just want to fine-tune a model to a very specific task and do NOT + want multi-turn conversations, roles, etc. +- Limit you to only certain roles that the template allows. + + + +### The `input_output` format + +You can construct your prompts without a template by using the +`input_output` format, by setting `type: input_output` in your +configuration file like this: + +**config.yml** + +```yaml +train_on_inputs: false # Mask segments of your data +datasets: + - path: output.jsonl + type: input_output # use template free prompt construction +``` + +Unlike `type: completion`, which is also template-free, +`type: input_output` allows you to mask segments of your text. More +details on how this works are described below. + + + +## Usage + +This is how you can use the `input_output` format: + + + +### 1. Prepare Data + +To use the `input_output` format, collect your data in the following +format into a jsonl file (below is the first row from the file +`output`.jsonl` pretty printed): + +```bash +$ head -n1 output.jsonl | python -m json.tool + +{.cell-output .cell-output-stdout} + { + "segments": [ + { + "label": true, + "text": "Hello\n" + }, + { + "label": true, + "text": "hi there!. " + }, + { + "label": false, + "text": "goodbye " + }, + { + "label": true, + "text": "farewell" + } + ] + } +``` + +Set `label:false` when you want to mask a segment of text so that the +model isn't trained on it. Some things to keep in mind: + +> [!IMPORTANT] +> 1. **EOS, BOS, spaces, newlines etc. are entirely up to you. 
>    Axolotl concatenates all the segments as-is.** The tokenizer
>    doesn't add anything additional. Notice how I added spaces,
>    newlines, `<s>` (BOS), and `</s>` (EOS) myself.
> 2. Make sure you check the materialized output to validate that the
>    prompt is getting assembled how you like.

### 2. Use `type: input_output`

Let's materialize data with our `output.jsonl` file by setting
`type: input_output` in our axolotl config:

```yaml
# training_config.yaml
base_model: mistralai/Mistral-7B-v0.1
data_seed: 49
seed: 49

datasets:
  - path: output.jsonl
    type: input_output
val_set_size: 0.1

sequence_len: 896
sample_packing: false

micro_batch_size: 2
gradient_accumulation_steps: 3
eval_batch_size: 2
num_epochs: 1
learning_rate: 0.0002

train_on_inputs: false
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
```

You can use the following command to materialize your data. The
`--debug` flag will print the tokens along with the labels, so you can
verify that the correct items are being ignored:

```bash
$ python -m axolotl.cli.preprocess training_config.yaml --debug

...
[2024-03-05 23:36:46,969] [INFO] [axolotl.check_example_labels:35] [PID:607731] [RANK:0] <s>(1, 1) Hello(22557, 22557)
(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)

```

The format is `decoded_token`(`label`, `token_id`). For example,
`<s>(1, 1)` means that the token is `<s>`, the label is `1`, and the
token_id is `1`. When the label is `-100`, that token is ignored during
training.

### 3. Check the prompts

Here is another way to check the materialized output:

```python
import os

import yaml
from datasets import load_from_disk
from transformers import AutoTokenizer

directory = sorted(os.listdir('last_run_prepared/'))
with open('training_config.yaml', 'r') as f:
    cfg = yaml.safe_load(f)
model_id = cfg['base_model']
tok = AutoTokenizer.from_pretrained(model_id)
ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
```

```python
>>> row = ds[0]
>>> print(tok.decode(row['input_ids']))
<s> Hello
 hi there!. goodbye farewell</s>
```

We can check that the right tokens are ignored by comparing the labels
to each token:

```python
import pandas as pd
pd.DataFrame([{'token': tok.decode(i), 'label': l, 'id': i} for i, l in
              zip(row['input_ids'], row['labels'])])
```

|    | token  | label | id    |
|----|--------|-------|-------|
| 0  | \<s\>  | 1     | 1     |
| 1  | Hello  | 22557 | 22557 |
| 2  | \\n    | 13    | 13    |
| 3  | hi     | 12014 | 12014 |
| 4  | there  | 736   | 736   |
| 5  | !      | 28808 | 28808 |
| 6  | .      | 28723 | 28723 |
| 7  |        | 28705 | 28705 |
| 8  | good   | -100  | 1179  |
| 9  | bye    | -100  | 17664 |
| 10 |        | -100  | 28705 |
| 11 | fare   | 19111 | 19111 |
| 12 | well   | 5458  | 5458  |
| 13 | \</s\> | 2     | 2     |

If we look at the input data, the above table seems correct! (The jsonl
version is repeated below for reference):

```bash
$ head -n1 output.jsonl | python -m json.tool

{
    "segments": [
        {
            "label": true,
            "text": "<s>Hello\n"
        },
        {
            "label": true,
            "text": "hi there!. "
        },
        {
            "label": false,
            "text": "goodbye "
        },
        {
            "label": true,
            "text": "farewell"
        }
    ]
}
```
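To summarize how the segments above become training data: each segment is tokenized, the results are concatenated as-is, and the labels of `label: false` segments are replaced with `-100`. The sketch below illustrates that idea with a made-up whitespace tokenizer; it is a simplification for intuition, not axolotl's actual implementation:

```python
IGNORE_INDEX = -100  # tokens with this label are skipped by the loss


def build_example(segments, tokenize):
    """Tokenize each segment, concatenate as-is, and mask the labels
    of segments marked label=false with IGNORE_INDEX."""
    input_ids, labels = [], []
    for seg in segments:
        ids = tokenize(seg["text"])
        input_ids.extend(ids)
        labels.extend(ids if seg["label"] else [IGNORE_INDEX] * len(ids))
    return {"input_ids": input_ids, "labels": labels}


# Toy whitespace "tokenizer", for illustration only.
vocab = {}
def toy_tokenize(text):
    return [vocab.setdefault(t, len(vocab)) for t in text.split()]


example = build_example(
    [
        {"label": True, "text": "hi there"},
        {"label": False, "text": "goodbye"},
        {"label": True, "text": "farewell"},
    ],
    toy_tokenize,
)
print(example["labels"])  # [0, 1, -100, 3]
```

Note that nothing is inserted between segments, which is why the real format leaves all spacing, BOS, and EOS handling up to you.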