Compare commits

..

43 Commits

Author SHA1 Message Date
NanoCode012
10d18e6c97 fix(test): replace jackfram llama with smollm 2025-02-28 16:40:49 +07:00
NanoCode012
75cbd15301 Fix(doc): address missing doc changes (#2362)
Some checks failed
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.6.0) (push) Has been cancelled
ci-cd / build-axolotl (vllm, 124, 12.4.1, true, 3.11, 2.5.1) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, true, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
* fix: add multiple tips about eos_token masking

* fix: format dataset preprocessing doc

* Update docs/dataset-formats/conversation.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-02-25 13:50:02 -05:00
NanoCode012
2efe1b4c09 Feat(doc): Reorganize documentation, fix broken syntax, update notes (#2348)
* feat(doc): organize docs, add to menu bar, fix broken formatting

* feat: add link to custom integrations

* feat: update readme for integrations to include citations and repo link

* chore: update lm_eval info

* chore: use fullname

* Update docs/cli.qmd per suggestion

Co-authored-by: Dan Saunders <danjsaund@gmail.com>

* feat: add sweep doc

* feat: add kd doc

* fix: remove toc

* fix: update deprecation

* feat: add more info about chat_template issues

* fix: heading level

* fix: shell->bash code block

* fix: ray link

* fix(doc): heading level, header links, formatting

* feat: add grpo docs

* feat: add style changes

* fix: wrong cli arg for lm-eval

* fix: remove old run method

* feat: load custom integration doc dynamically

* fix: remove old cli way

* fix: toc

* fix: minor formatting

---------

Co-authored-by: Dan Saunders <danjsaund@gmail.com>
2025-02-25 16:09:37 +07:00
NanoCode012
1110a37e21 feat: add deepseek_v3 sample packing (#2230) 2025-02-24 15:03:15 -05:00
Wing Lian
9850f42204 bump liger to 0.5.3 (#2353) 2025-02-24 12:40:54 -05:00
Matt Baker
00fc8109e4 Correctly reference mount paths (#2347)
* Correctly reference mount paths

* Also fix mount paths in lm_eval

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-02-24 11:12:57 -05:00
Wing Lian
2d5826f544 Relicense the logprob KD loss functions as Apache 2.0 (#2358) 2025-02-23 12:31:35 -05:00
Wing Lian
a4170030ab don't install extraneous old version of pydantic in ci and make sre to run multigpu ci (#2355) 2025-02-21 22:06:29 -05:00
NanoCode012
bf842730a5 fix(doc): add missing auto_find_batch_size (#2339) [skip ci] 2025-02-21 11:56:38 +07:00
Wing Lian
1db6ad60a7 support for passing init_lora_weights to lora_config (#2352) 2025-02-20 22:56:34 -05:00
salman
29b366b2e1 Bumping 0.15.1 TRL version for GRPO+PEFT fix (#2344)
* bumping TRL version

* apply upstream fixes to our custom fix

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-02-20 22:56:04 -05:00
NanoCode012
b53a41372f feat: update transformers version to 4.49.0 (#2340) 2025-02-20 21:12:06 -05:00
Wing Lian
02f45e94be calculate sample length fixes and SFT splitting fixes (#2351)
* fix chat template splitting long samples across multiple rows

* make the preprocessing faster
2025-02-20 14:29:58 -05:00
Dan Saunders
954e192f38 quick formatting fix for LoRA optims doc (#2349) 2025-02-19 09:23:31 -05:00
Tobias
8dfadc2b3c Fix sample packing producing longer sequences than specified by sequence_len (#2332)
* Extend MultiPackBatchSampler test to include shorter sequence length and drop long sequences filter

* Fix get_dataset_lengths for datasets that were previously filtered (e.g., with drop_long_seq_in_dataset)

* Update src/axolotl/utils/samplers/utils.py

Fix get_dataset_lengths for datasets that do not have position_ids or length attributes

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2025-02-19 12:02:35 +07:00
Wing Lian
23a9fcb0a7 make sure chatml dpo dataset loading works (#2333) 2025-02-18 16:08:40 -05:00
Dan Saunders
c3d4f6e295 Doc fix: TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL not necessary to use Triton kernel patches (#2343)
* removing note about TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL

* suggest using TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL for memory efficient attn
2025-02-18 10:06:31 -05:00
Wing Lian
7fa690fac8 bump dev version (#2342) 2025-02-18 04:30:59 -05:00
Wing Lian
3c743c4bfb v0.7.0 for release (#2341)
Some checks failed
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.6.0) (push) Has been cancelled
ci-cd / build-axolotl (vllm, 124, 12.4.1, true, 3.11, 2.5.1) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, true, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
2025-02-18 04:26:21 -05:00
NJordan72
91bb95685a chore: cleanup deprecated config elements (#2309)
* feat: update metadata fields and refactor config class in axolotlinputconfig

- Replace `metadata` fields with `json_schema_extra` in RayConfig class.
- Replace `Config` class with `ConfigDict` in AxolotlInputConfig.
- Set `populate_by_name` to `True` directly in `ConfigDict` instance.

* feat: update axolotlinputconfig in utils

* Replace `conlist` with `Annotated` for `datasets`, `test_datasets`, and `pretraining_dataset` fields
* Change default values for `lr_scheduler` and `optimizer` fields in `HyperparametersConfig` class
* Remove unnecessary Union from `evals_per_epoch` field in `AxolotlInputConfig` class
* Import `MinLen` from `annotated_types` module
* Remove import of `conlist` from `pydantic` module

* feat: update modelinputconfig and axolotlinputconfig in v0_4_1

- Removed ConfigDict import from pydantic in `src/axolotl/utils/config/models/input/v0_4_1/__init__.py`
- Added `model_config` with `protected_namespaces` to ModelInputConfig
- Replaced `config: ConfigDict` with `model_config` in AxolotlInputConfig
- Set `populate_by_name` to True in `model_config` for AxolotlInputConfig

* chore: get rid of unused import
2025-02-18 15:39:24 +07:00
NJordan72
b194e17c28 feat: add config for optional parameters in a chat message (#2260)
* feat: add config for optional parameters in a chat message

* chore: cleanup

* chore: fix nits and add light docs

* docs: update docs/dataset-formats/conversation.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* feat: configurable message mappings, jinja template analyzer

* chore: handle bradley terry

* docs: update docs

* refactor: change order of mappings, improve message transform

* refactor: make chat awware of property mappings

* chore: remove .python-version

* chore: revert change

* chore: add dataset validation to tests where appropriate

* chore: add dataset validation to tests where appropriate

* chore: clean up handling of ds_cfg

* chore: recursively serialize config

* make sure to use the return value from validate_config

* DefaultDict pickle/unpickle fix

* fix super call for override

* refactor: message fields

* chore: empty commit

* tests: validate config before using

* chore: add config validation to all e2e tests

* chore: add unneeded logging

* chore: add missed config validation

* chore: pass field_messages to prompter

* test: fix borked test

* chore: remove uninteded file

* chore: add deprecation warning and update chat_datasets script

* chore: lint

* refactor: message fields

* feat: update axolotlinputconfig and test_models

- add configdict import in axolotl/utils/config/models/input/v0_4_1/__init__.py
- remove unnecessary line breaks in sftdataset, dpodataset, ktodataset, stepwisesuperviseddataset classes
- update model_dump method in axolotlinputconfig to exclude none values
- correct typo in test_models.py comment

* feat: simplify dpodataset and ktodataset classes in config models

removed several optional fields from dpodataset and ktodataset classes in axolotl/utils/config/models/input/v0_4_1. this simplifies the configuration subsets for these datasets.

* feat: improve readability and structure in dataset configuration models

this commit enhances the readability and structure of the dataset configuration models in the `axolotl/utils/config/models/input/v0_4_1` module. it removes unused `configdict` import and adds line breaks to separate class definitions for better clarity. additionally, a minor documentation fix is included to ensure a newline at the end of the `stepwise_supervised.qmd` file.

* feat: change log level from info to debug in chattemplatestrategy

* feat(prompt_strategies): refactor chattemplateprompter and chattemplatestrategy

- Make `chat_template` a required parameter in `ChatTemplatePrompter` constructor
- Add default value for `message_property_mappings` in `ChatTemplatePrompter` constructor
- Add `messages_array_name` property to `ChatTemplatePrompter`
- Change `processor` type to Optional in `ChatTemplatePrompter`
- Add TypeError check for `processor` in `ChatTemplatePrompter.build_prompt`
- Remove `_messages` property from `ChatTemplateStrategy`
- Make `prompter` a required parameter and add type hint in `ChatTemplateStrategy` constructor
- Remove `messages` getter and setter from `ChatTemplateStrategy`
- Use `prompter.messages_array_name` in `ChatTemplateStrategy.get_conversation_thread`
- Remove condition to set `messages` field in `load` function

* feat(tests/utils): ignore type check in load_model call in test_models.py

* feat: improve type handling and test structure in chat templates

- Add return type hint for `get_chat_template` function in `chat_templates.py`
- Remove unnecessary assignment of `strategy.messages` in several test cases
- Add `messages_array_name` parameter to various test configurations in `test_chat_templates.py` and `test_chat_templates_advanced.py`
- Remove redundant `strategy.messages` assignment in `test_chat_templates_advanced.py`

* feat(axolotl): enhance chat strategy with datasetconfig support

This commit introduces support for DatasetConfig in the ChatTemplateStrategy. It also refines the strategy loader to handle different types of ds_cfg inputs and improves the clarity of the code by formatting and reordering. The key changes include:

- Importing Union from typing and BaseModel from pydantic.
- Adding DatasetConfig as an optional type for ds_cfg in StrategyLoader.
- Adjusting the handling of ds_cfg in StrategyLoader to account for BaseModel instances.
- Refactoring the prompter_params and strategy_params for better readability.
- Changing the reference from prompt[self.messages] to prompt[self.prompter.messages_array_name] in the is_prompt_batched method.

* feat: update message handling in btchattemplatestrategy

* Replace `self.messages` with direct string references to "chosen_messages" and "rejected_messages"
* Append system, user, and assistant content directly to "chosen_messages" and "rejected_messages"
* Add a new attribute "messages_array_name" to the `load` function parameters
* Remove the conditional attribute assignment for "field_messages" in the `load` function

* feat: add config validation in test_kd.py

- Import `validate_config` from `axolotl.utils.config`
- Validate the configuration in `test_llama_kd` and another function in `TestKnowledgeDistillation` class

* feat: enhance config validation and capabilities handling

* Import `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals`
* Update `validate_config` function to create `KTODataset` and `SFTDataset` instances using `dict(ds_cfg)`
* Replace `capabilities` and `env_capabilities` with instances of `GPUCapabilities` and `EnvCapabilities` respectively in `AxolotlConfigWCapabilities` model dump

* feat: update config validation in axolotl utils

- Remove import of `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals`
- Update `validate_config` function to use `capabilities` and `env_capabilities` directly instead of creating new instances of `GPUCapabilities` and `EnvCapabilities`

* feat: refactor strategyloader in chat_template.py

- Extracted the creation of strategy parameters into a separate function, `_get_strategy_params(cfg, dataset_config)`
- Created a new function, `_get_strategy_cls()`, to obtain the strategy class
- Replaced `ChatTemplateStrategy` with `strategy_cls` for strategy instantiation

* trigger CI

* chore: revert dataset config changes for kto/dpo

* subject: refactor: rename 'messages_array_name' to 'field_messages'

Body:
- Renamed 'messages_array_name' to 'field_messages' in 'ChatTemplatePrompter' class and its usages in 'chat_template.py'
- Updated 'load' function in 'bradley_terry/chat_template.py' to reflect the change
- Adjusted 'get_chat_template_msg_variables' and 'get_message_vars' methods in 'jinja_template_analyzer.py' to use the new variable name
- Modified 'StrategyLoader' in 'chat_template.py' to use 'field_messages'
- Updated tests in 'test_chat_templates.py' and 'test_chat_templates_advanced.py' to use 'field_messages' instead of 'messages_array_name'

* feat: refactor prompt strategies and update config models

* Remove redundant 'return None' in `axolotl/prompt_strategies/__init__.py`
* Simplify message handling in `axolotl/prompt_strategies/bradley_terry/chat_template.py` by using a single 'messages' list instead of separate 'chosen_messages' and 'rejected_messages' lists
* Update default 'message_property_mappings' in `axolotl/prompt_strategies/bradley_terry/chat_template.py`
* Add 'field_messages' field to `axolotl/utils/config/models/input/v0_4_1/__init__.py` configuration model

* chore: remove unused input

* chore: remove redundant type ignore

* fix: remove old configs and update examples

* fix: type check

* fix: remove loading old config in ChatMessage

* fix: update faq with potential new undefinederror

* fix: add debug if property mapped is not found

* chore: improve explanation for unmapped properties

* fix: update docs with new config

* chore: add note for deprecation config and del old config from dict

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-02-18 09:59:27 +07:00
Dan Saunders
3aac3b1da9 Move sweeps code to another module (#2338) 2025-02-17 15:46:04 -05:00
Dan Saunders
3d8425fa91 Activation function Triton kernels, LoRA custom autograd functions (#2324)
* LoRA + activation fn Triton kernels: initial commit

* implementing optims

* finalizing MLP LoRA kernels and progress on QKV / W kernels

* updates

* O projection optim

* adding monkey patching logic

* doc strings, typing, pre-commit fixes

* updates

* adding lora 8b kernels example

* working on fsdp support

* tests and fixes

* small fixes, getting tests to pass, adding doc strings

* integration tests for LoRA patching

* config.qmd

* remove unneeded pytest fixture

* fix

* review comments first pass

* improving tests, attention class agnostic patching

* adding support for more archs

* wip SiLU / GELU impls

* improved testing, small updates, etc.

* slightly updating docs

* rebase

* fixing test_attention_patching_integration

* additional review comments, fixing test in CI (hopefully)

* isolating problematic patching test

* relaxing allclose threshold to reduce flakiness

* fixing accidental change

* adding model arch agnostic attention class fetching

* removing unused activations
2025-02-17 14:23:15 -05:00
Seungduk Kim
97a2fa2781 Select input_ids explicitly after panda conversion (#2335)
Without selecting the column, applying `len` counts the whole row as 1 which resulting the total number of the samples instead of the token counts.
2025-02-17 00:07:27 -05:00
Wing Lian
a98526ef78 add support for include_tokens_per_second in training args (#2269)
* add support for include_tokens_per_second in training args

* Update docs/config.qmd

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update src/axolotl/core/trainer_builder.py

Co-authored-by: NanoCode012 <nano@axolotl.ai>

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-02-13 17:39:19 -05:00
NanoCode012
2e57391bf8 fix: add missing shards_idx, preprocess_shards to docs and validator (#2331) 2025-02-13 17:28:21 -05:00
minpeter
aa45fed451 Add bos_token and add_generation_prompt to the alpaca chat template (#2322)
* fix alpaca add_generation_prompt

* Alpaca template considering multi-turn

Co-authored-by: xzuyn <xzuyn@users.noreply.github.com>

---------

Co-authored-by: xzuyn <xzuyn@users.noreply.github.com>
2025-02-13 17:27:55 -05:00
NanoCode012
a09a5cfd1c feat(doc): add tensorboard config to docs (#2329) 2025-02-13 16:02:16 -05:00
NanoCode012
40362d60e0 feat(doc): Improve guide to dataset types with better examples (#2286) 2025-02-13 16:01:41 -05:00
Wing Lian
ffae8d6a95 GRPO (#2307) 2025-02-13 16:01:01 -05:00
Lee Park
fdbb1a207c [Fixing #2149] load_from_disk for RL-type training (#2193)
* Update rl.py

* Update rl.py

* Update rl.py

* refactor pref dataset loading to reuse load_dataset_w_config

* refactor again after rebase from main

* chore: add docstring and types

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-02-13 08:31:07 -05:00
Wing Lian
30046315d9 disable ray tests for latest torch release (#2328)
* disable ray tests for latest torch release

* move decorator from class to method
2025-02-12 18:29:02 -05:00
Wing Lian
e37a4a536a lint docs (#2327) 2025-02-12 10:04:26 -05:00
Sung Ching Liu
44f64ab627 Update faq.qmd (#2319)
* Update faq.qmd

Added Q&A for being stuck on saving preprocessed datasets

* Update faq.qmd

added details on preprocessing on cpu

* Update faq.qmd

* Update faq.qmd
2025-02-11 13:18:31 -05:00
NanoCode012
826f1b1494 feat(doc): Add multi-node torchrun info (#2304) 2025-02-08 06:02:02 -05:00
NanoCode012
526e5ee8b8 fix(config): missing config not being documented and fix model_ override (#2317)
* fix(config): missing config not being documented and fix model_ space override

* fix: delete redundant field
2025-02-08 06:01:48 -05:00
NanoCode012
fd8cb32547 chore: remove redundant py310 from tests (#2316) 2025-02-07 21:34:16 -05:00
NanoCode012
e48e2df4dd feat: update FA to 2.7.4.post1 which includes torch2.6 binary (#2315) 2025-02-07 21:34:01 -05:00
Wing Lian
b7616022ab bump transformers to 4.48.3 (#2318) 2025-02-07 21:33:44 -05:00
Wing Lian
1faf1a5c5a batch add of spectrum snr results (#2320) 2025-02-07 21:33:14 -05:00
NanoCode012
5bbad5ef93 feat: add torch2.6 to ci (#2311) 2025-02-07 07:28:54 -05:00
Wing Lian
a971eb4ce6 Torch 2.6 support for base docker image (#2312) 2025-02-05 09:24:02 -05:00
NanoCode012
a620d481e2 fix: drop long seq even if not sample packing (#2211)
* fix: drop long seq even if not sample packing

* fix: logging import

* fix: cfg passed being none

* fix: try to fix logging

* fix: refactor call to not use accelerate log

* fix: try to fix circular import issue

* fix: don't drop when skip prepare

* chore: remove duplicate line

* fix: update warning to mention that sequences will be trimmed

* fix: do not drop seq if input_ids don't exist

* fix: increase RM unittest sequence length to reduce trim warnings

* fix: solve conflicts

* fix: default min_seq_len in case of None
2025-02-04 09:43:35 -05:00
184 changed files with 18739 additions and 8647 deletions

View File

@@ -22,12 +22,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.10"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
@@ -40,6 +34,12 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -19,7 +19,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: '3.11'
- name: install dependencies
run: |
python3 -m pip install jupyter

View File

@@ -19,6 +19,6 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1

View File

@@ -24,8 +24,13 @@ jobs:
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
axolotl_extras: vllm
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -4,6 +4,10 @@ on:
pull_request:
paths:
- 'tests/e2e/multigpu/*.py'
- 'requirements.txt'
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
@@ -24,13 +28,21 @@ jobs:
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
axolotl_extras: # no vllm support for 2.4.1
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
# awaiting vllm#12721
axolotl_extras:
num_gpus: 2
nightly_build: "true"
@@ -42,7 +54,7 @@ jobs:
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip

View File

@@ -22,6 +22,11 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -36,7 +36,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
- name: Install dependencies
run: |

View File

@@ -12,7 +12,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1
env:
@@ -25,13 +25,8 @@ jobs:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
- python_version: "3.10"
pytorch_version: "2.5.1"
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
timeout-minutes: 20
steps:
@@ -112,13 +107,20 @@ jobs:
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
nightly_build: "true"
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip

View File

@@ -35,7 +35,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1
env:
@@ -48,13 +48,8 @@ jobs:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
- python_version: "3.10"
pytorch_version: "2.5.1"
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
timeout-minutes: 20
steps:
@@ -127,7 +122,7 @@ jobs:
max-parallel: 1
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
timeout-minutes: 20
steps:
@@ -209,14 +204,14 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
num_gpus: 1
axolotl_extras:
axolotl_extras: vllm
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip
@@ -251,13 +246,19 @@ jobs:
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip

View File

@@ -50,13 +50,14 @@ Features:
## 🚀 Quick Start
**Requirements**:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.10
- Python 3.11
- PyTorch ≥2.4.1
### Installation
```shell
```bash
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs
@@ -68,7 +69,7 @@ Other installation approaches are described [here](https://axolotl-ai-cloud.gith
### Your First Fine-tune
```shell
```bash
# Fetch axolotl examples
axolotl fetch examples

View File

@@ -3,10 +3,12 @@ project:
website:
title: "Axolotl"
description: "Fine-tuning"
description: "We make fine-tuning accessible, scalable, and fun"
favicon: favicon.jpg
navbar:
title: Axolotl
logo: image/axolotl_logo_digital_white.svg
title: false
background: dark
pinned: false
collapse: false
@@ -25,33 +27,58 @@ website:
contents:
- text: Home
href: index.qmd
- section: "How-To Guides"
- section: "Getting Started"
contents:
# TODO Edit folder structure after we have more docs.
- docs/getting-started.qmd
- docs/installation.qmd
- docs/debugging.qmd
- docs/cli.qmd
- docs/inference.qmd
- docs/multipack.qmd
- docs/fsdp_qlora.qmd
- docs/input_output.qmd
- docs/rlhf.qmd
- docs/nccl.qmd
- docs/mac.qmd
- docs/multi-gpu.qmd
- docs/multi-node.qmd
- docs/unsloth.qmd
- docs/amd_hpc.qmd
- docs/ray-integration.qmd
- section: "Dataset Formats"
contents: docs/dataset-formats/*
- section: "Deployments"
contents:
- docs/multi-gpu.qmd
- docs/multi-node.qmd
- docs/ray-integration.qmd
- docs/amd_hpc.qmd
- docs/mac.qmd
- section: "How To Guides"
contents:
- docs/multimodal.qmd
- docs/rlhf.qmd
- docs/reward_modelling.qmd
- docs/lr_groups.qmd
- docs/lora_optims.qmd
- section: "Core Concepts"
contents:
- docs/batch_vs_grad.qmd
- docs/dataset_preprocessing.qmd
- docs/multipack.qmd
- section: "Advanced Features"
contents:
- docs/fsdp_qlora.qmd
- docs/unsloth.qmd
- docs/torchao.qmd
- docs/custom_integrations.qmd
- section: "Troubleshooting"
contents:
- docs/faq.qmd
- docs/debugging.qmd
- docs/nccl.qmd
- section: "Reference"
contents:
- docs/config.qmd
- docs/faq.qmd
format:
html:
theme: materia
theme: darkly
css: styles.css
toc: true

View File

@@ -4,8 +4,8 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -37,15 +37,11 @@ temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = (
Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
force_build=True,
gpu="A10G",
)
.env(df_args)
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
cicd_image = Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
force_build=True,
gpu="A10G",
).env(df_args)
app = App("Axolotl CI/CD", secrets=[])

View File

@@ -1,6 +1,4 @@
"""
modal application to run axolotl gpu tests in Modal
"""
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code
import os

View File

@@ -1,5 +1,5 @@
---
title: Training with AMD GPUs on HPC Systems
title: AMD GPUs on HPC Systems
description: A comprehensive guide for using Axolotl on distributed systems with AMD GPUs
---

View File

@@ -1,28 +1,19 @@
# Axolotl CLI Documentation
---
title: "CLI Reference"
format:
html:
toc: true
toc-expand: 2
toc-depth: 3
execute:
enabled: false
---
The Axolotl CLI provides a streamlined interface for training and fine-tuning large language models. This guide covers
the CLI commands, their usage, and common examples.
### Table of Contents
- Basic Commands
- Command Reference
- fetch
- preprocess
- train
- inference
- merge-lora
- merge-sharded-fsdp-weights
- evaluate
- lm-eval
- Legacy CLI Usage
- Remote Compute with Modal Cloud
- Cloud Configuration
- Running on Modal Cloud
- Cloud Configuration Options
### Basic Commands
## Basic Commands
All Axolotl commands follow this general structure:
@@ -32,9 +23,9 @@ axolotl <command> [config.yml] [options]
The config file can be local or a URL to a raw YAML file.
### Command Reference
## Command Reference
#### fetch
### fetch
Downloads example configurations and deepspeed configs to your local machine.
@@ -49,7 +40,7 @@ axolotl fetch deepspeed_configs
axolotl fetch examples --dest path/to/folder
```
#### preprocess
### preprocess
Preprocesses and tokenizes your dataset before training. This is recommended for large datasets.
@@ -74,7 +65,7 @@ dataset_prepared_path: Local folder for saving preprocessed data
push_dataset_to_hub: HuggingFace repo to push preprocessed data (optional)
```
#### train
### train
Trains or fine-tunes a model using the configuration specified in your YAML file.
@@ -95,7 +86,38 @@ axolotl train config.yml --no-accelerate
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
```
#### inference
It is possible to run sweeps over multiple hyperparameters by passing in a sweeps config.
```bash
# Basic training with sweeps
axolotl train config.yml --sweep path/to/sweep.yaml
```
Example sweep config:
```yaml
_:
# This section is for dependent variables we need to fix
- load_in_8bit: false
load_in_4bit: false
adapter: lora
- load_in_8bit: true
load_in_4bit: false
adapter: lora
# These are independent variables
learning_rate: [0.0003, 0.0006]
lora_r:
- 16
- 32
lora_alpha:
- 16
- 32
- 64
```
### inference
Runs inference using your trained model in either CLI or Gradio interface mode.
@@ -115,7 +137,7 @@ cat prompt.txt | axolotl inference config.yml \
--base-model="./completed-model"
```
#### merge-lora
### merge-lora
Merges trained LoRA adapters into the base model.
@@ -137,7 +159,7 @@ gpu_memory_limit: Limit GPU memory usage
lora_on_cpu: Load LoRA weights on CPU
```
#### merge-sharded-fsdp-weights
### merge-sharded-fsdp-weights
Merges sharded FSDP model checkpoints into a single combined checkpoint.
@@ -146,7 +168,7 @@ Merges sharded FSDP model checkpoints into a single combined checkpoint.
axolotl merge-sharded-fsdp-weights config.yml
```
#### evaluate
### evaluate
Evaluates a model's performance using metrics specified in the config.
@@ -155,27 +177,27 @@ Evaluates a model's performance using metrics specified in the config.
axolotl evaluate config.yml
```
#### lm-eval
### lm-eval
Runs LM Evaluation Harness on your model.
```bash
# Basic evaluation
axolotl lm-eval config.yml
# Evaluate specific tasks
axolotl lm-eval config.yml --tasks arc_challenge,hellaswag
```
Configuration options:
```yaml
lm_eval_tasks: List of tasks to evaluate
lm_eval_batch_size: Batch size for evaluation
output_dir: Directory to save evaluation results
# List of tasks to evaluate
lm_eval_tasks:
- arc_challenge
- hellaswag
lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```
### Legacy CLI Usage
## Legacy CLI Usage
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:
@@ -195,12 +217,18 @@ accelerate launch -m axolotl.cli.inference config.yml \
--lora_model_dir="./outputs/lora-out" --gradio
```
### Remote Compute with Modal Cloud
::: {.callout-important}
When overriding CLI parameters in the legacy CLI, use same notation as in yaml file (e.g., `--lora_model_dir`).
**Note:** This differs from the new Click-based CLI, which uses dash notation (e.g., `--lora-model-dir`). Keep this in mind if you're referencing newer documentation or switching between CLI versions.
:::
## Remote Compute with Modal Cloud
Axolotl supports running training and inference workloads on Modal cloud infrastructure. This is configured using a
cloud YAML file alongside your regular Axolotl config.
#### Cloud Configuration
### Cloud Configuration
Create a cloud config YAML with your Modal settings:
@@ -215,13 +243,17 @@ branch: main # Git branch to use (optional)
volumes: # Persistent storage volumes
- name: axolotl-cache
mount: /workspace/cache
- name: axolotl-data
mount: /workspace/data
- name: axolotl-artifacts
mount: /workspace/artifacts
env: # Environment variables
- WANDB_API_KEY
- HF_TOKEN
```
#### Running on Modal Cloud
### Running on Modal Cloud
Commands that support the --cloud flag:
@@ -239,18 +271,18 @@ axolotl train config.yml --cloud cloud_config.yml --no-accelerate
axolotl lm-eval config.yml --cloud cloud_config.yml
```
#### Cloud Configuration Options
### Cloud Configuration Options
```yaml
provider: compute provider, currently only `modal` is supported
gpu: GPU type to use
gpu_count: Number of GPUs (default: 1)
memory: RAM in GB (default: 128)
timeout: Maximum runtime in seconds
timeout_preprocess: Preprocessing timeout
branch: Git branch to use
docker_tag: Custom Docker image tag
volumes: List of persistent storage volumes
env: Environment variables to pass
secrets: Secrets to inject
provider: # compute provider, currently only `modal` is supported
gpu: # GPU type to use
gpu_count: # Number of GPUs (default: 1)
memory: # RAM in GB (default: 128)
timeout: # Maximum runtime in seconds
timeout_preprocess: # Preprocessing timeout
branch: # Git branch to use
docker_tag: # Custom Docker image tag
volumes: # List of persistent storage volumes
env: # Environment variables to pass
secrets: # Secrets to inject
```

View File

@@ -46,6 +46,10 @@ overrides_of_model_config:
type: # linear | dynamic
factor: # float
# optional overrides the base model loading from_pretrained
overrides_of_model_kwargs:
# use_cache: False
# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:
@@ -87,7 +91,12 @@ datasets:
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
data_files: # Optional[str] path to source data files
shards: # Optional[int] number of shards to split data into
shards: # Optional[int] split dataset into N pieces (use with shards_idx)
shards_idx: # Optional[int] = 0 the index of sharded dataset to use
preprocess_shards: # Optional[int] process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)
name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
@@ -133,10 +142,19 @@ datasets:
# Key containing the messages (default: "messages")
field_messages: messages
# Key for role in each message (default: "role")
message_field_role: role
# Key for content in each message (default: "content")
message_field_content: content
# Mapping of properties from the input dataset to the chat template.
# (default: message_property_mappings={'role':'role', 'content':'content'})
# If a property exists in the template but not in this mapping, the system will attempt
# to load it directly from the message using the property name as the key.
# Example: In the mapping below, 'from' is loaded from input dataset and used as 'role',
# while 'value' is loaded and used as 'content' in the chat template.
message_property_mappings:
role: from
content: value
# ...
message_property_mappings:
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
roles:
@@ -148,7 +166,7 @@ datasets:
# IMPORTANT: The following fields determine which parts of the conversation to train on.
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
# See examples at `docs/dataset-formats/conversation.qmd`
# Note: If the below 4 fields are empty, defaults to training only on the last message.
# Note: If the below 4 fields are set to empty, defaults to training only on the last message.
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
roles_to_train: ["assistant"] # default
@@ -156,6 +174,7 @@ datasets:
# - all: train on all EOS tokens
# - turn (default): train on the EOS token at the end of each trainable turn
# - last: train on the last EOS token in the conversation
# TIP: Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
train_on_eos: last
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
message_field_training: training
@@ -296,6 +315,13 @@ lora_modules_to_save:
lora_fan_in_fan_out: false
# Apply custom LoRA autograd functions and activation function Triton kernels for
# speed and memory savings
# See: https://axolotl-ai-cloud.github.io/axolotl/docs/lora_optims.html
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
# LoRA+ hyperparameters
# For more details about the following options, see:
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
@@ -344,6 +370,9 @@ comet_mode: # Create a new experiment ("create") or log to an existing one ("get
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
# Tensorboard
use_tensorboard: # Optional[bool]
# Where to save the full-finetuned model to
output_dir: ./completed-model
@@ -378,6 +407,12 @@ save_total_limit: # Checkpoints saved at a time
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
max_steps:
# bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
include_tokens_per_second: # Optional[bool]
# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
auto_find_batch_size: # Optional[bool]
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]

View File

@@ -0,0 +1,57 @@
---
title: Custom Integrations
toc: true
toc-depth: 3
---
```{python}
#| echo: false
import re
def process_readme(integration_name):
try:
path = f'../src/axolotl/integrations/{integration_name}/README.md'
with open(path, 'r') as f:
txt = f.read()
# Remove h1 headings
txt = re.sub(r'^# .*\n?', '', txt, flags=re.MULTILINE)
# Convert h2 to h3
txt = re.sub(r'^## ', '### ', txt, flags=re.MULTILINE)
return txt
except FileNotFoundError:
return None
def print_section(name, folder_name):
output = f"\n## {name}\n"
content = process_readme(folder_name)
if content:
output += content
output += f"\nPlease see reference [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/{folder_name})\n"
return output
```
```{python}
#| output: asis
#| echo: false
# Introduction text
print("""
Axolotl adds custom features through `integrations`. They are located within the `src/axolotl/integrations` directory.
To enable them, please check the respective documentations.
""")
# Sections
sections = [
("Cut Cross Entropy", "cut_cross_entropy"),
("Grokfast", "grokfast"),
("Knowledge Distillation (KD)", "kd"),
("Liger Kernels", "liger"),
("Language Model Evaluation Harness (LM Eval)", "lm_eval"),
("Spectrum", "spectrum")
]
for section_name, folder_name in sections:
print(print_section(section_name, folder_name))
```

View File

@@ -6,7 +6,9 @@ order: 3
## sharegpt
IMPORTANT: ShareGPT is deprecated!. Please see `chat_template` section below.
::: {.callout-important}
ShareGPT is deprecated!. Please see [chat_template](#chat_template) section below.
:::
## pygmalion
@@ -22,7 +24,7 @@ Chat Template strategy uses a jinja2 template that converts a list of messages i
{"conversations": [{"role": "...", "content": "..."}]}
```
See `config.qmd` for full configs and supported templates.
See [configs](../config.qmd) for full configs and supported templates.
### Migrating from sharegpt
@@ -42,8 +44,9 @@ datasets:
type: chat_template
field_messages: conversations
message_field_role: from
message_field_content: value
message_property_mappings:
role: from
content: value
# new (if setting a new chat_template like chatml, gemma, etc)
chat_template: chatml
@@ -52,8 +55,9 @@ datasets:
type: chat_template
field_messages: conversations
message_field_role: from
message_field_content: value
message_property_mappings:
role: from
content: value
```
We recommend checking the below examples for other usecases.
@@ -100,6 +104,10 @@ datasets:
type: chat_template
```
::: {.callout-important}
Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
:::
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
For a data sample that looks like:
@@ -138,12 +146,15 @@ datasets:
type: chat_template
chat_template: tokenizer_default
field_messages: conversations
message_field_role: from
message_field_content: value
message_property_mappings:
role: from
content: value
roles_to_train: []
train_on_eos: turn
message_field_training: train
message_field_training_detail: train_detail
```
Tip: It is not necessary to use both `message_field_training` and `message_field_training_detail` at a time.
::: {.callout-tip}
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
:::

View File

@@ -1,14 +1,491 @@
---
title: Dataset Formats
description: Supported dataset formats.
listing:
fields: [title, description]
type: table
sort-ui: false
filter-ui: false
max-description-length: 250
description: Guide to Dataset Formats in Axolotl
back-to-top-navigation: true
toc: true
toc-depth: 5
---
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL format. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
Below are these various formats organized by task:
Axolotl is a training framework that aims to make the process convenient yet flexible to users by simply passing a config yaml file.
As there are a lot of available options in Axolotl, this guide aims to provide an simplify the user experience to choosing the proper choice.
Axolotl supports 3 kinds of training methods: pre-training, supervised fine-tuning, and preference-based post-training (e.g. DPO, ORPO, PRMs). Each method has their own dataset format which are described below.
## Pre-training
When aiming to train on large corpora of text datasets, pre-training is your go-to choice. Due to the size of these datasets, downloading the entire-datasets before beginning training would be prohibitively time-consuming. Axolotl supports [streaming](https://huggingface.co/docs/datasets/en/stream) to only load batches into memory at a time.
A sample format for a pre-training dataset is as follows:
```json
{"text": "first row"}
{"text": "second row"}
...
```
It is typically recommended to save your dataset as `.jsonl` due to its flexibility and simplicity.
Axolotl supports loading from a Hugging Face hub repo or from local files.
::: {.callout-important}
For pre-training only, Axolotl would split texts if it exceeds the context length into multiple smaller prompts.
:::
### Pre-training from Hugging Face hub datasets
As an example, to train using a Hugging Face dataset `hf_org/name`, you can pass the following config:
```yaml
pretraining_dataset: hf_org/name
```
### Pre-training from local dataset files
Given a few corpus files: `A.jsonl`, `B.jsonl`, and `C.jsonl`, your config will look like the below:
```yaml
pretraining_dataset:
- path: json
data_files:
- A.jsonl
- B.jsonl
- C.jsonl
```
While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet`, `arrow`, `SQL`, `Webdataset`) that are supported by [`Dataset.load_dataset`](https://huggingface.co/docs/datasets/loading#local-and-remote-files)
### Pre-training without streaming
On the rare case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
One benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs.
From Hugging Face:
```yaml
datasets:
- path: hf_org/name
type: completion
```
From local files (either example works):
```yaml
datasets:
- path: A.jsonl
type: completion
- path: json
data_files: ["A.jsonl", "B.jsonl", "C.jsonl"]
type: completion
```
### Pre-training dataset configuration tips
#### Setting max_steps
When using streaming for large datasets, Axolotl does not know in advance how large the dataset is and does not know when to stop.
Therefore, it is necessary to set `max_steps: int` in your config for pre-training to run, so that Axolotl knows when to stop training.
One step is equal to `sequence_len * micro_batch_size * gradient_accumulation_steps * total_num_gpus` tokens.
#### Group_by_length
It is recommended to leave this off if downloading from Hugging Face hub as it would download the entire dataset which can be very large.
### Reference
Please see docs [here](pretraining.qmd).
## Supervised fine-tuning (SFT)
Supervised fine-tuning is the process of training models to respond to an instruction or chat input.
As there are a wide variety of dataset formats, Axolotl tries to support a majority of the formats available in public datasets.
Axolotl provides four approaches for loading datasets, however, it's easier to work backwards from the dataset you have available to figure out which approach to use.
A flow chart is as follows:
1. Do you already have the dataset tokenized? If yes, check [Pre-Tokenized Dataset](#pre-tokenized-dataset).
2. Do you want to format the dataset yourself and manually choose each section to mask? If yes, check [Template Free Dataset](#template-free-dataset)
3. Is your dataset in a "conversation" format, containing a `list[messages]`? If yes, check [Conversation Dataset](#conversation-dataset)
4. Is your dataset in an "instruct" format, containing `{ instruction, response }`? If yes, check [Instruction Dataset](#instruction-dataset)
If you went through the flow chart and did not find one that matches, it is recommended to preprocess your dataset into one of the above or create a thread on Github Discussion.
::: {.callout-tip}
You can mix and match within each approach or across approaches to train a model on a variety of datasets.
:::
### Pre-Tokenized Dataset
We suggest this approach when you want to bring your own tokenized dataset.
Axolotl expects the dataset to have three keys:
- `input_ids`: from tokenizing formatted prompt
- `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]`
- `labels`: this is the same as `input_ids`, however, if you want to mask certain tokens, you would set those indices to `-100`.
::: {.callout-tip}
Make sure to add BOS/EOS tokens to your prompt and mask it appropriately.
:::
A config for this would look like:
```yaml
datasets:
- path: A.jsonl
type:
```
::: {.callout-note}
`type: ` is empty!
:::
Reference: [Pre-Tokenized Dataset Documentation](tokenized.qmd).
### Template Free Dataset
We reccomend this approach when you want granular control over the prompt formatting, special tokens, and masking, whilst letting Axolotl handle the tokenization. This is very useful if your dataset has unique prompts that differ across samples and where one single general template wouldn't suffice.
In the example below, you could see that there is no proper structure. At the same time, it's very flexible as there are no constraints on how your prompt can look.
```json
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
```
Each prompt must be have a key called `segments` which is a list of `{ text, label }`.
```yaml
datasets:
- path: A.jsonl
type: input_output
```
Reference: [Template Free Documentation](template_free.qmd).
### Conversation Dataset
`conversation` messages are a list of messages which usually contain a `role` and `content` key.
::: {.callout-tip}
Fun fact: Axolotl synonymously refers to "chat" messages as `conversation` messages due to how FastChat initially used this term to build a widely used [fastchat conversation](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py) method for formatting chat messages prior to the creation of `chat_templates`.
:::
#### What are `chat_templates`?
The current most popular and convenient method for inference is to use `chat_templates` for formatting prompts. Axolotl supports using `chat_templates` for training to ensure that the model performs in the same environment as in inference.
Here's a quick rundown on `chat_template`: A `chat_template` is a Jinja2 template which formats a list of messages into a prompt.
An example of a prompt formatted into a popular template called ChatML can be seen below:
Single prompt (pretty-printed):
```json
{
"messages": [
{
"role": "user",
"content": "Hi"
},
{
"role": "assistant",
"content": "How can I help you?"
},
{
"role": "user",
"content": "Can you add 3+5?"
},
{
"role": "assistant",
"content": "The answer is 8."
}
]
}
```
The ChatML template is as follows:
```jinja2
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}
```
The above prompt formatted into this template will result in:
```
<|im_start|>user
Hi<|im_end|>
<|im_start|>assistant
How can I help you?<|im_end|>
<|im_start|>user
Can you add 3+5?<|im_end|>
<|im_start|>assistant
The answer is 8.<|im_end|>
```
By using delimiters (`<|im_start|>` and `<|im_end|>`), a prompt separates different speakers which helps the model identify which portion belongs to whom.
#### Common Conversation Dataset formats
Older conversation datasets with the following format are colloquially called `sharegpt` datasets.
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
Newer conversation datasets usually follow the OpenAI format.
```json
{"messages": [{"role": "...", "content": "..."}]}
```
Axolotl supports both as well as allowing customization of any kind of key.
#### Chat Template Usage
To properly use this method, it is important to identify three things:
1. Which `chat_template` would you use?
2. What are the keys in your dataset, and what are the possible roles? For example, in OpenAI format, the keys would be `messages`, `role`, and `content`, respectively, whereas the possible roles are `system`, `user`, and `assistant`.
3. What do you want to mask? For instance, only assistant messages, only last message, or nothing.
##### Choosing a `chat_template`
There are a lot of `chat_templates` out there. Axolotl supports the common ones: [supported chat templates](https://github.com/axolotl-ai-cloud/axolotl/blob/860609392184cf62a7e0ca676658b170e059ce6c/src/axolotl/utils/chat_templates.py#L17). For example, to use ChatML, it would be `chat_template: chatml`.
However, it is also possible to use the already configured template within the tokenizer by specifying `chat_template: tokenizer_default`. If you want a fallback (in case some tokenizer does not have it pre-configured), you can do `chat_template: tokenizer_default_fallback_chatml` to fallback to the ChatML template if a tokenizer template was not found.
One last but powerful approach is to bring your own template. This can be set via:
```yaml
chat_template_jinja: # your template
```
##### Setting `chat_template` dataset keys
We currently default to OpenAI format for dataset keys, so if that's your current dataset format, there's nothing to do here.
If your dataset format is different, here are the keys you should check (with their defaults):
```yaml
datasets:
...
field_messages: messages # this should point to the key containing the list of conversations
message_property_mappings: # this is a mapping from keys in your dataset to keys in chat_template
role: role
content: content
```
In some `chat_templates` (e.g. [Gemma](https://huggingface.co/google/gemma-2b-it/blob/main/tokenizer_config.json#L1507)), the roles are hardcoded to `user` and `assistant`. Consequently, you may find it necessary to map the roles in your dataset to these above. We currently have some defaults that should work for common datasets, but if you get a `KeyError`, it would be necessary to add mapping for your roles. Here is an example of how it would look like:
```yaml
datasets:
...
roles:
assistant:
- gpt
- model
user:
- human
```
In the example above, all `gpt` and `model` values are converted to `assistant`. All `human` values are converted to `user.`
##### Handling masking
The common use case for `chat_template` is for chat messages, therefore, it is common to mask all non-assistant messages. Assistant messages refer to the bot messages that you want the model to learn on.
To train on all `assistant` messages, you would set the following configs.
```yaml
datasets:
...
roles_to_train: ["assistant"]
train_on_eos: "turn"
```
The `train_on_eos` config means that it would mask all EOS tokens for turns that aren't assistant-turns. The other options are: `all` and `last` to choose which EOS to train on.
Perhaps, you want to train on `assistant` and `narrator` roles, you can simply add `narrator` to the list of `roles_to_train`. You would also need to add it to the mapping of `roles` above.
```yaml
datasets:
...
roles_to_train: ["assistant", "narrator"]
roles:
assistant:
- gpt
- model
user:
- human
narrator: ["narrator"]
```
::: {.callout-tip}
As chat_templates may use hardcoded EOS/EOT tokens that are different from the tokenizer's EOS, it is highly recommended to set them. For example, `ChatML` uses `<|im_end|>` to end turns.
```yaml
special_tokens:
eos_token: <|im_end|>
```
:::
##### Applying `chat_template`
Once all the above steps are completed, you could combine all these configs together to form a bespoke configuration for your custom dataset.
```yaml
datasets:
- path: A.jsonl
type: chat_template
# step 1
chat_template: chatml
# step 2
field_messages: messages
message_property_mappings:
role: role
content: content
roles:
assistant:
- gpt
- model
- assistant
user:
- human
- user
# step 3
roles_to_train: ["assistant"]
train_on_eos: "turn"
special_tokens:
eos_token: <|im_end|>
```
If this config were to be applied to the sample dataset above, the output would look as such (which can be retrieved via `axolotl preprocess config.yaml --debug`):
```
<|im_start|>(-100, 128256) user(-100, 882)
(-100, 198) Hi(-100, 13347) <|im_end|>(-100, 128257)
(-100, 198) <|im_start|>(-100, 128256) assistant(-100, 78191)
(-100, 198) How(4438, 4438) can(649, 649) I(358, 358) help(1520, 1520) you(499, 499) ?(30, 30) <|im_end|>(128257, 128257)
(-100, 198) <|im_start|>(-100, 128256) user(-100, 882)
(-100, 198) Can(-100, 6854) you(-100, 499) add(-100, 923) (-100, 220) 3(-100, 18) +(-100, 10) 5(-100, 20) ?(-100, 30) <|im_end|>(-100, 128257)
(-100, 198) <|im_start|>(-100, 128256) assistant(-100, 78191)
(-100, 198) The(791, 791) answer(4320, 4320) is(374, 374) (220, 220) 8(23, 23) .(13, 13) <|im_end|>(128257, 128257)
(-100, 198)
```
The first number refers to the label, the second refers to the `token_id`. For example, `-100` labels appear on non-assistant portions, meaning that they are masked during. For assistant portions, the label is the same as the `token_id`.
::: {.callout-note}
If during `preprocess`, there are a lot of warnings of `Could not find content __ boundary`, please check the FAQ section for [chat_templates](../faq.qmd#chat-templates).
:::
#### Reference
Please see docs [here](conversation.qmd).
### Instruction Dataset
Instruction datasets are used to train instruction-following models and comprise a prompt, containing an instruction, and a single response. In contrast to chat datasets which may be multi-turn, instruct datasets are typically single-turn.
An example is of a common format called Alpaca:
```json
{"instruction": "...", "input": "...", "output": "..."}
```
Using those keys, a prompt can be built based on it.
```
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{input}
### Response:
{output}
```
This can be configured as such:
```yaml
datasets:
- path: A.jsonl
type: alpaca
```
Axolotl supports many kinds of instruction dataset. All of them can be found here (https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/inst_tune.html) with their respective type and sample row format.
Reference: [Instruction Dataset Documentation](inst_tune.qmd).
#### Custom Instruct Prompt Format
Due to the myriad possibilities of instruction formats, Axolotl allows customizing your own instruction format without having to dive into the code directly.
In the example below, a sample row is used to output in `mistral_v1` format.
```json
{"input": "...", "output": "..."}
```
```yaml
datasets:
- path: repo
type:
system_prompt: ""
field_system:
field_instruction: input
field_input:
field_output: output
# multi-line example with input
format: |-
[INST] {instruction} {input} [/INST]
# single-line example without input
no_input_format: "[INST] {instruction} [/INST]"
```
The config sets that the `field_instruction` is actually named `input`, and the `field_input` is empty as we don't have an `input` in this sample. Generally, `instruction` can be thought as the question to the model, and `input` as the additional information with `output` being the response. It is not necessary to have an `input` nor `system`. In the end, the most important part is to understand what format you want it to look like and how you can customize this to your use case.
Reference: [Custom Instruct Prompt Format Documentation](inst_tune.qmd#how-to-add-custom-prompt-format).
## Reinforcement Learning from Human Feedback (RLHF)
As there are multiple RLHF methods with their own dataset requirements. Please see [RLHF documentation](../rlhf.qmd) for more detail.

View File

@@ -27,7 +27,6 @@ pretraining_dataset:
type: pretrain
trust_remote_code:
skip: # number of rows of data to skip over from the beginning
...
```
:::

View File

@@ -1,7 +1,239 @@
---
title: Template-Free
description: Construct prompts without a template.
toc: true
toc-depth: 3
order: 4
---
See [these docs](../input_output.qmd).
## Background {#sec-background}
### Masking Inputs {#masking-inputs}
One of the most popular features of
[axolotl](https://github.com/axolotl-ai-cloud/axolotl) is
setting the following configuration value:
```yaml
train_on_inputs: false
```
If you declare a [dataset formats](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#dataset)
such as `alpaca` or `chatml`, axolotl knows what is an input
(i.e. human) vs. an output (i.e. the assistant) and masks the input
labels so that your model can focus on predicting the outputs only.
### You may not want prompt templates {#sec-you-may-not-want-prompt-templates}
However, there are many situations where you don't want to use one of
these formats or templates. This is because they can:
- Add unnecessary boilerplate to your prompts.
- Create artifacts like special delimiters `<|im_start|>` that can
quickly become footguns if you don't include them correctly at
inference time.
- Enforce a *chat* interface when you do not want one. Sometimes you
just want to fine-tune a model to a very specific task and do NOT
want multi-turn conversations, roles, etc.
- Limit you to only certain roles that the template allows.
### The `input_output` format {#sec-the-inputoutput-format}
You can construct your prompts without a template by using the
`input_output` format, by setting `type: input_output` in your
configuration file like this:
**config.yml**
```yaml
train_on_inputs: false # Mask segments of your data
datasets:
- path: output.jsonl
type: input_output # use template free prompt construction
```
Unlike `type: completion`, which is also template-free,
`type: input_output` allows you to mask segments of your text. More
details on how this works are described below.
## Usage {#sec-usage}
This is how you can use the `input_output` format:
### 1. Prepare Data {#sec-1-prepare-data}
To use the `input_output` format, collect your data in the following
format into a jsonl file (below is the first row from the file
`output`.jsonl` pretty printed):
```bash
$ head -n1 output.jsonl | python -m json.tool
```
:::{.cell-output .cell-output-stdout}
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
:::
Set `label:false` when you want to mask a segment of text so that the
model isn't trained on it. Some things to keep in mind:
> [!IMPORTANT]
> 1. **EOS, BOS, spaces, newlines etc. are entirely up to you. Axolotl
concatenates all the segments as-is.** The tokenizer doesn't add
anything additional. Notice how I added spaces, newlines, `<s>`
(BOS), and `</s>` (EOS) myself.
> 2. Make sure you check the materialized output to validate that the
prompt is getting assembled how you like.
### 2. Use `type: input_output` {#sec-2-use-type-inputoutput}
Let's materialize data with our `output.jsonl` file by setting
`type: input_output` in our axolotl config:
```yaml
# training_config.yaml
base_model: mistralai/Mistral-7B-v0.1
data_seed: 49
seed: 49
datasets:
- path: output.jsonl
type: input_output
val_set_size: 0.1
sequence_len: 896
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 3
eval_batch_size: 2
num_epochs: 1
learning_rate: 0.0002
train_on_inputs: false
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
```
You can use the following command to materialize your data. The
`--debug` flag will print the tokens, along with the labels so you can
verify that the correct items are being ignored:
```bash
axolotl preprocess training_config.yaml --debug
...
[2024-03-05 23:36:46,969] [INFO] [axolotl.check_example_labels:35] [PID:607731] [RANK:0] <s>(1, 1) Hello(22557, 22557)
(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)
```
The format is `decoded_token`(`label`, `token_id`), for example,
`<s>(1, 1)` means that the token is `<s>`, the label is `1` and the
token_id is `1`. When the label is `-100` then that token is ignored for
training.
### 3. Check the prompts {#sec-3-check-the-prompts}
Here is another way to check the materialized output:
```python
from transformers import AutoTokenizer
from datasets import load_from_disk
import yaml
directory = !ls last_run_prepared/
with open('training_config.yaml', 'r') as f:
cfg = yaml.safe_load(f)
model_id = cfg['base_model']
tok = AutoTokenizer.from_pretrained(model_id)
ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
```
```python
>>> row = ds[0]
>>> print(tok.decode(row['input_ids']))
<s> Hello
hi there!. goodbye farewell</s>
```
We can check that the right tokens are ignored by comparing the labels
to each token:
```python
import pandas as pd
pd.DataFrame([{'token': tok.decode(i), 'label': l, 'id':i} for i,l in
zip(row['input_ids'], row['labels'])])
```
| token | label | id |
|-------|-------|-------|
| 0 | \<s\> | 1 |
| 1 | Hello | 22557 |
| 2 | \\n | 13 |
| 3 | hi | 12014 |
| 4 | there | 736 |
| 5 | ! | 28808 |
| 6 | . | 28723 |
| 7 | | 28705 |
| 8 | good | -100 |
| 9 | bye | -100 |
| 10 | | -100 |
| 11 | fare | 19111 |
| 12 | well | 5458 |
| 13 | \</s\>| 2 |
If we look at the input data, the above table seems correct! (The jsonl
version is repeated below for reference):
```bash
$ head -n1 output.jsonl | python -m json.tool
```
:::{.cell-output .cell-output-stdout}
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
:::

View File

@@ -3,8 +3,11 @@ title: Dataset Preprocessing
description: How datasets are processed
---
## Overview
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
the (dataset format)[../dataset-formats/] and prompt strategies to:
the [dataset format](docs/dataset-formats) and prompt strategies to:
- parse the dataset based on the *dataset format*
- transform the dataset to how you would interact with the model based on the *prompt strategy*
- tokenize the dataset based on the configured model & tokenizer
@@ -12,10 +15,12 @@ the (dataset format)[../dataset-formats/] and prompt strategies to:
The processing of the datasets can happen one of two ways:
1. Before kicking off training by calling `python -m axolotl.cli.preprocess /path/to/your.yaml --debug`
1. Before kicking off training by calling `axolotl preprocess config.yaml --debug`
2. When training is started
What are the benefits of pre-processing? When training interactively or for sweeps
### What are the benefits of pre-processing?
When training interactively or for sweeps
(e.g. you are restarting the trainer often), processing the datasets can oftentimes be frustratingly
slow. Pre-processing will cache the tokenized/formatted datasets according to a hash of dependent
training parameters so that it will intelligently pull from its cache when possible.
@@ -28,8 +33,12 @@ default path of `./last_run_prepared/`, but will ignore anything already cached
setting `dataset_prepared_path: ./last_run_prepared`, the trainer will use whatever pre-processed
data is in the cache.
What are the edge cases? Let's say you are writing a custom prompt strategy or using a user-defined
### What are the edge cases?
Let's say you are writing a custom prompt strategy or using a user-defined
prompt template. Because the trainer cannot readily detect these changes, we cannot change the
calculated hash value for the pre-processed dataset. If you have `dataset_prepared_path: ...` set
calculated hash value for the pre-processed dataset.
If you have `dataset_prepared_path: ...` set
and change your prompt templating logic, it may not pick up the changes you made and you will be
training over the old prompt.

View File

@@ -31,11 +31,13 @@ While debugging it's helpful to simplify your test scenario as much as possible.
- Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`.
- Set `dataset_processes: 1` in your axolotl config or run the training command with `--dataset_processes=1`.
2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure `sample_packing: False` and `eval_sample_packing: False` to avoid errors. If you are in a pinch and don't have time to construct a small dataset but want to use from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training. For example, to shard the dataset into 20 pieces, add the following to your axolotl config):
```yaml
dataset:
datasets:
...
shards: 20
```
3. **Use a small model**: A good example of a small model is [TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0).
4. **Minimize iteration time**: Make sure the training loop finishes as fast as possible, with these settings.
- `micro_batch_size: 1`
@@ -85,7 +87,7 @@ The easiest way to get started is to modify the [.vscode/launch.json](../.vscode
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
```jsonc
```json
// .vscode/launch.json
{
"version": "0.2.0",
@@ -132,7 +134,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
Below is the [./vscode/tasks.json](../.vscode/tasks.json) file that defines the `cleanup-for-dataprep` task. This task is run before each debugging session when you use the above configuration. Note how there are two tasks that delete the two folders mentioned above. The third task `cleanup-for-dataprep` is a composite task that combines the two tasks. A composite task is necessary because VSCode does not allow you to specify multiple tasks in the `preLaunchTask` argument of the `launch.json` file.
```jsonc
```json
// .vscode/tasks.json
// this file is used by launch.json
{

View File

@@ -3,6 +3,7 @@ title: FAQ
description: Frequently asked questions
---
### General
**Q: The trainer stopped and hasn't progressed in several minutes.**
@@ -19,3 +20,33 @@ description: Frequently asked questions
**Q: AttributeError: 'DummyOptim' object has no attribute 'step'**
> A: You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
**Q: The codes is stuck on saving preprocessed datasets.**
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
### Chat templates
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
> A: This means that the property mapping for the stated attribute does not exist when building `chat_template` prompt. For example, if `no attribute 'content'`, please check you have added the correct mapping for `content` under `message_property_mappings`.
**Q: `Empty template generated for turn ___`**
> A: The `content` is empty for that turn.
**Q: `Could not find content start/end boundary for turn __`**
> A: The specific turn's start/end could not be detected. Please ensure you have set the `eos_token` following your `chat_template`. Otherwise, this could be a `chat_template` which doesn't use proper boundaries for each turn (like system). On the rare occurrence, make sure your content is not `[[dummy_message]]`. Please let us know about this.
**Q: `Content end boundary is before start boundary for turn ___`**
> A: This is an edge case which should not occur. Please create an Issue if this happens.
**Q: `Content end boundary is the same as start boundary for turn ___. This is likely an empty turn.`**
> A: This is likely an empty turn.
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.

View File

@@ -1,5 +1,5 @@
---
title: "Getting Started with Axolotl"
title: "Quickstart"
format:
html:
toc: true
@@ -17,12 +17,12 @@ Let's start by fine-tuning a small language model using LoRA. This example uses
Assuming `axolotl` is installed (if not, see our [Installation Guide](installation.qmd))
1. Download example configs:
```shell
```bash
axolotl fetch examples
```
2. Run the training:
```shell
```bash
axolotl train examples/llama-3/lora-1b.yml
```
@@ -108,7 +108,7 @@ Please consult the supported [Dataset Formats](dataset-formats/) for more detail
3. Run the training:
```shell
```bash
axolotl train my_training.yml
```
@@ -118,7 +118,7 @@ axolotl train my_training.yml
After training, test your model:
```shell
```bash
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
```
@@ -126,7 +126,7 @@ axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
For large datasets, preprocess first:
```shell
```bash
axolotl preprocess my_training.yml
```
@@ -134,7 +134,7 @@ axolotl preprocess my_training.yml
Launch a Gradio interface:
```shell
```bash
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
```

View File

@@ -1,11 +1,10 @@
---
title: "Inference Guide"
title: "Inference"
format:
html:
toc: true
toc-depth: 3
number-sections: true
code-tools: true
execute:
enabled: false
---

View File

@@ -3,263 +3,4 @@ title: Template-free prompt construction
description: "Template-free prompt construction with the `input_output` format"
---
<!-- TOC -->
- [Background](#background)
- [Masking Inputs](#masking-inputs)
- [You may not want prompt templates](#you-may-not-want-prompt-templates)
- [The `input_output` format](#the-input_output-format)
- [Usage](#usage)
- [1. Prepare Data](#1-prepare-data)
- [2. Use `type: input_output`](#2-use-type-input_output)
- [3. Check the prompts](#3-check-the-prompts)
<!-- /TOC -->
<a id="markdown-background" name="background"></a>
## Background
<a id="markdown-masking-inputs" name="masking-inputs"></a>
### Masking Inputs
One of the most popular features of
[axolotl](https://github.com/axolotl-ai-cloud/axolotl) is
setting the following configuration value:
```yaml
train_on_inputs: false
```
If you declare a [dataset formats](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#dataset)
such as `alpaca` or `chatml`, axolotl knows what is an input
(i.e. human) vs. an output (i.e. the assistant) and masks the input
labels so that your model can focus on predicting the outputs only.
<a id="markdown-you-may-not-want-prompt-templates" name="you-may-not-want-prompt-templates"></a>
### You may not want prompt templates
However, there are many situations where you don't want to use one of
these formats or templates. This is because they can:
- Add unnecessary boilerplate to your prompts.
- Create artifacts like special delimiters `<|im_start|>` that can
quickly become footguns if you don't include them correctly at
inference time.
- Enforce a *chat* interface when you do not want one. Sometimes you
just want to fine-tune a model to a very specific task and do NOT
want multi-turn conversations, roles, etc.
- Limit you to only certain roles that the template allows.
<a id="markdown-the-inputoutput-format" name="the-inputoutput-format"></a>
### The `input_output` format
You can construct your prompts without a template by using the
`input_output` format, by setting `type: input_output` in your
configuration file like this:
**config.yml**
```yaml
train_on_inputs: false # Mask segments of your data
datasets:
- path: output.jsonl
type: input_output # use template free prompt construction
```
Unlike `type: completion`, which is also template-free,
`type: input_output` allows you to mask segments of your text. More
details on how this works are described below.
<a id="markdown-usage" name="usage"></a>
## Usage
This is how you can use the `input_output` format:
<a id="markdown-1-prepare-data" name="1-prepare-data"></a>
### 1. Prepare Data
To use the `input_output` format, collect your data in the following
format into a jsonl file (below is the first row from the file
`output`.jsonl` pretty printed):
```bash
$ head -n1 output.jsonl | python -m json.tool
```
:::{.cell-output .cell-output-stdout}
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
:::
Set `label:false` when you want to mask a segment of text so that the
model isn't trained on it. Some things to keep in mind:
> [!IMPORTANT]
> 1. **EOS, BOS, spaces, newlines etc. are entirely up to you. Axolotl
concatenates all the segments as-is.** The tokenizer doesn't add
anything additional. Notice how I added spaces, newlines, `<s>`
(BOS), and `</s>` (EOS) myself.
> 2. Make sure you check the materialized output to validate that the
prompt is getting assembled how you like.
<a id="markdown-2-use-type-inputoutput" name="2-use-type-inputoutput"></a>
### 2. Use `type: input_output`
Let's materialize data with our `output.jsonl` file by setting
`type: input_output` in our axolotl config:
```yaml
# training_config.yaml
base_model: mistralai/Mistral-7B-v0.1
data_seed: 49
seed: 49
datasets:
- path: output.jsonl
type: input_output
val_set_size: 0.1
sequence_len: 896
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 3
eval_batch_size: 2
num_epochs: 1
learning_rate: 0.0002
train_on_inputs: false
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
```
You can use the following command to materialize your data. The
`--debug` flag will print the tokens, along with the labels so you can
verify that the correct items are being ignored:
```bash
$ python -m axolotl.cli.preprocess training_config.yaml --debug
...
[2024-03-05 23:36:46,969] [INFO] [axolotl.check_example_labels:35] [PID:607731] [RANK:0] <s>(1, 1) Hello(22557, 22557)
(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)
```
The format is `decoded_token`(`label`, `token_id`), for example,
`<s>(1, 1)` means that the token is `<s>`, the label is `1` and the
token_id is `1`. When the label is `-100` then that token is ignored for
training.
<a id="markdown-3-check-the-prompts" name="3-check-the-prompts"></a>
### 3. Check the prompts
Here is another way to check the materialized output:
```python
from transformers import AutoTokenizer
from datasets import load_from_disk
import yaml
directory = !ls last_run_prepared/
with open('training_config.yaml', 'r') as f:
cfg = yaml.safe_load(f)
model_id = cfg['base_model']
tok = AutoTokenizer.from_pretrained(model_id)
ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
```
```python
>>> row = ds[0]
>>> print(tok.decode(row['input_ids']))
<s> Hello
hi there!. goodbye farewell</s>
```
We can check that the right tokens are ignored by comparing the labels
to each token:
```python
import pandas as pd
pd.DataFrame([{'token': tok.decode(i), 'label': l, 'id':i} for i,l in
zip(row['input_ids'], row['labels'])])
```
| token | label | id |
|-------|-------|-------|
| 0 | \<s\> | 1 |
| 1 | Hello | 22557 |
| 2 | \\n | 13 |
| 3 | hi | 12014 |
| 4 | there | 736 |
| 5 | ! | 28808 |
| 6 | . | 28723 |
| 7 | | 28705 |
| 8 | good | -100 |
| 9 | bye | -100 |
| 10 | | -100 |
| 11 | fare | 19111 |
| 12 | well | 5458 |
| 13 | \</s\>| 2 |
If we look at the input data, the above table seems correct! (The jsonl
version is repeated below for reference):
```bash
$ head -n1 output.jsonl | python -m json.tool
```
:::{.cell-output .cell-output-stdout}
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
:::
The documentation moved to [here](dataset-formats/template_free.qmd).

View File

@@ -1,11 +1,10 @@
---
title: "Installation Guide"
title: "Installation"
format:
html:
toc: true
toc-depth: 3
number-sections: true
code-tools: true
execute:
enabled: false
---

127
docs/lora_optims.qmd Normal file
View File

@@ -0,0 +1,127 @@
---
title: "LoRA Optimizations"
description: "Custom autograd functions and Triton kernels in Axolotl for optimized LoRA fine-tuning"
---
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
to leverage operator fusion and tensor re-use in order to improve speed and reduce
memory usage during the forward and backward passes of these calculations.
We currently support several common model architectures, including (but not limited to):
- `llama`
- `mistral`
- `qwen2`
- `gemma`
- `gemma2`
<details>
The set of models we support is currently limited by our attention patching strategy,
which assumes (and replaces) specific code blocks for query / key / value and output
projections:
```python
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output)
""".lstrip(
"\n"
)
```
Is replaced with:
```python
PATCHED_QKV_CODE = """
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
PATCHED_O_CODE = """
attn_output = self.apply_o(attn_output)
""".lstrip(
"\n"
)
```
Where `apply_qkv` and `apply_o` are defined in the `axolotl.kernels.lora` module.
We welcome testing of other model architectures and / or PRs to expand our patching
logic to be compatible with more of them.
</details>
## Usage
These optimizations can be enabled in your Axolotl config YAML file. The
`lora_mlp_kernel` option enables the optimized MLP path, while `lora_qkv_kernel` and
`lora_o_kernel` enable the fused query-key-value projection and optimized output
projection, respectively.
```yaml
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
```
## Requirements
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
- Note: Set `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` to enable [memory-efficient attention on AMD GPUs](https://github.com/ROCm/aotriton/issues/16#issuecomment-2346675491)
- Targeted LoRA adapters cannot use Dropout
- This may limit model expressivity / cause overfitting
- Targeted LoRA adapters cannot have bias terms
- This may limit model expressivity
Models with pre-existing LoRA adapters that use Dropout or have bias terms may need to
be re-finetuned without these features in order to be useful.
## Implementation details
### Custom autograd functions
The LoRA MLP autograd function optimizes the entire MLP computation path. It fuses the
LoRA and base weight computations together and provides a single, efficient backward
pass for the entire MLP block.
For attention components, similar optimizations are provided through a function that
handles the query, key, and value projections, and a function that handles the output
projection. They are designed to work with the existing `transformers` attention
implementation via some monkey-patching logic.
### Triton kernels
Two activation functions (SwiGLU and GeGLU) are implemented with Triton kernels for
improved speed and memory performance. These kernels handle both the forward and
backward passes.
### Integration
The custom autograd functions and Triton kernels are designed to work together. The
autograd function manages the high-level computation flow and gradient tracking, while
calling the Triton kernels for the activation function computation. During the backward
pass, the kernel computes both the activation output and the required gradients, which
the autograd function then uses to compute the final gradients for the entire
computation path.
## Future Work
- Support for additional model architectures
- Support for the FSDP setting
- Support for dropout and bias
- Additional operator fusions

View File

@@ -19,4 +19,5 @@ Current support:
- [ ] DeepSpeed
Untested:
- FSDP

View File

@@ -1,5 +1,5 @@
---
title: "Multi-GPU Training Guide"
title: "Multi-GPU"
format:
html:
toc: true
@@ -35,7 +35,11 @@ deepspeed: deepspeed_configs/zero1.json
### Usage {#sec-deepspeed-usage}
```{.bash}
accelerate launch -m axolotl.cli.train examples/llama-2/config.yml --deepspeed deepspeed_configs/zero1.json
# Passing arg via config
axolotl train config.yml
# Passing arg via cli
axolotl train config.yml --deepspeed deepspeed_configs/zero1.json
```
### ZeRO Stages {#sec-zero-stages}
@@ -70,25 +74,7 @@ For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
### Liger Kernel Integration {#sec-liger}
::: {.callout-note}
Liger Kernel provides efficient Triton kernels for LLM training, offering:
- 20% increase in multi-GPU training throughput
- 60% reduction in memory usage
- Compatibility with both FSDP and DeepSpeed
:::
Configuration:
```{.yaml}
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```
Please see [docs](custom_integrations.qmd#liger) for more info.
## Troubleshooting {#sec-troubleshooting}

View File

@@ -3,6 +3,18 @@ title: Multi Node
description: How to use Axolotl on multiple machines
---
The below are three ways to train multi-node in Axolotl.
::: {.callout-important}
Each machine needs a copy of Axolotl, we suggest using the same commit to ensure compatibility.
You will also need to have the same configuration file for your model on each machine.
Make sure the main machine is reachable by other machines.
:::
## Accelerate
You will need to create a configuration for accelerate, either by using `accelerate config` and follow the instructions or you can use one of the preset below:
~/.cache/huggingface/accelerate/default_config.yaml
@@ -26,7 +38,7 @@ tpu_use_sudo: false
use_cpu: false
```
Configure your model to use FSDP with for example:
Configure your model to use FSDP in the Axolotl yaml. For example:
```yaml
fsdp:
- full_shard
@@ -37,12 +49,40 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
## Machine configuration
On each machine you need a copy of Axolotl, we suggest using the same commit to ensure compatibility.
You will also need to have the same configuration file for your model on each machine.
On the main machine only, make sure the port you set as `main_process_port` is open in TCP and reachable by other machines.
All you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine.
## Raytrain
Please see ray train doc [here](ray-integration.qmd).
## Torchrun
If you are using Infiniband, we recommend torchrun to utilize the full bandwidth.
Set the following env (change buffersize/socketname depending on your system):
```bash
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond"
export NCCL_BUFFSIZE=2097152
```
Run the following on each node:
```bash
torchrun --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port" -m axolotl.cli.train config.yaml
```
Please make sure to substitute the placeholder variables.
- `num_nodes`: Number of nodes (containing GPUs)
- `gpu_per_node`: Number of gpus per node
- `head_node_ip`: IP of the head node (make sure other machines can connect to this)
- `head_node_port`: Port of the head node (make sure other machines can connect to this. Default 29400)
- `rdzv_id`: A unique job ID that is used by the job across nodes.
::: {.callout-note}
You need to call `axolotl.cli.train` instead of `axolotl train` as the latter calls accelerate under the hood
:::
More info on the available configs can be found on the Pytorch docs [here](https://pytorch.org/docs/stable/elastic/run.html)

View File

@@ -13,13 +13,13 @@ Often, this timeout will happen after 30 minutes (the default setting) and is ac
Forcing cross-GPU communication via [NVLink](https://en.wikipedia.org/wiki/NVLink) may help without increasing timeouts. To verify that your configuration is leveraging NVLink run the following command:
```shell
```bash
nvidia-smi nvlink --status
```
To force NCCL to use NVLink, simply set this in the environment:
```shell
```bash
export NCCL_P2P_LEVEL=NVL
```
@@ -33,13 +33,13 @@ If NVLink is not available in your environment there are other options for ``NCC
To validate that acceptable data transfer speeds exist for your training job, running [NCCL Tests](https://github.com/NVIDIA/nccl-tests/blob/master/README.md) can help pinpoint bottlenecks, for example:
```shell
```bash
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
```
It can be useful when debugging NCCL communication timeouts to activate additional logging in both PyTorch and NCCL:
```shell
```bash
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL
export TORCH_DISTRIBUTED_DEBUG=INFO

View File

@@ -1,5 +1,5 @@
---
title: Ray Train integration
title: Ray Train
description: How to use Axolotl with Ray Train
---
@@ -9,7 +9,7 @@ With the `--use-ray` CLI flag, Axolotl will use Ray Train's [`TorchTrainer`](htt
## Ray cluster setup
A prerequisite using the Ray Train integration is to setup a Ray cluster on your desired node(s). For a detailed guide on how you can get started with ray clusters, check the official Ray docs here: https://docs.ray.io/en/latest/cluster/getting-started.html
A prerequisite using the Ray Train integration is to setup a Ray cluster on your desired node(s). For a detailed guide on how you can get started with ray clusters, check the official Ray docs [here](https://docs.ray.io/en/latest/cluster/getting-started.html).
Every Ray cluster has one _head_ node and a set of worker nodes. The head node is just like any other worker node, but it also runs certain special processes related to scheduling and orchestration. Ray-enabled scripts are run on the head node and depending on the resources (number of CPUs, GPUs, etc) they request, will be scheduled to run certain tasks on the worker nodes. For more on key concepts behind a Ray cluster, you can refer this [doc](https://docs.ray.io/en/latest/cluster/key-concepts.html#cluster-key-concepts).
@@ -58,13 +58,11 @@ You can find an example configuration at `configs/llama-3/lora-1b-ray.yaml`.
The key parameters to note here are:
```yaml
...
use_ray: true
ray_num_workers: 4
# optional
resources_per_worker:
GPU: 1
...
```
- `use_ray`: This is the flag that enables the Ray Train integration. You can either use the corresponding `--use-ray` flag in the CLI or set `use_ray` in the config file.

View File

@@ -1,26 +1,39 @@
---
title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-depth: 4
---
### Overview
## Overview
Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
feedback. Various methods include, but not limited to:
- [Direct Preference Optimization (DPO)](#dpo)
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- Direct Preference Optimization (DPO)
- Identity Preference Optimization (IPO)
### RLHF using Axolotl
## RLHF using Axolotl
>[!IMPORTANT]
>This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
::: {.callout-important}
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
:::
The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples with how you can use various preference datasets to train models that use ChatML
We rely on the [TRL](https://github.com/huggingface/trl) library for implementations of various RL training methods, which we wrap around to expose in axolotl. Each method has their own supported ways of loading datasets and prompt formats.
::: {.callout-tip}
You can find what each method supports by going into `src/axolotl/prompt_strategies/{method}` where `{method}` is one of our supported methods. The `type: ` can be retrieved from `{method}.{function_name}`.
:::
### DPO
Example config:
#### DPO
```yaml
rl: dpo
datasets:
@@ -32,12 +45,265 @@ datasets:
type: chatml
```
#### IPO
DPO supports the following types with the following dataset format:
#### chatml.argilla
```json
{
"system": "...", // optional
"instruction": "...",
"chosen_response": "...",
"rejected_response": "..."
}
```
#### chatml.argilla_chat
```json
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
#### chatml.icr
```json
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
```
#### chatml.intel
```json
{
"system": "...", // optional
"question": "...",
"chosen": "...",
"rejected": "..."
}
```
#### chatml.prompt_pairs
```json
{
"system": "...", // optional
"prompt": "...",
"chosen": "...",
"rejected": "..."
}
```
#### chatml.ultra
```json
{
"system": "...", // optional
"prompt": "...",
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
#### llama3.argilla
```json
{
"system": "...", // optional
"instruction": "...",
"chosen_response": "...",
"rejected_response": "..."
}
```
#### llama3.argilla_chat
```json
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
#### llama3.icr
```json
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
```
#### llama3.intel
```json
{
"system": "...", // optional
"question": "...",
"chosen": "...",
"rejected": "..."
}
```
#### llama3.prompt_pairs
```json
{
"system": "...", // optional
"prompt": "...",
"chosen": "...",
"rejected": "..."
}
```
#### llama3.ultra
```json
{
"system": "...", // optional
"prompt": "...",
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
#### zephyr.nectar
```json
{
"prompt": "...",
"answers": [
{
"answer": "...",
"rank": 1
},
{
"answer": "...",
"rank": 2
}
// ... more answers with ranks
]
}
```
#### chat_template.default
```yaml
rl: dpo
datasets:
- path: ...
split: train
type: chat_template.default
field_messages: "messages"
field_chosen: "chosen"
field_rejected: "rejected"
message_property_mappings:
role: role
content: content
roles:
user: ["user"]
assistant: ["assistant"]
system: ["system"]
```
Sample input format:
```json
{
"messages": [
{
"role": "system",
"content": "..."
},
{
"role": "user",
"content": "..."
},
// ... more messages
],
"chosen": {
"role": "assistant",
"content": "..."
},
"rejected": {
"role": "assistant",
"content": "..."
}
}
```
#### user_defined.default
For custom behaviors,
```yaml
rl: dpo
datasets:
- path: ...
split: train
type: user_defined.default
field_prompt: "prompt"
field_system: "system"
field_chosen: "chosen"
field_rejected: "rejected"
prompt_format: "{prompt}"
chosen_format: "{chosen}"
rejected_format: "{rejected}"
```
The input format is a simple JSON input with customizable fields based on the above config.
```json
{
"system": "...", // optional
"prompt": "...",
"chosen": "...",
"rejected": "..."
}
```
### IPO
As IPO is just DPO with a different loss function, all supported options for DPO works here.
```yaml
rl: ipo
```
#### ORPO
### ORPO
Paper: https://arxiv.org/abs/2403.07691
@@ -52,8 +318,28 @@ datasets:
type: chat_template.argilla
```
ORPO supports the following types with the following dataset format:
#### KTO
#### chat_template.argilla
```json
{
"system": "...", // optional
"prompt": "...", // if available, will be taken as user message for single-turn instead of from list below
// chosen/rejected should be same till last content and only even-number of alternating user/assistant turns
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
### KTO
```yaml
rl: kto
@@ -72,7 +358,186 @@ gradient_checkpointing_kwargs:
use_reentrant: true
```
#### Using local dataset files
KTO supports the following types with the following dataset format:
#### chatml.argilla
```json
{
"system": "...", // optional
"instruction": "...",
"completion": "..."
}
```
#### chatml.argilla_chat
```json
{
"chosen": [
{"role": "user", "content": "..."}
],
"completion": [
{"role": "assistant", "content": "..."}
]
}
```
#### chatml.intel
```json
{
"system": "...", // optional
"question": "...",
"completion": "..."
}
```
#### chatml.prompt_pairs
```json
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
```
#### chatml.ultra
```json
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
```
#### llama3.argilla
```json
{
"system": "...", // optional
"instruction": "...",
"completion": "..."
}
```
#### llama3.argilla_chat
```json
{
"completion": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
#### llama3.intel
```json
{
"system": "...", // optional
"question": "...",
"completion": "..."
}
```
#### llama3.prompt_pairs
```json
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
```
#### llama3.ultra
```json
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
```
#### user_defined.default
For custom behaviors,
```yaml
rl: kto
datasets:
- path: ...
split: train
type: user_defined.default
field_prompt: "prompt"
field_system: "system"
field_completion: "completion"
field_label: "label"
prompt_format: "{prompt}"
completion_format: "{completion}"
```
The input format is a simple JSON input with customizable fields based on the above config.
```json
{
"system": "...", // optional
"prompt": "...",
"completion": "...",
"label": "..."
}
```
### GRPO
GRPO uses custom reward functions and transformations. Please have them ready locally.
For ex, to load OpenAI's GSM8K and use a random reward for completions:
```python
# rewards.py
import random
def rand_reward_func(completions, **kwargs) -> list[float]:
return [random.uniform(0, 1) for _ in completions]
def oai_gsm8k_transform(cfg, *args, **kwargs):
def transform_fn(example, tokenizer=None):
label = example["answer"].split("####")[-1].strip().replace(",", "")
return {
"prompt": [{"role": "user", "content": example["question"]},],
"answer": label,
}
return transform_fn, {"remove_columns": ["question"]}
```
```yaml
rl: grpo
trl:
beta: 0.001
max_completion_length: 256
use_vllm: True
vllm_device: auto
vllm_gpu_memory_utilization: 0.15
num_generations: 4
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
datasets:
- path: openai/gsm8k
name: main
type: rewards.oai_gsm8k_transform # format: '{file_name}.{fn_name}'
```
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
### Using local dataset files
```yaml
datasets:
- ds_type: json
@@ -82,9 +547,9 @@ datasets:
type: chatml.intel
```
#### Trl autounwrap for peft
### TRL auto-unwrapping for PEFT
Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.
TRL supports auto-unwrapping PEFT models for RL training paradigms which rely on a reference model. This significantly reduces memory pressure as an additional refreference model does not need to be loaded, and reference model log-probabilities can be obtained by disabling PEFT adapters. This is enabled by default. To turn it off, pass the following config:
```yaml
# load ref model when adapter training.

View File

@@ -3,6 +3,12 @@ title: "PyTorch ao"
description: "Custom data types and layouts for training and inference"
---
To use experimental optimizers (`AdamWFp8`, `AdamW4bit`, `AdamW8bit`) from Pytorch Ao, please install the package as shown below.
::: {.callout-tip}
Some experimental optimizers are already present in regular Pytorch, so please re-check if you actually need this package!
:::
### Installation
Stable Release from the PyTorch index

View File

@@ -8,6 +8,12 @@ description: "Hyper-optimized QLoRA finetuning for single GPUs"
Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and VRAM over
standard industry baselines.
::: {.callout-important}
Due to breaking changes in transformers `v4.48.0`, users will need to downgrade to `<=v4.47.1` to use this patch.
This will later be deprecated in favor of [LoRA Optimizations](lora_optims.qmd).
:::
### Installation
@@ -17,7 +23,7 @@ The following will install the correct unsloth and extras from source.
python scripts/unsloth_install.py | sh
```
### Using unsloth w Axolotl
### Usage
Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.

View File

@@ -21,8 +21,9 @@ datasets:
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
message_property_mappings:
role: from
content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0

View File

@@ -16,8 +16,9 @@ datasets:
type: chat_template
drop_system_message: true
field_messages: conversations
message_field_role: from
message_field_content: value
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out

View File

@@ -13,8 +13,9 @@ datasets:
type: chat_template
drop_system_message: true
field_messages: conversations
message_field_role: from
message_field_content: value
message_property_mappings:
role: from
content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0

View File

@@ -17,8 +17,9 @@ datasets:
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
message_property_mappings:
role: from
content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.02

View File

@@ -17,8 +17,9 @@ datasets:
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
message_property_mappings:
role: role
content: content
roles:
system:
- system

View File

@@ -14,8 +14,9 @@ datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_field_role: role
message_field_content: content
message_property_mappings:
role: role
content: content
roles:
user:
- user

View File

@@ -17,8 +17,9 @@ datasets:
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
message_property_mappings:
role: role
content: content
roles:
system:
- system
@@ -31,8 +32,9 @@ datasets:
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
message_property_mappings:
role: role
content: content
roles:
system:
- system

View File

@@ -0,0 +1,82 @@
base_model: NousResearch/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
# Currently, we don't support dropout with our custom Triton kernels
# lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
# These options enable our custom Triton kernels / autograd
# functions for MLP and attention calculations
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -22,8 +22,9 @@ datasets:
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
message_property_mappings:
role: role
content: content
dataset_prepared_path:
val_set_size: 0.05

View File

@@ -14,8 +14,9 @@ datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_field_role: role
message_field_content: content
message_property_mappings:
role: role
content: content
roles:
user:
- user

View File

@@ -12,8 +12,9 @@ datasets:
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
message_property_mappings:
role: role
content: content
roles:
system:
- system

View File

@@ -1,7 +1,7 @@
---
toc-location: right-body
toc-title: Table Of Contents
toc-expand: 2
# toc-location: right-body
# toc-title: Table Of Contents
# toc-expand: 2
---
```{python}

View File

@@ -1,24 +1,24 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.1
bitsandbytes==0.45.2
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.0.post2
flash-attn==2.7.4.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.2
liger-kernel==0.5.3
# END section
packaging==23.2
peft==0.14.0
transformers==4.48.1
transformers==4.49.0
tokenizers>=0.21.0
accelerate==1.3.0
datasets==3.2.0
deepspeed==0.16.1
trl==0.13.0
trl==0.15.1
optimum==1.16.2
hf_transfer
@@ -26,7 +26,7 @@ sentencepiece
gradio==3.50.2
modal==0.70.5
pydantic==2.6.3
pydantic==2.10.6
addict
fire
PyYAML>=6.0

View File

@@ -31,27 +31,26 @@ def parse_dataset(dataset=None, split="train"):
ds_cfg["field_messages"] = field_messages
message_fields = features[field_messages][0].keys()
message_field_role = None
message_property_mappings = {"role": None, "content": None}
for key in ["from", "role"]:
if key in message_fields:
message_field_role = key
message_property_mappings["role"] = key
break
if not message_field_role:
if not message_property_mappings["role"]:
raise ValueError(
f'No role field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_role"] = message_field_role
message_field_content = None
for key in ["content", "text", "value"]:
if key in message_fields:
message_field_content = key
message_property_mappings["content"] = key
break
if not message_field_content:
if not message_property_mappings["content"]:
raise ValueError(
f'No content field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_content"] = message_field_content
ds_cfg["message_property_mappings"] = message_property_mappings
print(yaml.dump({"datasets": [ds_cfg]}))

View File

@@ -71,12 +71,15 @@ def parse_requirements():
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 5):
if (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post2")
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers==0.0.28.post3")
_install_requires.append("xformers>=0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
elif (major, minor) >= (2, 4):
if patch == 0:
@@ -122,7 +125,7 @@ setup(
},
extras_require={
"flash-attn": [
"flash-attn==2.7.0.post2",
"flash-attn==2.7.4.post1",
],
"deepspeed": [
"deepspeed==0.16.1",
@@ -153,5 +156,8 @@ setup(
"ray": [
"ray[train]",
],
"vllm": [
"vllm==0.7.2",
],
},
)

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.6.0"
__version__ = "0.8.0.dev0"

View File

@@ -35,13 +35,18 @@ def do_cli_train(
cloud_config: Union[Path, str],
config: Union[Path, str],
accelerate: bool = True,
cwd=None,
**kwargs,
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.train(config_yaml, accelerate=accelerate)
local_dirs = {}
if cwd and not Path(cwd).joinpath("src", "axolotl").exists():
local_dirs = {"/workspace/mounts": cwd}
cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs)
def do_cli_lm_eval(

View File

@@ -7,6 +7,7 @@ import os
import subprocess # nosec B404
from pathlib import Path
from random import randint
from typing import Optional
import modal
@@ -22,8 +23,18 @@ def run_cmd(cmd: str, run_folder: str, volumes=None):
# modal workaround so it doesn't use the automounted axolotl
new_env = copy.deepcopy(os.environ)
if "PYTHONPATH" in new_env:
del new_env["PYTHONPATH"]
paths = ["/workspace/mounts"]
for sub_python_path_str in new_env["PYTHONPATH"].split(":"):
sub_python_path = Path(sub_python_path_str)
if not sub_python_path.joinpath("src", "axolotl").exists():
# we don't want to use the automounted axolotl or unexpected behavior happens
paths.append(str(sub_python_path))
if paths:
new_env["PYTHONPATH"] = ":".join(paths)
else:
del new_env["PYTHONPATH"]
# Propagate errors from subprocess.
if exit_code := subprocess.call( # nosec B603
@@ -112,8 +123,6 @@ class ModalCloud(Cloud):
if env := self.get_env():
image = image.env(env)
image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
return image
def get_secrets(self):
@@ -203,9 +212,12 @@ class ModalCloud(Cloud):
memory = int(self.config.memory)
return 1024 * memory
def get_train_env(self):
def get_train_env(self, local_dirs=None):
image = self.get_image()
for mount, local_dir in (local_dirs or {}).items():
image = image.add_local_dir(local_dir, mount)
return self.app.function(
image=self.get_image(),
image=image,
volumes={k: v[0] for k, v in self.volumes.items()},
cpu=16.0,
gpu=self.get_train_gpu(),
@@ -214,14 +226,21 @@ class ModalCloud(Cloud):
secrets=self.get_secrets(),
)
def train(self, config_yaml: str, accelerate: bool = True):
modal_fn = self.get_train_env()(_train)
def train(
self,
config_yaml: str,
accelerate: bool = True,
local_dirs: Optional[dict[str, str]] = None,
**kwargs,
):
modal_fn = self.get_train_env(local_dirs)(_train)
with modal.enable_output():
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
accelerate=accelerate,
volumes={k: v[0] for k, v in self.volumes.items()},
**kwargs,
)
def lm_eval(self, config_yaml: str):
@@ -239,44 +258,41 @@ class ModalCloud(Cloud):
def _preprocess(config_yaml: str, volumes=None):
Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/artifacts/axolotl"
run_folder = "/workspace/mounts"
run_cmd(
"axolotl preprocess /workspace/artifacts/axolotl/config.yaml --dataset-processes=8",
"axolotl preprocess /workspace/mounts/config.yaml --dataset-processes=8",
run_folder,
volumes,
)
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/artifacts/axolotl"
run_folder = "/workspace/mounts"
if accelerate:
accelerate_args = "--accelerate"
else:
accelerate_args = "--no-accelerate"
num_processes_args = ""
if num_processes := kwargs.pop("num_processes", None):
num_processes_args = f"--num-processes {num_processes}"
run_cmd(
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
f"axolotl train {accelerate_args} {num_processes_args} /workspace/mounts/config.yaml",
run_folder,
volumes,
)
def _lm_eval(config_yaml: str, volumes=None):
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/artifacts/axolotl"
run_folder = "/workspace/mounts"
run_cmd(
"axolotl lm-eval /workspace/artifacts/axolotl/config.yaml",
"axolotl lm-eval /workspace/mounts/config.yaml",
run_folder,
volumes,
)

View File

@@ -1,135 +0,0 @@
"""CLI to run training on a model."""
import logging
import os
from pathlib import Path
from typing import Union
import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.common.datasets import load_datasets
from axolotl.integrations.base import PluginManager
from axolotl.integrations.lolcats.linear_llama.configuration_linear_llama import (
LinearLlamaConfig,
)
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
LinearLlamaForCausalLM,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config
from axolotl.utils.trainer import setup_trainer
LOG = logging.getLogger(__name__)
def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
"""
Convert attention to linear attention and perform attention transfer via distillation.
"""
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
# ensure quantization and peft are turned off (due to how we need to re-apply peft later)
cfg.load_in_8bit = False
cfg.load_in_4bit = False
cfg.adapter = None
# load model
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
# freeze model
for p in model.parameters():
p.requires_grad = False
# convert to linear llama
linear_llama_config = LinearLlamaConfig.from_llama(
model.config, cfg.attention_config
)
model = LinearLlamaForCausalLM.from_llama(
model, config=linear_llama_config, train_attention=True
)
# set save_path, save tokenizer and model config.
save_path = str(os.path.join(cfg.output_dir, "distilled"))
tokenizer.save_pretrained(save_path)
if hasattr(model, "config"):
model.config.save_pretrained(save_path)
# Get datasets
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# toggle attention to be trainable
model.toggle_attention(train=True)
# Setup trainer
trainer = setup_trainer(
cfg=cfg,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
model=(model, None, None),
tokenizer=tokenizer,
processor=None,
total_num_steps=total_num_steps,
)
# train
trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
# drop base_attention + remove training attn
model.toggle_attention(train=False)
model.remove_base_attention()
# NOTE: If in peft mode, consider whether to auto-merge
# save model
safe_serialization = cfg.save_safetensors is True
# NOTE: may need to consider other ways of saving due to multi-gpu etc
model.save_pretrained(save_path, safe_serialization=safe_serialization)
# cleanup
plugin_manager = PluginManager.get_instance()
del model
del tokenizer
plugin_manager.post_train_unload(cfg)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_train`.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# load cfg, force linearize and add plugin to linearize
parsed_cfg = load_cfg(
config,
linearize=True,
plugins=["axolotl.integrations.lolcats.LinearizePlugin"],
**kwargs,
)
parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
do_linearize(parsed_cfg, parsed_cli_args)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -2,19 +2,19 @@
# pylint: disable=redefined-outer-name
import logging
import random
import os
import subprocess # nosec B404
import tempfile
from copy import deepcopy
from itertools import product
from pathlib import Path
from typing import Optional
import click
import yaml
from dotenv import load_dotenv
import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.cli.sweeps import generate_sweep_configs
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
@@ -27,76 +27,6 @@ from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
def generate_sweep_configs(base_config, sweeps_config):
"""
Recursively generates all possible configurations by applying sweeps to the base config.
Args:
base_config (dict): The original configuration dictionary
sweeps_config (dict): Dictionary where keys are parameters and values are either:
- lists of values to sweep independently
- or for paired values, a list of dicts under the '_' key
Returns:
list: List of all possible configuration dictionaries
Example:
sweeps_config = {
'learning_rate': [0.1, 0.01],
'_': [
{'load_in_8bit': True, 'adapter': 'lora'},
{'load_in_4bit': True, 'adapter': 'qlora'}
]
}
"""
# Separate paired values from regular sweeps
paired_values = sweeps_config.get("_", [])
regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}
# Process regular sweeps
param_names = list(regular_sweeps.keys())
param_values = list(regular_sweeps.values())
# Generate combinations for regular sweeps
regular_combinations = list(product(*param_values)) if param_values else [()]
# Combine regular sweeps with paired values
all_combinations = []
for reg_combo in regular_combinations:
if paired_values:
for paired_set in paired_values:
new_config = {}
# new_config = deepcopy(base_config)
# Combine regular parameters with paired parameters
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
for param_name, param_value in full_combo.items():
new_config[param_name] = param_value
print(new_config)
all_combinations.append(new_config)
else:
# If no paired values, just use regular combinations
# new_config = deepcopy(base_config)
new_config = {}
for param_name, param_value in zip(param_names, reg_combo):
new_config[param_name] = param_value
print(new_config)
all_combinations.append(new_config)
# randomize the order of trials
random.seed(42)
random.shuffle(all_combinations)
# Generate a new config for each combination
result_configs = []
for combination in all_combinations:
new_config = deepcopy(base_config)
for param_name, param_value in combination.items():
new_config[param_name] = param_value
result_configs.append(new_config)
return result_configs
@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():
@@ -165,7 +95,6 @@ def train(
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
from axolotl.cli.cloud import do_cli_train
if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False
@@ -199,7 +128,16 @@ def train(
try:
if accelerate:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
from axolotl.cli.cloud import do_cli_train
cwd = os.getcwd()
do_cli_train(
cloud_config=cloud,
config=config,
accelerate=True,
cwd=cwd,
**kwargs,
)
else:
accelerate_args = []
if "main_process_port" in kwargs:
@@ -208,7 +146,7 @@ def train(
accelerate_args.append(str(main_process_port))
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes", None)
accelerate_args.append("--num-processes")
accelerate_args.append("--num_processes")
accelerate_args.append(str(num_processes))
base_cmd = ["accelerate", "launch"]
@@ -220,7 +158,11 @@ def train(
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
from axolotl.cli.cloud import do_cli_train
do_cli_train(
cloud_config=cloud, config=config, accelerate=False, **kwargs
)
else:
from axolotl.cli.train import do_cli
@@ -381,4 +323,5 @@ def main():
if __name__ == "__main__":
load_dotenv()
main()

77
src/axolotl/cli/sweeps.py Normal file
View File

@@ -0,0 +1,77 @@
"""Utilities for handling sweeps over configs for axolotl train CLI command"""
import random
from copy import deepcopy
from itertools import product
def generate_sweep_configs(
base_config: dict[str, list], sweeps_config: dict[str, list]
) -> list[dict[str, list]]:
"""
Recursively generates all possible configurations by applying sweeps to the base config.
Args:
base_config (dict): The original configuration dictionary
sweeps_config (dict): Dictionary where keys are parameters and values are either:
- lists of values to sweep independently
- or for paired values, a list of dicts under the '_' key
Returns:
list: List of all possible configuration dictionaries
Example:
sweeps_config = {
'learning_rate': [0.1, 0.01],
'_': [
{'load_in_8bit': True, 'adapter': 'lora'},
{'load_in_4bit': True, 'adapter': 'qlora'}
]
}
"""
# Separate paired values from regular sweeps
paired_values = sweeps_config.get("_", [])
regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}
# Process regular sweeps
param_names = list(regular_sweeps.keys())
param_values = list(regular_sweeps.values())
# Generate combinations for regular sweeps
regular_combinations = list(product(*param_values)) if param_values else [()]
# Combine regular sweeps with paired values
all_combinations = []
for reg_combo in regular_combinations:
if paired_values:
for paired_set in paired_values:
new_config = {}
# new_config = deepcopy(base_config)
# Combine regular parameters with paired parameters
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
for param_name, param_value in full_combo.items():
new_config[param_name] = param_value
print(new_config)
all_combinations.append(new_config)
else:
# If no paired values, just use regular combinations
# new_config = deepcopy(base_config)
new_config = {}
for param_name, param_value in zip(param_names, reg_combo):
new_config[param_name] = param_value
print(new_config)
all_combinations.append(new_config)
# randomize the order of trials
random.seed(42)
random.shuffle(all_combinations)
# Generate a new config for each combination
result_configs = []
for combination in all_combinations:
new_config = deepcopy(base_config)
for param_name, param_value in combination.items():
new_config[param_name] = param_value
result_configs.append(new_config)
return result_configs

View File

@@ -122,9 +122,11 @@ def load_preference_datasets(
`total_num_steps`.
"""
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
total_num_steps = int(
total_num_steps: Optional[int] = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cfg.rl == "grpo":
total_num_steps = None
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")

View File

@@ -39,7 +39,6 @@ from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers.base import (
AxolotlCPOTrainer,
AxolotlDPOTrainer,
AxolotlKTOTrainer,
AxolotlMambaTrainer,
AxolotlORPOTrainer,
@@ -48,9 +47,11 @@ from axolotl.core.trainers.base import (
AxolotlTrainer,
ReLoRATrainer,
)
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlDPOConfig,
AxolotlKTOConfig,
AxolotlORPOConfig,
AxolotlPRMConfig,
@@ -329,6 +330,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
training_arguments_kwargs = {}
if self.cfg.include_tokens_per_second is not None:
training_arguments_kwargs[
"include_tokens_per_second"
] = self.cfg.include_tokens_per_second
if self.cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True
else:
@@ -641,9 +648,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
tokenizer=self.tokenizer,
)
if self.cfg.rl == "orpo":
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs[
"neftune_noise_alpha"
@@ -652,7 +656,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {}
if self.cfg.reward_model:
trainer_kwargs["max_length"] = self.cfg.sequence_len
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
# pylint: disable=duplicate-code
if self.cfg.optimizer in [
@@ -965,10 +969,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.rl_beta
if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
if self.cfg.orpo_alpha:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
@@ -977,6 +982,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
@@ -1001,11 +1007,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.kto_undesirable_weight or 1.0
)
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "grpo":
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
else:
training_args_cls = AxolotlDPOConfig
if self.cfg.rl == "ipo":
@@ -1016,11 +1026,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
if self.cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs[
"use_logits_to_keep"
] = self.cfg.dpo_use_logits_to_keep
for blocklist_key in blocklist_args_kwargs:
if blocklist_key in training_args_kwargs:
del training_args_kwargs[blocklist_key]
max_steps = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
output_dir=self.cfg.output_dir,
self.cfg.output_dir,
per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=self.cfg.max_steps or total_num_steps,
max_steps=max_steps,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
learning_rate=self.cfg.learning_rate,
warmup_steps=self.cfg.warmup_steps,
@@ -1047,8 +1067,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs[
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = AxolotlDPOTrainer
if self.cfg.rl == "grpo":
trainer_cls = GRPOStrategy.get_trainer_class()
trainer_cls_args = [self.model]
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = DPOStrategy.get_trainer_class()
trainer_cls_args = [self.model, self.model_ref]
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
@@ -1063,12 +1088,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters.keys():
dpo_trainer_kwargs["processing_class"] = self.tokenizer
else:
if "tokenizer" in sig.parameters.keys():
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
else:
dpo_trainer_kwargs["processing_class"] = self.tokenizer
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
if self.cfg.datasets is not None and (
trainer_cls is DPOStrategy.get_trainer_class()
):
dpo_trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]

View File

@@ -5,30 +5,21 @@ module for customized trainers
from __future__ import annotations
# pylint: disable=too-many-lines
import gc
import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Any, Dict, Literal, Optional, Union
from typing import Dict, Literal, Optional
import torch
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import (
CPOTrainer,
DPOTrainer,
KTOTrainer,
ORPOTrainer,
PRMTrainer,
RewardTrainer,
)
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length
from axolotl.monkeypatch.relora import ReLoRAScheduler
@@ -847,107 +838,6 @@ class ReLoRATrainer(AxolotlTrainer):
return self.lr_scheduler
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "dpo"]
def __init__(self, *args, dataset_tags=None, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
self.model_accepts_loss_kwargs = False
def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
if loraplus_lr_ratio:
print("Using lora+")
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
@wraps(DPOTrainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
@staticmethod
def tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> Dict:
res = DPOTrainer.tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
if processing_class.bos_token and processing_class.bos_token_id is not None:
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
res["chosen_labels"] = res["chosen_labels"][1:]
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
res["rejected_labels"] = res["rejected_labels"][1:]
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
return res
def training_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
num_items_in_batch=None,
) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
gc.collect()
torch.cuda.empty_cache()
return loss
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers

View File

@@ -0,0 +1,33 @@
"""
DPO Specific Strategy for training
"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
class DPOStrategy:
"""
Strategy for DPO training
"""
@classmethod
def get_trainer_class(cls):
return AxolotlDPOTrainer
@classmethod
def get_training_args_class(cls):
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
return AxolotlDPOConfig
@classmethod
def set_training_args_kwargs(cls, cfg):
training_args_kwargs = {}
if cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
if cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
return training_args_kwargs

View File

@@ -0,0 +1,15 @@
"""
Axolotl specific DPO args
"""
from dataclasses import dataclass
from trl import DPOConfig
from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""

View File

@@ -0,0 +1,125 @@
"""
DPO trainer for axolotl
"""
import gc
from functools import wraps
from typing import Any, Dict, Union
import torch
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.core.trainers.base import (
SchedulerMixin,
_sanitize_kwargs_for_ds_tagging,
_sanitize_kwargs_for_tagging,
)
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "dpo"]
def __init__(self, *args, dataset_tags=None, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
self.model_accepts_loss_kwargs = False
def create_optimizer(self):
# pylint: disable=duplicate-code
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
if loraplus_lr_ratio:
print("Using lora+")
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
# pylint: disable=duplicate-code
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
@wraps(DPOTrainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
@staticmethod
def tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> Dict:
res = DPOTrainer.tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
if processing_class.bos_token and processing_class.bos_token_id is not None:
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
res["chosen_labels"] = res["chosen_labels"][1:]
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
res["rejected_labels"] = res["rejected_labels"][1:]
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
return res
def training_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
num_items_in_batch=None,
) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
gc.collect()
torch.cuda.empty_cache()
return loss

View File

@@ -0,0 +1,119 @@
"""
GRPO Specific Strategy for training
"""
import importlib
import inspect
import logging
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
LOG = logging.getLogger("axolotl")
class GRPOStrategy:
"""
Strategy for GRPO training
"""
@classmethod
def get_trainer_class(cls):
return AxolotlGRPOTrainer
@classmethod
def get_training_args_class(cls):
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
return AxolotlGRPOConfig
@classmethod
def set_training_args_kwargs(cls, cfg):
grpo_args_kwargs = {}
if cfg.trl and cfg.trl.use_vllm:
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
if cfg.trl and cfg.trl.vllm_device:
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
else:
grpo_args_kwargs["vllm_device"] = "auto"
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = cfg.trl.vllm_gpu_memory_utilization
if cfg.trl and cfg.trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
if cfg.trl and cfg.trl.num_generations:
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
if cfg.trl and cfg.trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
grpo_args_kwargs[
"ref_model_mixup_alpha"
] = cfg.trl.ref_model_mixup_alpha
if cfg.trl and cfg.trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
return grpo_args_kwargs
@classmethod
def set_trainer_args(cls, cfg):
trainer_args = []
if cfg.trl and cfg.trl.reward_funcs:
reward_funcs = []
for reward_func_fqn in cfg.trl.reward_funcs:
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
trainer_args.append(reward_funcs)
return trainer_args
@classmethod
def set_trainer_kwargs(cls, cfg):
trainer_kwargs = {}
if cfg.trl and cfg.trl.reward_processing_classes:
trainer_kwargs[
"reward_processing_classes"
] = cfg.trl.reward_processing_classes
return trainer_kwargs
@classmethod
def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument
# No data collation is needed in GRPO, handled by trl's trainer __init__
return None
@classmethod
def get_blocklist_args_kwargs(cls):
return ["dataset_num_proc"]
@classmethod
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
"""
Returns the reward function from the given fully qualified name, or the path to the reward function model.
Args:
reward_func_fqn (str): Fully qualified name of the reward function (e.g. r1_grpo.gsm8k_transform),
or a HF hub path to the reward model.
Raises:
ValueError: If the reward function does not accept at least two arguments.
Returns:
RewardFunc: A callable that accepts prompts and completions and returns rewards,
or a path to a reward model.
"""
try:
# use importlib to dynamically load the reward function from the module
reward_func_module_name = reward_func_fqn.split(".")[-1]
reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2])
reward_func = getattr(reward_func_module, reward_func_module_name)
if not len(inspect.signature(reward_func).parameters) >= 2:
raise ValueError(
"Reward function must accept at least two arguments: prompts: list and completions: list"
)
return reward_func
except ModuleNotFoundError:
# the user has passed a string (ideally indicating the path of a reward model)
LOG.info(
f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
)
return reward_func

View File

@@ -0,0 +1,15 @@
"""
Axolotl Specific Training Args
"""
from dataclasses import dataclass
from trl import GRPOConfig
from axolotl.core.training_args import AxolotlTrainingMixins
@dataclass
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""
Axolotl GRPO Config for GRPO training
"""

View File

@@ -0,0 +1,108 @@
"""
Axolotl GRPO trainer
"""
from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module
from transformers import PreTrainedModel
from trl import GRPOConfig, GRPOTrainer
from trl.models import unwrap_model_for_generation
from axolotl.core.trainers.base import SchedulerMixin
# mypy: ignore-errors
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
"""
Extend the base GRPOTrainer for axolotl helpers
"""
_tag_names = ["trl", "grpo", "axolotl"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# pylint: disable=access-member-before-definition
# Enable gradient checkpointing if requested
if kwargs["args"].gradient_checkpointing:
# Ensure use_cache is disabled
if hasattr(self.model, "config"):
self.model.config.use_cache = False
# Enable gradient checkpointing on the base model for PEFT
if is_peft_model(self.model) and hasattr(
self.model.base_model, "gradient_checkpointing_enable"
):
self.model.base_model.gradient_checkpointing_enable()
# Enable gradient checkpointing for non-PEFT models
elif hasattr(self.model, "gradient_checkpointing_enable"):
self.model.gradient_checkpointing_enable()
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
# pylint: enable=access-member-before-definition
def _enable_gradient_checkpointing(
self, model: PreTrainedModel, args: GRPOConfig
) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# pylint: disable=unused-argument,redefined-builtin
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs
or gradient_checkpointing_kwargs["use_reentrant"]
)
if use_reentrant:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(
make_inputs_require_grad
)
return model
# pylint: enable=unused-argument,redefined-builtin
def _move_model_to_vllm(self):
with unwrap_model_for_generation(
self.model,
self.accelerator,
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
) as unwrapped_model:
if is_compiled_module(unwrapped_model):
unwrapped_model = (
unwrapped_model._orig_mod # pylint: disable=protected-access
)
if is_peft_model(unwrapped_model):
unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict()
# Remove base_model and base_layer prefixes
state_dict = {
k.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", ""): v
for k, v in state_dict.items()
}
# Remove values with adapter prefix (example: "_lora")
state_dict = {
k: v
for k, v in state_dict.items()
if unwrapped_model.prefix not in k
}
# When module to save, remove its prefix and discard the original module
state_dict = {
k.replace("modules_to_save.default.", ""): v
for k, v in state_dict.items()
if "original_module" not in k
}
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = (
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
)
llm_model.load_weights(state_dict.items())
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
@dataclass
@@ -217,13 +217,6 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""
@dataclass
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
"""

View File

@@ -1,6 +1,10 @@
# Cut Cross Entropy
### Usage
Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy operation during loss calculation.
See https://github.com/apple/ml-cross-entropy
## Usage
```yaml
plugins:
@@ -8,3 +12,19 @@ plugins:
cut_cross_entropy: true
```
## Citation
```bib
@article{wijmans2024cut,
author = {Erik Wijmans and
Brody Huval and
Alexander Hertzberg and
Vladlen Koltun and
Philipp Kr\"ahenb\"uhl},
title = {Cut Your Losses in Large-Vocabulary Language Models},
journal = {arXiv},
year = {2024},
url = {https://arxiv.org/abs/2411.09009},
}
```

View File

@@ -2,7 +2,7 @@
See https://github.com/ironjr/grokfast
### Usage
## Usage
```yaml
plugins:
@@ -11,3 +11,14 @@ plugins:
grokfast_alpha: 2.0
grokfast_lamb: 0.98
```
## Citation
```bib
@article{lee2024grokfast,
title={{Grokfast}: Accelerated Grokking by Amplifying Slow Gradients},
author={Lee, Jaerin and Kang, Bong Gyun and Kim, Kihoon and Lee, Kyoung Mu},
journal={arXiv preprint arXiv:2405.20233},
year={2024}
}
```

View File

@@ -0,0 +1,23 @@
# Knowledge Distillation
## Usage
```yaml
plugins:
- "axolotl.integrations.kd.KDPlugin"
kd_trainer: True
kd_ce_alpha: 0.1
kd_alpha: 0.9
kd_temperature: 1.0
torch_compile: True # torch>=2.5.1, recommended to reduce vram
datasets:
- path: ...
type: "axolotl.integrations.kd.chat_template"
field_messages: "messages_combined"
logprobs_field: "llm_text_generation_vllm_logprobs" # for kd only, field of logprobs
```
An example dataset can be found at [`axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample`](https://huggingface.co/datasets/axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample)

View File

@@ -1,58 +0,0 @@
### AXOLOTL COMMUNITY LICENSE AGREEMENT
This Axolotl Community License Agreement (“Agreement”) is entered into by and between Axolotl AI Corp. (“Axolotl”) and
any individual or entity (“Licensee”) who wishes to use the Software (as defined below) in accordance with the terms
and conditions set forth in this Agreement.
1. Definitions
1.1 “Licensee” refers to any individual or entity who has obtained a copy of the Software under this Agreement.
1.2 “Plugin Integration” means independent integration software modules which may or may not be offered by Axolotl,
which may be licensed separately by their respective authors and/or licensors.
1.3 “Software” refers to the specific sub-directory of the Axolotl, Inc. software located at
https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations and its subdirectories which
permits Plugin Integrations to integrate with the Axolotl service.
2. Grant of License
2.1 Axolotl hereby grants Licensee a worldwide, non-exclusive, royalty-free, license to use, copy, modify, merge,
publish, distribute, sublicense, and/or otherwise exploit the Software, subject to the following conditions:
- Licensee must comply with all the terms and conditions of this Agreement.
- Licensee must include the original copyright notice and disclaimer of warranty in all copies or substantial
portions of the Software.
2.2 Licensee may use the Software for any lawful purpose, except as restricted in Section 3.
3. Restrictions
3.1 Licensee shall not use the Software for any activity that constitutes a commercial activity of offering for
free or for sale any services, platform, or equivalent to third parties for the purposes of allowing such
third parties to fine-tune artificial intelligence models.
3.2 Licensee shall not:
- Use the Software for any illegal or unauthorized purpose.
- Reverse engineer, decompile, or disassemble the Software.
- Remove or modify any copyright, trademark, or other proprietary notices contained in the Software.
- Use the Software in a way that could damage, disable, overburden, or impair the functionality of the
Software or interfere with any third-party use of the Software.
3.3 Axolotl reserves the right to restrict certain Plugin Integrations for use with the Software. To the extent Licensee integrates a permitted, applicable Plugin Integration with the Software, Licensee shall comply with any additional terms and conditions imposed by the licensors of such Plugin Integration for use of such Plugin Integrations. Licensee shall contact Axolotl if it has questions about whether its use of the Software falls beyond the scope of this Agreement.
4. Intellectual Property Rights
4.1 Axolotl and its contributors retain all intellectual property rights in and to the Software. Licensee
acknowledges that this Agreement does not transfer any ownership rights or intellectual property rights to
Licensee.
5. Disclaimer of Warranty
5.1 THE SOFTWARE IS PROVIDED “AS IS,” WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN ACTION OF
CONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
6. Termination
6.1 Axolotl may terminate this Agreement at any time if Licensee fails to comply with any of the terms and
conditions set forth herein. Upon termination, Licensee shall cease all use of the Software and destroy any
copies in its possession.
7. Governing Law
7.1 This Agreement shall be governed by and construed in accordance with the laws of the State of California,
without regards to conflicts of laws provisions thereof.
8. Entire Agreement
8.1 This Agreement constitutes the entire agreement between Axolotl and Licensee with respect to the subject matter
hereof and supersedes all prior or contemporaneous understandings or agreements between the parties concerning
the Software, whether written or oral. Axolotl may update the terms of this Agreement from time to time, and
Licensees continued use of the Software after any such updates shall constitute acceptance of updated terms
on a go-forward basis. Axolotl will use commercially reasonable efforts to provide Licensee notice of any
material updates. By using the Software, Licensee acknowledges that it has read, understood, and agrees to be
bound by the terms and conditions of this Agreement.
This Agreement was last updated on August 23, 2024.

View File

@@ -1,14 +1,16 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
loss for top_k KL divergence

View File

@@ -0,0 +1,36 @@
# Liger Kernel Integration
Liger Kernel provides efficient Triton kernels for LLM training, offering:
- 20% increase in multi-GPU training throughput
- 60% reduction in memory usage
- Compatibility with both FSDP and DeepSpeed
See https://github.com/linkedin/Liger-Kernel
## Usage
```yaml
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```
## Citation
```bib
@article{hsu2024ligerkernelefficienttriton,
title={Liger Kernel: Efficient Triton Kernels for LLM Training},
author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
year={2024},
eprint={2410.10989},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2410.10989},
journal={arXiv preprint arXiv:2410.10989},
}
```

View File

@@ -1,6 +1,10 @@
# LM Eval Harness
### Usage
Run evaluation on model using the popular lm-evaluation-harness library.
See https://github.com/EleutherAI/lm-evaluation-harness
## Usage
```yaml
plugins:
@@ -10,4 +14,22 @@ lm_eval_tasks:
- gsm8k
- hellaswag
- arc_easy
lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```
## Citation
```bib
@misc{eval-harness,
author = {Gao, Leo and Tow, Jonathan and Abbasi, Baber and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and Le Noac'h, Alain and Li, Haonan and McDonell, Kyle and Muennighoff, Niklas and Ociepa, Chris and Phang, Jason and Reynolds, Laria and Schoelkopf, Hailey and Skowron, Aviya and Sutawika, Lintang and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy},
title = {A framework for few-shot language model evaluation},
month = 07,
year = 2024,
publisher = {Zenodo},
version = {v0.4.3},
doi = {10.5281/zenodo.12608602},
url = {https://zenodo.org/records/12608602}
}
```

View File

@@ -1,201 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -1,44 +0,0 @@
# Low-rank Linear Conversion via Attention Transfer (LoLCATs)
https://github.com/HazyResearch/lolcats/
### Usage
Install `causal_dot_product` CUDA kernel (check the README in the `csrc` directory):
```bash
cd src/axolotl/integrations/lolcats/linear_llama/csrc
# Edit `setup.py` to point to the correct CUDA capabilities L40-44
# nano setup.py
# Build the CUDA kernel
python setup.py install
```
Step 1:
```yaml
plugins:
- axolotl.integrations.lolcats.LinearizePlugin
linearize: true
```
Run axolotl: `python -m axolotl.cli.convert_linear_attention config.yaml` TODO: change path CLI
Step 2: Remove the config `linearize: true` and finetune with lora with below possible targets.
```yaml
lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
# with optional config below but this requires patching axolotl
# to allow this config to work with lora
# unfrozen_parameters: ['.*feature_map_q.mlp.layer.*', '.*feature_map_k.mlp.layer.*', '.*window_factors.*']
```
`axolotl train config.yaml --base-model={output_dir}/distilled --trust-remote-code --learning-rate=0.0001 # --wandb-project="..."`
Step 3: Run inference on the finetuned model
`axolotl inference config.yaml --lora-model-dir="{output_dir}" --trust-remote-code # --prompter="AlpacaPrompter"`

View File

@@ -1,43 +0,0 @@
"""
Module for the Plugin for LoLCATs linear attention integration with Axolotl.
Low-rank Linear Conversion via Attention Transfer
"""
import logging
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lolcats.trainer.distill_attention_xent_mse import (
DistillAttentionXentMSETrainer,
)
from .args import LinearAttentionArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.lolcats")
class LinearizePlugin(BasePlugin):
"""
Plugin for lolcats integration with Axolotl.
"""
def __init__(self):
super().__init__()
# Register the Linear Llama model with transformers
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
register_linear_llama,
)
register_linear_llama()
def get_input_args(self):
return "axolotl.integrations.lolcats.LinearAttentionArgs"
def get_trainer_cls(self, cfg):
# defualt to XentMSE
# TODO: add check to allow MSE_linear
if cfg.linearize:
return DistillAttentionXentMSETrainer
return None

View File

@@ -1,47 +0,0 @@
"""
Module for handling linear attention input arguments.
"""
from typing import Optional
from pydantic import BaseModel
class FeatureMapKwargs(BaseModel):
"""Args for feature map"""
eps: float
mlp: Optional[None] = None
fullspace: bool
class LearnedKernelKwargs(BaseModel):
"""Args for learned kernel"""
feature_dim: int
skip_connection: bool
bias: bool
zero_init: bool
class AttentionConfig(BaseModel):
"""Args for attention config"""
attention_type: str
feature_map: str
feature_map_kwargs: FeatureMapKwargs
layer_idx: Optional[None] = None
learned_kernel: str
learned_kernel_kwargs: LearnedKernelKwargs
tie_qk_kernels: bool
train_qk: bool
class LinearAttentionArgs(BaseModel):
"""
Input args for linear attention
"""
attention_config: AttentionConfig
linearize: Optional[bool] = False

View File

@@ -1,90 +0,0 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Linear LLaMA model configuration"""
from typing import Optional
from transformers import LlamaConfig
class LinearLlamaConfig(LlamaConfig):
"""
This is the configuration class to store the configuration of a [`LinearLlamaModel`].
It is a modified LlamaConfig that includes additional parameters for linear attention.
Args:
attention_config (`dict`):
Dictionary containing the configuration for linear attention mechanism.
Expected contents:
`attention_type` (str):
The type of attention to convert to.
`feature_map` (`str`):
The type of feature map to use for linear attention.
`feature_map_kwargs` (`dict`):
Additional arguments for the feature map.
`learned_kernel` (`str`, *optional*):
Type of learned kernel to use, if any.
`learned_kernel_kwargs` (`dict`, *optional*):
Additional arguments for the learned kernel.
`tie_qk_kernels` (`bool`, *optional*, defaults to False):
Whether to tie query and key kernels.
`rotary_config` (`dict`, *optional*):
Configuration for rotary embeddings.
`train_attention` (`bool`, *optional*, defaults to False):
Whether to train attention to match softmax attention.
`remove_base_attn` (`bool`, *optional*, defaults to True):
Whether to remove base attention after initialization.
`mask_value` (`int`, *optional*, defaults to 0):
Value to use for masking.
`eps` (`float`, *optional*, defaults to 1e-12):
Epsilon value for numerical stability.
`fp32_attention` (`bool`, *optional*, defaults to False):
Whether to use fp32 precision for attention computation.
`track_state_grads` (`bool`, *optional*, defaults to False):
Whether to track gradients of attention states.
**kwargs:
Additional arguments inherited from LlamaConfig.
"""
model_type = "linear_llama"
def __init__(self, attention_config: Optional[dict] = None, **kwargs):
super().__init__(**kwargs)
# Set auto_map
self.auto_map = {
"AutoConfig": "configuration_linear_llama.LinearLlamaConfig",
"AutoModel": "modeling_linear_llama.LinearLlamaModel",
"AutoModelForCausalLM": "modeling_linear_llama.LinearLlamaForCausalLM",
}
# Set default attention config if none provided
self.attention_config = attention_config or {"attention_type": "softmax"}
@classmethod
def from_llama(cls, llama_config: LlamaConfig, attention_config: dict):
"""
Instantiate a LinearLlamaConfig from a LlamaConfig and additional attention config.
Args:
llama_config (:class:`~transformers.LlamaConfig`):
The LlamaConfig to inherit from.
attention_config (`dict`):
Dictionary containing the configuration for linear attention mechanism.
"""
return cls(attention_config=attention_config, **llama_config.to_dict())

View File

@@ -1,30 +0,0 @@
# Causal linear attention CUDA kernel
Usage:
```bash
cd src/axolotl/integrations/lolcats/linear_llama/csrc
# Edit `setup.py` to point to the correct CUDA capabilities L40-44
# nano setup.py
# Build the CUDA kernel
python setup.py install
```
Reference: https://github.com/idiap/fast-transformers/
```bib
@inproceedings{katharopoulos_et_al_2020,
author = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.},
title = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
year = {2020}
}
@article{vyas_et_al_2020,
author={Vyas, A. and Katharopoulos, A. and Fleuret, F.},
title={Fast Transformers with Clustered Attention},
booktitle = {Proceedings of the International Conference on Neural Information Processing Systems (NeurIPS)},
year={2020}
}
```

View File

@@ -1,6 +0,0 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#
from .causal_attention import causal_dot_product

View File

@@ -1,225 +0,0 @@
//
// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
// Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
// Apoorv Vyas <avyas@idiap.ch>
//
#include <torch/extension.h>
/**
* Compute a*b^T and save it into out.
*
* a \in R^A
* b \in R^B
*/
inline void vvt_dot(float *a, float *b, float *out, int A, int B) {
for (int i=0; i<A; i++) {
float * bi = b;
for (int j=0; j<B; j++) {
*out += (*a) * (*bi);
out++;
bi++;
}
a++;
}
}
/**
* Implement a vector matrix product v*m and save it into out.
*
* v \in R^A
* m \in R^{AxB}
*/
inline void vm_dot(float *v, float *m, float *out, int A, int B) {
// TODO: Consider removing the zeroing part and assuming out already
// contains 0s
for (int i=0; i<B; i++) {
out[i] = 0;
}
for (int i=0; i<A; i++) {
float *oi = out;
for (int j=0; j<B; j++) {
*oi += (*v) * (*m);
oi++;
m++;
}
v++;
}
}
/**
* Implement a vector transposed-matrix product and save it into out.
*
* v \in R^B
* m \in R^{AxB}
*/
inline void vmt_dot(float *v, float *m, float *out, int A, int B) {
for (int i=0; i<A; i++) {
float *vi = v;
float s = 0;
for (int j=0; j<B; j++) {
s += (*vi) * (*m);
vi++;
m++;
}
// TODO: Should we be aggregating? See the comment on vm_dot.
*out = s;
out++;
}
}
/**
* Compute the causally masked dot products of queries, keys and values.
*
* Basically compute V_j' = (Q_{0:j} * K_{0:j}^T) * V_{0:j} for all j. The
* computation is done efficiently by changing the order of the dot products.
*/
void causal_dot_product(
const torch::Tensor queries,
const torch::Tensor keys,
const torch::Tensor values,
torch::Tensor product
) {
// Extract some shapes
int N = queries.size(0);
int H = queries.size(1);
int L = queries.size(2);
int E = queries.size(3);
int M = values.size(3);
// Create accessors for all the arguments
auto qa = queries.accessor<float, 4>();
auto ka = keys.accessor<float, 4>();
auto va = values.accessor<float, 4>();
auto pa = product.accessor<float, 4>();
#pragma omp parallel for collapse(2)
for (int n=0; n<N; n++) {
for (int h=0; h<H; h++) {
auto kv = torch::zeros({E, M}, queries.options());
float *kvp = kv.data_ptr<float>();
for (int l=0; l<L; l++) {
vvt_dot(
&ka[n][h][l][0],
&va[n][h][l][0],
kvp,
E,
M
);
vm_dot(
&qa[n][h][l][0],
kvp,
&pa[n][h][l][0],
E,
M
);
}
}
}
}
/**
* Compute the gradients of queries, keys and values given the gradient of the
* causal_dot_product output.
*
* Make sure that everything is computed in O(N D^2) complexity.
*/
void causal_dot_backward(
const torch::Tensor queries,
const torch::Tensor keys,
const torch::Tensor values,
const torch::Tensor grad_out,
torch::Tensor grad_queries,
torch::Tensor grad_keys,
torch::Tensor grad_values
) {
// Extract some shapes
int N = queries.size(0);
int H = queries.size(1);
int L = queries.size(2);
int E = queries.size(3);
int M = values.size(3);
// Create accessors for all the arguments
auto qa = queries.accessor<float, 4>();
auto ka = keys.accessor<float, 4>();
auto va = values.accessor<float, 4>();
auto ga = grad_out.accessor<float, 4>();
auto gqa = grad_queries.accessor<float, 4>();
auto gka = grad_keys.accessor<float, 4>();
auto gva = grad_values.accessor<float, 4>();
#pragma omp parallel for collapse(2)
for (int n=0; n<N; n++) {
for (int h=0; h<H; h++) {
auto kv = torch::zeros({E, M}, queries.options());
float *kvp = kv.data_ptr<float>();
// Compute the gradient wrt the queries
for (int l=0; l<L; l++) {
vvt_dot(
&ka[n][h][l][0],
&va[n][h][l][0],
kvp,
E,
M
);
vmt_dot(
&ga[n][h][l][0],
kvp,
&gqa[n][h][l][0],
E,
M
);
}
// Compute the gradient wrt the keys and values
kv.zero_();
for (int l=L-1; l>=0; l--) {
vvt_dot(
&qa[n][h][l][0],
&ga[n][h][l][0],
kvp,
E,
M
);
vmt_dot(
&va[n][h][l][0],
kvp,
&gka[n][h][l][0],
E,
M
);
vm_dot(
&ka[n][h][l][0],
kvp,
&gva[n][h][l][0],
E,
M
);
}
}
}
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"causal_dot_product",
&causal_dot_product,
"Compute the weighted sum of values but attending only to previous "
"values."
);
m.def(
"causal_dot_backward",
&causal_dot_backward,
"Compute the gradient of queries, keys and values given the gradient "
"of causal_dot_product."
);
}

View File

@@ -1,67 +0,0 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#
import torch
try:
from causal_attention_cuda import causal_dot_backward as causal_dot_backward_cuda
from causal_attention_cuda import causal_dot_product as causal_dot_product_cuda
except ImportError as e:
print(e)
causal_dot_product_cuda = causal_dot_backward_cuda = None
class CausalDotProduct(torch.autograd.Function):
"""Compute the weighted sum of values but attending only to previous
values."""
dot = {
# "cpu": causal_dot_product_cpu,
"cuda": causal_dot_product_cuda
}
dot_backward = {
# "cpu": causal_dot_backward_cpu,
"cuda": causal_dot_backward_cuda
}
@staticmethod
def forward(ctx, Q, K, V):
# Save the inputs for the gradient computation
ctx.save_for_backward(Q, K, V)
# Create the output tensor
device = Q.device
N, H, L, _ = Q.shape
_, _, _, M = V.shape
product = torch.zeros((N, H, L, M), dtype=Q.dtype, device=device)
# Actually perform the dot product
CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product)
# breakpoint()
# CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product)
return product
@staticmethod
def backward(ctx, grad_out):
# Extract the saved tensors
Q, K, V = ctx.saved_tensors
# Allocate memory for the gradients
grad_Q = torch.zeros_like(Q)
grad_K = torch.zeros_like(K)
grad_V = torch.zeros_like(V)
# Actually compute the gradients
CausalDotProduct.dot_backward[Q.device.type](
Q.data, K.data, V.data, grad_out, grad_Q, grad_K, grad_V
)
return grad_Q, grad_K, grad_V
# Alias the autograd functions to python style snake case naming
causal_dot_product = CausalDotProduct.apply

View File

@@ -1,65 +0,0 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#
import subprocess # nosec
import torch
from setuptools import setup
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
def get_last_arch_torch():
arch = torch.cuda.get_arch_list()[-1]
print(f"Found arch: {arch} from existing torch installation")
return arch
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True # nosec
)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args
arch = get_last_arch_torch()
sm_num = arch[-2:]
cc_flag = ["--generate-code=arch=compute_90,code=compute_90"] # for H100
# cc_flag = ['--generate-code=arch=compute_80,code=compute_80'] # for A100
# cc_flag = ['--generate-code=arch=compute_89,code=compute_89'] # for RTX 6000, 4090
# cc_flag = ['--generate-code=arch=compute_86,code=compute_86'] # for A6000, 3090
# cc_flag = ['--generate-code=arch=compute_75,code=compute_75']
setup(
name="causal_attention_cuda_cpp",
ext_modules=[
CUDAExtension(
"causal_attention_cuda",
[
# 'causal_attention.cpp',
"causal_attention_cuda.cu",
],
extra_compile_args={
"cxx": ["-O3"],
"nvcc": append_nvcc_threads(
["-O3", "-lineinfo", "--use_fast_math", "-std=c++17"] + cc_flag
),
},
)
],
cmdclass={"build_ext": BuildExtension},
)

View File

@@ -1,856 +0,0 @@
"""
Linear attention classes
"""
import copy
from typing import Any, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache
# Causal linear attention dot product CUDA kernel from fast-transformers
try:
from csrc import causal_dot_product as fast_causal_dot_product
except ImportError:
fast_causal_dot_product = None
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
# -------------------
# Attention functions
# -------------------
def causal_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""
Causal linear attention dot product
- If available, use CUDA kernel from fast-transformers
"""
if fast_causal_dot_product is None:
kv = torch.einsum("bhlf,bhld->bhlfd", k, v)
return torch.einsum("bhlf,bhlfd->bhld", q, kv.cumsum(dim=2))
return fast_causal_dot_product(q, k, v)
def linear_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
fp32_attention: bool = False,
eps: float = 1e-12,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Compute linear attention with CUDA kernel implementation from fast-transformers
- https://github.com/idiap/fast-transformers
- Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim);
v is shape (b, h, l, head_dim)
"""
dtype = q.dtype
# Causal mask already applied
y = causal_dot_product(
q.contiguous().to(dtype=torch.float32),
k.contiguous().to(dtype=torch.float32),
v.contiguous().to(dtype=torch.float32),
)
if fp32_attention:
y = (
y
/ (
torch.einsum("bhld,bhld->bhl", q.float(), k.float().cumsum(dim=2)) + eps
)[..., None]
).to(dtype=dtype)
else:
y = y.to(dtype=dtype)
k = k.float().cumsum(dim=2).to(dtype=dtype)
y = y / (torch.einsum("bhld,bhld->bhl", q, k) + eps)[..., None]
return y, None, None
def softmax_attention(
q: torch.Tensor,
k: torch.Tensor,
v: Optional[torch.Tensor] = None,
causal: bool = True,
fp32_attention: bool = True,
):
"""
Standard softmax attention; only compute outputs if v is not None
-> Assume q, k, v are shape (batch_size, num_heads, seq_len, head_dim)
"""
y = None
a = torch.einsum("bhmd,bhnd->bhmn", q, k) * (k.shape[-1] ** -0.5)
if causal: # Apply causal mask
m, n = a.shape[-2:]
causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(
n - m + 1
)
a = a.masked_fill(causal_mask, -torch.finfo(a.dtype).max)
if fp32_attention:
a = torch.softmax(a, dim=-1, dtype=torch.float32).to(q.dtype)
else:
a = torch.softmax(a, dim=-1)
if v is not None:
y = torch.einsum("bhmn,bhnd->bhmd", a, v)
return y, a, None
def quadratic_attention(
q: torch.Tensor,
k: torch.Tensor,
v: Optional[torch.Tensor] = None,
causal: bool = True,
fp32_attention: bool = False,
eps: float = 1e-12,
):
"""
Compute attention with feature maps by instantiating L x L matrix of attention weights
-> Use for attention distillation
-> Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim); v is shape (b, h, l, head_dim)
"""
y = None
dtype = q.dtype
if fp32_attention:
q, k = q.float(), k.float()
a = torch.einsum("bhmd,bhnd->bhmn", q, k) # note we don't scale, tho we could
if causal: # Apply causal mask
m, n = a.shape[-2:]
causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(
n - m + 1
)
a = a.masked_fill(causal_mask, 0)
# Normalize to compute attention
a = a / (a.sum(dim=-1, keepdim=True) + eps)
a = a.to(dtype=dtype) if fp32_attention else a
if torch.isnan(a).sum() > 0:
breakpoint()
if v is not None:
y = torch.einsum("bhmn,bhnd->bhmd", a, v)
return y, a, None
# ---------------------
# Attention layer class
# ---------------------
class LolcatsLinearAttention(nn.Module):
"""
LoLCATs attention implementation initialized from a
`LlamaAttention` or `MistralAttention` object (base_attn)
Most of the arguments are directly tied to argparse args
- For now we don't support padding.
"""
def __init__(
self,
base_attn: nn.Module, # like LlamaAttention
feature_map: str,
feature_map_kwargs: dict,
layer_idx: Optional[int] = None,
max_layer_idx: Optional[int] = None,
learned_kernel: Optional[str] = None,
learned_kernel_kwargs: Optional[dict] = None,
tie_qk_kernels: Optional[bool] = False,
rotary_config: Optional[dict] = None,
train_attention: Optional[bool] = False,
remove_base_attn: bool = True,
attention_type: Optional[str] = "lolcats_llama",
mask_value: int = 0,
eps: float = 1e-12,
fp32_attention: bool = False,
track_state_grads: bool = False,
rank: Optional[int] = 0,
**kwargs,
) -> None:
super().__init__()
self.base_config = getattr(base_attn, "config", None)
if self.base_config is not None:
self.base_config = self.base_config.to_dict()
self.attention_type = attention_type
self.mask_value = mask_value
self.eps = eps
self.layer_idx = layer_idx if layer_idx is not None else base_attn.layer_idx
self.max_layer_idx = max_layer_idx
self.tie_qk_kernels = tie_qk_kernels
self.train_attention = train_attention
self.base_inference = False
self.fp32_attention = fp32_attention
self.track_state_grads = track_state_grads
if rank == 0: # multi-gpu
if fp32_attention and layer_idx == 0:
print(f"-> fp32_attention is {fp32_attention}")
if layer_idx == 0 and feature_map_kwargs is not None:
for k, v in feature_map_kwargs.items():
print(f"-> {k}: {v}")
if layer_idx == 0 and learned_kernel_kwargs is not None:
for k, v in learned_kernel_kwargs.items():
print(f"-> {k}: {v}")
self.remove_base_attn = remove_base_attn
self.init_weights_(base_attn, remove_base_attn)
self.init_feature_map_(
feature_map, feature_map_kwargs, learned_kernel, learned_kernel_kwargs
)
def init_feature_map_(
self,
feature_map: str,
feature_map_kwargs: dict,
learned_kernel: Optional[str] = None,
learned_kernel_kwargs: Optional[dict] = None,
):
"""
Initialize MLP-based feature map
"""
self.fmap_gqa = False # Turn True if specified below
if learned_kernel is not None and learned_kernel_kwargs is not None:
# Ensure dict
learned_kernel_kwargs = {k: v for k, v in learned_kernel_kwargs.items()}
learned_kernel_kwargs["num_heads"] = self.num_heads
learned_kernel_kwargs["head_dim"] = self.head_dim
learned_kernel_kwargs["dtype"] = self.q_proj.weight.dtype
learned_kernel_kwargs["device"] = self.q_proj.weight.device
# Create MLP
mlp_learned_kernel = init_learned_kernel(
learned_kernel, **learned_kernel_kwargs
)
# Add "activation"; see src.models.feature_map.py
self.feature_map_q = init_feature_map(
name=feature_map, mlp=mlp_learned_kernel, **feature_map_kwargs
)
if self.tie_qk_kernels: # tie mlp weights for query and key feature maps
self.feature_map_k = self.feature_map_q
else:
self.feature_map_k = copy.deepcopy(self.feature_map_q)
def init_weights_(self, base_attn: nn.Module, remove_base_attn: bool = True):
"""
Initialize module layers, weights, positional dependencies, etc.
from original softmax attention layer (base_attn)
"""
# Make other attributes accessible
self.attention_dropout = 0 # We don't use dropout
self.hidden_size = base_attn.config.hidden_size
self.num_heads = base_attn.config.num_attention_heads
self.head_dim = base_attn.head_dim
self.num_key_value_heads = base_attn.config.num_key_value_heads
self.num_key_value_groups = base_attn.num_key_value_groups
self.q_shape = [self.num_heads, self.head_dim]
self.k_shape = [self.num_key_value_heads, self.head_dim]
self.v_shape = [self.num_key_value_heads, self.head_dim]
# Copy original model projection layers
self.q_proj = base_attn.q_proj
self.k_proj = base_attn.k_proj
self.v_proj = base_attn.v_proj
self.o_proj = base_attn.o_proj
try: # If wanting to use FA2 for ground-truth inference
self._flash_attn_uses_top_left_mask = (
base_attn._flash_attn_uses_top_left_mask
)
except AttributeError:
pass
if self.remove_base_attn or remove_base_attn:
del base_attn # We don't need to keep these around
else:
self.base_attn = base_attn # For some training runs helpful to just call
def process_qkv(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_value: Optional[Any] = None,
):
"""
Compute queries, keys, and values
"""
b, l, _ = hidden_states.size()
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
kv_seq_len = k.shape[-2]
# Shape is (batch_size, seq_len, num_heads, head_dim)
q = q.view(b, l, *self.q_shape).transpose(1, 2)
k = k.view(b, l, *self.k_shape).transpose(1, 2)
v = v.view(b, l, *self.v_shape).transpose(1, 2)
if (
past_key_value is not None
): # and k.shape[2] > q.shape[2]: # e.g., when generating
past_key_value.window_size = getattr(
self, "decode_window_size", None
) # self.decode_window_size
if isinstance(
past_key_value, Cache
): # In Transformers v4.36+ this is a DynamicCache object
kv_seq_len += past_key_value.get_usable_length(
kv_seq_len, self.layer_idx
)
else:
kv_seq_len += past_key_value[0].shape[-2]
# Apply rotary embeddings
if position_embeddings is not None:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb(q, k, cos, sin)
k = repeat_kv(k, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)
return q, k, v, kv_seq_len
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_value: Optional[Any] = None, # "legacy" cache approach
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Forward pass modified from transformers.models.mistral.modeling_mistral (v4.36)
- Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
q, k, v, kv_seq_len = self.process_qkv(
hidden_states, attention_mask, position_embeddings, past_key_value
)
if self.base_inference:
with torch.no_grad():
# 1. Compute "ground-truth" attention output and weights
y_true, _, _ = softmax_attention(q, k, v, causal=True)
y_true = (
y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
)
y_true = self.o_proj(y_true)
attn_weights = (None, None)
elif self.train_attention: # Distilling / learning attentions
# Note for now we assume no padding when distilling; attention masks only enforce causality
assert (
output_attentions is True
), f"When training feature maps, output_attentions should be True but is {output_attentions}"
with torch.no_grad():
# 1. Compute "ground-truth" attention output and weights
_y_true, attn_true, _ = softmax_attention(q, k, v, causal=True)
y_true = (
_y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
)
y_true = self.o_proj(y_true)
# 2. Compute "predicted" attention (just weights)
q, k = self.feature_map_q.q_map(q), self.feature_map_k.k_map(k)
y_pred, attn_pred, _ = quadratic_attention(q, k, v, causal=True)
attn_weights = ( # type: ignore
(attn_pred, attn_true),
(y_pred, _y_true),
) # Save both attention weights so we can supervise.
else: # Finetuning
q, k = self.feature_map_q(q), self.feature_map_k(k)
# Apply prefill mask
if attention_mask is not None and q.shape[2] > 1:
if len(attention_mask.shape) == 4:
lin_attn_mask = (attention_mask == 0)[:, :1, -1, :l][
..., None
] # b, 1, k_len, 1
else:
lin_attn_mask = attention_mask.bool()[:, None, :, None] # b, 1, k_len, 1
k = k.masked_fill(~lin_attn_mask, 0)
if past_key_value is not None: # Initialize states
if len(past_key_value.kv_states) == self.layer_idx:
b, h, _, f = k.shape
past_key_value.kv_states.append(
torch.zeros(
b, h, f, self.head_dim, dtype=q.dtype, device=q.device
)
)
past_key_value.k_states.append(
torch.zeros(b, h, 1, f, dtype=q.dtype, device=q.device)
)
# Generating
if q.shape[2] == 1 and kv_seq_len > 1 and past_key_value is not None:
assert use_cache is True
kv_state, k_state = past_key_value.update(
k, v, self.layer_idx, accumulate_in_fp32=self.fp32_attention
)
if self.fp32_attention:
q = q.float()
y_true = (
torch.einsum("bhlf,bhfd->bhld", q, kv_state.float())
/ torch.einsum("bhlf,bhlf->bhl", q, k_state.float())[
..., None
]
).to(dtype=k.dtype)
else:
y_true = (
torch.einsum("bhlf,bhfd->bhld", q, kv_state)
/ torch.einsum("bhlf,bhlf->bhl", q, k_state)[..., None]
)
else:
kv_state = past_key_value.kv_states[self.layer_idx]
k_state = past_key_value.k_states[self.layer_idx]
y_true, _, _ = linear_attention(
q, k, v, self.fp32_attention, self.eps
) # Ordinarily the states are ignored
past_key_value.update(
k.detach(),
v.detach(),
self.layer_idx,
accumulate_in_fp32=self.fp32_attention,
)
# doing some unnecessary recomputation here
else:
y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps)
# Concatenate heads and apply output projection
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
y_true = self.o_proj(y_true)
attn_weights = None
return y_true, attn_weights
class LinearAttentionState(Cache):
"""
Handle the KV and K states for linear attention
- Adopts HF Transformers `past_key_values` convention
- Inherits from `Cache` class
- Modified from transformers.cache_utils.DynamicCache (v4.36)
"""
def __init__(self) -> None:
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
self._seen_tokens_by_layer: List[int] = []
self.kv_states: List[torch.Tensor] = []
self.k_states: List[torch.Tensor] = []
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""
Returns the sequence length of the cached states. A layer index can be optionally passed.
"""
if layer_idx is None:
raise ValueError("Layer index must not be None")
if len(self._seen_tokens_by_layer) <= layer_idx: # Initializing kv and k states
self._seen_tokens_by_layer.append(0)
return self._seen_tokens_by_layer[layer_idx]
def get_max_length(self) -> Optional[int]:
"""
Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.
"""
return None
def get_usable_length(
self, new_seq_length: int, layer_idx: Optional[int] = 0
) -> int:
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
# Cache without size limit -> all cache is usable
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
# length, we will need to evict part of the cache (and thus not all cache is usable)
max_length = self.get_max_length()
previous_seq_length = self.get_seq_length(layer_idx)
if max_length is not None and previous_seq_length + new_seq_length > max_length:
return max_length - new_seq_length
return previous_seq_length
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: Optional[int] = None,
cache_kwargs: Optional[Any] = None,
accumulate_in_fp32: bool = True,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
if layer_idx is None:
raise ValueError("Layer index must not be None")
with torch.no_grad():
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
dtype = key_states.dtype
if accumulate_in_fp32:
key_states, value_states = key_states.float(), value_states.float()
kv_state = torch.einsum(
"bhlf,bhld->bhfd", key_states, value_states
).detach()
k_state = key_states.sum(
dim=-2, keepdim=True
).detach() # b, h, 1, f; note the 1
# Update the cache
if len(self.k_states) <= layer_idx: # Initializing kv and k states
print(
"if len(self.k_states) <= layer_idx: # Initializing kv and k states"
)
self.kv_states.append(kv_state.to(dtype))
self.k_states.append(k_state.to(dtype))
else:
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
dtype
)
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
dtype
)
self.kv_states[layer_idx] = kv_state
self.k_states[layer_idx] = k_state
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
return self.kv_states[layer_idx], self.k_states[layer_idx]
def to_legacy_cache(self):
"""Hack, but just return self"""
return self
def reorder_cache(self, beam_idx: torch.LongTensor):
"""
Reorders the cache for beam search, given the selected beam indices.
-> Copied from transformers/src/transformers/cache_utils.py
"""
raise NotImplementedError(
"Reordering cache not implemented for LinearAttentionState"
)
# -------------------
# feature map functions
# -------------------
def init_feature_map(name: str, mlp: nn.Module, **kwargs):
"""
Initialize feature map final activation for linear attention
"""
return FeatureMap(activation_name=name, mlp=mlp, **kwargs)
def init_feature_map_act(name: str, fullspace: bool = True, **kwargs):
"""
Initialize feature map final activation for linear attention
"""
if name == "softmax_dim" and fullspace:
return SoftmaxDim(**kwargs)
elif name == "softmax_dim" and not fullspace:
return SoftmaxDimHalfspace(**kwargs)
elif name == "exp_dim" and fullspace:
return Exp(**kwargs)
elif name == "exp_dim" and not fullspace:
return ExpHalfspace(**kwargs)
elif name == "pos_elu":
return PosELU(**kwargs)
elif name == "relu":
return ReLU(**kwargs)
else:
raise NotImplementedError
def init_learned_kernel(name: str, **kwargs):
"""
Initialize feature map MLP for linear attention
"""
if name == "untied_head_einsum":
return FeatureMapMLP(**kwargs)
elif name == "untied_head_adapter":
return FeatureMapAdapter(**kwargs)
else:
raise NotImplementedError
class FeatureMap(nn.Module):
"""
Final 'activation' of feature map. Can probably be combined with
`FeatureMapMLP` below
Full feature map is like f(xW + b)
-> This is the `f` part
"""
def __init__(
self,
activation_name: str,
head_dim_idx: int = -1,
eps: float = 1e-12,
mlp: Optional[nn.Module] = None,
fullspace: bool = True,
):
super().__init__()
self.head_dim_idx = head_dim_idx
self.eps = eps
self.mlp = mlp if mlp is not None else nn.Identity()
self.activation = init_feature_map_act(activation_name, fullspace, eps=eps)
def forward(self, x: torch.Tensor, *mlp_args, **mlp_kwargs):
"""
Assume x.shape is (batch_size, n_heads, seq_len, head_dim)
"""
return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x)
def q_map(self, *args, **kwargs):
"""
Use for inference in case q and k feature maps differ
"""
return self.forward(*args, **kwargs)
def k_map(self, *args, **kwargs):
"""
Use for inference in case q and k feature maps differ
"""
return self.forward(*args, **kwargs)
# -----------------------
# Feature map activations
# -----------------------
class FeatureMapAct(nn.Module):
"""
Base class for feature map activations
"""
def __init__(self, eps: float = 1e-12):
super().__init__()
self.eps = eps
def forward(self, x: torch.Tensor, *args, **kwargs):
"""
x.shape is (batch_size, n_heads, seq_len, head_dim)
"""
return x
class PosELU(FeatureMapAct):
"""
1 + ELU activation as in https://arxiv.org/abs/2006.16236
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return (1 + F.elu(x)).clamp(min=self.eps)
class ReLU(FeatureMapAct):
"""
ReLU activation as in https://arxiv.org/abs/2103.13076
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return F.relu(x).clamp(min=self.eps)
class SoftmaxDim(FeatureMapAct):
"""
Softmax activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return torch.cat(
[torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1)], dim=-1
).clamp(min=self.eps)
class SoftmaxDimHalfspace(FeatureMapAct):
"""
Softmax activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return torch.softmax(x, dim=-1).clamp(min=self.eps)
class Exp(FeatureMapAct):
"""
Exp activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
x_max = torch.amax(x, dim=-1, keepdim=True)
x_min = torch.amin(x, dim=-1, keepdim=True)
return torch.cat([torch.exp(x - x_max), torch.exp(-x + x_min)], dim=-1).clamp(
min=self.eps
)
class ExpHalfspace(FeatureMapAct):
"""
Exp activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
x_max = torch.amax(x, dim=-1, keepdim=True)
return torch.exp(x - x_max).clamp(min=self.eps)
# ----------------
# Feature map MLPs
# ----------------
class FeatureMapMLP(nn.Module):
"""
Learnable MLP in feature map.
Full feature map is like f(xW + b)
-> This is the `W` and (optional) `b` part
"""
def __init__(
self,
num_heads: int,
head_dim: int, # input dim
feature_dim: int, # output dim
dtype: torch.dtype,
device: torch.device,
skip_connection: bool = False,
bias: bool = False,
zero_init: bool = False,
normal_init: bool = False,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.feature_dim = feature_dim
self.dtype = dtype
self.device = device
self.skip_connection = skip_connection
self.bias = bias
self.zero_init = zero_init
self.normal_init = normal_init
self.init_weights_()
if self.zero_init: # Zero-out weights or set as identity post-initialization
self.zero_init_with_skip_() if self.skip_connection else self.zero_init_()
if self.normal_init:
with torch.no_grad():
nn.init.normal_(self.layer)
if self.skip_connection:
assertion_fail = f"If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}"
assert self.head_dim == self.feature_dim, assertion_fail
def init_weights_(self):
"""
Initialize (W)eights and (b)iases
"""
self.layer = nn.Parameter(
torch.zeros(
(self.num_heads, self.head_dim, self.feature_dim),
dtype=self.dtype,
device=self.device,
)
)
nn.init.kaiming_uniform_(self.layer)
if self.bias:
self.bias = nn.Parameter(
torch.zeros(
(1, self.num_heads, 1, 1), # self.feature_dim),
dtype=self.dtype,
device=self.device,
)
)
nn.init.kaiming_uniform_(self.bias)
else:
self.bias = 0.0 # hack
def zero_init_with_skip_(self):
"""
Initialize weights to zero matrix if skip connection
"""
with torch.no_grad():
nn.init.zeros_(self.layer)
def zero_init_(self):
"""
Initialize weights to identity matrix if no skip connection
"""
with torch.no_grad():
for i in range(self.layer.shape[0]):
try:
nn.init.eye_(self.layer[i])
except RuntimeError:
with torch.no_grad():
dtype = self.layer[i].dtype
weight = torch.eye(
*self.layer[i].shape,
requires_grad=self.layer[i].requires_grad,
device=self.layer[i].device,
)
self.layer[i] = weight.to(dtype=dtype)
def forward(self, x: torch.Tensor):
"""
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
"""
_x = torch.einsum("hdf,bhld->bhlf", self.layer, x) + self.bias
return x + _x if self.skip_connection else _x
class FeatureMapAdapter(FeatureMapMLP):
"""
Learnable Feature map with bottleneck adapter
as in https://arxiv.org/abs/1902.00751
We don't use but could be fun to try
"""
def __init__(self, hidden_dim: int, *args, **kwargs):
kwargs["skip_connection"] = True
kwargs["bias"] = True
kwargs["zero_init"] = True
self.hidden_dim = hidden_dim
super().__init__(*args, **kwargs)
def init_weights_(self):
"""
Initialize (W)eights and (b)iases
"""
kwargs = {"dtype": self.dtype, "device": self.device}
self.layer0 = nn.Parameter(
torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs)
)
self.layer1 = nn.Parameter(
torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs)
)
nn.init.kaiming_uniform_(self.layer0)
nn.init.kaiming_uniform_(self.layer1)
self.bias0 = nn.Parameter(
torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs)
)
self.bias1 = nn.Parameter(
torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs)
)
nn.init.kaiming_uniform_(self.bias0)
nn.init.kaiming_uniform_(self.bias1)
def zero_init_with_skip_(self):
with torch.no_grad():
nn.init.zeros_(self.layer0)
nn.init.zeros_(self.layer1)
nn.init.zeros_(self.bias0)
nn.init.zeros_(self.bias1)
def zero_init_(self):
raise NotImplementedError
def forward(self, x: torch.Tensor):
"""
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
-> Down-project, apply nonlinearity, up-project; add skip connection
"""
_x = torch.einsum("hde,bhld->bhle", self.layer0, x) + self.bias0
_x = F.relu(_x)
_x = torch.einsum("hef,bhle->bhlf", self.layer1, _x) + self.bias1
return x + _x if self.skip_connection else _x

View File

@@ -1,460 +0,0 @@
"""
Subquadratic attention combining sliding window and linear attentions
- Using "standard" sliding windows
- Didactically computes outputs with n^2 attention weights for now
- Copied + adapted from linear_window_attention_tk.py for single-file reference
For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache
from .linear_attention import (
LinearAttentionState,
LolcatsLinearAttention,
softmax_attention,
)
# ----------------------
# Sliding window helpers
# ----------------------
def get_masks(
window_size: int, q_len: int, k_len: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Return masks for softmax and linear attention terms
-> 1 is include, 0 is ignore
"""
causal_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
k_len - q_len
)
linear_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
k_len - q_len - window_size
)
window_mask = causal_mask - linear_mask
# Return softmax mask (window), linear attention mask
# -> shapes broadcast over (b, h, q_len, k_len)
return window_mask[None, None, ...], linear_mask[None, None, ...]
def hybrid_attention_quadratic(
q: torch.Tensor,
k: torch.Tensor,
f_q: torch.Tensor,
f_k: torch.Tensor,
v: torch.Tensor,
window_factor: torch.Tensor,
linear_factor: torch.Tensor,
window_size: int,
kv_state: Optional[torch.Tensor] = None,
k_state: Optional[torch.Tensor] = None,
eps: float = 1e-12,
mask_value: float = -1e8,
):
"""
Hybrid attention combining sliding window and linear attentions
"""
mask_window, mask_linear = get_masks(
window_size, q.shape[-2], k.shape[-2], q.device
)
# 1. Sliding window (softmax attention)
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# 2. Under window (linear attention)
a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
sum_ln = a_ln.sum(dim=-1, keepdim=True)
# 3. Combine
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
# Allow outputs to also depend on prior kv_state and k_state
y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
if (
kv_state is not None and k_state is not None
): # Combine with prior kv_state and k_state
y += linear_factor * torch.einsum(
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
)
sum_ln += (
linear_factor
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
)
y = (y / (sum_sm + sum_ln)).to(q.dtype)
return y, a # attention weights only for the last chunk
# ---------------------
# Attention layer class
# ---------------------
class LolcatsSlidingWindowAttention(LolcatsLinearAttention):
"""
Lolcats attention combining sliding window and linear attention
"""
def __init__(
self,
window_size: int = 64,
decode_window_size: Optional[int] = None,
affine_attention_factors: bool = False,
init_window_factor: float = 0,
train_window_factor: bool = True,
state_grad_enabled: bool = False,
**kwargs,
):
self.window_size = window_size
self.decode_window_size = (
decode_window_size if decode_window_size is not None else window_size
)
self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
super().__init__(**kwargs)
self.attention_type = kwargs["attention_type"] # 'hedgehog_llama_window_sw'
# Determine how we compute attentions
self.quadratic_attention = hybrid_attention_quadratic
self.attention_type = kwargs[
"attention_type"
] # 'hedgehog_long_llama_window_sw'
# Learnable factor for combining attentions
self.affine_attention_factors = affine_attention_factors
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
if train_window_factor:
self.window_factors = nn.Parameter(
init_window_factor
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
)
else:
self.register_buffer(
"window_factors",
init_window_factor
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
)
# Whether we use original flash attention 2 inference (use during attention transfer)
self.base_inference = False
self.state_grad_enabled = state_grad_enabled
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Forward pass with the option to compute attention weights multiple ways
if self.train_attention is True
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
q, k, v, kv_seq_len = self.process_qkv(
hidden_states, attention_mask, position_ids, past_key_value
)
f_q, f_k = self.feature_map_q(q), self.feature_map_k(
k
) # Have to do after repeat for grouped-query attn if we use same fmap
if self.train_attention:
# 1. Compute "ground-truth" attention output and weights
with torch.no_grad():
_y_true, a_true = softmax_attention(q, k, v)[:2]
y_true = (
_y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
)
y_true = self.o_proj(y_true)
# 2. Compute "predicted" attention outputs
# compute attn weights under sliding window
window_factors = F.sigmoid(self.window_factors)
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
y_pred, a_pred = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
)
attn_weights = ((a_pred, a_true), (y_pred, _y_true))
else:
attn_weights = None
# attention_mask = None # For now this is always True
if past_key_value is None: # Regular training
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_true, a_pred = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
)
attn_weights = a_pred
else:
past_key_value.window_size = self.decode_window_size
if (
f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
): # Generating
assert use_cache is True
_kv = past_key_value.update_for_decoding(
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
)
k_cache, v_cache, f_kv_state, f_k_state = _kv
# Sliding window + linear attention decode
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
# Softmax attention terms
a_sm = torch.einsum(
"bhmd,bhnd->bhmn", q.float(), k_cache.float()
) * (k.shape[-1] ** -0.5)
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# Combine with linear attention terms
y_true = torch.einsum(
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
) + linear_factors * torch.einsum(
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
)
sum_ln = (
linear_factors
* torch.einsum(
"bhlf,bhnf->bhl", f_q.float(), f_k_state.float()
)[..., None]
)
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
else: # Stateful training
try:
kv_state = past_key_value.kv_states[self.layer_idx]
k_state = past_key_value.k_states[self.layer_idx]
except IndexError:
kv_state, k_state = None, None
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_true, _ = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
kv_state=kv_state,
k_state=k_state,
)
# Save and update KV cache and states
# past_key_value.update(k, v.detach(), self.layer_idx,
# fmap_key_states=f_k.detach(),
# accumulate_in_fp32=True)
past_key_value.update(
k,
v,
self.layer_idx,
fmap_key_states=f_k,
accumulate_in_fp32=True,
)
# Concatenate heads and apply output projection
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
y_true = self.o_proj(y_true)
return y_true, attn_weights, past_key_value
class LinearAttentionSlidingWindowCache(LinearAttentionState):
"""
Class for `past_key_values`
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
"""
def __init__(self, window_size: int = 64) -> None:
super().__init__()
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
self._seen_tokens_by_layer: List[int] = []
self.kv_states: List[torch.Tensor] = []
self.k_states: List[torch.Tensor] = []
# Account for sliding windows
self.decode_kv_states: List[torch.Tensor] = []
self.decode_k_states: List[torch.Tensor] = []
self.k_cache: List[torch.Tensor] = []
self.v_cache: List[torch.Tensor] = []
self.window_size = window_size
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: Optional[int] = None,
cache_kwargs: Optional[Any] = None,
accumulate_in_fp32: bool = False,
fmap_key_states: Optional[torch.Tensor] = None, # should not be None
grad_enabled: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update KV, K states; and KV cache during training
- For decoding, use `self.decode_kv_states` to keep track of KV states
up to sliding window terms
- For (chunked) training, use `self.kv_states` to keep track of KV states
up to end of sequence
- Likewise for `self.decode_k_states` and `self.k_states`
"""
if fmap_key_states is None:
raise ValueError("fmap_key_states must not be None")
if layer_idx is None:
raise ValueError("Layer index must not be None")
with torch.set_grad_enabled(grad_enabled):
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
dtype = key_states.dtype
if accumulate_in_fp32:
# key_states = key_states.float()
fmap_key_states = fmap_key_states.float()
value_states = value_states.float()
# Decoding KV state (KV terms up to last window_size)
decode_kv_state = torch.einsum(
"bhlf,bhld->bhfd",
fmap_key_states[:, :, : -self.window_size],
value_states[:, :, : -self.window_size],
)
# KV state
kv_state = decode_kv_state + torch.einsum(
"bhlf,bhld->bhfd",
fmap_key_states[:, :, -self.window_size :],
value_states[:, :, -self.window_size :],
)
# shape is b, h, 1, f; note the 1
decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
dim=-2, keepdim=True
)
k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
dim=-2, keepdim=True
)
# Update the cache
if len(self.k_states) <= layer_idx: # Initializing kv and k states
self.kv_states.append(kv_state.to(dtype))
self.k_states.append(k_state.to(dtype))
self.decode_kv_states.append(decode_kv_state.to(dtype))
self.decode_k_states.append(decode_k_state.to(dtype))
self.k_cache.append(key_states[:, :, -self.window_size :, :])
self.v_cache.append(
value_states[:, :, -self.window_size :, :].to(dtype)
)
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
else:
# Update kv and k states recurrently
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
dtype
)
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
dtype
)
self.kv_states[layer_idx] = kv_state
self.k_states[layer_idx] = k_state
decode_kv_state = (
self.decode_kv_states[layer_idx].to(kv_state.dtype)
+ decode_kv_state
).to(dtype)
decode_k_state = (
self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
).to(dtype)
self.decode_kv_states[layer_idx] = decode_kv_state
self.decode_k_states[layer_idx] = decode_k_state
self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
return self.kv_states[layer_idx], self.k_states[layer_idx]
def update_for_decoding(
self,
keys: torch.Tensor,
values: torch.Tensor,
layer_idx: int,
feature_map_k: Callable,
dtype: torch.dtype,
):
"""
Update the decoding KV and K states, and KV cache, during decodeing
"""
with torch.no_grad():
k_cache = self.k_cache[layer_idx]
v_cache = self.v_cache[layer_idx]
if k_cache.shape[-2] < self.window_size: # build window-size cache
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
else:
# MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
# if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
# f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
# else:
# f_k_state = feature_map_k(k_cache[:, :, :1, :])
# -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
k_state = feature_map_k(k_cache[:, :, :1, :])
v_state = v_cache[:, :, :1, :]
kv_state = torch.einsum(
"bhlf,bhld->bhfd", k_state.float(), v_state.float()
).to(
dtype
) # b, h, f, d
self.decode_kv_states[layer_idx] += kv_state
self.decode_k_states[layer_idx] += k_state
self.k_cache[layer_idx] = torch.cat(
[k_cache[:, :, 1:, :], keys], dim=-2
)
self.v_cache[layer_idx] = torch.cat(
[v_cache[:, :, 1:, :], values], dim=-2
)
if layer_idx == 0:
self._seen_tokens += keys.shape[-2]
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
return (
self.k_cache[layer_idx],
self.v_cache[layer_idx],
self.decode_kv_states[layer_idx],
self.decode_k_states[layer_idx],
)

View File

@@ -1,685 +0,0 @@
"""
Subquadratic attention combining sliding window and linear attentions
- Using "standard" sliding windows
- Didactically computes outputs with n^2 attention weights for now
- Copied + adapted from linear_window_attention_tk.py for single-file reference
For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""
import logging
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache
try:
from transformers.modeling_flash_attention_utils import _flash_attention_forward
except ModuleNotFoundError:
_flash_attention_forward = None # Transformers v4.36
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
# Causal linear attention dot product CUDA kernel from fast-transformers
from .linear_attention import (
LinearAttentionState,
LolcatsLinearAttention,
causal_dot_product,
)
LOG = logging.getLogger(__name__)
# ----------------------
# Sliding window helpers
# ----------------------
def get_masks(
window_size: int, q_len: int, k_len: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Return masks for softmax and linear attention terms
-> 1 is include, 0 is ignore
"""
causal_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
max(k_len - q_len, 0)
)
linear_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
max(k_len - q_len, 0) - window_size
)
window_mask = causal_mask - linear_mask
# Return softmax mask (window), linear attention mask
# -> shapes broadcast over (b, h, q_len, k_len)
return window_mask[None, None, ...], linear_mask[None, None, ...]
def hybrid_attention_quadratic(
q: torch.Tensor,
k: torch.Tensor,
f_q: torch.Tensor,
f_k: torch.Tensor,
v: torch.Tensor,
window_factor: torch.Tensor,
linear_factor: torch.Tensor,
window_size: int,
kv_state: Optional[torch.Tensor] = None,
k_state: Optional[torch.Tensor] = None,
eps: float = 1e-12,
mask_value: float = -1e8,
):
"""
Hybrid attention combining sliding window and linear attentions
"""
mask_window, mask_linear = get_masks(
window_size, q.shape[-2], k.shape[-2], q.device
)
# 1. Sliding window (softmax attention)
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# 2. Under window (linear attention)
a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
sum_ln = a_ln.sum(dim=-1, keepdim=True)
# 3. Combine
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
# Allow outputs to also depend on prior kv_state and k_state
y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
if (
kv_state is not None and k_state is not None
): # Combine with prior kv_state and k_state
y += linear_factor * torch.einsum(
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
)
sum_ln += (
linear_factor
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
)
y = (y / (sum_sm + sum_ln)).to(q.dtype)
return y, a # attention weights only for the last chunk
# ------------------------------
# Hybrid window attention linear
# ------------------------------
def under_window_linear_attention(
f_q: torch.Tensor,
f_k: torch.Tensor,
v: torch.Tensor,
window_size: int,
linear_factor: torch.Tensor,
eps: float = 1e-12,
):
"""Compute hybrid window attention dot product with linear complexity in q_len"""
dtype = f_q.dtype
w = window_size
f_k = F.pad(f_k, (0, 0, w, 0), value=0)[:, :, :-w, :]
v = F.pad(v, (0, 0, w, 0), value=0)[:, :, :-w, :]
qkv = linear_factor * causal_dot_product(
f_q.contiguous().to(dtype=torch.float32),
f_k.contiguous().to(dtype=torch.float32),
v.contiguous().to(dtype=torch.float32),
).to(dtype=dtype)
sum_f_k = f_k.float().cumsum(dim=2).to(dtype=dtype)
sum_qk = linear_factor * torch.einsum("bhld,bhld->bhl", f_q, sum_f_k)[..., None]
sum_qk[sum_qk == 0] += eps
return qkv, sum_qk
def sliding_window_softmax_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
window_size: int,
window_factor: torch.Tensor,
mask_value: float = -1e8,
):
"""
Compute sliding window softmax attention without materializing
O(seq_len^2) attention weights
"""
d = q.shape[-1]
# Compute windows for keys
window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
k = F.pad(k, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
v = F.pad(v, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
# Compute windowed_softmax(qk); causal in its construction
a_sm = torch.einsum("bhld,bhldw->bhlw", q, k) * (d**-0.5)
a_sm[a_sm == 0] = -torch.finfo(
q.dtype
).max # heuristic for zeroing out padding above
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
return torch.einsum("bhlw,bhldw->bhld", a_sm, v), sum_sm
# return torch.einsum('bhlw,bhldw->bhld', torch.softmax(qk, dim=-1), v)
def hybrid_attention_linear(
q: torch.Tensor,
k: torch.Tensor,
f_q: torch.Tensor,
f_k: torch.Tensor,
v: torch.Tensor,
window_factor: Optional[torch.Tensor] = None,
linear_factor: Optional[torch.Tensor] = None,
window_size: int = 64,
kv_state: Optional[torch.Tensor] = None,
k_state: Optional[torch.Tensor] = None,
eps: float = 1e-12,
mask_value: float = -1e8,
):
"""
Alternative hybrid attention combining sliding window and linear attentions
-> Uses O(n) memory if n is sequence length by padding and unfolding windows
"""
# window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
if window_factor is None:
raise ValueError("window_factor must be provided")
if linear_factor is None:
raise ValueError("linear_factor must be provided")
# 1. Sliding window (softmax attention)
with torch.no_grad():
qkv_sm, sum_qk_sm = sliding_window_softmax_attention(
q, k, v, window_size, window_factor, mask_value
)
# 2. Under window (linear attention)
qkv_ln, sum_qk_ln = under_window_linear_attention(
f_q, f_k, v, window_size, linear_factor, eps
)
# 3. Combine
y = (qkv_sm + qkv_ln) / (sum_qk_sm + sum_qk_ln)
return y, None
# ---------------------
# Attention layer class
# ---------------------
class LolcatsLinearSlidingWindowAttention(LolcatsLinearAttention):
"""
Lolcats attention combining sliding window and linear attention
"""
def __init__(
self,
window_size: int = 64,
decode_window_size: Optional[int] = None,
affine_attention_factors: bool = False,
init_window_factor: float = 0,
train_window_factor: bool = True,
state_grad_enabled: bool = False,
**kwargs,
):
self.window_size = window_size
self.decode_window_size = (
decode_window_size if decode_window_size is not None else window_size
)
self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
super().__init__(**kwargs)
# Determine how we compute attentions
self.linear_attention = hybrid_attention_linear
self.attention_type = "lolcats_llama_window_sw"
# Learnable factor for combining attentions
self.affine_attention_factors = affine_attention_factors
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
if train_window_factor:
self.window_factors = nn.Parameter(
init_window_factor
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
)
else:
self.register_buffer(
"window_factors",
init_window_factor
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
)
# Whether we use original flash attention 2 inference (use during attention transfer)
self.base_inference = False
self.state_grad_enabled = state_grad_enabled
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Forward pass with the option to compute attention weights multiple ways
if self.train_attention is True
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
if self.train_attention and self.base_inference:
with torch.no_grad():
_y_true = flash_attention_2(
self, # self.base_attn,
hidden_states=hidden_states,
attention_mask=None,
position_ids=position_ids,
past_key_value=None,
output_attentions=False,
use_cache=False,
)[0]
# _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
y_true = _y_true.reshape(b, l, -1).contiguous()
y_true = self.o_proj(y_true)
# layer_io = (hidden_states, _y_true) # hack
layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack
return y_true, layer_io, None
else:
q, k, v, kv_seq_len = self.process_qkv(
hidden_states, attention_mask, position_ids, past_key_value
)
f_q, f_k = self.feature_map_q(q), self.feature_map_k(
k
) # Have to do after repeat for grouped-query attn if we use same fmap
attn_weights = None
# attention_mask = None # For now this is always True
if past_key_value is None: # Regular training
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_true, a_pred = self.linear_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
)
attn_weights = a_pred
else:
past_key_value.window_size = self.decode_window_size
if (
f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
): # Generating
assert use_cache is True
_kv = past_key_value.update_for_decoding(
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
)
k_cache, v_cache, f_kv_state, f_k_state = _kv
# Sliding window + linear attention decode
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
# Softmax attention terms
a_sm = torch.einsum(
"bhmd,bhnd->bhmn", q.float(), k_cache.float()
) * (k.shape[-1] ** -0.5)
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# Combine with linear attention terms
y_true = torch.einsum(
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
) + linear_factors * torch.einsum(
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
)
sum_ln = (
linear_factors
* torch.einsum(
"bhlf,bhnf->bhl", f_q.float(), f_k_state.float()
)[..., None]
)
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
else: # Stateful training
try:
kv_state = past_key_value.kv_states[self.layer_idx]
k_state = past_key_value.k_states[self.layer_idx]
except IndexError:
kv_state, k_state = None, None
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_true, _ = self.linear_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
kv_state=kv_state,
k_state=k_state,
)
# Save and update KV cache and states
# past_key_value.update(k, v.detach(), self.layer_idx,
# fmap_key_states=f_k.detach(),
# accumulate_in_fp32=True)
past_key_value.update(
k,
v,
self.layer_idx,
fmap_key_states=f_k,
accumulate_in_fp32=True,
)
# Concatenate heads and apply output projection
_y_true = y_true.transpose(1, 2).contiguous()
y_true = self.o_proj(_y_true.view(b, l, self.hidden_size))
if self.train_attention:
attn_weights = _y_true # flash_attn outputs are shape (b, l, h, d)
return y_true, attn_weights, past_key_value
class LinearAttentionSlidingWindowCache(LinearAttentionState):
"""
Class for `past_key_values`
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
"""
def __init__(self, window_size: int = 64) -> None:
super().__init__()
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
self._seen_tokens_by_layer: List[int] = []
self.kv_states: List[torch.Tensor] = []
self.k_states: List[torch.Tensor] = []
# Account for sliding windows
self.decode_kv_states: List[torch.Tensor] = []
self.decode_k_states: List[torch.Tensor] = []
self.k_cache: List[torch.Tensor] = []
self.v_cache: List[torch.Tensor] = []
self.window_size = window_size
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: Optional[int] = None,
cache_kwargs: Optional[Any] = None,
accumulate_in_fp32: bool = False,
fmap_key_states: Optional[torch.Tensor] = None, # should not be None
grad_enabled: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update KV, K states; and KV cache during training
- For decoding, use `self.decode_kv_states` to keep track of KV states
up to sliding window terms
- For (chunked) training, use `self.kv_states` to keep track of KV states
up to end of sequence
- Likewise for `self.decode_k_states` and `self.k_states`
"""
if fmap_key_states is None:
raise ValueError("fmap_key_states must not be None")
if layer_idx is None:
raise ValueError("Layer index must not be None")
with torch.set_grad_enabled(grad_enabled):
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
dtype = key_states.dtype
if accumulate_in_fp32:
# key_states = key_states.float()
fmap_key_states = fmap_key_states.float()
value_states = value_states.float()
# Decoding KV state (KV terms up to last window_size)
decode_kv_state = torch.einsum(
"bhlf,bhld->bhfd",
fmap_key_states[:, :, : -self.window_size],
value_states[:, :, : -self.window_size],
)
# KV state
kv_state = decode_kv_state + torch.einsum(
"bhlf,bhld->bhfd",
fmap_key_states[:, :, -self.window_size :],
value_states[:, :, -self.window_size :],
)
# shape is b, h, 1, f; note the 1
decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
dim=-2, keepdim=True
)
k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
dim=-2, keepdim=True
)
# Update the cache
if len(self.k_states) <= layer_idx: # Initializing kv and k states
self.kv_states.append(kv_state.to(dtype))
self.k_states.append(k_state.to(dtype))
self.decode_kv_states.append(decode_kv_state.to(dtype))
self.decode_k_states.append(decode_k_state.to(dtype))
self.k_cache.append(key_states[:, :, -self.window_size :, :])
self.v_cache.append(
value_states[:, :, -self.window_size :, :].to(dtype)
)
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
else:
# Update kv and k states recurrently
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
dtype
)
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
dtype
)
self.kv_states[layer_idx] = kv_state
self.k_states[layer_idx] = k_state
decode_kv_state = (
self.decode_kv_states[layer_idx].to(kv_state.dtype)
+ decode_kv_state
).to(dtype)
decode_k_state = (
self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
).to(dtype)
self.decode_kv_states[layer_idx] = decode_kv_state
self.decode_k_states[layer_idx] = decode_k_state
self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
return self.kv_states[layer_idx], self.k_states[layer_idx]
def update_for_decoding(
self,
keys: torch.Tensor,
values: torch.Tensor,
layer_idx: int,
feature_map_k: Callable,
dtype: torch.dtype,
):
"""
Update the decoding KV and K states, and KV cache, during decodeing
"""
with torch.no_grad():
k_cache = self.k_cache[layer_idx]
v_cache = self.v_cache[layer_idx]
if k_cache.shape[-2] < self.window_size: # build window-size cache
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
else:
# MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
# if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
# f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
# else:
# f_k_state = feature_map_k(k_cache[:, :, :1, :])
# -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
k_state = feature_map_k(k_cache[:, :, :1, :])
v_state = v_cache[:, :, :1, :]
kv_state = torch.einsum(
"bhlf,bhld->bhfd", k_state.float(), v_state.float()
).to(
dtype
) # b, h, f, d
self.decode_kv_states[layer_idx] += kv_state
self.decode_k_states[layer_idx] += k_state
self.k_cache[layer_idx] = torch.cat(
[k_cache[:, :, 1:, :], keys], dim=-2
)
self.v_cache[layer_idx] = torch.cat(
[v_cache[:, :, 1:, :], values], dim=-2
)
if layer_idx == 0:
self._seen_tokens += keys.shape[-2]
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
return (
self.k_cache[layer_idx],
self.v_cache[layer_idx],
self.decode_kv_states[layer_idx],
self.decode_k_states[layer_idx],
)
# -----------------
# Flash Attention 2
# -----------------
def flash_attention_2(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
):
"""
Wrapper for LlamaFlashAttention2
Copied and modified from HF Transformers v4.36 and v4.43 implementations
- (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
- (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456
"""
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
try: # As in Transformers v4.36
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
except Exception: # As in Transformers v4.39
cos, sin = self.rotary_emb(key_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
LOG.debug(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
if getattr(self, "_flash_attention_forward", False):
attn_output = self._flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
is_causal=True,
)
else:
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=0, # dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=True,
)
return attn_output, past_key_value

View File

@@ -1,24 +0,0 @@
"""
LoLCATs attention combining sliding window and linear attentions
- Using standard sliding window arrangement
- Training over long sequences with fixed memory with recurrent view
- During attention transfer, use Flash Attention to compute softmax attention outputs
For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""
from .linear_window_attention_sw import hybrid_attention_quadratic
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
class LolcatsSlidingWindowLongAttention(LolcatsTKWindowLongAttention):
"""
Lolcats attention combining sliding window and linear attention
"""
def __init__(self, remove_base_attn=True, **kwargs):
# keep self.base_attn for Flash Attention inference
super().__init__(remove_base_attn=True, **kwargs)
self.quadratic_attention = hybrid_attention_quadratic

View File

@@ -1,466 +0,0 @@
"""
Subquadratic attention combining sliding window and linear attentions
- Using the TK "terracing" arrangement
For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""
import math
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache
from .linear_attention import (
LinearAttentionState,
LolcatsLinearAttention,
softmax_attention,
)
# ----------------------
# Sliding window helpers
# ----------------------
def get_masks(
window_size: int, q_len: int, k_len: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Return masks for softmax and linear attention terms
-> 1 is include, 0 is ignore
"""
win_len = window_size
m = math.ceil(max(q_len, k_len) / window_size)
# Creates an n x n mask where n = window_size^2
mask = torch.block_diag(
*[
torch.ones(
(win_len, win_len),
)
]
* m
)
mask += torch.roll(mask, -win_len, -1) # this adds the terracing
if mask.shape[0] > q_len:
mask = mask[-q_len:]
if mask.shape[1] > k_len:
mask = mask[:, -k_len:]
# Return softmax mask (window), linear attention mask
mask = mask[None, None, ...] # b, h, q_len, k_len
return (
torch.tril(mask).to(device=device, dtype=torch.int),
torch.tril(1 - mask).to(device=device, dtype=torch.int),
)
def hybrid_attention_quadratic(
q: torch.Tensor,
k: torch.Tensor,
f_q: torch.Tensor,
f_k: torch.Tensor,
v: torch.Tensor,
window_factor: torch.Tensor,
linear_factor: torch.Tensor,
window_size: int,
kv_state: Optional[torch.Tensor] = None,
k_state: Optional[torch.Tensor] = None,
eps: float = 1e-12,
mask_value: float = -1e8,
):
"""
Hybrid attention combining sliding window and linear attentions
"""
mask_window, mask_linear = get_masks(
window_size, q.shape[-2], k.shape[-2], q.device
)
# 1. Sliding window (softmax attention)
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# 2. Under window (linear attention)
a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
sum_ln = a_ln.sum(dim=-1, keepdim=True)
# 3. Combine
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
# Allow outputs to also depend on prior kv_state and k_state
y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
if (
kv_state is not None and k_state is not None
): # Combine with prior kv_state and k_state
y += linear_factor * torch.einsum(
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
)
sum_ln += (
linear_factor
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
)
y = (y / (sum_sm + sum_ln)).to(q.dtype)
return y, a # attention weights only for the last chunk
# ---------------------
# Attention layer class
# ---------------------
class LolcatsTKWindowAttention(LolcatsLinearAttention):
"""
Lolcats attention combining sliding window and linear attention
"""
def __init__(
self,
window_size: int = 64,
decode_window_size: Optional[int] = None,
affine_attention_factors: bool = False,
init_window_factor: float = 0,
train_window_factor: bool = True,
state_grad_enabled: bool = False,
**kwargs,
):
self.window_size = window_size
self.decode_window_size = (
decode_window_size if decode_window_size is not None else window_size
)
self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
super().__init__(**kwargs)
self.attention_type = kwargs["attention_type"] # 'hedgehog_llama_window_tk'
# Determine how we compute attentions
self.quadratic_attention = hybrid_attention_quadratic
self.attention_type = kwargs[
"attention_type"
] # 'hedgehog_long_llama_window_tk'
# Learnable factor for combining attentions
self.affine_attention_factors = affine_attention_factors
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
if train_window_factor:
self.window_factors = nn.Parameter(
init_window_factor
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
)
else:
self.register_buffer(
"window_factors",
init_window_factor
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
)
# Whether we use original flash attention 2 inference (use during attention transfer)
self.base_inference = False
self.state_grad_enabled = state_grad_enabled
self.window_factor = self.window_factors # legacy naming support
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Forward pass with the option to compute attention weights multiple ways
if self.train_attention is True
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
q, k, v, kv_seq_len = self.process_qkv(
hidden_states, attention_mask, position_ids, past_key_value
)
f_q, f_k = self.feature_map_q(q), self.feature_map_k(
k
) # Have to do after repeat for grouped-query attn if we use same fmap
if self.train_attention:
# 1. Compute "ground-truth" attention output and weights
with torch.no_grad():
_y_true, a_true = softmax_attention(q, k, v)[:2]
y_true = (
_y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
)
y_true = self.o_proj(y_true)
# 2. Compute "predicted" attention outputs
# compute attn weights under sliding window
window_factors = F.sigmoid(self.window_factors)
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
y_pred, a_pred = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
)
attn_weights = ((a_pred, a_true), (y_pred, _y_true))
else:
attn_weights = None
# attention_mask = None # For now this is always True
if past_key_value is None: # Regular training
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_true, a_pred = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
)
attn_weights = a_pred
else:
past_key_value.window_size = self.decode_window_size
if (
f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
): # Generating
assert use_cache is True
_kv = past_key_value.update_for_decoding(
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
)
k_cache, v_cache, f_kv_state, f_k_state = _kv
# Sliding window + linear attention decode
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
# Softmax attention terms
a_sm = torch.einsum(
"bhmd,bhnd->bhmn", q.float(), k_cache.float()
) * (k.shape[-1] ** -0.5)
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# Combine with linear attention terms
y_true = torch.einsum(
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
) + linear_factors * torch.einsum(
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
)
sum_ln = (
linear_factors
* torch.einsum(
"bhld,bhnd->bhl", f_q.float(), f_k_state.float()
)[..., None]
)
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
else: # Stateful training
try:
kv_state = past_key_value.kv_states[self.layer_idx]
k_state = past_key_value.k_states[self.layer_idx]
except IndexError:
kv_state, k_state = None, None
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_true, _ = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
kv_state=kv_state,
k_state=k_state,
)
# Save and update KV cache and states
# past_key_value.update(k, v.detach(), self.layer_idx,
# fmap_key_states=f_k.detach(),
# accumulate_in_fp32=True)
past_key_value.update(
k,
v,
self.layer_idx,
fmap_key_states=f_k,
accumulate_in_fp32=True,
)
# Concatenate heads and apply output projection
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
y_true = self.o_proj(y_true)
return y_true, attn_weights, past_key_value
class LinearAttentionTKWindowCache(LinearAttentionState):
"""
Class for `past_key_values`
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
"""
def __init__(self, window_size: int = 64) -> None:
super().__init__()
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
self._seen_tokens_by_layer: List[int] = []
self.kv_states: List[torch.Tensor] = []
self.k_states: List[torch.Tensor] = []
# Account for sliding windows
self.decode_kv_states: List[torch.Tensor] = []
self.decode_k_states: List[torch.Tensor] = []
self.k_cache: List[torch.Tensor] = []
self.v_cache: List[torch.Tensor] = []
self.window_size = window_size
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: Optional[int] = None,
cache_kwargs: Optional[Any] = None,
accumulate_in_fp32: bool = False,
fmap_key_states: Optional[torch.Tensor] = None, # should not be None
grad_enabled: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update KV, K states; and KV cache during training
- For decoding, use `self.decode_kv_states` to keep track of KV states
up to sliding window terms
- For (chunked) training, use `self.kv_states` to keep track of KV states
up to end of sequence
- Likewise for `self.decode_k_states` and `self.k_states`
"""
if fmap_key_states is None:
raise ValueError("fmap_key_states should not be None")
if layer_idx is None:
raise ValueError("layer_idx should not be None")
with torch.set_grad_enabled(grad_enabled):
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
dtype = key_states.dtype
if accumulate_in_fp32:
# key_states = key_states.float()
fmap_key_states = fmap_key_states.float()
value_states = value_states.float()
# Decoding KV state (KV terms up to last window_size)
decode_kv_state = torch.einsum(
"bhlf,bhld->bhfd",
fmap_key_states[:, :, : -self.window_size],
value_states[:, :, : -self.window_size],
)
# KV state
kv_state = decode_kv_state + torch.einsum(
"bhlf,bhld->bhfd",
fmap_key_states[:, :, -self.window_size :],
value_states[:, :, -self.window_size :],
)
# shape is b, h, 1, f; note the 1
decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
dim=-2, keepdim=True
)
k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
dim=-2, keepdim=True
)
# Update the cache
if len(self.k_states) <= layer_idx: # Initializing kv and k states
self.kv_states.append(kv_state.to(dtype))
self.k_states.append(k_state.to(dtype))
self.decode_kv_states.append(decode_kv_state.to(dtype))
self.decode_k_states.append(decode_k_state.to(dtype))
self.k_cache.append(key_states[:, :, -self.window_size :, :])
self.v_cache.append(
value_states[:, :, -self.window_size :, :].to(dtype)
)
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
else:
# Update kv and k states recurrently
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
dtype
)
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
dtype
)
self.kv_states[layer_idx] = kv_state
self.k_states[layer_idx] = k_state
decode_kv_state = (
self.decode_kv_states[layer_idx].to(kv_state.dtype)
+ decode_kv_state
).to(dtype)
decode_k_state = (
self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
).to(dtype)
self.decode_kv_states[layer_idx] = decode_kv_state
self.decode_k_states[layer_idx] = decode_k_state
self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
return self.kv_states[layer_idx], self.k_states[layer_idx]
def update_for_decoding(
self,
keys: torch.Tensor,
values: torch.Tensor,
layer_idx: int,
feature_map_k: Callable,
dtype: torch.dtype,
):
"""
Update the decoding KV and K states, and KV cache, during decodeing
"""
with torch.no_grad():
k_cache = self.k_cache[layer_idx]
v_cache = self.v_cache[layer_idx]
if k_cache.shape[-2] < self.window_size: # build window-size cache
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
else:
k_state = feature_map_k(k_cache[:, :, :1, :])
v_state = v_cache[:, :, :1, :]
kv_state = torch.einsum(
"bhlf,bhld->bhfd", k_state.float(), v_state.float()
).to(
dtype
) # b, h, f, d
self.decode_kv_states[layer_idx] += kv_state
self.decode_k_states[layer_idx] += k_state
self.k_cache[layer_idx] = torch.cat(
[k_cache[:, :, 1:, :], keys], dim=-2
)
self.v_cache[layer_idx] = torch.cat(
[v_cache[:, :, 1:, :], values], dim=-2
)
if layer_idx == 0:
self._seen_tokens += keys.shape[-2]
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
return (
self.k_cache[layer_idx],
self.v_cache[layer_idx],
self.decode_kv_states[layer_idx],
self.decode_k_states[layer_idx],
)

View File

@@ -1,219 +0,0 @@
"""
LoLCATs + ThunderKittens linear attention + sliding window for generation
"""
import logging
from typing import Any, Callable, List, Optional
import torch
import torch.nn.functional as F
from .linear_attention import LinearAttentionState
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
LOG = logging.getLogger(__name__)
try:
from thunderkittens import hedgehog as tk_window_hedgehog_attention
LOG.debug("Successfully imported ThunderKittens for TK window attention")
except ImportError:
LOG.debug("Failed to import ThunderKittens for TK window attention")
class LolcatsWindowAttentionTKGen(LolcatsTKWindowLongAttention):
def __init__(self, *args, window_size: int = 64, **kwargs):
super().__init__(*args, **kwargs)
self.train_attention = False
self.base_inference = False
self.window_size = 64 # hard-coded support for TK kernel
self.decode_window_size = 64
b, h, l, d = 1, 32, 8192, 128
self.y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device="cuda")
self.kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device="cuda")
self.k_state = torch.zeros(b, h, d, dtype=torch.float32, device="cuda")
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Any] = None, # “legacy” cache approach
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Forward pass with the option to compute attention weights multiple ways
if self.train_attention is True
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
assert (
past_key_value is not None
), "past_key_value must be provided for generation"
assert (
self.train_attention is False
), "train_attention is not supported for generation"
assert (
self.base_inference is False
), "base_inference is not supported for generation"
assert use_cache is True, "use_cache must be True for generation"
past_key_value.window_size = self.decode_window_size
q, k, v, kv_seq_len = self.process_qkv(
hidden_states, attention_mask, position_ids, past_key_value
)
if q.shape[2] == 1 and kv_seq_len > 1: # Generating after prefill
f_q = self.feature_map_q(q)
_kv = past_key_value.update_for_decoding(
k, v, self.layer_idx, self.feature_map_k
)
k_cache, v_cache, kv_state, k_state = _kv
# Sliding window + linear attention decode
window_factors = F.sigmoid(self.window_factors)
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
# Softmax attention terms
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k_cache.float()) * (
k.shape[-1] ** -0.5
)
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# Combine with linear attention terms
y_true = torch.einsum(
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
) + linear_factors * torch.einsum(
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
)
sum_ln = (
linear_factors
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[
..., None
]
)
self.y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
else: # Process prefill
# Use TK-implemented linear + terrace window attention
b, h, l, d = q.shape
device = q.device
# tk.hedgehog arguments
# y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device=device)
# kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device=device)
# k_state = torch.zeros(b, h, d, dtype=torch.float32, device=device)
betas = F.sigmoid(self.window_factors[0, :, 0, 0].to(dtype=torch.float32))
alphas = (
1 - betas
if self.affine_attention_factors
else torch.ones(betas.shape, dtype=torch.float32, device=device)
)
q_map = self.feature_map_q.mlp.layer
k_map = self.feature_map_k.mlp.layer
# Saves outputs to y_pred, k_state, kv_state, where we fuse:
# 1. f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)
# 2. y_pred = attention(q, k, f_q, f_k, v) # b, h, l, d
# 3. kv_state = torch.einsum(bhlf,bhld->bhfd,
# f_k[:, :, :-self.window_size],
# v[:, :, :-self.window_size]) # b, h, f, d
# 4. k_state = f_k[:, :, :-self.window_size].sum(dim=-2) # b, h, d
tk_window_hedgehog_attention(
q.contiguous(),
k.contiguous(),
v.contiguous(),
self.y_true,
self.k_state,
self.kv_state,
q_map,
k_map,
alphas,
betas,
)
past_key_value.update_with_kv(
self.kv_state, self.k_state.unsqueeze(-2), k, v, self.layer_idx
)
# Concatenate heads and apply output projection
y_true = self.y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
y_true = self.o_proj(y_true)
return y_true, None, past_key_value
class LinearAttentionTKWindowGenerationCache(LinearAttentionState):
"""
Class for `past_key_values`
-> Alternative to KV cache; here we only maintain a “KV state” and “K state”
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
"""
def __init__(self, window_size: int = 64) -> None:
super().__init__()
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
self._seen_tokens_by_layer: List[int] = []
self.window_size = window_size
self.decode_kv_states: List[torch.Tensor] = []
self.decode_k_states: List[torch.Tensor] = []
self.k_cache: List[torch.Tensor] = []
self.v_cache: List[torch.Tensor] = []
def update_with_kv(
self,
kv_state: torch.Tensor,
k_state: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_idx: int,
):
"""
Update the cache with new KV and K states
"""
if layer_idx == 0:
self._seen_tokens += k.shape[2]
self._seen_tokens_by_layer.append(k.shape[2])
# Initialize KV and K states
if len(self.decode_k_states) <= layer_idx:
self.decode_kv_states.append(kv_state)
self.decode_k_states.append(k_state)
else: # Update KV and K states
self.decode_kv_states[layer_idx] = (
self.decode_kv_states[layer_idx] + kv_state
)
self.decode_k_states[layer_idx] = self.decode_k_states[layer_idx] + k_state
self.k_cache.append(k[:, :, -self.window_size :, :])
self.v_cache.append(v[:, :, -self.window_size :, :])
def update_for_decoding(
self, k: torch.Tensor, v: torch.Tensor, layer_idx: int, feature_map_k: Callable
):
"""
Update the cache for decoding
"""
k_cache = self.k_cache[layer_idx]
v_cache = self.v_cache[layer_idx]
k_state = feature_map_k(k_cache[:, :, :1, :])
v_state = v_cache[:, :, :1, :]
kv_state = torch.einsum("bhlf,bhld->bhfd", k_state.float(), v_state.float()).to(
k.dtype
)
self.decode_kv_states[layer_idx] += kv_state
self.decode_k_states[layer_idx] += k_state
self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], k], dim=-2)
self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], v], dim=-2)
if layer_idx == 0:
self._seen_tokens += k.shape[-2]
self._seen_tokens_by_layer[layer_idx] += k.shape[-2]
return (
self.k_cache[layer_idx],
self.v_cache[layer_idx],
self.decode_kv_states[layer_idx],
self.decode_k_states[layer_idx],
)

View File

@@ -1,306 +0,0 @@
"""
LoLCATs attention combining sliding window and linear attentions
- Using the TK "terracing" arrangement
- Training over long sequences with fixed memory with recurrent view
- During attention transfer, use Flash Attention to compute softmax attention outputs
For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine to model the entire sequence
"""
import logging
from typing import Optional
import torch
import torch.nn.functional as F
from transformers.cache_utils import Cache
try:
from transformers.modeling_flash_attention_utils import _flash_attention_forward
except ModuleNotFoundError:
_flash_attention_forward = None # Transformers v4.36
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from .linear_attention import softmax_attention
from .linear_window_attention_tk import LolcatsTKWindowAttention
LOG = logging.getLogger(
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long"
)
class LolcatsTKWindowLongAttention(LolcatsTKWindowAttention):
"""
Lolcats attention combining sliding window and linear attention
"""
def __init__(self, remove_base_attn=True, **kwargs):
# keep self.base_attn for Flash Attention inference
super().__init__(remove_base_attn=True, **kwargs)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
):
"""
Forward pass with the option to compute attention weights multiple ways
if self.train_attention is True
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
if self.train_attention and self.base_inference:
with torch.no_grad():
# LOG.debug(hidden_states.shape)
_y_true = flash_attention_2(
self, # self.base_attn,
hidden_states=hidden_states,
attention_mask=None,
position_ids=position_ids,
past_key_value=None,
output_attentions=False,
# output_hidden_states=False,
use_cache=False,
)[0]
# _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
y_true = _y_true.reshape(b, l, -1).contiguous()
y_true = self.o_proj(y_true)
layer_io = (hidden_states, _y_true) # hack
# layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack
return y_true, layer_io, None
q, k, v, kv_seq_len = self.process_qkv(
hidden_states, attention_mask, position_ids, past_key_value
)
f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)
# attention_mask = None # For now this is always True
if past_key_value is None: # Regular training
window_factors = F.sigmoid(self.window_factors)
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
y_pred, a_pred = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
)
else:
past_key_value.window_size = self.decode_window_size
if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating
assert use_cache is True
_kv = past_key_value.update_for_decoding(
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
)
k_cache, v_cache, f_kv_state, f_k_state = _kv
# Sliding window + linear attention decode
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k_cache.float()) * (
k.shape[-1] ** -0.5
)
# a_sm = torch.softmax(a_sm, dim=-1)
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
y_pred = torch.einsum(
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
) + linear_factors * torch.einsum(
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
)
sum_ln = (
linear_factors
* torch.einsum("bhlf,bhnf->bhl", f_q.float(), f_k_state.float())[
..., None
]
)
y_pred = (y_pred / (sum_sm + sum_ln)).to(q.dtype)
else: # Stateful training
if (
self.state_grad_enabled
and self.layer_idx == 0
and position_ids is not None
):
LOG.debug(
f"\n position_ids: [{position_ids[0, 0]}, {position_ids[0, -1]}]"
)
LOG.debug(
f"q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}"
)
try:
kv_state = past_key_value.kv_states[self.layer_idx]
k_state = past_key_value.k_states[self.layer_idx]
except IndexError:
kv_state, k_state = None, None
window_factors = F.sigmoid(self.window_factors)
linear_factors = (
1 - window_factors if self.affine_attention_factors else 1
)
y_pred, a_pred = self.quadratic_attention(
q,
k,
f_q,
f_k,
v,
window_factors,
linear_factors,
window_size=self.window_size,
kv_state=kv_state,
k_state=k_state,
)
# Save and update KV cache and states
# past_key_value.update(k, v.detach(), self.layer_idx,
# fmap_key_states=f_k.detach(),
# accumulate_in_fp32=True)
past_key_value.update(
k, v, self.layer_idx, fmap_key_states=f_k, accumulate_in_fp32=True
)
# Concatenate heads and apply output projection
_y_pred = y_pred.transpose(1, 2).contiguous()
y_pred = self.o_proj(_y_pred.view(b, l, self.hidden_size))
if self.train_attention:
with torch.no_grad():
a_true = softmax_attention(q, k, None, causal=True)[1]
attn_weights = (_y_pred, (a_pred, a_true))
else:
attn_weights = _y_pred # flash_attn outputs are shape (b, l, h, d)
return y_pred, attn_weights, past_key_value
# -----------------
# Flash Attention 2
# -----------------
def flash_attention_2(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
):
"""
Wrapper for LlamaFlashAttention2
Copied and modified from HF Transformers v4.36 and v4.43 implementations
- (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
- (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456
"""
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
try: # As in Transformers v4.36
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
except Exception: # As in Transformers v4.39
cos, sin = self.rotary_emb(key_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
LOG.debug(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
if getattr(self, "_flash_attention_forward", False):
attn_output = self._flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
is_causal=True,
)
else:
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=0, # dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=True,
)
return attn_output, past_key_value

View File

@@ -1,361 +0,0 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
"""Linear LLaMA model implementation."""
import logging
from functools import partial
from typing import Any, Optional
from torch import nn
from tqdm import tqdm
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
)
from .configuration_linear_llama import LinearLlamaConfig
LOG = logging.getLogger(__name__)
class LinearLlamaDecoderLayer(LlamaDecoderLayer):
"""
Modified LlamaDecoderLayer that uses LinearAttention instead of standard attention.
"""
def __init__(self, config: LinearLlamaConfig, layer_idx: int):
super().__init__(config, layer_idx)
# Replace the attention layer with our custom attention
self.self_attn = convert_llama_attention(
layer=self, attention_config=config.attention_config
)
class LinearLlamaModel(LlamaModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LinearLlamaDecoderLayer`]
Args:
config: LinearLlamaConfig
"""
config_class = LinearLlamaConfig
base_model_prefix = "linear_llama"
def __init__(self, config: LinearLlamaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
LinearLlamaDecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
class LinearLlamaForCausalLM(LlamaForCausalLM):
"""
Linear LLaMA model for causal language modeling.
"""
config_class = LinearLlamaConfig
base_model_prefix = "linear_llama"
def __init__(self, config):
super().__init__(config)
self.model = LinearLlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
@classmethod
def from_llama(
cls,
model: LlamaForCausalLM,
config: LinearLlamaConfig,
train_attention: bool = False,
remove_base_attn: bool = True,
) -> "LinearLlamaForCausalLM":
"""
Initialize a LinearLlamaForCausalLM from a LlamaModel
"""
if config is None:
raise ValueError("Missing config")
# initialize a new model with config
new_model = cls(config=config)
# remove the default model and lm_head
del new_model.model
del new_model.lm_head
# load converted model, lm_head, and vocab_size from llama model
new_model.model = convert_attention(
model.model,
attention_config=config.attention_config,
train_attention=train_attention,
remove_base_attn=remove_base_attn,
)
new_model.lm_head = model.lm_head
new_model.vocab_size = model.vocab_size
return new_model
def toggle_attention(self, train: bool = True):
"""
Toggle attention to be trainable or not
"""
toggle_attention(self.model, train=train)
def remove_base_attention(self):
"""
Remove base attention after distillation
"""
remove_base_attention(self.model)
def convert_attention(
model: nn.Module,
attention_config: dict,
train_attention: bool = False,
remove_base_attn: bool = True,
):
"""
Call to convert all attention layers
"""
# Get the layers to convert if provided
softmax_attns = attention_config.get("softmax_attentions", [])
# Get the attention to convert to
attention_type = attention_config.get("attention_type")
if attention_type != "softmax":
layers = traverse_layers(model)
for layer_idx, layer in enumerate(
tqdm(layers, desc="Converting attentions...")
):
if layer_idx not in softmax_attns:
layer.self_attn = convert_llama_attention(
layer,
attention_config,
layers,
train_attention,
remove_base_attn,
)
layer.self_attn.converted = True
else:
# Freeze any preserved softmax attention layers
for p in layer.parameters():
p.requires_grad = False
else:
LOG.info(
f"-> attention_config.attention_type is {attention_type}; not converting attentions"
)
return model
def toggle_attention(llama_model: nn.Module, train: bool = False):
"""
Make attentions trainable if train is True
-> Set train_attention = False when finetuning
"""
for layer in traverse_layers(llama_model):
layer.self_attn.train_attention = train
return llama_model
def remove_base_attention(llama_model: nn.Module):
"""
Remove teacher attention after distillation (if we keep it)
"""
for layer in traverse_layers(llama_model):
if getattr(layer.self_attn, "base_attn", False):
del layer.self_attn.base_attn
return llama_model
def traverse_layers(model: nn.Module, verbose: bool = False):
"""
Return list of model layers
"""
try:
layers = model.model.layers
if verbose:
LOG.info("-> Loading from model.model.layers")
except AttributeError as e: # if base model
if verbose:
LOG.info(e)
try:
layers = model.layers
if verbose:
LOG.info("-> Loading from model.layers")
except AttributeError as e1: # If we make a PEFT model
if verbose:
LOG.info(e1)
layers = model.base_model.model.model.layers
if verbose:
LOG.info("-> Loading from model.base_model.model.model.layers")
return layers
def convert_llama_attention(
layer: nn.Module,
attention_config: dict,
layers: Optional[list[nn.Module]] = None, # list of layers
train_attention: bool = False,
remove_base_attn: bool = True,
):
"""
Converts a single layer's attention layer as specified by attention_config
"""
return get_attention(**attention_config)(
base_attn=layer.self_attn,
layer_idx=layer.self_attn.layer_idx, # Transformers v4.36
max_layer_idx=len(layers) - 1 if layers else None,
train_attention=train_attention,
remove_base_attn=remove_base_attn,
)
def get_attention(attention_type: str, **kwargs):
"""
Get the linear attention class; either purely linear or linear with sliding window
-> 'linear' == 'lolcats_llama'
-> 'linear and sliding_window' == 'lolcats_llama_window_*'
"""
kwargs["attention_type"] = attention_type
if attention_type == "lolcats_llama":
from .linear_attention import LolcatsLinearAttention
return partial(LolcatsLinearAttention, **kwargs)
elif attention_type == "lolcats_llama_window_tk":
from .linear_window_attention_tk import LolcatsTKWindowAttention
return partial(LolcatsTKWindowAttention, **kwargs)
elif attention_type == "lolcats_llama_window_sw":
from .linear_window_attention_sw import LolcatsSlidingWindowAttention
return partial(LolcatsSlidingWindowAttention, **kwargs)
elif attention_type == "lolcats_llama_window_sw_linear":
from .linear_window_attention_sw_linear import (
LolcatsLinearSlidingWindowAttention,
)
return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
# Experimental chunked linear attentions below
elif attention_type == "lolcats_long_llama_window_tk":
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
return partial(LolcatsTKWindowLongAttention, **kwargs)
elif attention_type == "lolcats_long_llama_window_sw":
from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
return partial(LolcatsSlidingWindowLongAttention, **kwargs)
# TK generation build (requires Thunderkittens)
elif attention_type == "lolcats_llama_window_tk_gen":
from .linear_window_attention_tk_gen import LolcatsWindowAttentionTKGen
return partial(LolcatsWindowAttentionTKGen, **kwargs)
else:
LOG.info(f"-> attention_type {attention_type} not handled... returning None")
return None
def get_attention_cache(attention_type: str, past_key_values: Any = None):
"""
Determine how we store past keys and values when generating
"""
if attention_type is None:
return past_key_values
# LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
elif "lolcats_llama_window_tk_gen" in attention_type:
from .linear_window_attention_tk_gen import (
LinearAttentionTKWindowGenerationCache,
)
return LinearAttentionTKWindowGenerationCache()
elif "llama_window_tk" in attention_type:
from .linear_window_attention_tk import LinearAttentionTKWindowCache
return LinearAttentionTKWindowCache()
elif "llama_window_sw" in attention_type:
from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
return LinearAttentionSlidingWindowCache()
elif "llama_window_sw_linear" in attention_type:
from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
return LinearAttentionSlidingWindowCache()
# TK generation build (requires Thunderkittens)
elif attention_type == "lolcats_llama_window_tk_gen":
from .linear_window_attention_tk_gen import (
LinearAttentionTKWindowGenerationCache,
)
return LinearAttentionTKWindowGenerationCache()
elif "softmax" in attention_type:
return past_key_values
else:
from .linear_attention import LinearAttentionState
return LinearAttentionState()
def register_linear_llama():
"""
Register Linear LLaMA model with the Transformers library.
"""
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
AutoConfig.register("linear_llama", LinearLlamaConfig)
AutoModel.register(LinearLlamaConfig, LinearLlamaModel)
AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)
# registering for auto classes to save files
LinearLlamaConfig.register_for_auto_class("AutoConfig")
LinearLlamaModel.register_for_auto_class("AutoModel")
LinearLlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")

View File

@@ -1,118 +0,0 @@
"""
Custom trainer class for distilling attentions ("attention transfer"). Can substitute for Hugging Face trainer.
In this implementation we support using either just the softmax attention outputs, or the softmax attention weights.
"""
from typing import Any
from torch import Tensor, nn, tensor
from axolotl.core.trainers.base import AxolotlTrainer
class DistillAttentionXentMSETrainer(AxolotlTrainer):
"""
Custom trainer class for distilling attentions.
- We compute and store the attention outputs and/or weights for each head and layer,
for both the "teacher" softmax attentions and "student" learnable subquadratic attentions
- We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights)
"""
def __init__(
self,
model: nn.Module,
mse_factor: float = 1e3,
xent_factor: float = 0,
**kwargs: Any,
):
super().__init__(model=model, **kwargs)
self.criterion_xent = nn.CrossEntropyLoss(reduction="mean")
self.criterion_mse = nn.MSELoss(reduction="mean")
self.mse_factor = mse_factor
self.xent_factor = xent_factor
# self.compute_loss_backprop = False # Whether we backprop in self.compute_loss # NOTE: this config seems unnecessary
self.model_accepts_loss_kwargs = False # added to combat explosive loss
def compute_loss(
self,
model: nn.Module,
inputs: dict[str, Tensor],
return_outputs=False,
num_items_in_batch=None,
) -> tuple[Tensor, dict]:
"""
Attention distillation ("attention transfer")
- For each layer and head, get attentions and train to
minimize some combo of MSE and cross-entropy loss
"""
# alias inputs to data
data = inputs
device = model.device
# Filter out labels
inputs = {k: v.to(device) for k, v in data.items() if k != "labels"}
# set num_items_in_batch
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
# Forward pass
outputs = model(**inputs, output_attentions=True, use_cache=False)
outputs = outputs.get("attentions")
# Attentions are tuple[tuple[torch.Tensor, torch.Tensor]]
# n_layers x (predicted_attns, true_attns)
# predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len)
loss_mse = tensor(0.0, device=device)
loss_xent = tensor(0.0, device=device)
n_layers = 0 # Number of layers to distill
softmax_layers = []
for layer_idx, attns in enumerate(outputs):
if attns is not None:
if len(attns) != 2:
attns = attns.cpu()
else:
if self.xent_factor > 0:
# Cross-entropy loss
a_pred, a_true = attns[0]
a_pred = a_pred.clamp(
min=1e-12
).log() # nn.CrossEntropy assumes unnormalized logits
k_len = a_true.shape[-1] # batch, n_heads, q_len, k_len
# Compute mean cross-entropy over all queries
a_pred = a_pred.contiguous().view(-1, k_len)
a_true = a_true.contiguous().view(-1, k_len)
loss_xent += self.criterion_xent(a_pred, a_true)
if self.mse_factor > 0:
loss_mse += self.criterion_mse(*attns[1])
n_layers += 1
else:
softmax_layers.append(layer_idx)
if n_layers > 0:
loss_xent = loss_xent / n_layers * self.xent_factor
loss_mse = loss_mse / n_layers * self.mse_factor
loss = loss_xent + loss_mse
if "position_ids" in data:
outputs = {
"loss_xent": loss_xent.item() if self.xent_factor > 0 else 0,
"loss_mse": loss_mse if self.mse_factor > 0 else 0,
"input_len": data["position_ids"].shape[1],
"position_ids": data["position_ids"][0].detach().cpu().numpy(),
"mse_factor": self.mse_factor,
"xent_factor": self.xent_factor,
}
else:
outputs = {
"loss_xent": loss_xent.item() if self.xent_factor > 0 else 0,
"loss_mse": loss_mse if self.mse_factor > 0 else 0,
"mse_factor": self.mse_factor,
"xent_factor": self.xent_factor,
}
return (loss, outputs) if return_outputs else loss

View File

@@ -1,15 +1,17 @@
## Spectrum: Targeted Training on Signal to Noise Ratio
# Spectrum: Targeted Training on Signal to Noise Ratio
by Eric Hartford, Lucas Atkins, Fernando Fernandes, David Golchinfar
This plugin contains code to freeze the bottom fraction of modules in a model, based on the Signal-to-Noise Ratio (SNR).
### Overview
See https://github.com/cognitivecomputations/spectrum
## Overview
Spectrum is a tool for scanning and evaluating the Signal-to-Noise Ratio (SNR) of layers in large language models.
By identifying the top n% of layers with the highest SNR, you can optimize training efficiency.
### Usage
## Usage
```yaml
plugins:
@@ -19,3 +21,17 @@ spectrum_top_fraction: 0.5
# Optional if using a pre-scanned model as your base_model. Useful if using a model mirror
spectrum_model_name: meta-llama/Meta-Llama-3.1-8B
```
## Citation
```bib
@misc{hartford2024spectrumtargetedtrainingsignal,
title={Spectrum: Targeted Training on Signal to Noise Ratio},
author={Eric Hartford and Lucas Atkins and Fernando Fernandes Neto and David Golchinfar},
year={2024},
eprint={2406.06623},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2406.06623},
}
```

Some files were not shown because too many files have changed in this diff Show More