Compare commits
30 commits: kto_fix ... fix/xforme

| Author | SHA1 | Date |
|---|---|---|
|  | 7888a35118 |  |
|  | 873385b7d5 |  |
|  | cf0c79d52e |  |
|  | 4ba80a0e5a |  |
|  | c49682132b |  |
|  | e46239f8d3 |  |
|  | 05f03b541a |  |
|  | a4e430e7c4 |  |
|  | 6cdcb8ddd5 |  |
|  | a7811ad4a0 |  |
|  | e2da821e67 |  |
|  | 2c34a4634e |  |
|  | a9b0733f2c |  |
|  | 9f00465a5c |  |
|  | 86bac48d14 |  |
|  | e44953d50c |  |
|  | 23f0c51d88 |  |
|  | 113e9cd193 |  |
|  | 61825a464a |  |
|  | c907ac173e |  |
|  | 187227d837 |  |
|  | f8de8bb4f2 |  |
|  | 8e604848a4 |  |
|  | aae4337f40 |  |
|  | 38df5a36ea |  |
|  | 4d92a68a96 |  |
|  | 85147ec430 |  |
|  | 51cd409488 |  |
|  | 7235123d44 |  |
|  | 4f5eb42a73 |  |
8 changes: .github/workflows/base.yml (vendored)

@@ -40,6 +40,12 @@ jobs:
          python_version: "3.11"
          pytorch: 2.6.0
          torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+       - cuda: "128"
+         cuda_version: 12.8.1
+         cudnn_version: ""
+         python_version: "3.11"
+         pytorch: nightly
+         torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
    steps:
      - name: Checkout
        uses: actions/checkout@v4

@@ -61,7 +67,7 @@ jobs:
        uses: docker/build-push-action@v4
        with:
          context: .
-         file: ./docker/Dockerfile-base
+         file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || './docker/Dockerfile-base' }}
          push: ${{ github.event_name != 'pull_request' }}
          tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
          labels: ${{ steps.metadata.outputs.labels }}
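The `file:` expression above relies on GitHub Actions' `&&`/`||` chaining as a ternary: when the condition before `&&` is truthy, the expression evaluates to the value after `&&`, otherwise to the value after `||`. A minimal sketch of the same idiom, with a hypothetical step and illustrative values:

```yaml
steps:
  - name: Show selected value  # hypothetical step for illustration
    # prints 'nightly-image' when matrix.pytorch is 'nightly', else 'stable-image'
    run: echo "${{ matrix.pytorch == 'nightly' && 'nightly-image' || 'stable-image' }}"
```

One caveat of the idiom: if the value after `&&` is itself falsy (for example an empty string), the `||` branch is taken instead.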
7 changes: .github/workflows/docs.yml (vendored)

@@ -20,9 +20,12 @@ jobs:
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'
-     - name: install dependencies
+     - name: Install dependencies
        run: |
-         python3 -m pip install jupyter
+         python3 -m pip install jupyter quartodoc
+         python3 -m pip install -e . --no-deps
+     - name: Build autodoc
+       run: quartodoc build
      - name: Publish to GitHub Pages (and render)
        uses: quarto-dev/quarto-actions/publish@v2
        with:
49 changes: .github/workflows/precommit-autoupdate.yml (vendored, new file)

@@ -0,0 +1,49 @@
name: Pre-commit auto-update

on:
  schedule:
    - cron: '0 0 * * 0' # Run weekly
  workflow_dispatch: # Manual kickoff

jobs:
  auto-update:
    runs-on: ubuntu-latest
    permissions:
      contents: write
      pull-requests: write
    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Update pre-commit hooks
        id: update
        run: |
          pip install pre-commit
          pre-commit autoupdate
          if [[ -n $(git status --porcelain) ]]; then
            echo "changes=true" >> $GITHUB_OUTPUT
            git diff .pre-commit-config.yaml > pre-commit-update.diff
          fi

      - name: Create Pull Request
        if: steps.update.outputs.changes == 'true'
        uses: peter-evans/create-pull-request@v6
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          branch: update/pre-commit-hooks
          delete-branch: true
          title: "chore: update pre-commit hooks"
          commit-message: "chore: update pre-commit hooks"
          body: |
            Automated PR to update pre-commit hooks to their latest versions.

            <details>
            <summary>Changes:</summary>

            ```diff
            ${{ steps.update.outputs.diff }}
            ```
            </details>
2 changes: .github/workflows/pypi.yml (vendored)

@@ -40,7 +40,7 @@ jobs:

      - name: Install dependencies
        run: |
-         pip3 install wheel packaging
+         pip3 install wheel packaging==23.2
          pip3 install --no-build-isolation -e .
          pip3 install -r requirements-dev.txt -r requirements-tests.txt
6 changes: .github/workflows/tests-nightly.yml (vendored)

@@ -42,7 +42,7 @@ jobs:
      - name: upgrade pip
        run: |
          pip3 install --upgrade pip
-         pip3 install --upgrade packaging setuptools wheel
+         pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel

      - name: Install PyTorch
        run: |

@@ -59,7 +59,7 @@ jobs:
      - name: Install dependencies
        run: |
          pip3 install --upgrade pip
-         pip3 install --upgrade packaging
+         pip3 install --upgrade packaging==23.2
          pip3 install --no-build-isolation -U -e .
          python scripts/unsloth_install.py | sh
          python scripts/cutcrossentropy_install.py | sh

@@ -136,4 +136,4 @@ jobs:
          echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
      - name: Run tests job on Modal
        run: |
-         modal run cicd.tests
+         modal run cicd.e2e_tests
21 changes: .github/workflows/tests.yml (vendored)

@@ -63,7 +63,7 @@ jobs:
          path: |
            /home/runner/.cache/huggingface/hub/datasets--*
            /home/runner/.cache/huggingface/hub/models--*
-         key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
+         key: ${{ runner.os }}-hf-hub-cache-v2

      - name: Setup Python
        uses: actions/setup-python@v5

@@ -74,7 +74,7 @@ jobs:
      - name: upgrade pip
        run: |
          pip3 install --upgrade pip
-         pip3 install --upgrade packaging setuptools wheel
+         pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel

      - name: Install PyTorch
        run: |

@@ -98,8 +98,9 @@ jobs:

      - name: Run tests
        run: |
-         pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
+         pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
          pytest -v tests/patched/
+         pytest -v tests/cli/

      - name: cleanup pip cache
        run: |

@@ -136,7 +137,7 @@ jobs:
          path: |
            /home/runner/.cache/huggingface/hub/datasets--*
            /home/runner/.cache/huggingface/hub/models--*
-         key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
+         key: ${{ runner.os }}-hf-hub-cache-v2

      - name: Setup Python
        uses: actions/setup-python@v5

@@ -147,7 +148,7 @@ jobs:
      - name: upgrade pip
        run: |
          pip3 install --upgrade pip
-         pip3 install --upgrade packaging setuptools setuptools_scm build wheel
+         pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel

      - name: Install PyTorch
        run: |

@@ -170,10 +171,14 @@ jobs:
        run: |
          axolotl --help

+     - name: Show HF cache
+       run: huggingface-cli scan-cache
+
      - name: Run tests
        run: |
-         pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
+         pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
          pytest -v tests/patched/
+         pytest -v tests/cli/

      - name: cleanup pip cache
        run: |

@@ -227,7 +232,7 @@ jobs:
          echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
      - name: Run tests job on Modal
        run: |
-         modal run cicd.tests
+         modal run cicd.e2e_tests

  docker-e2e-tests:
    if: github.repository_owner == 'axolotl-ai-cloud'

@@ -274,4 +279,4 @@ jobs:
          echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
      - name: Run tests job on Modal
        run: |
-         modal run cicd.tests
+         modal run cicd.e2e_tests
4 changes: .gitignore (vendored)

@@ -181,6 +181,10 @@ prepared-datasets/
  submit.sh
  *.out*

+ # Quartodoc generated files
+ objects.json
+ site_libs/
+
  typings/
  out/
@@ -1,3 +1,4 @@
  [settings]
  profile=black
+ known_third_party=wandb,comet_ml
  known_local_folder=src,tests
@@ -3,7 +3,7 @@ default_language_version:

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
-   rev: v4.4.0
+   rev: v5.0.0
    hooks:
      - id: check-yaml
      - id: end-of-file-fixer

@@ -11,23 +11,23 @@ repos:
      - id: no-commit-to-branch
        args: ['--branch', 'main']
  - repo: https://github.com/psf/black
-   rev: 23.3.0
+   rev: 25.1.0
    hooks:
      - id: black
  - repo: https://github.com/pycqa/isort
-   rev: 5.12.0
+   rev: 6.0.1
    hooks:
      - id: isort
  - repo: https://github.com/PyCQA/flake8
-   rev: 6.1.0
+   rev: 7.1.2
    hooks:
      - id: flake8
- - repo: https://github.com/PyCQA/pylint
-   rev: v3.3.0
+ - repo: https://github.com/pylint-dev/pylint
+   rev: v3.3.6
    hooks:
      - id: pylint
  - repo: https://github.com/pre-commit/mirrors-mypy
-   rev: v1.3.0
+   rev: v1.15.0
    hooks:
      - id: mypy
        additional_dependencies:

@@ -36,7 +36,7 @@ repos:
        'pydantic>=2.5.3',
      ]
  - repo: https://github.com/PyCQA/bandit
-   rev: 1.7.5
+   rev: 1.8.3
    hooks:
      - id: bandit
        args: [
||||
@@ -55,6 +55,7 @@ Features:
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
|
||||
# Download example axolotl configs, deepspeed configs
|
||||
@@ -96,6 +97,7 @@ That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github
|
||||
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
|
||||
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
|
||||
- [API Reference](https://axolotl-ai-cloud.github.io/axolotl/docs/api/) - Auto-generated code documentation
|
||||
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
|
||||
|
||||
## 🤝 Getting Help
|
||||
|
||||
201 changes: _quarto.yml

@@ -1,6 +1,179 @@
project:
  type: website

quartodoc:
  dir: docs/api
  package: axolotl
  title: API Reference
  parser: google

  sections:
    - title: Core
      desc: Core functionality for training
      contents:
        - train
        - evaluate
        - datasets
        - convert
        - prompt_tokenizers
        - logging_config
        - core.trainer_builder
        - core.training_args
        - core.chat.messages
        - core.chat.format.chatml
        - core.chat.format.llama3x
        - core.chat.format.shared
        - core.datasets.chat
        - core.datasets.transforms.chat_builder
    - title: CLI
      desc: Command-line interface
      contents:
        - cli.main
        - cli.train
        - cli.evaluate
        - cli.args
        - cli.checks
        - cli.config
        - cli.inference
        - cli.merge_lora
        - cli.merge_sharded_fsdp_weights
        - cli.preprocess
        - cli.sweeps
        - cli.utils
        - cli.cloud.base
        - cli.cloud.modal_
    - title: Trainers
      desc: Training implementations
      contents:
        - core.trainers.base
        - core.trainers.trl
        - core.trainers.dpo.trainer
        - core.trainers.grpo.trainer
    - title: Prompt Strategies
      desc: Prompt formatting strategies
      contents:
        - prompt_strategies.base
        - prompt_strategies.chat_template
        - prompt_strategies.alpaca_chat
        - prompt_strategies.alpaca_instruct
        - prompt_strategies.alpaca_w_system
        - prompt_strategies.user_defined
        - prompt_strategies.llama2_chat
        - prompt_strategies.completion
        - prompt_strategies.input_output
        - prompt_strategies.stepwise_supervised
        - prompt_strategies.metharme
        - prompt_strategies.orcamini
        - prompt_strategies.pygmalion
        - prompt_strategies.messages.chat
        - prompt_strategies.dpo.chat_template
        - prompt_strategies.dpo.llama3
        - prompt_strategies.dpo.chatml
        - prompt_strategies.dpo.zephyr
        - prompt_strategies.dpo.user_defined
        - prompt_strategies.dpo.passthrough
        - prompt_strategies.kto.llama3
        - prompt_strategies.kto.chatml
        - prompt_strategies.kto.user_defined
        - prompt_strategies.orpo.chat_template
        - prompt_strategies.bradley_terry.llama3
    - title: Kernels
      desc: Low-level performance optimizations
      contents:
        - kernels.lora
        - kernels.geglu
        - kernels.swiglu
        - kernels.quantize
        - kernels.utils
    - title: MonkeyPatches
      desc: Runtime patches for model optimizations
      contents:
        - monkeypatch.llama_attn_hijack_flash
        - monkeypatch.llama_attn_hijack_xformers
        - monkeypatch.mistral_attn_hijack_flash
        - monkeypatch.multipack
        - monkeypatch.relora
        - monkeypatch.llama_expand_mask
        - monkeypatch.lora_kernels
        - monkeypatch.utils
        - monkeypatch.btlm_attn_hijack_flash
        - monkeypatch.llama_patch_multipack
        - monkeypatch.stablelm_attn_hijack_flash
        - monkeypatch.trainer_fsdp_optim
        - monkeypatch.transformers_fa_utils
        - monkeypatch.unsloth_
        - monkeypatch.attention.mllama
        - monkeypatch.data.batch_dataset_fetcher
        - monkeypatch.mixtral
    - title: Utils
      desc: Utility functions
      contents:
        - utils.models
        - utils.tokenization
        - utils.chat_templates
        - utils.lora
        - utils.lora_embeddings
        - utils.model_shard_quant
        - utils.bench
        - utils.freeze
        - utils.trainer
        - utils.schedulers
        - utils.distributed
        - utils.dict
        - utils.optimizers.adopt
        - utils.data.pretraining
        - utils.data.sft
        - utils.gradient_checkpointing.unsloth
    - title: Schemas
      desc: Pydantic data models for Axolotl config
      contents:
        - utils.schemas.config
        - utils.schemas.model
        - utils.schemas.training
        - utils.schemas.datasets
        - utils.schemas.peft
        - utils.schemas.trl
        - utils.schemas.multimodal
        - utils.schemas.integrations
        - utils.schemas.enums
        - utils.schemas.utils
    - title: Integrations
      desc: Third-party integrations and extensions
      contents:
        - integrations.base
        - integrations.cut_cross_entropy.args
        - integrations.grokfast.optimizer
        - integrations.kd.trainer
        - integrations.liger.args
        - integrations.lm_eval.args
        - integrations.spectrum.args
    - title: Common
      desc: Common utilities and shared functionality
      contents:
        - common.architectures
        - common.const
        - common.datasets
    - title: Models
      desc: Custom model implementations
      contents:
        - models.mamba.modeling_mamba
    - title: Data Processing
      desc: Data processing utilities
      contents:
        - utils.collators.core
        - utils.collators.batching
        - utils.collators.mamba
        - utils.collators.mm_chat
        - utils.samplers.multipack
    - title: Callbacks
      desc: Training callbacks
      contents:
        - utils.callbacks.perplexity
        - utils.callbacks.profiler
        - utils.callbacks.lisa
        - utils.callbacks.mlflow_
        - utils.callbacks.comet_

website:
  title: "Axolotl"
  description: "We make fine-tuning accessible, scalable, and fun"

@@ -32,8 +205,11 @@ website:
      contents:
        - docs/getting-started.qmd
        - docs/installation.qmd
+       - docs/cli.qmd
        - docs/inference.qmd
-       - docs/cli.qmd
        - docs/config.qmd
+       - text: "API Reference"
+         href: docs/api

    - section: "Dataset Formats"
      contents: docs/dataset-formats/*

@@ -74,12 +250,27 @@ website:
        - docs/debugging.qmd
        - docs/nccl.qmd

-   - section: "Reference"
-     contents:
-       - docs/config.qmd

format:
  html:
    theme: darkly
    css: styles.css
    toc: true
+   # Enable better handling of line breaks in markdown
+   preserve-tabs: true
+   html-math-method: mathjax
+   # Improved markdown processing options
+   md-extensions:
+     - markdown_it
+     - def_list
+     - attr_list
+     - fenced_divs
+     - tables
+     - html_admonition
+     - lineblocks
+     - fancy_lists
+   # Control whitespace handling
+   whitespace: preserve
+   # Process newlines in paragraphs
+   wrap: preserve
+   # Better line break handling
+   preserve-linebreaks: true
@@ -31,10 +31,11 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN pip install packaging==23.2 setuptools==75.8.0
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py | sh
|
||||
|
||||
@@ -3,9 +3,10 @@ set -e

python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"

- pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
+ pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/
  pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
  pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
  pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
  pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
+ pytest -v --durations=10 /workspace/axolotl/tests/cli
- pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
+ pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/
@@ -1,4 +1,5 @@
"""Modal app to run axolotl GPU tests"""

+ # pylint: disable=duplicate-code

import os
@@ -1,6 +1,7 @@
"""
modal application to run axolotl gpu tests in Modal
"""

+ # pylint: disable=duplicate-code

import os
@@ -28,7 +28,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"

WORKDIR /workspace

- RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
+ RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
      python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
      python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
      python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
39 changes: docker/Dockerfile-base-nightly (new file)

@@ -0,0 +1,39 @@
ARG CUDA_VERSION="12.8.1"
ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4

FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder

ENV PATH="/root/miniconda3/bin:${PATH}"

ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="nightly"
ARG CUDA="128"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"

ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST

RUN apt-get update \
    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
    && wget \
        https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
    && mkdir /root/.conda \
    && bash Miniconda3-latest-Linux-x86_64.sh -b \
    && rm -f Miniconda3-latest-Linux-x86_64.sh \
    && conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"

ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"

WORKDIR /workspace

RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
    python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
    python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
    python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"

RUN git lfs install --skip-repo && \
    pip3 install awscli && \
    # The base image ships with `pydantic==1.8.2` which is not working
    pip3 install -U --no-cache-dir pydantic==1.10.10
2 changes: docs/.gitignore (vendored)

@@ -1,2 +1,4 @@
  /.quarto/
  _site/
+ /api/*.qmd
+ /api/*.html
@@ -1,5 +1,5 @@
---
- title: "CLI Reference"
+ title: "Command Line Interface (CLI)"
format:
  html:
    toc: true
102 changes: docs/config.qmd

@@ -1,5 +1,5 @@
---
- title: Config options
+ title: Config Reference
description: A complete list of all configuration options.
---

@@ -30,6 +30,11 @@ tokenizer_legacy:
# Resize the model embeddings when new tokens are added to multiples of 32
# This is reported to improve training speed on some models
resize_token_embeddings_to_32x:
+ # Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
+ shrink_embeddings:
+ # Whether to load the model with randomly initialized weights. Useful for
+ # pre-training a model from scratch or debugging purposes.
+ random_init_weights:

# (Internal use only)
# Used to identify which model this is based on

@@ -83,6 +88,12 @@ gpu_memory_limit: 20GiB
# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
lora_on_cpu: true

+ # List[str]. Add plugins to extend the pipeline.
+ # See `src/axolotl/integrations` for the available plugins or doc below for more details.
+ # https://axolotl-ai-cloud.github.io/axolotl/docs/custom_integrations.html
+ plugins:
+ # - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

# A list of one or more datasets to finetune the model with
datasets:
  # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files

@@ -205,10 +216,46 @@ test_datasets:
    data_files:
      - /workspace/data/eval.jsonl

- # use RL training: 'dpo', 'ipo', 'kto'
+ # use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'
rl:
- # whether to perform weighting if doing DPO training. Boolean.
- dpo_use_weighting:
+ rl_beta: # Optional[float]. The beta parameter for the RL training.
+
+ # dpo
+ dpo_use_weighting: # Optional[bool]. Whether to perform weighting.
+ rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper.
+
+ # orpo
+ orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping.
+
+ # kto
+ kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss.
+ kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss.
+
+ # simpo
+ cpo_alpha: 1.0 # Weight of the BC regularizer
+ simpo_gamma: 0.5 # Target reward margin for the SimPO loss
+
+ # grpo
+ trl:
+   use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
+   vllm_device: # Optional[str]. Device to use for VLLM.
+   vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
+   vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
+   vllm_dtype: # Optional[str]. Data type for VLLM.
+
+   beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`.
+   max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
+
+   reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir.
+   reward_weights: # Optional[list[float]]. List of reward weights for the reward functions.
+
+   num_generations: # Optional[int]. Number of generations to sample.
+   log_completions: # Optional[bool]. Whether to log completions.
+
+   sync_ref_model: # Optional[bool]. Whether to sync the reference model.
+   ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
+   ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
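Taken together, a minimal GRPO setup using only the options above might look like the sketch below; `my_rewards.format_reward` is a hypothetical module path and the values are illustrative, not recommended defaults:

```yaml
rl: grpo
trl:
  beta: 0.04                    # illustrative
  max_completion_length: 256    # illustrative
  num_generations: 4            # illustrative
  reward_funcs:
    - my_rewards.format_reward  # hypothetical function, importable from the current dir
  reward_weights:
    - 1.0
```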
# reward modelling: `True` or `False`
reward_model:

@@ -232,7 +279,7 @@ default_system_message: You are a helpful assistant. Please give a long and deta
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
# Push prepared dataset to hub
- push_dataset_to_hub: # repo path
+ push_dataset_to_hub: # Optional[str] repo_org/repo_name
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# if not set.
dataset_processes: # defaults to os.cpu_count() if not set

@@ -419,6 +466,7 @@ auto_find_batch_size: # Optional[bool]

eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
do_causal_lm_eval: # Whether to run causal language model evaluation for metrics in `eval_causal_lm_metrics`.
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]

profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.

@@ -459,36 +507,58 @@ lr_div_factor: # Learning rate div factor

# Specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see:
- # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
+ # https://github.com/huggingface/transformers/blob/cbf924b76c03828101a34069a96d209314114fd5/src/transformers/training_args.py#L144-L189
#
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
# in the examples/ for your model and fine-tuning use case.
#
# Valid values for 'optimizer' include:
# - adamw_hf
# - adamw_torch
# - adamw_torch_fused
# - adamw_torch_xla
# - adamw_torch_npu_fused
# - adamw_apex_fused
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
# - adafactor
# - adamw_anyprecision
# - adamw_torch_4bit
# - ademamix
# - sgd
# - adagrad
# - adamw_bnb_8bit
# - adamw_8bit # alias for adamw_bnb_8bit
# - ademamix_8bit
# - lion_8bit
# - lion_32bit
# - paged_adamw_32bit
# - paged_adamw_8bit
# - paged_ademamix_32bit
# - paged_ademamix_8bit
# - paged_lion_32bit
# - paged_lion_8bit
# - rmsprop
# - rmsprop_bnb
# - rmsprop_bnb_8bit
# - rmsprop_bnb_32bit
# - galore_adamw
# - galore_adamw_8bit
# - galore_adafactor
# - galore_adamw_layerwise
# - galore_adamw_8bit_layerwise
# - galore_adafactor_layerwise
# - lomo
# - adalomo
# - grokadamw
# - schedule_free_adamw
# - schedule_free_sgd
# - apollo_adamw
# - apollo_adamw_layerwise
#
# Additional custom optimizers include:
# - optimi_adamw
# - ao_adamw_8bit
# - ao_adamw_fp8
optimizer:
# Dictionary of arguments to pass to the optimizer
optim_args:
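As a sketch of how `optimizer` and `optim_args` combine (which keys are accepted depends on the optimizer you pick; these values are illustrative):

```yaml
optimizer: adamw_bnb_8bit
optim_args:
  betas: [0.9, 0.999]  # forwarded to the optimizer's constructor
  eps: 1.0e-8          # illustrative value
```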
@@ -540,6 +610,14 @@ resume_from_checkpoint:
# Be careful with this being turned on between different models.
auto_resume_from_checkpoints: false

+ ## Multimodal section
+ # int | tuple[int, int] | None . Size to resize images to, width x height.
+ # Will read from model/processor config if not set.
+ image_size:
+ # str. Algorithm to use for image resizing. "bilinear", "bicubic", "lanczos". Default is "bilinear".
+ image_resize_algorithm: 'bilinear'
+ ## End of multimodal section
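For example, a sketch combining the two options, assuming the tuple form is written as a YAML list (values are illustrative):

```yaml
image_size: [512, 384]           # width x height; a single int such as 512 also works
image_resize_algorithm: bicubic  # one of bilinear (default), bicubic, lanczos
```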
# Don't mess with this, it's here for accelerate and torchrun
local_rank:

@@ -573,6 +651,14 @@ ddp_timeout:
ddp_bucket_cap_mb:
ddp_broadcast_buffers:

+ # Sequence parallelism
+ # Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
+ # Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
+ # E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
+ # subsequences, or set to 4 to split into four equal-sized subsequences.
+ # See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
+ sequence_parallel_degree:

# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:
@@ -55,3 +55,47 @@ sections = [
for section_name, folder_name in sections:
    print(print_section(section_name, folder_name))
```

## Adding a new integration

Plugins can be used to customize the behavior of the training pipeline through [hooks](https://en.wikipedia.org/wiki/Hooking). See [`axolotl.integrations.BasePlugin`](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/base.py) for the possible hooks.

To add a new integration, please follow these steps:

1. Create a new folder in the `src/axolotl/integrations` directory.
2. Add any relevant files (`LICENSE`, `README.md`, `ACKNOWLEDGEMENTS.md`, etc.) to the new folder.
3. Add `__init__.py` and `args.py` files to the new folder.
   - `__init__.py` should import the integration and hook into the appropriate functions.
   - `args.py` should define the arguments for the integration.
4. (If applicable) Add CPU tests under `tests/integrations` or GPU tests under `tests/e2e/integrations`.

::: {.callout-tip}

See [src/axolotl/integrations/cut_cross_entropy](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/cut_cross_entropy) for a minimal integration example.

:::

::: {.callout-warning}

If you could not load your integration, please ensure you are pip installing in editable mode

```bash
pip install -e .
```

and have correctly spelled the integration name in the config file.

```yaml
plugins:
  - axolotl.integrations.your_integration_name.YourIntegrationPlugin
```

:::

::: {.callout-note}

It is not necessary to place your integration in the `integrations` folder. It can be in any location, so long as it's installed as a package in your python env.

See this repo for an example: [https://github.com/axolotl-ai-cloud/diff-transformer](https://github.com/axolotl-ai-cloud/diff-transformer)

:::
@@ -6,7 +6,7 @@ description: How datasets are processed
## Overview

Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
- the [dataset format](docs/dataset-formats) and prompt strategies to:
+ the [dataset format](dataset-formats) and prompt strategies to:

- parse the dataset based on the *dataset format*
- transform the dataset to how you would interact with the model based on the *prompt strategy*
@@ -103,8 +103,7 @@ This uses the same tags as the [`main` image](#sec-main-tags).

- `JUPYTER_DISABLE`: Disable Jupyter lab.
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
- - `PUBLIC_KEY`: Add a public key for the SSH service.
- - `SSH_KEY`: Add a private key for the SSH service.
+ - `PUBLIC_KEY` / `SSH_KEY`: Add a public key for the SSH service.

#### Volume mounts
14 changes: docs/faq.qmd

@@ -27,6 +27,20 @@ description: Frequently asked questions

> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.

+ **Q: Received mismatch error on merge adapters / loading adapters between torch.Size of checkpoint and model.**
+
+ > A: This is likely due to a vocab size mismatch. By default, Axolotl expands the model's embeddings if the tokenizer has more tokens than the model. Please use the `axolotl merge-lora` command to merge the adapters instead of using your own scripts.
+
+ > On the other hand, if the model has more tokens than the tokenizer, Axolotl does not shrink the model's embeddings unless `shrink_embeddings: true` is set in the config.
+
+ **Q: How to call Axolotl via custom python scripts?**
+
+ > A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` for how each command is called.
+
+ **Q: How to know the value to use for `fsdp_transformer_layer_cls_to_wrap`?**
+
+ > A: This is the class name of the transformer layer to wrap with FSDP. For example, for `LlamaForCausalLM`, the value is `LlamaDecoderLayer`. To find this for a specific model, check the model's `PreTrainedModel` definition and look for the `_no_split_modules` variable in the `modeling_<model_name>.py` file within the `transformers` library.
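For a Llama model, the resulting settings might look like the following sketch; the surrounding `fsdp`/`fsdp_config` keys follow the FSDP examples in this repo, and the class name comes from `LlamaForCausalLM`'s `_no_split_modules`:

```yaml
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer  # adapt to your model's decoder layer class
```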
### Chat templates

**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
@@ -36,7 +36,9 @@ The YAML configuration file controls everything about your training. Here's what

```yaml
base_model: NousResearch/Llama-3.2-1B
# hub_model_id: username/custom_model_name

+ load_in_8bit: true
+ adapter: lora

datasets:
  - path: teknium/GPT4-LLM-Cleaned

@@ -44,11 +46,15 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
-
- adapter: lora
- lora_model_dir:
```

+ ::: {.callout-tip}
+ `load_in_8bit: true` and `adapter: lora` enable LoRA adapter finetuning.
+
+ - To perform full finetuning, remove these two lines.
+ - To perform QLoRA finetuning, replace them with `load_in_4bit: true` and `adapter: qlora`, as sketched below.
+ :::
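As a sketch, the QLoRA variant of the example above changes only those two lines; everything else stays the same:

```yaml
base_model: NousResearch/Llama-3.2-1B

# QLoRA instead of LoRA:
load_in_4bit: true
adapter: qlora
```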
See our [Config options](config.qmd) for more details.

### Training {#sec-training}

@@ -56,7 +62,7 @@ See our [Config options](config.qmd) for more details.
When you run `axolotl train`, Axolotl:

1. Downloads the base model
- 2. (If specified) applies LoRA adapter layers
+ 2. (If specified) applies QLoRA/LoRA adapter layers
3. Loads and processes the dataset
4. Runs the training loop
5. Saves the trained model and / or LoRA weights

@@ -69,6 +75,8 @@ Let's modify the example for your own data:

```yaml
base_model: NousResearch/Nous-Hermes-llama-1b-v1

+ load_in_8bit: true
+ adapter: lora

# Training settings

@@ -104,8 +112,6 @@ format):
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
```

- Please consult the supported [Dataset Formats](dataset-formats/) for more details.
-
3. Run the training:

```bash
@@ -1,5 +1,5 @@
---
- title: "Inference"
+ title: "Inference and Merging"
format:
  html:
    toc: true

@@ -9,10 +9,14 @@ execute:
  enabled: false
---

- This guide covers how to use your trained models for inference, including model loading, interactive testing, and common troubleshooting steps.
+ This guide covers how to use your trained models for inference, including model loading, interactive testing, merging adapters, and common troubleshooting steps.

## Quick Start {#sec-quickstart}

+ ::: {.callout-tip}
+ Use the same config used for training when running inference or merging.
+ :::

### Basic Inference {#sec-basic}

::: {.panel-tabset}
@@ -22,6 +22,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir

### PyPI Installation (Recommended) {#sec-pypi}

```{.bash}
+ pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
```

@@ -37,7 +38,7 @@ For the latest features between releases:

```{.bash}
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
- pip3 install packaging ninja
+ pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```

@@ -78,6 +79,7 @@ For providers supporting Docker:
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
+ - [Novita](https://novita.ai/gpus-console?templateId=311)

### Google Colab {#sec-colab}

@@ -107,7 +109,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
2. Install PyTorch: https://pytorch.org/get-started/locally/
3. Install Axolotl:
```{.bash}
- pip3 install packaging
+ pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Hugging Face:
@@ -66,6 +66,10 @@ logic to be compatible with more of them.

</details>

+ ::: {.callout-tip}
+ Check out our [LoRA optimizations blog](https://axolotlai.substack.com/p/accelerating-lora-fine-tuning-with).
+ :::

## Usage

These optimizations can be enabled in your Axolotl config YAML file. The
@@ -1,28 +1,171 @@
- # MultiModal / Vision Language Models (BETA)
+ ---
+ title: MultiModal / Vision Language Models (BETA)
+ format:
+   html:
+     toc: true
+     toc-depth: 3
+ ---

- ### Supported Models
+ ## Supported Models

- - Mllama, i.e. llama with vision models
+ - [Mllama](#sec-mllama)
+ - [Pixtral](#sec-pixtral)
+ - [Llava-1.5](#sec-llava-15)
+ - [Mistral-Small-3.1](#sec-mistral-small-31)
+ - [Gemma-3](#sec-gemma-3)
+ - [Qwen2-VL](#sec-qwen2-vl)
+ - [Qwen2.5-VL](#sec-qwen25-vl)

- ### Usage
+ ## Usage

- Currently multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama w/ LoRA,
- you'll need to use the following in YAML in combination with the rest of the required hyperparams.
+ Multimodal support is limited and doesn't have full feature parity.
+
+ Here are the hyperparams you'll need to use to finetune a multimodal model.

  ```yaml
- base_model: alpindale/Llama-3.2-11B-Vision-Instruct
  processor_type: AutoProcessor
- chat_template: llama3_2_vision
  skip_prepare_dataset: true
+ remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
+ sample_packing: false # not yet supported with multimodal
+
+ chat_template: # see in next section

+ # example dataset
  datasets:
    - path: HuggingFaceH4/llava-instruct-mix-vsft
      type: chat_template
      split: train[:1%]
      field_messages: messages
- remove_unused_columns: false
- sample_packing: false

- # only finetune the Language model, leave the vision model and vision tower frozen
+ # (optional) if doing lora, only finetune the Language model,
+ # leave the vision model and vision tower frozen
+ # load_in_8bit: true
  adapter: lora
  lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

+ # (optional) if you want to resize images to a set size
+ image_size: 512
+ image_resize_algorithm: bilinear
  ```

+ Please see the [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.
+
+ ::: {.callout-warning}
+ Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
+ :::
+
+ ### Mllama {#sec-mllama}
+
+ ```yaml
+ base_model: meta-llama/Llama-3.2-11B-Vision-Instruct
+
+ chat_template: llama3_2_vision
+ ```
+
+ ### Pixtral {#sec-pixtral}
+
+ ```yaml
+ base_model: mistralai/Pixtral-12B-2409
+
+ chat_template: pixtral
+ ```
+
+ ### Llava-1.5 {#sec-llava-15}
+
+ ```yaml
+ base_model: llava-hf/llava-1.5-7b-hf
+
+ chat_template: llava
+ ```
+
+ ### Mistral-Small-3.1 {#sec-mistral-small-31}
+
+ ```yaml
+ base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
+
+ chat_template: mistral_v7_tekken
+ ```
+
+ ### Gemma-3 {#sec-gemma-3}
+
+ ::: {.callout-tip}
+ The Gemma3-1B model is a text-only model, so please train it as a regular text model.
+ :::
+
+ For the multi-modal 4B/12B/27B models, use the following config:
+
+ ```yaml
+ base_model: google/gemma-3-4b-it
+
+ chat_template: gemma3
+ ```
+
+ ### Qwen2-VL {#sec-qwen2-vl}
+
+ ```yaml
+ base_model: Qwen/Qwen2-VL-7B-Instruct
+
+ chat_template: qwen2_vl
+ ```
+
+ ### Qwen2.5-VL {#sec-qwen25-vl}
+
+ ```yaml
+ base_model: Qwen/Qwen2.5-VL-7B-Instruct
+
+ chat_template: qwen2_vl # same as qwen2-vl
+ ```
+
+ ## Dataset Format
+
+ For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
+
+ - A message is a list of `role` and `content`.
+ - `role` can be `system`, `user`, `assistant`, etc.
+ - `content` is a list of `type` and (`text` or `image` or `path` or `url` or `base64`).
+
+ ::: {.callout-note}
+ For backwards compatibility:
+
+ - If the dataset has an `images` or `image` column of `list[Image]`, it will be appended to the first `content` list as `{"type": "image", "image": ...}`. However, if the content already has a `{"type": "image"}` entry but no `image` key, the `image` key will be set on it.
+ - If `content` is a string, it will be converted to a list with `type` as `text`.
+ :::
+
+ ::: {.callout-tip}
+ For image loading, you can use the following keys within `content` alongside `"type": "image"`:
+
+ - `"path": "/path/to/image.jpg"`
+ - `"url": "https://example.com/image.jpg"`
+ - `"base64": "..."`
+ - `"image": PIL.Image`
+ :::
+
+ Here is an example of a multi-modal dataset:
+
+ ```json
+ [
+   {
+     "messages": [
+       {
+         "role": "system",
+         "content": [
+           {"type": "text", "text": "You are a helpful assistant."}
+         ]
+       },
+       {
+         "role": "user",
+         "content": [
+           {"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
+           {"type": "text", "text": "Describe this image in detail."}
+         ]
+       },
+       {
+         "role": "assistant",
+         "content": [
+           {"type": "text", "text": "The image is a bee."}
+         ]
+       }
+     ]
+   }
+ ]
+ ```
@@ -41,6 +41,10 @@ Bradley-Terry chat templates expect single-turn conversations in the following f

### Process Reward Models (PRM)

+ ::: {.callout-tip}
+ Check out our [PRM blog](https://axolotlai.substack.com/p/process-reward-models).
+ :::

Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.

```yaml
base_model: Qwen/Qwen2.5-3B
@@ -298,7 +298,7 @@ The input format is a simple JSON input with customizable fields based on the ab

### IPO

- As IPO is just DPO with a different loss function, all supported options for DPO works here.
+ As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.

```yaml
rl: ipo

@@ -344,8 +344,9 @@ ORPO supports the following types with the following dataset format:

```yaml
rl: kto
- rl_beta: 0.5
- kto_desirable_weight: 0.2
+ rl_beta: 0.1 # default
+ kto_desirable_weight: 1.0 # default
+ kto_undesirable_weight: 1.0 # default

remove_unused_columns: false

@@ -497,6 +498,10 @@ The input format is a simple JSON input with customizable fields based on the ab

### GRPO

+ ::: {.callout-tip}
+ Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
+ :::

GRPO uses custom reward functions and transformations. Please have them ready locally.

For example, to load OpenAI's GSM8K and use a random reward for completions:

@@ -540,6 +545,19 @@ To see other examples of custom reward functions, please see [TRL GRPO Docs](htt

To see descriptions of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).

+ ### SimPO
+
+ SimPO uses the [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with an alternative loss function.
+
+ ```yaml
+ rl: simpo
+ rl_beta: 0.1 # default in CPOTrainer
+ cpo_alpha: 1.0 # default in CPOTrainer
+ simpo_gamma: 0.5 # default in CPOTrainer
+ ```
+
+ This method uses the same dataset format as [DPO](#dpo).

### Using local dataset files

```yaml
90 changes: docs/sequence_parallelism.qmd (new file)

---
title: Sequence Parallelism
description: Train with long sequences split across multiple GPUs.
---

# Sequence Parallelism

Sequence parallelism is a technique that splits sequences across multiple GPUs,
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
GPU processes a different portion of the sequence, and the results are aggregated
through a ring communication pattern.

## When to Use Sequence Parallelism

Use sequence parallelism when:

- You need to train with sequence lengths that don't fit into a single GPU's memory
- You have multiple GPUs available
- You're experiencing OOM (Out Of Memory) errors with long sequences

## Configuration

To enable sequence parallelism, add the following to your configuration file:

```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
```

The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:

- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4

## Implementation Details

When sequence parallelism is enabled:

1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
4. The trainer uses special ring communication patterns for attention operations

## Requirements

To use sequence parallelism, you need:

- Multiple GPUs (at least 2)
- The `ring-flash-attn` package. Install with:
  - `pip install axolotl[ring-flash-attn]` (preferred)
  - `pip install ring-flash-attn>=0.1.4`

## Limitations

- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
- May have a small performance overhead due to communication between GPUs

## Example

```yaml
# Example config with sequence parallelism
base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192
sequence_parallel_degree: 2 # Split each sequence into 2 parts
flash_attention: true # Required with sequence parallelism
...
```

This will train the Llama 3 8B model with 8K context length, with each sequence split
into 2 subsequences of length 4096 across 2 GPUs.

## Sample Packing with Sequence Parallelism

Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:

1. Samples are first packed together
2. The packed sequences are then divided across GPUs in the sequence parallel group
3. Position IDs are automatically adjusted to maintain proper relative positions

## Effect on Batch Size

When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:

- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases

For example:

- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
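Written out as a formula (a sketch consistent with the example above, ignoring gradient accumulation; multiply by `gradient_accumulation_steps` if you use it):

$$
\text{effective global batch size} = \frac{N_\text{GPUs}}{\text{sequence\_parallel\_degree}} \times \text{micro\_batch\_size}
$$

With the numbers above, $8 / 4 \times 2 = 4$, matching the drop from 16 to 4.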
71 changes: examples/cohere/command-r-7b-qlora.yml (new file)

base_model: CohereForAI/c4ai-command-r7b-12-2024
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

# huggingface repo
chat_template: cohere
datasets:
  - path: cgato/SlimOrcaDedupCleaned
    type: chat_template
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

val_set_size: 0.0
output_dir: ./outputs/out

adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true

sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
74 changes: examples/gemma3/gemma-3-1b-qlora.yml (new file)

base_model: google/gemma-3-1b-it
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: true
strict: false

# huggingface repo
chat_template: gemma3
datasets:
  - path: cgato/SlimOrcaDedupCleaned
    type: chat_template
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

val_set_size: 0.0
output_dir: ./outputs/out

adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true

sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

63 examples/gemma3/gemma-3-4b-lora.yml Normal file
@@ -0,0 +1,63 @@
base_model: google/gemma-3-4b-it
processor_type: AutoProcessor
strict: false

# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

chat_template: gemma3
datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
    field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out

adapter: lora
lora_model_dir:

sequence_len: 2048
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true

gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
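
The `lora_target_modules` value in the config above is a regular expression over full module names (PEFT matches a string `target_modules` with `re.fullmatch`). A quick check of what it selects; the module names here are illustrative:

import re

pattern = r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj"

print(bool(re.fullmatch(pattern, "language_model.model.layers.0.self_attn.q_proj")))  # True
print(bool(re.fullmatch(pattern, "language_model.model.layers.11.mlp.down_proj")))    # True
print(bool(re.fullmatch(pattern, "vision_tower.layers.0.self_attn.q_proj")))          # False -- vision tower untouched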

@@ -19,7 +19,6 @@ val_set_size: 0.0
output_dir: ./outputs/lora-out

dataset_exact_deduplication: true
test_value: true

sequence_len: 4096
sample_packing: true

@@ -55,7 +55,7 @@ tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:

63 examples/llava/lora-7b.yaml Normal file
@@ -0,0 +1,63 @@
base_model: llava-hf/llava-1.5-7b-hf
processor_type: AutoProcessor
strict: false

# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

chat_template: llava
datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
    field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

adapter: lora
lora_model_dir:

sequence_len: 8192
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true

gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

66 examples/mistral/mistral-small-3.1-24B-lora.yml Normal file
@@ -0,0 +1,66 @@
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
processor_type: AutoProcessor
strict: false

load_in_8bit: true

# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

chat_template: mistral_v7_tekken
datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
    field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out

adapter: lora
lora_model_dir:

sequence_len: 2048
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true

gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet.
eager_attention:

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

65 examples/pixtral/lora-12b.yml Normal file
@@ -0,0 +1,65 @@
base_model: mistral-community/pixtral-12b
processor_type: AutoProcessor
strict: false

# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

chat_template: pixtral
datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
    field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

adapter: lora
lora_model_dir:

sequence_len: 8192
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true

gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet
eager_attention:

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <pad>

63 examples/qwen2-vl/lora-7b.yaml Normal file
@@ -0,0 +1,63 @@
base_model: Qwen/Qwen2-VL-7B-Instruct
processor_type: AutoProcessor
strict: false

# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

chat_template: qwen2_vl
datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
    field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

adapter: lora
lora_model_dir:

sequence_len: 8192
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true

gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"]
build-backend = "setuptools.build_meta"

[project]

@@ -8,6 +8,7 @@ dynamic = ["version", "dependencies", "optional-dependencies"]
description = "LLM Trainer"
readme = "README.md"
requires-python = ">=3.10"
# license = "Apache-2.0"

[project.scripts]
axolotl = "axolotl.cli.main:main"

@@ -2,3 +2,5 @@ pre-commit
black
mypy
types-requests
quartodoc
jupyter

@@ -1,23 +1,22 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/

# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.2
bitsandbytes==0.45.3
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.4.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.3
liger-kernel==0.5.5
# END section

packaging==23.2

peft==0.14.0
transformers==4.49.0
tokenizers>=0.21.0
accelerate==1.3.0
datasets==3.2.0
deepspeed==0.16.1
peft==0.15.0
transformers==4.50.0
tokenizers>=0.21.1
accelerate==1.5.2
datasets==3.5.0
deepspeed==0.16.4
trl==0.15.1

optimum==1.16.2

@@ -36,6 +35,7 @@ einops
colorama
numba
numpy>=1.24.4,<=2.0.1

# qlora things
evaluate==0.4.1
scipy

@@ -1,315 +0,0 @@
accelerate==0.34.1
addict==2.4.0
aiofiles==23.2.1
aiohttp==3.9.0
aiosignal==1.3.1
aiostream==0.5.2
alembic==1.13.1
annotated-types==0.6.0
annoy==1.17.3
ansible==6.7.0
ansible-core==2.13.13
ansible-vault==2.1.0
anyio==3.7.1
appdirs==1.4.4
art==6.0
asgiref==3.7.2
async-timeout==4.0.2
attrdict==2.0.1
attrs==22.2.0
awscli==1.32.75
-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
backoff==2.2.1
base58==2.1.1
beartype==0.17.2
bitnet==0.2.1
bitsandbytes==0.42.0
bittensor==6.7.0
black==23.7.0
blinker==1.7.0
boto3==1.34.75
botocore==1.34.75
cachetools==5.3.3
cachy==0.1.1
certifi==2023.7.22
cffi==1.16.0
cfgv==3.3.1
chai-guanaco==1.2.4
charset-normalizer==3.2.0
cleo==0.6.8
click==8.1.7
cloudpickle==2.0.0
cohere==4.11.2
colorama==0.4.4
coloredlogs==15.0.1
CoLT5-attention==0.10.20
contextlib2==21.6.0
contourpy==1.2.0
cryptography==41.0.3
cycler==0.12.1
cytoolz==0.12.3
databricks-cli==0.18.0
dataclasses-json==0.5.7
datasets==2.11.0
ddt==1.6.0
decorator==5.1.1
deepspeed==0.15.0
# Editable Git install with no remote (dialogpt==0.1)
-e /Users/wing/Projects/ml/dialogpt/src
dill==0.3.6
distlib==0.3.6
docker==7.0.0
docker-pycreds==0.4.0
docstring-parser==0.15
docutils==0.16
ecdsa==0.18.0
einops==0.7.0
einops-exts==0.0.4
einx==0.1.3
entrypoints==0.4
eth-hash==0.6.0
eth-keys==0.5.0
eth-typing==4.0.0
eth-utils==2.3.1
evaluate==0.4.0
exceptiongroup==1.1.1
fastapi==0.109.2
fastcore==1.5.29
ffmpy==0.4.0
filelock==3.12.2
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
fire==0.5.0
first==2.0.2
flake8==7.0.0
Flask==3.0.1
fonttools==4.47.2
frozendict==2.4.1
frozenlist==1.3.3
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
fsspec==2023.6.0
fuzzywuzzy==0.18.0
gitdb==4.0.10
GitPython==3.1.31
google-pasta==0.2.0
gradio==4.42.0
gradio_client==1.3.0
greenlet==2.0.2
grpclib==0.4.7
gunicorn==21.2.0
h11==0.14.0
h2==4.1.0
hpack==4.0.0
httpcore==0.17.3
httpx==0.24.1
huggingface-hub==0.23.4
humanfriendly==10.0
hyperframe==6.0.1
identify==2.5.24
idna==3.4
immutables==0.20
importlib-metadata==6.7.0
importlib-resources==6.1.1
inflection==0.5.1
iniconfig==2.0.0
itsdangerous==2.1.2
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.3.2
jsonlines==3.1.0
jsonschema==2.6.0
kiwisolver==1.4.5
langchain==0.0.144
Levenshtein==0.24.0
libcst==1.1.0
liger-kernel==0.0.0
lion-pytorch==0.1.2
llama-cpp-python==0.1.36
llvmlite==0.40.1
local-attention==1.9.0
loguru==0.7.0
Mako==1.3.2
Markdown==3.5.2
markdown-it-py==3.0.0
markdown2==2.4.10
MarkupSafe==2.1.2
marshmallow==3.19.0
marshmallow-enum==1.5.1
matplotlib==3.8.2
mccabe==0.7.0
mdurl==0.1.2
MEGABYTE-pytorch==0.0.7
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
mlflow==2.10.0
modal==0.62.77
more-itertools==10.2.0
mpmath==1.2.1
msgpack==1.0.7
msgpack-numpy-opentensor==0.5.0
multidict==6.0.4
multiprocess==0.70.14
munch==2.5.0
mypy==1.3.0
mypy-extensions==1.0.0
nest-asyncio==1.6.0
netaddr==0.10.1
networkx==3.0rc1
nh3==0.2.14
nodeenv==1.8.0
nomic==2.0.2
numba==0.57.1
numexpr==2.8.4
numpy==1.24.4
oauthlib==3.2.2
openai==0.27.4
openapi==1.1.0
openapi-schema-pydantic==1.2.4
optimum==1.8.6
orjson==3.10.7
packaging==23.1
pandas==2.0.0
parameterized==0.9.0
password-strength==0.0.3.post2
pastel==0.1.1
pathos==0.3.0
pathspec==0.11.1
pathtools==0.1.2
peft==0.11.1
pendulum==3.0.0
Pillow==9.5.0
pip-tools==1.11.0
platformdirs==3.2.0
pluggy==1.4.0
poetry==0.7.1
pox==0.3.2
ppft==1.7.6.6
pre-commit==3.3.2
prettytable==3.10.0
prompt-toolkit==3.0.39
protobuf==3.20.2
protobuf3-to-dict==0.1.5
psutil==5.9.5
psycopg==3.1.18
PuLP==2.8.0
py==1.11.0
py-bip39-bindings==0.1.11
py-cpuinfo==9.0.0
py-ed25519-zebra-bindings==1.0.1
py-sr25519-bindings==0.2.0
pyarrow==11.0.0
pyasn1==0.6.0
pycodestyle==2.11.1
pycparser==2.21
pycryptodome==3.20.0
pydantic==2.5.3
pydantic_core==2.14.6
pydub==0.25.1
pyfiglet==0.8.post1
pyflakes==3.2.0
Pygments==2.15.1
PyJWT==2.8.0
pylev==1.4.0
PyNaCl==1.5.0
pynvml==11.5.0
pyparsing==2.4.7
pyrsistent==0.14.11
pytest==8.0.2
pytest-asyncio==0.23.4
python-dateutil==2.8.2
python-dotenv==1.0.1
python-Levenshtein==0.24.0
python-multipart==0.0.9
pytz==2023.3
PyYAML==6.0.1
querystring-parser==1.2.4
rapidfuzz==3.6.1
regex==2023.6.3
requests==2.31.0
requests-toolbelt==0.8.0
resolvelib==0.8.1
responses==0.18.0
retry==0.9.2
rich==13.7.0
rsa==4.7.2
ruff==0.6.3
s3transfer==0.10.1
safetensors==0.4.5
sagemaker==2.148.0
scalecodec==1.2.7
schedulefree==1.2.1
schema==0.7.5
scikit-learn==1.4.0
scipy==1.9.3
seaborn==0.13.2
semantic-version==2.10.0
sentencepiece==0.2.0
sentry-sdk==1.19.1
setproctitle==1.3.2
shellingham==1.5.4
shortuuid==1.0.11
shtab==1.6.5
sigtools==4.0.1
six==1.16.0
skypilot==0.4.1
smdebug-rulesconfig==1.0.1
smmap==5.0.0
sniffio==1.3.0
SQLAlchemy==1.4.47
sqlparse==0.4.4
starlette==0.36.3
substrate-interface==1.5.2
svgwrite==1.4.3
sympy==1.11.1
synchronicity==0.6.7
tabulate==0.9.0
tblib==1.7.0
tenacity==8.2.2
tensor-parallel==2.0.0
termcolor==2.2.0
text2art==0.2.0
threadpoolctl==3.2.0
tiktoken==0.6.0
time-machine==2.14.1
timm==0.9.16
tokenizers==0.19.1
tokenmonster==1.1.12
toml==0.9.6
tomli==2.0.1
tomlkit==0.12.0
toolz==0.12.1
torch==2.2.0
torchdata==0.6.1
torchdiffeq==0.2.3
TorchFix==0.4.0
torchtext==0.15.2
torchvision==0.17.0
tqdm==4.66.2
transformers==4.44.2
trl==0.9.6
typer==0.12.5
types-certifi==2021.10.8.3
types-requests==2.31.0.20240125
types-setuptools==69.0.0.20240125
types-toml==0.10.8.7
typing==3.7.4.3
typing-inspect==0.8.0
typing_extensions==4.9.0
tyro==0.5.18
tzdata==2023.3
unique-names-generator==1.0.2
urllib3==2.2.2
uvicorn==0.22.0
vector_quantize_pytorch==1.14.1
virtualenv==20.23.0
voyager==2.0.2
wandb==0.16.2
watchfiles==0.21.0
wavedrom==2.0.3.post3
wcwidth==0.2.6
websocket-client==1.7.0
websockets==12.0
Werkzeug==3.0.1
wonderwords==2.2.0
xxhash==3.2.0
yarl==1.8.2
zetascale==2.2.7
zipp==3.15.0

@@ -1,6 +1,7 @@
"""
helper script to parse chat datasets into a usable yaml
"""

import click
import yaml
from datasets import load_dataset

@@ -1,4 +1,5 @@
"""Script to output the correct installation command for cut-cross-entropy."""

import importlib.util
import sys
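
The script's body is not shown in this diff; as a hedged sketch, installer-hint scripts like this one usually probe for the package with `importlib.util.find_spec` and print the appropriate `pip` command:

import importlib.util
import sys

# Hypothetical sketch only -- the real script's checks and output may differ.
if importlib.util.find_spec("cut_cross_entropy") is None:
    # Not installed: emit the install command for the caller to run.
    print("pip install cut-cross-entropy")
    sys.exit(1)
print("cut-cross-entropy is already installed", file=sys.stderr)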

24 setup.py
@@ -16,13 +16,7 @@ def parse_requirements():
    with open("./requirements.txt", encoding="utf-8") as requirements_file:
        lines = [r.strip() for r in requirements_file.readlines()]
        for line in lines:
            is_extras = (
                "flash-attn" in line
                or "flash-attention" in line
                or "deepspeed" in line
                or "mamba-ssm" in line
                or "lion-pytorch" in line
            )
            is_extras = "deepspeed" in line or "mamba-ssm" in line
            if line.startswith("--extra-index-url"):
                # Handle custom index URLs
                _, url = line.split()

@@ -39,7 +33,6 @@ def parse_requirements():
    "bitsandbytes",
    "triton",
    "mamba-ssm",
    "flash-attn",
    "xformers",
    "autoawq",
    "liger-kernel",

@@ -124,11 +117,10 @@ setup(
        ],
    },
    extras_require={
        "flash-attn": [
            "flash-attn==2.7.4.post1",
        ],
        "flash-attn": ["flash-attn==2.7.4.post1"],
        "ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
        "deepspeed": [
            "deepspeed==0.16.1",
            "deepspeed==0.16.4",
            "deepspeed-kernels",
        ],
        "mamba-ssm": [

@@ -141,15 +133,15 @@ setup(
        "mlflow": [
            "mlflow",
        ],
        "lion-pytorch": [
            "lion-pytorch==0.1.2",
        ],
        "galore": [
            "galore_torch",
        ],
        "apollo": [
            "apollo-torch",
        ],
        "optimizers": [
            "galore_torch",
            "lion-pytorch==0.1.2",
            "apollo-torch",
            "lomo-optim==0.1.1",
            "torch-optimi==0.2.1",
        ],

@@ -1,6 +1,7 @@
"""
launch axolotl in supported cloud platforms
"""

from pathlib import Path
from typing import Union

@@ -1,6 +1,7 @@
"""
base class for cloud platforms from cli
"""

from abc import ABC, abstractmethod

@@ -1,6 +1,7 @@
"""
Modal Cloud support from CLI
"""

import copy
import json
import os

@@ -56,7 +56,7 @@ def do_inference(
        cfg: Dictionary mapping `axolotl` config keys to values.
        cli_args: Inference-specific CLI arguments.
    """
    model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
    model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)
    prompter = cli_args.prompter

    prompter_module = None

@@ -151,7 +151,7 @@ def do_inference_gradio(
    """
    import gradio as gr

    model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
    model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)
    prompter = cli_args.prompter

    prompter_module = None

@@ -1,4 +1,5 @@
"""Click CLI definitions for various axolotl commands."""

# pylint: disable=redefined-outer-name

import logging

@@ -24,7 +25,7 @@ from axolotl.cli.utils import (
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
from axolotl.utils.schemas.config import AxolotlInputConfig


@click.group()

@@ -27,7 +27,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
    """
    print_axolotl_text_art()

    model, tokenizer = load_model_and_tokenizer(cfg=cfg)
    model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
    safe_serialization = cfg.save_safetensors is True

    LOG.info("Running merge of LoRA with base model...")

@@ -44,6 +44,9 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
    )
    tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))

    if processor:
        processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
    """

@@ -17,13 +17,14 @@ from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault

LOG = logging.getLogger(__name__)


def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
    """
    Trains a `transformers` model by first loading the dataset(s) specified in the
    `axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin

@@ -33,6 +34,9 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
        cfg: Dictionary mapping `axolotl` config keys to values.
        cli_args: Training-specific CLI arguments.
    """
    # Enable expandable segments for cuda allocation to improve VRAM usage
    set_pytorch_cuda_alloc_conf()

    print_axolotl_text_art()
    check_accelerate_default_config()
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
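
For context, the expandable-segments behavior referenced in the comment above is controlled by PyTorch's allocator environment variable. A hedged sketch of what `set_pytorch_cuda_alloc_conf` presumably arranges (its actual implementation is not part of this diff):

import os

# Hypothetical sketch: must run before the first CUDA allocation. Expandable
# segments let the caching allocator grow existing segments instead of
# fragmenting VRAM across many fixed-size ones.
def set_pytorch_cuda_alloc_conf_sketch() -> None:
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")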

@@ -44,16 +48,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

    model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
    del model, tokenizer, trainer

    plugin_manager = PluginManager.get_instance()

    del model
    del tokenizer
    del trainer

    plugin_manager.post_train_unload(cfg)


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
    """
    Parses `axolotl` config, CLI args, and calls `do_train`.

@@ -5,7 +5,6 @@ import dataclasses
import hashlib
import json
import logging
import typing
from functools import wraps
from pathlib import Path
from types import NoneType

@@ -14,17 +13,22 @@ from typing import Any, Callable, Type, Union, get_args, get_origin
import click
import requests
from pydantic import BaseModel
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    ProcessorMixin,
)

from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.models import load_model, load_processor, load_tokenizer

configure_logging()
LOG = logging.getLogger(__name__)


def strip_optional_type(field_type: type | typing._SpecialForm | None):
def strip_optional_type(field_type: type | str | None):
    """
    Extracts the non-`None` type from an `Optional` / `Union` type.

@@ -296,9 +300,13 @@ def load_model_and_tokenizer(
    *,
    cfg: DictDefault,
    inference: bool = False,
) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
) -> tuple[
    PreTrainedModel,
    PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
    ProcessorMixin | None,
]:
    """
    Helper function for loading a model and tokenizer specified in the given `axolotl`
    Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
    config.

    Args:

@@ -306,7 +314,7 @@ def load_model_and_tokenizer(
        inference: Boolean denoting inference mode.

    Returns:
        `transformers` model and tokenizer.
        Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
    """
    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
    tokenizer = load_tokenizer(cfg)

@@ -314,4 +322,9 @@ def load_model_and_tokenizer(
    LOG.info("loading model...")
    model, _ = load_model(cfg, tokenizer, inference=inference)

    return model, tokenizer
    processor = None
    if cfg.is_multimodal:
        LOG.info("loading processor...")
        processor = load_processor(cfg, tokenizer)

    return model, tokenizer, processor
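
Downstream call sites (see the inference and merge-lora hunks above) now unpack three values. A minimal usage sketch, assuming the helper lives in `axolotl.cli.utils` as the surrounding imports suggest; the config path is illustrative:

from axolotl.cli.config import load_cfg  # import shown in the train.py hunk above
from axolotl.cli.utils import load_model_and_tokenizer  # assumed module path

cfg = load_cfg("examples/llava/lora-7b.yaml")  # multimodal configs set is_multimodal
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg, inference=True)

# processor is None for text-only models; guard before using it.
if processor is not None:
    processor.save_pretrained("./outputs/merged")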

@@ -1,6 +1,5 @@
"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""


import json
import sys

@@ -1,6 +1,7 @@
"""
ChatML transformation functions for MessageContents
"""

from typing import Optional

from ..messages import MessageContents, Messages

@@ -1,6 +1,7 @@
"""
Llama 3.x chat formatting functions for MessageContents
"""

from typing import Optional

from ..messages import MessageContents, Messages

@@ -1,6 +1,7 @@
"""
shared functions for format transforms
"""

from axolotl.core.chat.messages import MessageContents, Messages


@@ -1,6 +1,7 @@
"""
internal message representations of chat messages
"""

import json
from enum import Enum
from typing import Any, Callable, List, Optional, Union

@@ -1,6 +1,7 @@
"""
chat dataset module
"""

import os
from typing import Callable, Optional, Union

@@ -1,6 +1,7 @@
"""
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
"""

from typing import Any, Mapping, Union

@@ -13,9 +13,7 @@
# limitations under the License.

# pylint: disable=too-many-lines
"""
Builder for the training args and trainer
"""
"""Builder for the training args and trainer"""

import abc
import importlib

@@ -38,7 +36,7 @@ from transformers import (
from transformers.training_args import OptimizerNames
from trl.trainer.utils import RewardDataCollatorWithPadding

from axolotl.core.trainers.base import (
from axolotl.core.trainers import (
    AxolotlCPOTrainer,
    AxolotlKTOTrainer,
    AxolotlMambaTrainer,

@@ -62,6 +60,7 @@ from axolotl.core.training_args import (
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
    EvalFirstStepCallback,

@@ -85,8 +84,8 @@ from axolotl.utils.collators import (
    V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
from axolotl.utils.models import ensure_dtype
from axolotl.utils.schemas.enums import CustomSupportedOptimizers

try:
    import torch._dynamo  # pylint: disable=ungrouped-imports

@@ -332,9 +331,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        training_arguments_kwargs = {}

        if self.cfg.include_tokens_per_second is not None:
            training_arguments_kwargs[
                "include_tokens_per_second"
            ] = self.cfg.include_tokens_per_second
            training_arguments_kwargs["include_tokens_per_second"] = (
                self.cfg.include_tokens_per_second
            )

        if self.cfg.bf16 == "full":
            training_arguments_kwargs["bf16_full_eval"] = True

@@ -351,13 +350,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
            training_arguments_kwargs["seed"] = self.cfg.seed

        if self.cfg.gradient_checkpointing:
            training_arguments_kwargs[
                "gradient_checkpointing"
            ] = self.cfg.gradient_checkpointing
            training_arguments_kwargs["gradient_checkpointing"] = (
                self.cfg.gradient_checkpointing
            )
            if self.cfg.gradient_checkpointing_kwargs is not None:
                training_arguments_kwargs[
                    "gradient_checkpointing_kwargs"
                ] = self.cfg.gradient_checkpointing_kwargs
                training_arguments_kwargs["gradient_checkpointing_kwargs"] = (
                    self.cfg.gradient_checkpointing_kwargs
                )
        if self.cfg.fsdp:
            training_arguments_kwargs["fsdp"] = self.cfg.fsdp
            if self.cfg.fsdp_config:

@@ -373,9 +372,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
            training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed

        if self.cfg.lr_quadratic_warmup is not None:
            training_arguments_kwargs[
                "lr_quadratic_warmup"
            ] = self.cfg.lr_quadratic_warmup
            training_arguments_kwargs["lr_quadratic_warmup"] = (
                self.cfg.lr_quadratic_warmup
            )

        if self.cfg.adam_beta1:
            training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1

@@ -399,28 +398,28 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
            training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors

        if self.cfg.dataloader_pin_memory is not None:
            training_arguments_kwargs[
                "dataloader_pin_memory"
            ] = self.cfg.dataloader_pin_memory
            training_arguments_kwargs["dataloader_pin_memory"] = (
                self.cfg.dataloader_pin_memory
            )
        if self.cfg.dataloader_num_workers is not None:
            training_arguments_kwargs[
                "dataloader_num_workers"
            ] = self.cfg.dataloader_num_workers
            training_arguments_kwargs["dataloader_num_workers"] = (
                self.cfg.dataloader_num_workers
            )
        if self.cfg.dataloader_prefetch_factor is not None:
            training_arguments_kwargs[
                "dataloader_prefetch_factor"
            ] = self.cfg.dataloader_prefetch_factor
            training_arguments_kwargs["dataloader_prefetch_factor"] = (
                self.cfg.dataloader_prefetch_factor
            )
        if self.cfg.dataloader_drop_last is not None:
            training_arguments_kwargs[
                "dataloader_drop_last"
            ] = self.cfg.dataloader_drop_last
            training_arguments_kwargs["dataloader_drop_last"] = (
                self.cfg.dataloader_drop_last
            )
        elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
            training_arguments_kwargs["dataloader_drop_last"] = True

        if self.cfg.remove_unused_columns is not None:
            training_arguments_kwargs[
                "remove_unused_columns"
            ] = self.cfg.remove_unused_columns
            training_arguments_kwargs["remove_unused_columns"] = (
                self.cfg.remove_unused_columns
            )

        if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
            # no eval set, so don't eval

@@ -452,9 +451,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        if self.cfg.do_causal_lm_eval:
            training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
        if self.cfg.metric_for_best_model:
            training_arguments_kwargs[
                "metric_for_best_model"
            ] = self.cfg.metric_for_best_model
            training_arguments_kwargs["metric_for_best_model"] = (
                self.cfg.metric_for_best_model
            )
        if self.cfg.greater_is_better:
            training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better

@@ -467,13 +466,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
            )
            training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
            if self.cfg.torch_compile_backend:
                training_arguments_kwargs[
                    "torch_compile_backend"
                ] = self.cfg.torch_compile_backend
                training_arguments_kwargs["torch_compile_backend"] = (
                    self.cfg.torch_compile_backend
                )
            if self.cfg.torch_compile_mode:
                training_arguments_kwargs[
                    "torch_compile_mode"
                ] = self.cfg.torch_compile_mode
                training_arguments_kwargs["torch_compile_mode"] = (
                    self.cfg.torch_compile_mode
                )

        # DDP Config
        if self.cfg.ddp_timeout:

@@ -482,32 +481,32 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        if self.cfg.ddp_bucket_cap_mb:
            training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
        if self.cfg.ddp_broadcast_buffers is not None:
            training_arguments_kwargs[
                "ddp_broadcast_buffers"
            ] = self.cfg.ddp_broadcast_buffers
            training_arguments_kwargs["ddp_broadcast_buffers"] = (
                self.cfg.ddp_broadcast_buffers
            )

        # these are all the "standard" kwargs that are def used
        training_arguments_kwargs["max_steps"] = (
            total_num_steps if self.cfg.max_steps else -1
        )
        training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
        training_arguments_kwargs[
            "per_device_train_batch_size"
        ] = self.cfg.micro_batch_size
        training_arguments_kwargs["per_device_train_batch_size"] = (
            self.cfg.micro_batch_size
        )
        if self.cfg.eval_batch_size:
            training_arguments_kwargs[
                "per_device_eval_batch_size"
            ] = self.cfg.eval_batch_size
            training_arguments_kwargs["per_device_eval_batch_size"] = (
                self.cfg.eval_batch_size
            )
        if self.cfg.auto_find_batch_size is not None:
            training_arguments_kwargs[
                "auto_find_batch_size"
            ] = self.cfg.auto_find_batch_size
        training_arguments_kwargs[
            "gradient_accumulation_steps"
        ] = self.cfg.gradient_accumulation_steps
        training_arguments_kwargs[
            "eval_accumulation_steps"
        ] = self.cfg.gradient_accumulation_steps
            training_arguments_kwargs["auto_find_batch_size"] = (
                self.cfg.auto_find_batch_size
            )
        training_arguments_kwargs["gradient_accumulation_steps"] = (
            self.cfg.gradient_accumulation_steps
        )
        training_arguments_kwargs["eval_accumulation_steps"] = (
            self.cfg.gradient_accumulation_steps
        )
        training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
        training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
        training_arguments_kwargs["output_dir"] = self.cfg.output_dir

@@ -554,9 +553,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):

        if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
            training_arguments_kwargs["lr_scheduler_type"] = "cosine"
            training_arguments_kwargs[
                "alternate_lr_scheduler_type"
            ] = self.cfg.lr_scheduler
            training_arguments_kwargs["alternate_lr_scheduler_type"] = (
                self.cfg.lr_scheduler
            )
        else:
            training_arguments_kwargs["lr_scheduler_type"] = (
                self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"

@@ -565,9 +564,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
            self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
        )
        training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
        training_arguments_kwargs[
            "cosine_constant_lr_ratio"
        ] = self.cfg.cosine_constant_lr_ratio
        training_arguments_kwargs["cosine_constant_lr_ratio"] = (
            self.cfg.cosine_constant_lr_ratio
        )
        training_arguments_kwargs["weight_decay"] = (
            self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
        )

@@ -580,40 +579,40 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                self.cfg.eval_sample_packing
            )
        if self.cfg.sample_packing_bin_size is not None:
            training_arguments_kwargs[
                "sample_packing_bin_size"
            ] = self.cfg.sample_packing_bin_size
            training_arguments_kwargs["sample_packing_bin_size"] = (
                self.cfg.sample_packing_bin_size
            )
        if self.cfg.sample_packing_group_size is not None:
            training_arguments_kwargs[
                "sample_packing_group_size"
            ] = self.cfg.sample_packing_group_size
            training_arguments_kwargs["sample_packing_group_size"] = (
                self.cfg.sample_packing_group_size
            )
        if self.cfg.sample_packing_eff_est:
            training_arguments_kwargs[
                "sample_packing_efficiency"
            ] = self.cfg.sample_packing_eff_est
            training_arguments_kwargs["sample_packing_efficiency"] = (
                self.cfg.sample_packing_eff_est
            )

        if self.cfg.relora_steps:
            training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
            training_arguments_kwargs[
                "relora_warmup_steps"
            ] = self.cfg.relora_warmup_steps
            training_arguments_kwargs["relora_warmup_steps"] = (
                self.cfg.relora_warmup_steps
            )
            if self.cfg.relora_anneal_steps:
                training_arguments_kwargs[
                    "relora_anneal_steps"
                ] = self.cfg.relora_anneal_steps
                training_arguments_kwargs["relora_anneal_steps"] = (
                    self.cfg.relora_anneal_steps
                )
            if self.cfg.relora_prune_ratio:
                training_arguments_kwargs[
                    "relora_prune_ratio"
                ] = self.cfg.relora_prune_ratio
                training_arguments_kwargs["relora_prune_ratio"] = (
                    self.cfg.relora_prune_ratio
                )

        if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
            training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
            training_arguments_kwargs[
                "lisa_step_interval"
            ] = self.cfg.lisa_step_interval
            training_arguments_kwargs[
                "lisa_layers_attribute"
            ] = self.cfg.lisa_layers_attribute
            training_arguments_kwargs["lisa_step_interval"] = (
                self.cfg.lisa_step_interval
            )
            training_arguments_kwargs["lisa_layers_attribute"] = (
                self.cfg.lisa_layers_attribute
            )

        training_arguments_kwargs = self.hook_pre_create_training_args(
            training_arguments_kwargs

@@ -627,9 +626,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        )

        if self.cfg.neftune_noise_alpha is not None:
            training_arguments_kwargs[
                "neftune_noise_alpha"
            ] = self.cfg.neftune_noise_alpha
            training_arguments_kwargs["neftune_noise_alpha"] = (
                self.cfg.neftune_noise_alpha
            )

        trainer_kwargs = {}

@@ -731,24 +730,30 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
            importlib.import_module("torchdistx")

        if self.cfg.optim_target_modules:
            training_arguments_kwargs[
                "optim_target_modules"
            ] = self.cfg.optim_target_modules
            training_arguments_kwargs["optim_target_modules"] = (
                self.cfg.optim_target_modules
            )

        training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
        training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale

        training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
        training_arguments_kwargs[
            "loraplus_lr_embedding"
        ] = self.cfg.loraplus_lr_embedding
        training_arguments_kwargs["loraplus_lr_embedding"] = (
            self.cfg.loraplus_lr_embedding
        )
        training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups

        if self.cfg.accelerator_config:
            training_arguments_kwargs[
                "accelerator_config"
            ] = self.cfg.accelerator_config
            training_arguments_kwargs["accelerator_config"] = (
                self.cfg.accelerator_config
            )

        if self.cfg.image_size:
            training_arguments_kwargs["image_size"] = self.cfg.image_size
        if self.cfg.image_resize_algorithm:
            training_arguments_kwargs["image_resize_algorithm"] = (
                self.cfg.image_resize_algorithm
            )
        if self.cfg.kd_ce_alpha is not None:
            training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
        if self.cfg.kd_alpha is not None:

@@ -756,13 +761,17 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        if self.cfg.kd_temperature is not None:
            training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
        if self.cfg.kd_zscore_base_temp is not None:
            training_arguments_kwargs[
                "kd_zscore_base_temp"
            ] = self.cfg.kd_zscore_base_temp
            training_arguments_kwargs["kd_zscore_base_temp"] = (
                self.cfg.kd_zscore_base_temp
            )
        if self.cfg.kd_top_k_before_softmax is not None:
            training_arguments_kwargs[
                "kd_top_k_before_softmax"
            ] = self.cfg.kd_top_k_before_softmax
            training_arguments_kwargs["kd_top_k_before_softmax"] = (
                self.cfg.kd_top_k_before_softmax
            )

        training_arguments_kwargs["sequence_parallel_degree"] = (
            self.cfg.sequence_parallel_degree
        )

        if self.cfg.reward_model:
            training_args_cls = AxolotlRewardConfig

@@ -847,9 +856,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
    ):
        if training_args.pretraining:
            if self.cfg.pretraining_sample_concatenation is False:
                return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
            if self.cfg.micro_batch_size > 1:
            if (
                self.cfg.pretraining_sample_concatenation is False
                or self.cfg.micro_batch_size > 1
            ):
                return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
            return None

@@ -877,9 +887,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
            if "max_length" in kwargs:
                kwargs.pop("max_length")
        elif use_batch_sampler_collator:
            if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
                collator = V2BatchSamplerDataCollatorForSeq2Seq
            elif (
            if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or (
                self.cfg.model_config_type in ["llama"]
                and self.cfg.flash_attention is not True
            ):

@@ -889,8 +897,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        else:
            if self.cfg.processor_type and self.processor:
                collator = MultiModalChatDataCollator
                kwargs["processor"] = self.processor
                kwargs["chat_template"] = training_args.chat_template
                kwargs["processing_strategy"] = get_processing_strategy(
                    self.processor,
                    training_args.chat_template,
                    self.cfg.chat_template,
                    image_size=training_args.image_size,
                    image_resize_algorithm=training_args.image_resize_algorithm,
                )
            elif self.cfg.batch_flattening:
                collator = DataCollatorWithFlattening
                collator_args.pop(0)

@@ -910,6 +923,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                collator = DataCollatorForSeq2Seq

        kwargs["return_tensors"] = "pt"
        if issubclass(collator, DataCollatorForSeq2Seq):
            kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree

        return collator(
            *collator_args,

@@ -972,32 +987,32 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
            self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
        )
        if self.cfg.remove_unused_columns is not None:
            training_args_kwargs[
                "remove_unused_columns"
            ] = self.cfg.remove_unused_columns
            training_args_kwargs["remove_unused_columns"] = (
                self.cfg.remove_unused_columns
            )
        else:
            training_args_kwargs["remove_unused_columns"] = False

        if self.cfg.dataloader_pin_memory is not None:
            training_args_kwargs[
                "dataloader_pin_memory"
            ] = self.cfg.dataloader_pin_memory
            training_args_kwargs["dataloader_pin_memory"] = (
                self.cfg.dataloader_pin_memory
            )
        if self.cfg.dataloader_num_workers is not None:
            training_args_kwargs[
                "dataloader_num_workers"
            ] = self.cfg.dataloader_num_workers
            training_args_kwargs["dataloader_num_workers"] = (
                self.cfg.dataloader_num_workers
            )
        if self.cfg.dataloader_prefetch_factor is not None:
            training_args_kwargs[
                "dataloader_prefetch_factor"
            ] = self.cfg.dataloader_prefetch_factor
            training_args_kwargs["dataloader_prefetch_factor"] = (
                self.cfg.dataloader_prefetch_factor
            )
        if self.cfg.gradient_checkpointing:
            training_args_kwargs[
                "gradient_checkpointing"
            ] = self.cfg.gradient_checkpointing
            training_args_kwargs["gradient_checkpointing"] = (
                self.cfg.gradient_checkpointing
            )
            if self.cfg.gradient_checkpointing_kwargs is not None:
                training_args_kwargs[
                    "gradient_checkpointing_kwargs"
                ] = self.cfg.gradient_checkpointing_kwargs
                training_args_kwargs["gradient_checkpointing_kwargs"] = (
                    self.cfg.gradient_checkpointing_kwargs
                )
            else:
                training_args_kwargs["gradient_checkpointing_kwargs"] = {
                    "use_reentrant": False

@@ -1071,9 +1086,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
        if self.cfg.dpo_use_weighting is not None:
            training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
        if self.cfg.dpo_use_logits_to_keep is not None:
            training_args_kwargs[
                "use_logits_to_keep"
            ] = self.cfg.dpo_use_logits_to_keep
            training_args_kwargs["use_logits_to_keep"] = (
                self.cfg.dpo_use_logits_to_keep
            )

        for blocklist_key in blocklist_args_kwargs:
            if blocklist_key in training_args_kwargs:

@@ -1108,9 +1123,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
        if self.cfg.adapter and self.peft_config:
            dpo_trainer_kwargs["peft_config"] = self.peft_config
        if self.cfg.precompute_ref_log_probs is not None:
            dpo_trainer_kwargs[
                "precompute_ref_log_probs"
            ] = self.cfg.precompute_ref_log_probs
            dpo_trainer_kwargs["precompute_ref_log_probs"] = (
                self.cfg.precompute_ref_log_probs
            )
        if self.cfg.rl == "grpo":
            trainer_cls = GRPOStrategy.get_trainer_class()
            trainer_cls_args = [self.model]

@@ -0,0 +1,18 @@
"""Init for axolotl.core.trainers"""

# pylint: disable=unused-import
# flake8: noqa

from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .relora import ReLoRATrainer
from .trl import (
    AxolotlCPOTrainer,
    AxolotlKTOTrainer,
    AxolotlORPOTrainer,
    AxolotlPRMTrainer,
    AxolotlRewardTrainer,
    TRLPPOTrainer,
)
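
With this `__init__`, the trainer classes re-exported here can be imported from the package root, which is exactly what the trainer_builder.py hunk above switches to:

# Equivalent imports after this change; the second form is what
# trainer_builder.py now uses.
from axolotl.core.trainers.trl import AxolotlKTOTrainer as FromModule
from axolotl.core.trainers import AxolotlKTOTrainer as FromPackage

assert FromModule is FromPackage  # same class object, just re-exported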

@@ -1,365 +1,47 @@
"""
module for customized trainers
"""
"""Module for customized trainers"""

# pylint: disable=too-many-lines

from __future__ import annotations

# pylint: disable=too-many-lines
import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Dict, Literal, Optional
from typing import Any, Literal

import datasets
import torch
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import (
    BatchSampler,
    DataLoader,
    RandomSampler,
    Sampler,
    SequentialSampler,
)
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length
from typing_extensions import override

from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.monkeypatch.relora import ReLoRAScheduler
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
    RexLR,
    get_cosine_schedule_with_min_lr,
    get_cosine_schedule_with_quadratic_warmup,
    get_cosine_schedule_with_warmup_decay_constant,
from axolotl.core.trainers.mixins import (
    OptimizerMixin,
    SchedulerMixin,
    SequenceParallelMixin,
)
from axolotl.core.trainers.utils import (
    sanitize_kwargs_for_ds_tagging,
    sanitize_kwargs_for_tagging,
)
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

LOG = logging.getLogger("axolotl.core.trainer_builder")
LOG = logging.getLogger(__name__)


def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
    if isinstance(tag_names, str):
        tag_names = [tag_names]

    if kwargs is not None:
        if "tags" not in kwargs:
            kwargs["tags"] = tag_names
        elif "tags" in kwargs and isinstance(kwargs["tags"], list):
            kwargs["tags"].extend(tag_names)
        elif "tags" in kwargs and isinstance(kwargs["tags"], str):
            tag_names.append(kwargs["tags"])
            kwargs["tags"] = tag_names

    return kwargs
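
The helper's three branches merge trainer tags into whatever the caller passes; concretely:

# Behavior of _sanitize_kwargs_for_tagging above (now re-exported as
# sanitize_kwargs_for_tagging from axolotl.core.trainers.utils per the new imports):
print(_sanitize_kwargs_for_tagging("axolotl", {}))                 # {'tags': ['axolotl']}
print(_sanitize_kwargs_for_tagging("axolotl", {"tags": ["sft"]}))  # {'tags': ['sft', 'axolotl']}
print(_sanitize_kwargs_for_tagging("axolotl", {"tags": "sft"}))    # {'tags': ['axolotl', 'sft']}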


def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
    if isinstance(dataset_tags, str):
        dataset_tags = [dataset_tags]

    if (dataset_tags is not None) and (kwargs is not None):
        if "dataset_tags" not in kwargs:
            kwargs["dataset_tags"] = dataset_tags
        elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
            kwargs["dataset_tags"].extend(dataset_tags)
        elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
            dataset_tags.append(kwargs["dataset_tags"])
            kwargs["dataset_tags"] = dataset_tags

    return kwargs


class SchedulerMixin(Trainer):
    """
    Mixin class for scheduler setup in CausalTrainer.
    """

    args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]

    def create_scheduler(
        self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
    ):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
        passed as an argument.

        Args:
            num_training_steps (int): The number of training steps to do.
            optimizer (torch.optim.Optimizer): The training optimizer
        """
        use_cosine_quadratic = (
            self.args.lr_scheduler_type == "cosine"
            and self.args.lr_quadratic_warmup is True
        )

        use_cosine_min_lr = (
            self.args.lr_scheduler_type == "cosine"
            and self.args.cosine_min_lr_ratio is not None
        )

        # fmt: off
        if self.lr_scheduler is None:  # type: ignore # pylint: disable=access-member-before-definition
            # fmt: on
            if self.args.alternate_lr_scheduler_type == "one_cycle":
                num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
                pct_start = num_warmup_steps / num_training_steps
                extra_lr_kwargs = {}
                if "pct_start" not in self.args.lr_scheduler_kwargs:
                    extra_lr_kwargs["pct_start"] = pct_start
                if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
                    extra_lr_kwargs["anneal_strategy"] = "cos"

                self.lr_scheduler = OneCycleLR(
                    optimizer,
                    max_lr=self.args.learning_rate,
                    total_steps=num_training_steps,
                    **extra_lr_kwargs,
                    **self.args.lr_scheduler_kwargs,
                )
            elif self.args.alternate_lr_scheduler_type == "rex":
                if use_cosine_min_lr:
                    assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"

                self.lr_scheduler = RexLR(
                    optimizer=optimizer,
                    max_lr=self.args.learning_rate,
                    min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
                    total_steps=num_training_steps,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                )
            elif use_cosine_quadratic:
                if use_cosine_min_lr:
                    LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")

                self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
                    optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                )
            elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
                assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
                assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
                self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant(  # pylint: disable=attribute-defined-outside-init
                    optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                    min_lr_ratio=self.args.cosine_min_lr_ratio,
                    constant_lr_ratio=self.args.cosine_constant_lr_ratio,
                )
            elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
                assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
                self.lr_scheduler = get_cosine_schedule_with_min_lr(  # pylint: disable=attribute-defined-outside-init
                    optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                    min_lr_ratio=self.args.cosine_min_lr_ratio,
                )
            else:
                return super().create_scheduler(num_training_steps, optimizer=optimizer)
        else:
            if use_cosine_quadratic:
                LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")

            if use_cosine_min_lr:
                LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")

        return self.lr_scheduler
||||
|
||||
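To make the min-lr arithmetic above concrete (values hypothetical): with `learning_rate=2e-4` and `cosine_min_lr_ratio=0.1`, the `rex` branch decays from the peak rate down to

    min_lr = 2e-4 * 0.1  # == 2e-05

rather than to zero.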
class OptimizerMixin(Trainer):
    """
    Mixin class for shared handling of building custom optimizers
    """

    args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]

    def create_optimizer_grouped_parameters(
        self, opt_model, optimizer_kwargs
    ) -> list[dict]:
        decay_parameters = self.get_decay_parameter_names(opt_model)
        params: dict = {
            "to_weight_decay": {},  # LayerNorm and bias
            "embeddings": {},  # lm_head, embed_tokens,
            "no_weight_decay": {},
        }
        lr_groups_lookup = {}
        lr_groups_learning_rates = {}
        if self.args.lr_groups:
            for lr_group in self.args.lr_groups:
                group_name = lr_group["name"]
                group_modules = lr_group["modules"]
                for module in group_modules:
                    lr_groups_lookup[module] = group_name
                lr_groups_learning_rates[group_name] = lr_group["lr"]
                params[f"to_weight_decay_{group_name}"] = {}

        for name, param in opt_model.named_parameters():
            if not param.requires_grad:
                continue
            if name.endswith("modules_to_save.default.weight") or any(
                embed_name in name for embed_name in ["embed_tokens", "lm_head"]
            ):
                params["embeddings"][name] = param
            elif name in decay_parameters:
                lr_group_modules = [
                    group_modules
                    for group_modules in lr_groups_lookup
                    if group_modules in name
                ]
                if lr_groups_lookup and any(lr_group_modules):
                    lr_group_module = lr_group_modules[0]
                    group_name = lr_groups_lookup[lr_group_module]
                    params[f"to_weight_decay_{group_name}"][name] = param
                else:
                    params["to_weight_decay"][name] = param
            else:
                params["no_weight_decay"][name] = param
        optimizer_grouped_parameters = []
        if params["to_weight_decay"]:
            optimizer_grouped_parameters.append(
                {
                    "params": list(params["to_weight_decay"].values()),
                    "weight_decay": self.args.weight_decay,
                    "lr": optimizer_kwargs["lr"],
                }
            )
        if params["embeddings"]:
            lr = optimizer_kwargs["lr"]  # pylint: disable=invalid-name
            if self.args.embedding_lr_scale:
                lr *= self.args.embedding_lr_scale  # pylint: disable=invalid-name
            elif self.args.embedding_lr:
                lr = self.args.embedding_lr  # pylint: disable=invalid-name
            optimizer_grouped_parameters.append(
                {
                    "params": list(params["embeddings"].values()),
                    "weight_decay": 0.0,
                    "lr": lr,
                }
            )
        if params["no_weight_decay"]:
            optimizer_grouped_parameters.append(
                {
                    "params": list(params["no_weight_decay"].values()),
                    "weight_decay": 0.0,
                    "lr": optimizer_kwargs["lr"],
                }
            )
        for group_name, group_lr in lr_groups_learning_rates.items():
            if params[f"to_weight_decay_{group_name}"]:
                optimizer_grouped_parameters.append(
                    {
                        "params": list(
                            params[f"to_weight_decay_{group_name}"].values()
                        ),
                        "weight_decay": self.args.weight_decay,
                        "lr": group_lr,
                    }
                )

        return optimizer_grouped_parameters
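A sketch of the `lr_groups` structure this lookup expects, inferred from how the keys are read above (the module names are hypothetical):

    lr_groups = [
        {"name": "embeddings_slow", "modules": ["embed_tokens"], "lr": 1e-5},
        {"name": "attn_fast", "modules": ["q_proj", "k_proj"], "lr": 5e-4},
    ]
    # Any decayed parameter whose name contains "q_proj" or "k_proj" lands in
    # params["to_weight_decay_attn_fast"] and is trained at lr=5e-4.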
    def create_optimizer(self):
        if (
            self.args.loraplus_lr_ratio is None
            and self.args.embedding_lr_scale is None
            and self.args.embedding_lr is None
            and self.args.lr_groups is None
            and self.optimizer_cls_and_kwargs is None
        ):
            return super().create_optimizer()

        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if (
            not self.optimizer
            and self.optimizer_cls_and_kwargs is not None
            and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
        ):
            optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
            self.optimizer = optimizer_factory_cls()(
                opt_model, self.args, **optimizer_kwargs
            )

        if not self.optimizer:
            if self.optimizer_cls_and_kwargs is not None:
                optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
            else:
                optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
                    self.args, opt_model
                )

            optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
                opt_model, optimizer_kwargs
            )

            if self.args.loraplus_lr_ratio is not None:
                loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
                loraplus_lr_embedding = getattr(
                    self.args, "loraplus_lr_embedding", 1e-6
                )
                self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
                    opt_model,
                    optimizer_cls,
                    loraplus_lr_ratio=loraplus_lr_ratio,
                    loraplus_lr_embedding=loraplus_lr_embedding,
                    **optimizer_kwargs,
                )
            else:
                # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
                # e.g. for GaLore optimizer.
                if "params" in optimizer_kwargs:
                    optimizer_grouped_parameters = optimizer_kwargs.pop("params")

                # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
                # e.g. for LOMO optimizer.
                if "model" in optimizer_kwargs:
                    optimizer_grouped_parameters = optimizer_kwargs.pop("model")

                # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
                # to avoid arguments conflicts.
                if "optimizer_dict" in optimizer_kwargs:
                    optimizer_grouped_parameters = optimizer_kwargs.pop(
                        "optimizer_dict"
                    )

                self.optimizer = optimizer_cls(
                    optimizer_grouped_parameters, **optimizer_kwargs
                )

                if optimizer_cls.__name__ == "Adam8bit":
                    import bitsandbytes

                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                    skipped = 0
                    for module in opt_model.modules():
                        if isinstance(module, nn.Embedding):
                            skipped += sum(
                                {
                                    p.data_ptr(): p.numel() for p in module.parameters()
                                }.values()
                            )
                            LOG.info(f"skipped {module}: {skipped/2**20}M params")
                            manager.register_module_override(
                                module, "weight", {"optim_bits": 32}
                            )
                            LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
                    LOG.info(f"skipped: {skipped/2**20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
                self.optimizer
            )

        return self.optimizer
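`optimizer_cls_and_kwargs` follows the upstream `transformers.Trainer` convention of an `(optimizer_class, kwargs)` tuple; a minimal sketch (values hypothetical):

    import torch

    optimizer_cls_and_kwargs = (torch.optim.AdamW, {"lr": 2e-4, "weight_decay": 0.01})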
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
    """
    Extend the base Trainer for axolotl helpers
    """
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
    """Extend the base Trainer for axolotl helpers"""

    args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]
    tag_names = ["axolotl"]

@@ -376,12 +58,18 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
        self.eval_data_collator = eval_data_collator
        self.dataset_tags = dataset_tags
        self._signature_columns = None  # workaround for pylint

        super().__init__(*_args, **kwargs)

        self.train_data_collator = self.data_collator
        self._stored_metrics = defaultdict(lambda: defaultdict(list))
        if self.args.orpo_alpha:
            self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

        # Initialize sequence parallelism if enabled
        if self.args.sequence_parallel_degree > 1:
            self._setup_sequence_parallel()

    def _wrap_model(self, model, training=True, dataloader=None):
        if self.args.torch_compile:
            torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
@@ -394,142 +82,247 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
            )
        return super()._wrap_model(model, training=training, dataloader=dataloader)
    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.args.sample_packing and not self.args.pretraining:
            if self.args.multipack_real_batches:
                batch_size = self.args.per_device_train_batch_size
                batch_max_len = self.args.max_seq_length
            else:
                batch_size = 1
                train_batch_size = (
                    self.state.train_batch_size or self.args.per_device_train_batch_size
                )
                batch_max_len = train_batch_size * self.args.max_seq_length

            if self.args.curriculum_sampling:
                sampler = SequentialSampler(self.train_dataset)
            else:
                sampler = RandomSampler(self.train_dataset)

            return MultipackBatchSampler(
                sampler,
                lengths=get_dataset_lengths(self.train_dataset),
                packing_efficiency_estimate=self.args.sample_packing_efficiency,
                batch_max_len=batch_max_len,
                batch_size=batch_size,
                group_size=self.args.sample_packing_group_size,
                bin_size=self.args.sample_packing_bin_size,
                drop_last=True,
            )
        if self.args.curriculum_sampling:
            return SequentialSampler(self.train_dataset)
        return super()._get_train_sampler()

    def _get_eval_sampler(
        self, eval_dataset: Dataset
    ) -> Optional[torch.utils.data.Sampler]:
        if self.args.sample_packing and self.args.eval_sample_packing is not False:
            if self.args.multipack_real_batches:
                batch_size = self.args.per_device_eval_batch_size
                batch_max_len = self.args.max_seq_length
            else:
                batch_size = 1
                batch_max_len = (
                    self.args.per_device_eval_batch_size * self.args.max_seq_length
                )
            return MultipackBatchSampler(
                SequentialSampler(eval_dataset),
                lengths=get_dataset_lengths(self.eval_dataset),
                packing_efficiency_estimate=self.args.sample_packing_efficiency,
                batch_max_len=batch_max_len,
                batch_size=batch_size,
                group_size=self.args.sample_packing_group_size,
                bin_size=self.args.sample_packing_bin_size,
                drop_last=True,
            )
        return super()._get_eval_sampler(eval_dataset)

    def get_train_dataloader(self) -> DataLoader:
        if self.args.sample_packing and not self.args.pretraining:
            train_dataset = self.train_dataset
            if "length" in train_dataset.features.keys():
                train_dataset = train_dataset.remove_columns(["length"])
            data_collator = self.data_collator
            dataloader_params = {
                "batch_size": self._train_batch_size,
                "collate_fn": data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
            }
            if self.args.dataloader_prefetch_factor:
                dataloader_params[
                    "prefetch_factor"
                ] = self.args.dataloader_prefetch_factor

            sampler = self._get_train_sampler()
            return self.accelerator.prepare_data_loader(
                DataLoader(train_dataset, **dataloader_params)
            )
        return super().get_train_dataloader()

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        if self.args.sample_packing and self.args.eval_sample_packing is not False:
            eval_dataset = (
                eval_dataset if eval_dataset is not None else self.eval_dataset
            )

            eval_sampler = self._get_eval_sampler(eval_dataset)
            eval_dataset = eval_dataset.remove_columns(["length"])
            data_collator = self.data_collator
            dataloader_params = {
                "batch_size": self.args.eval_batch_size,
                "collate_fn": data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
            }
            if self.args.dataloader_prefetch_factor:
                dataloader_params[
                    "prefetch_factor"
                ] = self.args.dataloader_prefetch_factor

            if isinstance(eval_sampler, BatchSampler):
                dataloader_params["batch_sampler"] = eval_sampler
                del dataloader_params["batch_size"]
            else:
                dataloader_params["sampler"] = eval_sampler
                dataloader_params["drop_last"] = self.args.dataloader_drop_last

            self.accelerator.even_batches = False
            return self.accelerator.prepare_data_loader(
                DataLoader(eval_dataset, **dataloader_params)
            )

    def _create_multipack_sampler(
        self, base_sampler: Sampler, dataset: Dataset
    ) -> MultipackBatchSampler:
        """
        Helper method to create a `MultipackBatchSampler` for multipacking sequences
        for training.

        Args:
            base_sampler: Sampler to wrap with `MultipackBatchSampler`.
            dataset: Dataset to sample from.

        Returns:
            Multipack (sample packing) batch sampler.
        """
        if self.args.multipack_real_batches:
            batch_size = self.args.per_device_train_batch_size
            batch_max_len = self.args.max_seq_length
        else:
            batch_size = 1
            train_batch_size = (
                self.state.train_batch_size or self.args.per_device_train_batch_size
            )
            batch_max_len = train_batch_size * self.args.max_seq_length

        return MultipackBatchSampler(
            base_sampler,
            lengths=get_dataset_lengths(dataset),
            packing_efficiency_estimate=self.args.sample_packing_efficiency,
            batch_max_len=batch_max_len,
            batch_size=batch_size,
            drop_last=True,
        )

    def _get_train_sampler(self) -> Sampler | None:
        """
        Helper method to get the sampler for training. Handles cases for sequence
        parallelism, sample packing, and curriculum sampling (sequential).

        Returns:
            If the dataset is non-empty, a sampler is returned, the type of which
            depends on the passed training args.
        """
        use_sample_packing = self.args.sample_packing and not self.args.pretraining

        # Determine the base sampler first
        if self.args.sequence_parallel_degree > 1:
            base_sampler = self._sp_get_train_sampler(self.train_dataset)
        elif self.args.curriculum_sampling:
            base_sampler = SequentialSampler(self.train_dataset)
        elif use_sample_packing:
            base_sampler = RandomSampler(self.train_dataset)
        else:
            # Default to parent class implementation for standard random sampling
            return super()._get_train_sampler()

        # Apply multipack wrapper if needed
        if use_sample_packing:
            return self._create_multipack_sampler(
                base_sampler=base_sampler,
                dataset=self.train_dataset,
            )

        return base_sampler

    def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
        """
        Helper method to get the sampler for evaluation. Handles sequence parallelism
        and sample packing cases.

        Returns:
            If the dataset is non-empty, a sampler is returned, the type of which
            depends on the passed training args.
        """
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        # Multipacking enabled if training is enabled and eval is not explicitly disabled
        use_multipack = (
            self.args.sample_packing and self.args.eval_sample_packing is not False
        )

        # Determine the base sampler
        if self.args.sequence_parallel_degree > 1:
            base_sampler = self._sp_get_eval_sampler(eval_dataset)
        elif use_multipack:
            base_sampler = SequentialSampler(eval_dataset)
        else:
            return super()._get_eval_sampler(eval_dataset)

        # Apply multipack wrapper if needed
        if use_multipack:
            return self._create_multipack_sampler(
                base_sampler=base_sampler,
                dataset=eval_dataset,
            )

        return base_sampler

    def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
        """Create common dataloader parameters for train or eval."""
        batch_size = custom_batch_size or (
            self.args.eval_batch_size if is_eval else self._train_batch_size
        )

        params = {
            "batch_size": batch_size,
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
        }

        # Add persistent workers only for training
        if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
            params["persistent_workers"] = self.args.dataloader_persistent_workers

        # Add prefetch factor if specified
        if self.args.dataloader_prefetch_factor:
            params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        return params

    def _prepare_dataloader(
        self, dataset, sampler, is_eval=False, custom_batch_size=None
    ):
        """Prepare a dataloader with the given dataset and sampler."""
        # Get base parameters
        dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)

        # Add sampler configuration
        if not isinstance(dataset, torch.utils.data.IterableDataset):
            if isinstance(sampler, BatchSampler):
                # batch_size and batch_sampler are mutually exclusive
                dataloader_params["batch_sampler"] = sampler
                del dataloader_params["batch_size"]
            else:
                dataloader_params["sampler"] = sampler
                dataloader_params["drop_last"] = self.args.dataloader_drop_last
                dataloader_params["worker_init_fn"] = seed_worker

            if not is_eval:
                dataloader_params["worker_init_fn"] = seed_worker

        # Create the dataloader
        dataloader = DataLoader(dataset, **dataloader_params)

        if self.args.sample_packing and (
            (not is_eval and not self.args.pretraining)
            or (is_eval and self.args.eval_sample_packing is not False)
        ):
            self.accelerator.even_batches = False

        # Return unprepared dataloader if using sequence parallelism
        if self.args.sequence_parallel_degree > 1:
            return dataloader

        # Otherwise prepare with accelerator
        return self.accelerator.prepare_data_loader(dataloader)

    def get_train_dataloader(self) -> DataLoader:
        """Get dataloader for training"""
        train_dataset = self.train_dataset
        data_collator = self.data_collator  # type: ignore

        # Handle dataset preprocessing
        if isinstance(train_dataset, datasets.Dataset):
            if self.args.sample_packing and not self.args.pretraining:
                train_dataset = train_dataset.remove_columns(["length"])
            if not self.args.sample_packing or self.args.pretraining:
                train_dataset = self._remove_unused_columns(
                    train_dataset, description="training"
                )
        else:
            self.data_collator = self._get_collator_with_removed_columns(  # pylint: disable=attribute-defined-outside-init
                data_collator,
                description="training",
            )

        # Get sampler and create dataloader
        sampler = self._get_train_sampler()
        return self._prepare_dataloader(train_dataset, sampler, is_eval=False)

    def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
        """Get dataloader for evaluation"""
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        # Handle special case: sample packing is enabled but eval_sample_packing is False
        if self.args.sample_packing and self.args.eval_sample_packing is False:
            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
                self.eval_data_collator
            )
            if eval_dataset:
                if "length" in eval_dataset.column_names:
                    eval_dataset = eval_dataset.remove_columns(["length"])
            dataloader = super().get_eval_dataloader(eval_dataset)
            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
                self.train_data_collator
            )

            return dataloader

        # Handle sample packing or sequence parallelism
        if (
            self.args.sample_packing
            and self.args.eval_sample_packing is not False
            or self.args.sequence_parallel_degree > 1
        ):
            # Get appropriate data collator
            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
                self.eval_data_collator
                if hasattr(self, "eval_data_collator") and self.eval_data_collator
                else self.data_collator
            )
            if "length" in eval_dataset.column_names:
                eval_dataset = eval_dataset.remove_columns(["length"])

            # Handle dataset preprocessing for SP
            if self.args.sequence_parallel_degree > 1:
                if isinstance(eval_dataset, datasets.Dataset):
                    eval_dataset = self._remove_unused_columns(
                        eval_dataset, description="evaluation"
                    )
                else:
                    self.data_collator = self._get_collator_with_removed_columns(  # pylint: disable=attribute-defined-outside-init
                        self.data_collator, description="evaluation"
                    )

            # Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
            batch_size = (
                self.args.eval_batch_size
                if self.args.sample_packing
                else self.args.per_device_eval_batch_size
            )
            sampler = self._get_eval_sampler(eval_dataset)
            dataloader = self._prepare_dataloader(
                eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
            )

            return dataloader

        return super().get_eval_dataloader(eval_dataset)
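To make the packing arithmetic concrete (numbers hypothetical): with `per_device_train_batch_size=4`, `max_seq_length=2048`, and `multipack_real_batches` left off, each "batch" handed to the dataloader is a single packed bin of at most

    batch_max_len = 4 * 2048  # 8192 tokens

so samples of lengths [3000, 2500, 1500, 900] (7900 tokens total) can be packed into one bin by `MultipackBatchSampler`.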
    def _get_bench_sampler(
        self, bench_dataset: Dataset
    ) -> Optional[torch.utils.data.Sampler]:
    ) -> torch.utils.data.Sampler | None:
        if self.args.world_size <= 1:
            return SequentialSampler(bench_dataset)
        return None

@@ -554,6 +347,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
        return DataLoader(bench_dataset, **dataloader_params)
        # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
    @override
    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
@@ -570,6 +364,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
                return_outputs=return_outputs,
                num_items_in_batch=num_items_in_batch,
            )

        return super().compute_loss(
            model,
            inputs,
@@ -744,10 +539,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
        Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
        """
        kwargs = _sanitize_kwargs_for_ds_tagging(
        kwargs = sanitize_kwargs_for_ds_tagging(
            dataset_tags=self.dataset_tags, kwargs=kwargs
        )
        kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
        kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)

        return super().push_to_hub(*args, **kwargs)
@@ -764,15 +559,13 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):

        return res

    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
    def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
        """
        Log `logs` on the various objects watching training, including stored metrics.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
            start_time (`Optional[float]`):
                The start of training.
            logs: The values to log.
            start_time: The start of training.
        """
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
@@ -784,7 +577,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
        return super().log(logs, start_time)

    def store_metrics(
        self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
        self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
    ) -> None:
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)
@@ -797,110 +590,26 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
            os.makedirs(output_dir, exist_ok=True)
        return super()._save_checkpoint(model, trial, **kwargs)
    def training_step(
        self,
        model: nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        num_items_in_batch: int | None = None,
    ) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs. Overrides the
        `transformers.trainer.Trainer` method to handle sequence parallelism if
        enabled.

        Args:
            model: Model to perform training step for.
            inputs: Dictionary mapping input names to tensors.
        """
        # Set up sequence parallelism for this step if enabled
        if self.args.sequence_parallel_degree > 1:
            self._update_ring_flash_attn_params(inputs)

        # Proceed with normal training step
        loss = super().training_step(model, inputs, num_items_in_batch)

        return loss


class AxolotlMambaTrainer(AxolotlTrainer):
    """
    Mamba specific trainer to handle loss calculation
    """

    tag_names = ["axolotl", "mamba"]

    def compute_loss(
        self,
        model,
        inputs,
        return_outputs=False,  # pylint: disable=unused-argument
        num_items_in_batch=None,  # pylint: disable=unused-argument
    ):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids).logits

        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
        )

        return lm_loss


class ReLoRATrainer(AxolotlTrainer):
    """
    Trainer subclass that uses the OneCycleLR scheduler
    """

    tag_names = ["axolotl", "relora"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lr_scheduler = None

    def create_scheduler(
        self,
        num_training_steps: int,
        optimizer: Optional[torch.optim.Optimizer] = None,
    ):
        optimizer = self.optimizer if optimizer is None else optimizer
        lr_scheduler = super().create_scheduler(num_training_steps, optimizer)

        if self.args.relora_steps:
            warmup_steps = (
                self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
            )
            anneal_steps = (
                self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
            )
            self.lr_scheduler = ReLoRAScheduler(
                optimizer,
                lr_scheduler,
                self.args.relora_steps,
                anneal_steps,
                warmup_steps,
            )
        else:
            self.lr_scheduler = lr_scheduler

        return self.lr_scheduler


class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
    """
    Extend the base ORPOTrainer for axolotl helpers
    """

    tag_names = ["axolotl", "orpo"]


class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
    """
    Extend the base KTOTrainer for axolotl helpers
    """

    tag_names = ["axolotl", "kto"]


class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
    """
    Extend the base CPOTrainer for axolotl helpers
    """

    tag_names = ["axolotl", "cpo"]


class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
    """
    Extend the base RewardTrainer for axolotl helpers
    """

    tag_names = ["axolotl", "reward"]


class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
    """
    Extend the base trl.PRMTrainer for axolotl helpers
    """

    tag_names = ["axolotl", "prm"]
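The Mamba loss above is the standard next-token shift; a tiny worked sketch (shapes hypothetical):

    # lm_logits: (batch, seq_len, vocab); position t predicts token t+1
    # input_ids: [[5, 7, 9, 2]]
    shift_logits = lm_logits[:, :-1, :].contiguous()  # logits at positions 0..2
    labels = labels[:, 1:].contiguous()               # targets [[7, 9, 2]]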
@@ -1,6 +1,7 @@
"""
DPO Specific Strategy for training
"""

from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
@@ -1,6 +1,7 @@
"""
Axolotl specific DPO args
"""

from dataclasses import dataclass

from trl import DPOConfig
@@ -1,6 +1,7 @@
"""
DPO trainer for axolotl
"""

import gc
from functools import wraps
from typing import Any, Dict, Union
@@ -12,10 +13,10 @@ from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer

from axolotl.core.trainers.base import (
    SchedulerMixin,
    _sanitize_kwargs_for_ds_tagging,
    _sanitize_kwargs_for_tagging,
from axolotl.core.trainers.mixins import SchedulerMixin
from axolotl.core.trainers.utils import (
    sanitize_kwargs_for_ds_tagging,
    sanitize_kwargs_for_tagging,
)

if is_sagemaker_mp_enabled():
@@ -73,10 +74,10 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
        Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
        """
        kwargs = _sanitize_kwargs_for_ds_tagging(
        kwargs = sanitize_kwargs_for_ds_tagging(
            dataset_tags=self.dataset_tags, kwargs=kwargs
        )
        kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
        kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)

        return super().push_to_hub(*args, **kwargs)
@@ -9,7 +9,7 @@ import logging
from trl.trainer.grpo_trainer import RewardFunc

from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
from axolotl.utils.schemas.trl import TRLConfig

LOG = logging.getLogger("axolotl")

@@ -45,9 +45,9 @@ class GRPOStrategy:
            )

        if trl.vllm_gpu_memory_utilization:
            grpo_args_kwargs[
                "vllm_gpu_memory_utilization"
            ] = trl.vllm_gpu_memory_utilization
            grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
                trl.vllm_gpu_memory_utilization
            )

        if trl.vllm_max_model_len:
            grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
@@ -86,9 +86,9 @@ class GRPOStrategy:
    def set_trainer_kwargs(cls, cfg):
        trainer_kwargs = {}
        if cfg.trl and cfg.trl.reward_processing_classes:
            trainer_kwargs[
                "reward_processing_classes"
            ] = cfg.trl.reward_processing_classes
            trainer_kwargs["reward_processing_classes"] = (
                cfg.trl.reward_processing_classes
            )
        return trainer_kwargs

    @classmethod
@@ -1,6 +1,7 @@
"""
Axolotl Specific Training Args
"""

from dataclasses import dataclass

from trl import GRPOConfig
@@ -1,6 +1,7 @@
"""
Axolotl GRPO trainer
"""

from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module
from transformers import PreTrainedModel
32
src/axolotl/core/trainers/mamba.py
Normal file
@@ -0,0 +1,32 @@
"""Module for mamba trainer"""
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
|
||||
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
"""Mamba specific trainer to handle loss calculation"""
|
||||
|
||||
tag_names = ["axolotl", "mamba"]
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False, # pylint: disable=unused-argument
|
||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||
):
|
||||
input_ids = inputs.pop("input_ids")
|
||||
lm_logits = model(input_ids).logits
|
||||
|
||||
labels = input_ids.to(lm_logits.device)
|
||||
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
|
||||
loss_fct = torch.nn.CrossEntropyLoss()
|
||||
lm_loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
||||
)
|
||||
|
||||
return lm_loss
|
||||
8
src/axolotl/core/trainers/mixins/__init__.py
Normal file
@@ -0,0 +1,8 @@
"""Init for axolotl.core.trainers.mixins"""
|
||||
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from .optimizer import OptimizerMixin
|
||||
from .scheduler import SchedulerMixin
|
||||
from .sequence_parallel import SequenceParallelMixin
|
||||
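The trainers compose these mixins ahead of the upstream `Trainer` in the MRO, as in the `base.py` diff above:

    from transformers import Trainer

    from axolotl.core.trainers.mixins import (
        OptimizerMixin,
        SchedulerMixin,
        SequenceParallelMixin,
    )

    class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
        ...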
201
src/axolotl/core/trainers/mixins/optimizer.py
Normal file
@@ -0,0 +1,201 @@
"""Module for Axolotl trainer optimizer mixin"""
|
||||
|
||||
import logging
|
||||
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import nn
|
||||
from transformers.trainer import Trainer
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
|
||||
from axolotl.integrations.base import BaseOptimizerFactory
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
import smdistributed.modelparallel.torch as smp
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OptimizerMixin(Trainer):
|
||||
"""Mixin class for shared handling of building custom optimizers"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
def create_optimizer_grouped_parameters(
|
||||
self, opt_model, optimizer_kwargs
|
||||
) -> list[dict]:
|
||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||
params: dict = {
|
||||
"to_weight_decay": {}, # LayerNorm and bias
|
||||
"embeddings": {}, # lm_head, embed_tokens,
|
||||
"no_weight_decay": {},
|
||||
}
|
||||
lr_groups_lookup = {}
|
||||
lr_groups_learning_rates = {}
|
||||
if self.args.lr_groups:
|
||||
for lr_group in self.args.lr_groups:
|
||||
group_name = lr_group["name"]
|
||||
group_modules = lr_group["modules"]
|
||||
for module in group_modules:
|
||||
lr_groups_lookup[module] = group_name
|
||||
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
||||
params[f"to_weight_decay_{group_name}"] = {}
|
||||
|
||||
for name, param in opt_model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
if name.endswith("modules_to_save.default.weight") or any(
|
||||
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||
):
|
||||
params["embeddings"][name] = param
|
||||
elif name in decay_parameters:
|
||||
lr_group_modules = [
|
||||
group_modules
|
||||
for group_modules in lr_groups_lookup
|
||||
if group_modules in name
|
||||
]
|
||||
if lr_groups_lookup and any(lr_group_modules):
|
||||
lr_group_module = lr_group_modules[0]
|
||||
group_name = lr_groups_lookup[lr_group_module]
|
||||
params[f"to_weight_decay_{group_name}"][name] = param
|
||||
else:
|
||||
params["to_weight_decay"][name] = param
|
||||
else:
|
||||
params["no_weight_decay"][name] = param
|
||||
optimizer_grouped_parameters = []
|
||||
if params["to_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["to_weight_decay"].values()),
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
if params["embeddings"]:
|
||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||
if self.args.embedding_lr_scale:
|
||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||
elif self.args.embedding_lr:
|
||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["embeddings"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": lr,
|
||||
}
|
||||
)
|
||||
if params["no_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["no_weight_decay"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
for group_name, group_lr in lr_groups_learning_rates.items():
|
||||
if params[f"to_weight_decay_{group_name}"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(
|
||||
params[f"to_weight_decay_{group_name}"].values()
|
||||
),
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": group_lr,
|
||||
}
|
||||
)
|
||||
|
||||
return optimizer_grouped_parameters
|
||||
|
||||
def create_optimizer(self):
|
||||
if (
|
||||
self.args.loraplus_lr_ratio is None
|
||||
and self.args.embedding_lr_scale is None
|
||||
and self.args.embedding_lr is None
|
||||
and self.args.lr_groups is None
|
||||
and self.optimizer_cls_and_kwargs is None
|
||||
):
|
||||
return super().create_optimizer()
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
|
||||
if (
|
||||
not self.optimizer
|
||||
and self.optimizer_cls_and_kwargs is not None
|
||||
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
||||
):
|
||||
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||
self.optimizer = optimizer_factory_cls()(
|
||||
opt_model, self.args, **optimizer_kwargs
|
||||
)
|
||||
|
||||
if not self.optimizer:
|
||||
if self.optimizer_cls_and_kwargs is not None:
|
||||
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||
else:
|
||||
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
||||
self.args, opt_model
|
||||
)
|
||||
|
||||
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
||||
opt_model, optimizer_kwargs
|
||||
)
|
||||
|
||||
if self.args.loraplus_lr_ratio is not None:
|
||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||
loraplus_lr_embedding = getattr(
|
||||
self.args, "loraplus_lr_embedding", 1e-6
|
||||
)
|
||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||
opt_model,
|
||||
optimizer_cls,
|
||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
else:
|
||||
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for GaLore optimizer.
|
||||
if "params" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||
|
||||
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||
# e.g. for LOMO optimizer.
|
||||
if "model" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||
|
||||
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||
# to avoid arguments conflicts.
|
||||
if "optimizer_dict" in optimizer_kwargs:
|
||||
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
||||
"optimizer_dict"
|
||||
)
|
||||
|
||||
self.optimizer = optimizer_cls(
|
||||
optimizer_grouped_parameters, **optimizer_kwargs
|
||||
)
|
||||
|
||||
if optimizer_cls.__name__ == "Adam8bit":
|
||||
import bitsandbytes
|
||||
|
||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||
|
||||
skipped = 0
|
||||
for module in opt_model.modules():
|
||||
if isinstance(module, nn.Embedding):
|
||||
skipped += sum(
|
||||
{
|
||||
p.data_ptr(): p.numel() for p in module.parameters()
|
||||
}.values()
|
||||
)
|
||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||
manager.register_module_override(
|
||||
module, "weight", {"optim_bits": 32}
|
||||
)
|
||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
self.optimizer
|
||||
)
|
||||
|
||||
return self.optimizer
|
||||
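Given how `create_optimizer` invokes it (`optimizer_factory_cls()(opt_model, self.args, **optimizer_kwargs)`), a plugin's factory is an instance that is called with the model and training args and returns an optimizer. A hypothetical sketch, not the actual `BaseOptimizerFactory` contract:

    import torch

    from axolotl.integrations.base import BaseOptimizerFactory

    class AdamWFactory(BaseOptimizerFactory):  # hypothetical example
        def __call__(self, opt_model, training_args, **optimizer_kwargs):
            trainable = [p for p in opt_model.parameters() if p.requires_grad]
            return torch.optim.AdamW(trainable, **optimizer_kwargs)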
113
src/axolotl/core/trainers/mixins/scheduler.py
Normal file
@@ -0,0 +1,113 @@
"""Module for Axolotl trainer scheduler mixin"""
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
from axolotl.utils.schedulers import (
|
||||
RexLR,
|
||||
get_cosine_schedule_with_min_lr,
|
||||
get_cosine_schedule_with_quadratic_warmup,
|
||||
get_cosine_schedule_with_warmup_decay_constant,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SchedulerMixin(Trainer):
|
||||
"""
|
||||
Mixin class for scheduler setup in CausalTrainer.
|
||||
"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
):
|
||||
"""
|
||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||
passed as an argument.
|
||||
|
||||
Args:
|
||||
num_training_steps (int): The number of training steps to do.
|
||||
optimizer (torch.optim.Optimizer): The training optimizer
|
||||
"""
|
||||
use_cosine_quadratic = (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.lr_quadratic_warmup is True
|
||||
)
|
||||
|
||||
use_cosine_min_lr = (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.cosine_min_lr_ratio is not None
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||
# fmt: on
|
||||
if self.args.alternate_lr_scheduler_type == "one_cycle":
|
||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||
pct_start = num_warmup_steps / num_training_steps
|
||||
extra_lr_kwargs = {}
|
||||
if "pct_start" not in self.args.lr_scheduler_kwargs:
|
||||
extra_lr_kwargs["pct_start"] = pct_start
|
||||
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
|
||||
extra_lr_kwargs["anneal_strategy"] = "cos"
|
||||
|
||||
self.lr_scheduler = OneCycleLR(
|
||||
optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
total_steps=num_training_steps,
|
||||
**extra_lr_kwargs,
|
||||
**self.args.lr_scheduler_kwargs,
|
||||
)
|
||||
elif self.args.alternate_lr_scheduler_type == "rex":
|
||||
if use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
|
||||
self.lr_scheduler = RexLR(
|
||||
optimizer=optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
||||
total_steps=num_training_steps,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
)
|
||||
elif use_cosine_quadratic:
|
||||
if use_cosine_min_lr:
|
||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||
|
||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
||||
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
||||
)
|
||||
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||
else:
|
||||
if use_cosine_quadratic:
|
||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||
|
||||
if use_cosine_min_lr:
|
||||
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||
|
||||
return self.lr_scheduler
|
||||
131
src/axolotl/core/trainers/mixins/sequence_parallel.py
Normal file
@@ -0,0 +1,131 @@
"""Module for Axolotl trainer sequence parallelism mixin"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from datasets import Dataset
|
||||
from torch.utils.data import DistributedSampler, Sampler
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from ring_flash_attn import update_ring_flash_attn_params
|
||||
except ImportError:
|
||||
# We pass silently here, but raise an ImportError in our Axolotl config validation
|
||||
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
|
||||
pass
|
||||
|
||||
|
||||
class SequenceParallelMixin:
|
||||
"""
|
||||
Mixin class for sequence parallelism support in trainers.
|
||||
|
||||
This mixin provides functionality for handling sequence parallelism,
|
||||
including creating appropriate samplers, managing data partitioning,
|
||||
and updating ring flash attention parameters during training.
|
||||
"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
|
||||
def _setup_sequence_parallel(self):
|
||||
"""Set up sequence parallelism environment."""
|
||||
self.ring_attn_group = get_ring_attn_group()
|
||||
|
||||
def _create_sequence_parallel_sampler(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
shuffle: bool = True,
|
||||
is_eval: bool = False,
|
||||
) -> DistributedSampler:
|
||||
"""
|
||||
Helper method to create sampler for sequence parallelism (SP).
|
||||
|
||||
We create a distributed sampler with rank equal to the SP group ID, which
|
||||
means that all ranks in the SP group receive the same sample / set of samples
|
||||
per training step. We also set the number of replicas equal to the number of
|
||||
SP groups, which is a bit of a hack / unintended use, but works!
|
||||
|
||||
Args:
|
||||
dataset: Dataset to sample from.
|
||||
shuffle: Whether to shuffle the dataset.
|
||||
is_eval: Whether we are creating a sampler for evaluation or training.
|
||||
|
||||
Returns:
|
||||
Distributed sampler.
|
||||
"""
|
||||
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
||||
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
||||
|
||||
return DistributedSampler(
|
||||
dataset,
|
||||
num_replicas=num_sp_groups,
|
||||
rank=sp_group_id,
|
||||
seed=self.args.seed if shuffle else None,
|
||||
shuffle=shuffle,
|
||||
drop_last=not is_eval,
|
||||
)
|
||||
|
||||
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
|
||||
"""
|
||||
Get a training sampler configured for sequence parallelism.
|
||||
|
||||
Args:
|
||||
dataset: The training dataset
|
||||
|
||||
Returns:
|
||||
Configured sequence parallel sampler.
|
||||
"""
|
||||
return self._create_sequence_parallel_sampler(
|
||||
dataset,
|
||||
shuffle=not self.args.curriculum_sampling,
|
||||
)
|
||||
|
||||
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
||||
"""
|
||||
Get an evaluation sampler configured for sequence parallelism.
|
||||
|
||||
Args:
|
||||
eval_dataset: The evaluation dataset.
|
||||
|
||||
Returns:
|
||||
Configured sequence parallel sampler.
|
||||
"""
|
||||
return self._create_sequence_parallel_sampler(
|
||||
eval_dataset, shuffle=False, is_eval=True
|
||||
)
|
||||
|
||||
def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
|
||||
"""
|
||||
Calculate the cu_seqlens for the current forward pass and pass the value to
|
||||
the substituted ring_flash_attn. This is accomplished by using the passed
|
||||
`input_ids`.
|
||||
|
||||
Args:
|
||||
inputs: Current batch of inputs.
|
||||
"""
|
||||
# At this point, inputs should already be partitioned by the sequence
|
||||
# parallel data collator
|
||||
batch_size = inputs["input_ids"].shape[0]
|
||||
seq_len = inputs["input_ids"].shape[1]
|
||||
packed_seq_lens = [seq_len] * batch_size
|
||||
|
||||
# Calculate the full sequence length across all GPUs in this SP group
|
||||
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
||||
|
||||
cu_seqlens = torch.cumsum(
|
||||
torch.tensor(
|
||||
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||
),
|
||||
dim=-1,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_seqlens = F.pad(
|
||||
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||
)
|
||||
|
||||
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
||||
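A worked example of the arithmetic in this mixin (numbers hypothetical): with `world_size=8` and `sequence_parallel_degree=4` there are 8 // 4 = 2 SP groups, and global rank 5 samples as group 5 // 4 = 1. For a partitioned batch of `batch_size=2` with per-rank `seq_len=512`:

    # torch.cumsum over [512, 512]      -> [512, 1024]
    # F.pad left with 0                 -> [0, 512, 1024]
    # F.pad right with 512 * 4 = 2048   -> [0, 512, 1024, 2048]
    # cu_seqlens == tensor([0, 512, 1024, 2048], dtype=torch.int32)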
43
src/axolotl/core/trainers/relora.py
Normal file
@@ -0,0 +1,43 @@
"""Module for ReLoRA trainer"""
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||
|
||||
|
||||
class ReLoRATrainer(AxolotlTrainer):
|
||||
"""Trainer subclass that uses the `OneCycleLR` scheduler"""
|
||||
|
||||
tag_names = ["axolotl", "relora"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: torch.optim.Optimizer | None = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
if self.args.relora_steps:
|
||||
warmup_steps = (
|
||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||
)
|
||||
anneal_steps = (
|
||||
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
||||
)
|
||||
self.lr_scheduler = ReLoRAScheduler(
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
self.args.relora_steps,
|
||||
anneal_steps,
|
||||
warmup_steps,
|
||||
)
|
||||
else:
|
||||
self.lr_scheduler = lr_scheduler
|
||||
|
||||
return self.lr_scheduler
|
||||
@@ -1,15 +1,25 @@
"""
module for TRL PPO training
"""
"""Module for TRL PPO trainer"""

from typing import Literal, Union

import torch
from tqdm import tqdm
from trl import PPOTrainer
from trl import (
    CPOTrainer,
    KTOTrainer,
    ORPOTrainer,
    PPOTrainer,
    PRMTrainer,
    RewardTrainer,
)

from axolotl.core.trainers.mixins.scheduler import SchedulerMixin


class TRLPPOTrainer(PPOTrainer):
    """
    wrapper for ppo trainer to handle customizations
    """
    """Wrapper for TRL PPO trainer to handle customizations"""

    tag_names = ["axolotl", "ppo"]

    def train(
        self,
@@ -30,9 +40,7 @@ class TRLPPOTrainer(PPOTrainer):
            "batch_size": 16,
        }

        for epoch, batch in tqdm(  # pylint: disable=unused-variable
            enumerate(self.dataloader)
        ):
        for _, batch in tqdm(enumerate(self.dataloader)):
            query_tensors = batch["input_ids"]

            # generate model response
@@ -64,3 +72,189 @@ class TRLPPOTrainer(PPOTrainer):
                rewards,
                columns_to_log=["query", "response", "ref_response", "ref_rewards"],
            )
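The `AxolotlORPOTrainer` below combines the chosen-sequence NLL with TRL's odds-ratio term as `loss = policy_nll_loss - losses.mean()`. Schematically, per the ORPO formulation (a sketch, not TRL's exact implementation):

    import torch
    import torch.nn.functional as F

    def odds_ratio_term(chosen_logps, rejected_logps, beta=0.1):
        # log odds(p) = logp - log(1 - exp(logp)); take the chosen/rejected difference
        log_odds = (chosen_logps - rejected_logps) - (
            torch.log1p(-torch.exp(chosen_logps))
            - torch.log1p(-torch.exp(rejected_logps))
        )
        return beta * F.logsigmoid(log_odds)  # the `losses` used in the class below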
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "orpo"]
|
||||
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch: dict[str, Union[list, torch.LongTensor]],
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
):
|
||||
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
|
||||
|
||||
# TODO remove once https://github.com/huggingface/trl/pull/3069 is included in a trl release
|
||||
|
||||
metrics = {}
|
||||
|
||||
forward_output = self.concatenated_forward(model, batch)
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits,
|
||||
policy_rejected_logits,
|
||||
policy_nll_loss,
|
||||
) = forward_output[:5]
|
||||
if self.aux_loss_enabled:
|
||||
aux_loss = forward_output[5]
|
||||
|
||||
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = (
|
||||
self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
|
||||
)
|
||||
# full ORPO loss
|
||||
loss = policy_nll_loss - losses.mean()
|
||||
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(
|
||||
chosen_rewards
|
||||
).mean()
|
||||
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(
|
||||
rejected_rewards
|
||||
).mean()
|
||||
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(
|
||||
reward_accuracies
|
||||
).mean()
|
||||
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
|
||||
chosen_rewards - rejected_rewards
|
||||
).mean()
|
||||
metrics[f"{prefix}logps/rejected"] = (
|
||||
self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
|
||||
)
|
||||
metrics[f"{prefix}logps/chosen"] = (
|
||||
self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
|
||||
)
|
||||
metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
|
||||
policy_rejected_logits.detach().mean()
|
||||
).mean()
|
||||
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
|
||||
policy_chosen_logits.detach().mean()
|
||||
).mean()
|
||||
metrics[f"{prefix}nll_loss"] = (
|
||||
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
|
||||
)
|
||||
metrics[f"{prefix}log_odds_ratio"] = (
|
||||
self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
|
||||
)
|
||||
metrics[f"{prefix}log_odds_chosen"] = (
|
||||
self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
|
||||
)
|
||||
for k, v in metrics.items():
|
||||
metrics[k] = v.item()
|
||||
if self.aux_loss_enabled:
|
||||
loss += self.aux_loss_coef * aux_loss
|
||||
|
||||
return loss, metrics
|
||||
|
||||
|
||||
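For reference, the objective assembled above is the standard ORPO combination: the NLL term from `concatenated_forward` minus the mean of the odds-ratio preference term returned by `odds_ratio_loss` (which already carries the β weighting internally):

$$
\mathcal{L}_{\mathrm{ORPO}} \;=\; \mathcal{L}_{\mathrm{NLL}} \;-\; \mathbb{E}\!\left[\beta \log \sigma\!\left(\log \frac{\operatorname{odds}_\theta(y_w \mid x)}{\operatorname{odds}_\theta(y_l \mid x)}\right)\right],
\qquad
\operatorname{odds}_\theta(y \mid x) = \frac{p_\theta(y \mid x)}{1 - p_\theta(y \mid x)}
$$

which matches `loss = policy_nll_loss - losses.mean()` above.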
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
    """
    Extend the base KTOTrainer for axolotl helpers
    """

    tag_names = ["axolotl", "kto"]


class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
    """
    Extend the base CPOTrainer for axolotl helpers
    """

    tag_names = ["axolotl", "cpo"]

    def get_batch_loss_metrics(
        self,
        model,
        batch: dict[str, Union[list, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
        metrics = {}

        forward_output = self.concatenated_forward(model, batch)
        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_nll_loss,
        ) = forward_output[:5]
        if self.aux_loss_enabled:
            aux_loss = forward_output[5]

        losses, chosen_rewards, rejected_rewards = self.cpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
        )

        loss = losses.mean() + self.cpo_alpha * policy_nll_loss
        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = (
            self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
        )
        metrics[f"{prefix}rewards/rejected"] = (
            self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
        )
        metrics[f"{prefix}rewards/accuracies"] = (
            self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
        )
        metrics[f"{prefix}rewards/margins"] = (
            self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards)
            .mean()
            .item()
        )
        metrics[f"{prefix}logps/rejected"] = (
            self.accelerator.gather_for_metrics(policy_rejected_logps)
            .detach()
            .mean()
            .item()
        )
        metrics[f"{prefix}logps/chosen"] = (
            self.accelerator.gather_for_metrics(policy_chosen_logps)
            .detach()
            .mean()
            .item()
        )
        metrics[f"{prefix}logits/rejected"] = (
            self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean())
            .mean()
            .item()
        )
        metrics[f"{prefix}logits/chosen"] = (
            self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean())
            .mean()
            .item()
        )
        metrics[f"{prefix}nll_loss"] = (
            self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
        )

        if self.aux_loss_enabled:
            loss += self.aux_loss_coef * aux_loss

        return loss, metrics


class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
    """
    Extend the base RewardTrainer for axolotl helpers
    """

    tag_names = ["axolotl", "reward"]


class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
    """
    Extend the base trl.PRMTrainer for axolotl helpers
    """

    tag_names = ["axolotl", "prm"]

33
src/axolotl/core/trainers/utils.py
Normal file
@@ -0,0 +1,33 @@
"""Utils for Axolotl trainers"""


def sanitize_kwargs_for_tagging(tag_names, kwargs=None):
    if isinstance(tag_names, str):
        tag_names = [tag_names]

    if kwargs is not None:
        if "tags" not in kwargs:
            kwargs["tags"] = tag_names
        elif "tags" in kwargs and isinstance(kwargs["tags"], list):
            kwargs["tags"].extend(tag_names)
        elif "tags" in kwargs and isinstance(kwargs["tags"], str):
            tag_names.append(kwargs["tags"])
            kwargs["tags"] = tag_names

    return kwargs


def sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
    if isinstance(dataset_tags, str):
        dataset_tags = [dataset_tags]

    if (dataset_tags is not None) and (kwargs is not None):
        if "dataset_tags" not in kwargs:
            kwargs["dataset_tags"] = dataset_tags
        elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
            kwargs["dataset_tags"].extend(dataset_tags)
        elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
            dataset_tags.append(kwargs["dataset_tags"])
            kwargs["dataset_tags"] = dataset_tags

    return kwargs
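Both helpers normalize whatever `tags`/`dataset_tags` shape the caller passed into a flat list before the kwargs reach hub upload or model-card creation. A quick illustration of the three branches (values hypothetical):

```python
from axolotl.core.trainers.utils import sanitize_kwargs_for_tagging

# no tags yet -> the trainer's tag_names are used as-is
sanitize_kwargs_for_tagging(["axolotl", "orpo"], {})
# {'tags': ['axolotl', 'orpo']}

# caller passed a list -> trainer tags are appended
sanitize_kwargs_for_tagging(["axolotl"], {"tags": ["my-run"]})
# {'tags': ['my-run', 'axolotl']}

# caller passed a single string -> it is folded into the list
sanitize_kwargs_for_tagging("axolotl", {"tags": "my-run"})
# {'tags': ['axolotl', 'my-run']}
```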
@@ -1,9 +1,11 @@
"""
extra axolotl specific training args
"""

from dataclasses import dataclass, field
from typing import Optional

from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig

@@ -206,14 +208,33 @@ class AxolotlTrainingMixins:
        },
    )
    sequence_parallel_degree: Optional[int] = field(
        default=1,
        metadata={"help": "The number of workers to use in sequence parallelism"},
    )

    # multi-modal section
    image_size: int | tuple[int, int] | None = field(
        default=None,
        metadata={"help": "The size of the image to resize to"},
    )
    image_resize_algorithm: Resampling | None = field(
        default=None,
        metadata={"help": "The algorithm to use for image resizing"},
    )
    # end of multi-modal section


@dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
    """
    Training arguments for Causal trainer

    This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
    so it can't be used as a mixin.
    This code is duplicated due to HF TrainingArguments not setting output_dir with a
    default value so it can't be used as a mixin.
    """
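The new multi-modal fields thread image preprocessing options through the training arguments. A minimal sketch of setting them directly (values and the import path are assumptions for illustration; `output_dir` is the one required `TrainingArguments` field):

```python
from PIL.Image import Resampling

# module path assumed for illustration
from axolotl.core.training_args import AxolotlTrainingArguments

args = AxolotlTrainingArguments(
    output_dir="./outputs",
    sequence_parallel_degree=1,
    image_size=(448, 448),                       # resize every image to 448x448
    image_resize_algorithm=Resampling.BILINEAR,  # PIL resampling filter
)
```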
@@ -8,6 +8,8 @@ from typing import Dict, Optional

import torch
from accelerate.logging import get_logger
from datasets import Dataset
from transformers.trainer import Trainer

from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
@@ -25,18 +27,18 @@ LOG = get_logger("axolotl.evaluate")


def evaluate_dataset(
    trainer, dataset, dataset_type: str, flash_optimum: bool = False
    trainer: Trainer, dataset: Dataset, dataset_type: str, flash_optimum: bool = False
) -> Optional[Dict[str, float]]:
    """Helper function to evaluate a single dataset safely.
    """Helper function to evaluate a single dataset.

    Args:
        trainer: The trainer instance
        dataset: Dataset to evaluate
        dataset_type: Type of dataset ('train' or 'eval')
        flash_optimum: Whether to use flash optimum
        trainer: The trainer instance.
        dataset: Dataset to evaluate.
        dataset_type: Type of dataset ('train' or 'eval').
        flash_optimum: Whether to use flash optimum.

    Returns:
        Dictionary of metrics or None if dataset is None
        Dictionary of metrics or None if dataset is None.
    """
    if dataset is None:
        return None
@@ -63,17 +65,14 @@ def evaluate_dataset(

def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
    """
    Evaluate a model on training and validation datasets
    Evaluate a model on training and validation datasets.

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.
        dataset_meta: Dataset metadata containing training and evaluation datasets.

    Returns:
        Tuple containing:
            - The model (either PeftModel or PreTrainedModel)
            - The tokenizer
            - Dictionary of evaluation metrics
        Dictionary mapping metric names to their values.
    """
    # pylint: disable=duplicate-code
    # Enable expandable segments for cuda allocation to improve VRAM usage
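With the tightened signature and return type, calling the helper directly looks like this (a sketch; `trainer` and `eval_dataset` are assumed to exist from prior setup):

```python
# metrics is None when the dataset is missing, so the caller can skip logging
metrics = evaluate_dataset(trainer, eval_dataset, dataset_type="eval")
if metrics is not None:
    for name, value in metrics.items():
        print(f"eval/{name}: {value:.4f}")
```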
@@ -11,19 +11,17 @@
# the License.

"""
module to handle merging the plugins' input arguments with the base configurations.
Module to handle merging the plugins' input arguments with the base configurations.

this was moved here to prevent circular imports
This was moved here to prevent circular imports.
"""

from typing import Any, Dict, List

from axolotl.utils.config.models.input.v0_4_1 import (
from axolotl.utils.schemas.config import (
    AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
)
from axolotl.utils.config.models.input.v0_4_1 import (
    AxolotlInputConfig as AxolotlInputConfigBase,
)
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase


def merge_input_args():
@@ -1,6 +1,6 @@
# Cut Cross Entropy

Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy operation during loss calculation.
Cut Cross Entropy (CCE) reduces VRAM usage through optimization on the cross-entropy operation during loss calculation.

See https://github.com/apple/ml-cross-entropy

@@ -29,6 +29,20 @@ plugins:
cut_cross_entropy: true
```

## Supported Models

- llama
- phi3
- gemma
- gemma2
- gemma3
- gemma3_text
- mistral
- mistral3
- qwen2
- cohere
- cohere2
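The saving comes from never materializing the full logit matrix during loss computation. For a batch of $B$ sequences of length $T$ over a vocabulary $V$, the logits alone occupy

$$
B \cdot T \cdot |V| \cdot \text{bytes per element}
$$

so a single 8192-token sequence over a 128k vocabulary already costs roughly 2 GB in bf16, before any upcast to fp32. CCE instead evaluates the cross-entropy in a fused kernel over blocks, keeping only the loss and small workspaces resident (numbers illustrative).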
## Citation

```bib
@@ -25,8 +25,8 @@ import torch

from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
from axolotl.utils.distributed import zero_only

from ...utils.distributed import zero_only
from .args import CutCrossEntropyArgs  # pylint: disable=unused-import. # noqa: F401

LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
@@ -72,7 +72,9 @@ class CutCrossEntropyPlugin(BasePlugin):
        if cfg.cut_cross_entropy:
            self._check_requirements()

            from cut_cross_entropy.transformers import cce_patch
            from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
                cce_patch,
            )

            with zero_only():
                LOG.info(
201
src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py
Normal file
@@ -0,0 +1,201 @@
"""Cohere and Cohere2 CCE patch."""

# This patch is based off transformers 4.50.0.
# It patches the forward function for CohereForCausalLM and Cohere2ForCausalLM.
# It scales the lm_head weight by the logit scale in advance instead of scaling
# the logits, as the scaling commutes with the projection and is mathematically equivalent.
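# Why pre-scaling the weight is equivalent: the head is linear, so the scale
# commutes with the projection, (logit_scale * W) @ h == logit_scale * (W @ h);
# feeding `lm_head.weight * logit_scale` into the fused linear-cross-entropy
# kernel therefore yields the same loss as scaling the logits afterwards.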

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Tuple, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import (
    _CONFIG_FOR_DOC,
    COHERE_INPUTS_DOCSTRING,
    KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg

_PATCH_OPTS: PatchOptions | None = None


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
    self,
    input_ids: torch.LongTensor | None = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, CohereForCausalLM

    >>> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
    >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        # scale the weight by logit_scale in place of scaling the logits
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight * self.logit_scale,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        logits = logits * self.logit_scale  # main diff from Llama

        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def patch_cohere(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.cohere import modeling_cohere

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_cohere.CohereForCausalLM
        ), f"Expected a CohereForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_cohere.CohereForCausalLM.forward = cce_forward
    return None


def patch_cohere2(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.cohere2 import modeling_cohere2

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_cohere2.Cohere2ForCausalLM
        ), f"Expected a Cohere2ForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_cohere2.Cohere2ForCausalLM.forward = cce_forward
    return None
175
src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py
Normal file
@@ -0,0 +1,175 @@
"""Gemma CCE patch"""

# This patch is based off transformers 4.50.0.

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Tuple, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma.modeling_gemma import (
    _CONFIG_FOR_DOC,
    GEMMA_INPUTS_DOCSTRING,
    KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg

_PATCH_OPTS: PatchOptions | None = None


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
    self,
    input_ids: torch.LongTensor | None = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, GemmaForCausalLM

    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

    >>> prompt = "What is your favorite condiment?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "What is your favorite condiment?"
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def patch_gemma(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.gemma import modeling_gemma

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_gemma.GemmaForCausalLM
        ), f"Expected a GemmaForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_gemma.GemmaForCausalLM.forward = cce_forward
    return None
459
src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py
Normal file
@@ -0,0 +1,459 @@
"""Gemma2 and Gemma3 (text and multimodal) CCE patch."""

# Implementation originally adapted from https://github.com/apple/ml-cross-entropy/pull/29
# and updated for transformers 4.50.0.
# This is a modified version of the patch that allows for deferred logits calculation for gemma3 and works
# with both gemma3 (text and multimodal) models.
# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Tuple, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
)
from torch import nn
from transformers.cache_utils import Cache, HybridCache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma3.modeling_gemma3 import (
    _CONFIG_FOR_DOC,
    GEMMA3_INPUTS_DOCSTRING,
    Gemma3CausalLMOutputWithPast,
    logger,
)
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    is_torchdynamo_compiling,
    replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg

from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce

_PATCH_OPTS: PatchOptions | None = None


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
    self,
    input_ids: torch.LongTensor | None = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[HybridCache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    defer_logits_calculation: bool = False,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    defer_logits_calculation (`bool`, *optional*):
        If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
        memory overhead of calculating logits using regular lm_head forward pass and to use CCE.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Gemma3ForCausalLM

    >>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

    >>> prompt = "What is your favorite condiment?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "What is your favorite condiment?"
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )
    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **loss_kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            softcap=getattr(self.config, "final_logit_softcapping", None),
            **loss_kwargs,
        )
    elif _PATCH_OPTS is not None and defer_logits_calculation:
        # defer logits calculation to the ConditionalGeneration forward
        logits = hidden_states[:, slice_indices, :]
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if self.config.final_logit_softcapping is not None:
            # final logit soft-capping: c * tanh(z / c) bounds logits to (-c, c)
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward_multimodal(
    self,
    input_ids: torch.LongTensor | None = None,
    pixel_values: torch.FloatTensor | None = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
    token_type_ids: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **lm_kwargs,
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration

    >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
    >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")

    >>> prompt = "answer en Where is the cow standing?"
    >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(**inputs, max_length=30)
    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "answer en Where is the cow standing?\nbeach"
    ```"""

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    is_training = token_type_ids is not None and labels is not None

    # Replace image id with PAD if the image token is OOV, to avoid index-errors
    if input_ids is not None and self.config.image_token_index >= self.vocab_size:
        special_image_mask = input_ids == self.config.image_token_index
        llm_input_ids = input_ids.clone()
        llm_input_ids[special_image_mask] = 0
    else:
        llm_input_ids = input_ids  # type: ignore

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(llm_input_ids)

    if cache_position is None:
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0  # type: ignore
        )
        cache_position = torch.arange(  # type: ignore
            past_seen_tokens,
            past_seen_tokens + inputs_embeds.shape[1],
            device=inputs_embeds.device,
        )

    # Merge text and images
    if pixel_values is not None:
        image_features = self.get_image_features(pixel_values)

        if input_ids is None:
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(
                    self.config.image_token_index,
                    dtype=torch.long,
                    device=inputs_embeds.device,
                )
            )
        else:
            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
                -1
            )
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
                inputs_embeds.device
            )

        if (
            not is_torchdynamo_compiling()
            and inputs_embeds[special_image_mask].numel() != image_features.numel()
        ):
            image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
            raise ValueError(
                f"Number of images does not match number of special image tokens in the input text. "
                f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
                "tokens from image embeddings."
            )
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)  # type: ignore

    # mask out pad-token-ids in labels for BC
    if labels is not None and self.pad_token_id in labels:
        logger.warning_once(
            "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
            "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
        )
        labels = torch.where(  # type: ignore
            input_ids == self.pad_token_id, self.config.ignore_index, labels
        )

    causal_mask = self._update_causal_mask(  # pylint: disable=protected-access
        attention_mask,
        token_type_ids,
        past_key_values,
        cache_position,
        inputs_embeds,
        is_training,
    )
    outputs = self.language_model(
        attention_mask=causal_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        logits_to_keep=logits_to_keep,
        defer_logits_calculation=True,  # enable deferred logits calculation
        **lm_kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states,
            self.language_model.lm_head.weight,
            labels,
            _PATCH_OPTS,
            softcap=getattr(self.config, "final_logit_softcapping", None),
            **lm_kwargs,
        )
    else:
        logits = hidden_states
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(
                    logits.device
                )
                shift_logits = shift_logits[
                    shift_attention_mask.to(logits.device) != 0
                ].contiguous()
                shift_labels = shift_labels[
                    shift_attention_mask.to(shift_labels.device) != 0
                ].contiguous()
            else:
                shift_logits = shift_logits.contiguous()
                shift_labels = shift_labels.contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()

            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
            flat_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(flat_logits, flat_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Gemma3CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
    )


def patch_gemma2(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.gemma2 import modeling_gemma2

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_gemma2.Gemma2ForCausalLM
        ), f"Expected a Gemma2ForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_gemma2.Gemma2ForCausalLM.forward = cce_forward
    return None


def patch_gemma3_text(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.gemma3 import modeling_gemma3

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_gemma3.Gemma3ForCausalLM
        ), f"Expected a Gemma3ForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
    return None


def patch_gemma3(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.gemma3 import modeling_gemma3

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_gemma3.Gemma3ForConditionalGeneration
        ), f"Expected a Gemma3ForConditionalGeneration model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)

        # patch the causal model to enable deferred logits calculation
        maybe_model.language_model.forward = MethodType(
            cce_forward, maybe_model.language_model
        )
        return maybe_model

    modeling_gemma3.Gemma3ForConditionalGeneration.forward = cce_forward_multimodal
    # patch the causal model to enable deferred logits calculation
    modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
    return None
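The multimodal wrapper above never asks the inner `Gemma3ForCausalLM` to run `lm_head`: with `defer_logits_calculation=True` the inner forward returns hidden states in the `logits` slot, and the wrapper either feeds them to the fused CCE kernel or projects them itself. A stripped-down sketch of that handshake (names hypothetical; the fused kernel is replaced by a reference implementation):

```python
import torch
import torch.nn.functional as F


def fused_linear_loss(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # stand-in for cut_cross_entropy's fused kernel; the real kernel never
    # materializes the full logit matrix, this reference version does
    logits = hidden @ weight.T
    return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))


def inner_forward(hidden: torch.Tensor, lm_head: torch.nn.Linear, defer: bool) -> torch.Tensor:
    # with defer=True the "logits" slot actually carries hidden states
    return hidden if defer else lm_head(hidden)


def outer_forward(hidden: torch.Tensor, lm_head: torch.nn.Linear, labels=None):
    out = inner_forward(hidden, lm_head, defer=labels is not None)
    if labels is not None:
        # training path: fused loss consumes hidden states + weight directly
        return fused_linear_loss(out, lm_head.weight, labels)
    return out  # inference path: real logits
```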
@@ -0,0 +1,392 @@
"""Mistral and Mistral3 CCE patch."""

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Tuple, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.mistral3.modeling_mistral3 import (
    Mistral3CausalLMOutputWithPast,
)
from transformers.models.mistral.modeling_mistral import (
    _CONFIG_FOR_DOC,
    MISTRAL_INPUTS_DOCSTRING,
    KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    is_torchdynamo_compiling,
    replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg

_PATCH_OPTS: PatchOptions | None = None


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
    self,
    input_ids: torch.LongTensor | None = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    defer_logits_calculation: bool = False,
    **kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    defer_logits_calculation (`bool`, *optional*):
        If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
        memory overhead of calculating logits using regular lm_head forward pass and to use CCE.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, MistralForCausalLM

    >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf")
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    elif _PATCH_OPTS is not None and defer_logits_calculation:
        # defer logits calculation to the ConditionalGeneration forward
        logits = hidden_states[:, slice_indices, :]
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def cce_forward_multimodal(
    self,
    input_ids: torch.LongTensor | None = None,
    pixel_values: torch.FloatTensor | None = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[Union[int, list[int]]] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    image_sizes: torch.Tensor | None = None,
    **lm_kwargs,
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration

    >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
    >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")

    >>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "What is the image?The image depicts two cats lying on a pink blanket."
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )
    vision_feature_layer = (
        vision_feature_layer
        if vision_feature_layer is not None
        else self.config.vision_feature_layer
    )

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if pixel_values is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
        )

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    if pixel_values is not None:
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            image_sizes=image_sizes,
        )

        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
            inputs_embeds.device
        )
        if (
            not is_torchdynamo_compiling()
            and inputs_embeds[special_image_mask].numel() != image_features.numel()
        ):
            n_image_tokens = (input_ids == self.config.image_token_index).sum()
            n_image_features = image_features.shape[0] * image_features.shape[1]
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)  # type: ignore

    outputs = self.language_model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        logits_to_keep=logits_to_keep,
        defer_logits_calculation=True,  # enable deferred logits calculation
        **lm_kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states,
            self.language_model.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **lm_kwargs,
        )
    else:
        logits = hidden_states
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
                    logits.device
                )
                shift_logits = logits[..., :-1, :][
                    shift_attention_mask.to(logits.device) != 0
                ].contiguous()
                shift_labels = labels[..., 1:][
                    shift_attention_mask.to(labels.device) != 0
                ].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1).to(shift_logits.device),
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Mistral3CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
def patch_mistral(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.mistral import modeling_mistral
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_mistral.MistralForCausalLM
|
||||
), f"Expected a MistralForCausalLM model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||
return maybe_model
|
||||
|
||||
modeling_mistral.MistralForCausalLM.forward = cce_forward
|
||||
return None
|
||||
|
||||
|
||||
def patch_mistral3(
|
||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||
patch_options: PatchOptions,
|
||||
) -> TransformersModelT | None:
|
||||
global _PATCH_OPTS # pylint: disable=global-statement
|
||||
from transformers.models.mistral import modeling_mistral
|
||||
from transformers.models.mistral3 import modeling_mistral3
|
||||
|
||||
_PATCH_OPTS = patch_options
|
||||
|
||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||
assert isinstance(
|
||||
maybe_model, modeling_mistral3.Mistral3ForConditionalGeneration
|
||||
), f"Expected a Mistral3ForConditionalGeneration model. Got {type(maybe_model)}."
|
||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||
|
||||
# patch the causal model to enable deferred logits calculation
|
||||
maybe_model.language_model.forward = MethodType(
|
||||
cce_forward, maybe_model.language_model
|
||||
)
|
||||
return maybe_model
|
||||
|
||||
modeling_mistral3.Mistral3ForConditionalGeneration.forward = cce_forward_multimodal
|
||||
# patch the causal model to enable deferred logits calculation
|
||||
modeling_mistral.MistralForCausalLM.forward = cce_forward
|
||||
return None
|
||||
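A note on the `defer_logits_calculation=True` handshake used throughout these patches: the patched causal LM skips the `lm_head` projection and hands raw hidden states back to the multimodal wrapper, which then fuses projection and loss via `apply_lce`. A minimal, self-contained sketch of the idea (toy shapes; `inner_forward` and `outer_loss` are illustrative names, not part of the patch):

```python
import torch
import torch.nn.functional as F

def inner_forward(hidden, lm_head_w, defer_logits_calculation=False):
    if defer_logits_calculation:
        return hidden  # hand back hidden states; skip the big vocab projection
    return hidden @ lm_head_w.T  # (batch, seq, vocab) logits materialized

def outer_loss(hidden, lm_head_w, labels):
    deferred = inner_forward(hidden, lm_head_w, defer_logits_calculation=True)
    # stand-in for apply_lce: project and compute CE in one place, where a
    # fused kernel can avoid ever materializing the full logits tensor
    logits = deferred @ lm_head_w.T
    return F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        labels[:, 1:].reshape(-1),
        ignore_index=-100,
    )

h = torch.randn(2, 8, 16)          # (batch, seq, hidden)
w = torch.randn(100, 16)           # (vocab, hidden) lm_head weight
y = torch.randint(0, 100, (2, 8))  # labels
print(outer_loss(h, w, y))
```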
379
src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py
Normal file
@@ -0,0 +1,379 @@
"""Mllama CCE patch."""

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Tuple, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.mllama.modeling_mllama import (
    MLLAMA_INPUTS_DOCSTRING,
    _prepare_cross_attention_mask,
)
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg

_PATCH_OPTS: PatchOptions | None = None


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
)
def cce_forward(
    self,
    input_ids: torch.LongTensor | None = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    cross_attention_states: Optional[torch.LongTensor] = None,
    cross_attention_mask: Optional[torch.LongTensor] = None,
    full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    defer_logits_calculation: bool = False,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only
        for that token can save memory, which becomes significant for long sequences or a large vocabulary size.
        If a `torch.Tensor`, it must be 1D, corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    defer_logits_calculation (`bool`, *optional*):
        If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
        memory overhead of calculating logits with the regular lm_head forward pass and to use CCE instead.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, MllamaForCausalLM

    >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
    >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")

    >>> prompt = "If I had to write a haiku, it would be:"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
    >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    >>> print(result)
    If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
    I love the idea of snowflakes gently falling, each one
    ```
    """
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        cross_attention_states=cross_attention_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        cross_attention_mask=cross_attention_mask,
        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **loss_kwargs,
        )
    elif _PATCH_OPTS is not None and defer_logits_calculation:
        # defer logits calculation to the ConditionalGeneration forward
        logits = hidden_states[:, slice_indices, :]
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :]).float()

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class="MllamaConfig"
)
def cce_forward_multimodal(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    aspect_ratio_mask: Optional[torch.Tensor] = None,
    aspect_ratio_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_mask: Optional[torch.Tensor] = None,
    cross_attention_states: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only
        for that token can save memory, which becomes significant for long sequences or a large vocabulary size.
        If a `torch.Tensor`, it must be 1D, corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, MllamaForConditionalGeneration

    >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
    >>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
    >>> processor = AutoProcessor.from_pretrained(checkpoint)

    >>> prompt = "<|image|>If I had to write a haiku for this one"
    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

    >>> # Generate
    >>> output = model.generate(**inputs, max_new_tokens=15)

    >>> prompt_len = inputs.input_ids.shape[-1]
    >>> generated_ids = output[:, prompt_len:]
    >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    >>> print(generated_text)
    [', it would be:.\\nA stop sign in Chinatown.\\n']
    ```
    """
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if pixel_values is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
        )

    if pixel_values is not None and cross_attention_states is not None:
        raise ValueError(
            "`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
        )

    if pixel_values is not None:
        if aspect_ratio_ids is None:
            raise ValueError(
                "`aspect_ratio_ids` must be provided if `pixel_values` is provided"
            )
        # get vision tokens from vision model
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            aspect_ratio_ids=aspect_ratio_ids,
            aspect_ratio_mask=aspect_ratio_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )
        cross_attention_states = vision_outputs[0]
        cross_attention_states = self.multi_modal_projector(
            cross_attention_states
        ).reshape(
            -1, cross_attention_states.shape[-2], self.hidden_size  # type: ignore
        )

    if cross_attention_mask is not None:
        cross_attention_mask, full_text_row_masked_out_mask = (
            _prepare_cross_attention_mask(
                cross_attention_mask,
                num_vision_tokens=self.vision_model.num_patches,
                dtype=self.dtype,
            )
        )
    else:
        full_text_row_masked_out_mask = None

    if cross_attention_mask is not None and cache_position is not None:
        cross_attention_mask = cross_attention_mask[:, :, cache_position]
        full_text_row_masked_out_mask = full_text_row_masked_out_mask[
            :, :, cache_position
        ]

    outputs = self.language_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        cross_attention_states=cross_attention_states,
        cross_attention_mask=cross_attention_mask,
        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
        past_key_values=past_key_values,
        use_cache=use_cache,
        inputs_embeds=inputs_embeds,
        output_hidden_states=output_hidden_states,
        output_attentions=output_attentions,
        return_dict=return_dict,
        cache_position=cache_position,
        logits_to_keep=logits_to_keep,
        defer_logits_calculation=True,  # enable deferred logits calculation
        **loss_kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states,
            self.language_model.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **loss_kwargs,
        )
    else:
        # Temporary fix to calculate the loss in the main class, as the model's vocab size may be resized
        logits = hidden_states

        if labels is not None:
            loss = self.loss_function(
                logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs
            )

    if not return_dict:
        return (loss,) + outputs if loss is not None else outputs

    return CausalLMOutputWithPast(
        loss=loss,
        logits=outputs.logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def patch_mllama(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.mllama import modeling_mllama

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_mllama.MllamaForConditionalGeneration
        ), f"Expected a MllamaForConditionalGeneration model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)

        # patch the language model
        maybe_model.language_model.forward = MethodType(
            cce_forward, maybe_model.language_model
        )
        return maybe_model

    modeling_mllama.MllamaForConditionalGeneration.forward = cce_forward_multimodal

    # patch the causal language model
    modeling_mllama.MllamaForCausalLM.forward = cce_forward
    return None
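The `slice_indices` computed in `cce_forward` above is how `logits_to_keep` is realized: for an `int` k, `slice(-k, None)` keeps the last k positions, with `k == 0` degenerating to `slice(0, None)` (keep everything, since `-0 == 0`); a tensor selects explicit positions. A quick demonstration (toy tensor, not from the patch):

```python
import torch

h = torch.arange(20.0).reshape(1, 10, 2)  # (batch, seq, hidden)

assert h[:, slice(-0, None), :].shape[1] == 10  # int 0: keep all positions
assert h[:, slice(-3, None), :].shape[1] == 3   # int 3: keep the last 3
assert h[:, torch.tensor([0, 4, 9]), :].shape[1] == 3  # tensor: explicit indices
```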
@@ -0,0 +1,85 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.

"""Cut Cross Entropy patcher"""

import transformers
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
from cut_cross_entropy.transformers.llama import patch_llama
from cut_cross_entropy.transformers.phi3 import patch_phi3
from cut_cross_entropy.transformers.qwen2 import patch_qwen2
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT

from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
    patch_cohere,
    patch_cohere2,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma import patch_gemma
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
    patch_gemma2,
    patch_gemma3,
    patch_gemma3_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
    patch_mistral,
    patch_mistral3,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama

CUT_CROSS_ENTROPY_MODEL_MAPPING = {
    "llama": patch_llama,
    "mllama": patch_mllama,
    "phi3": patch_phi3,
    "gemma": patch_gemma,
    "gemma2": patch_gemma2,
    "gemma3": patch_gemma3,
    "gemma3_text": patch_gemma3_text,
    "mistral": patch_mistral,
    "mistral3": patch_mistral3,
    "qwen2": patch_qwen2,
    "cohere": patch_cohere,
    "cohere2": patch_cohere2,
}


def cce_patch(
    model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig,
    impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
    reduction: str = "mean",
    filter_eps: float | str | None = "auto",
    accum_e_fp32: bool = False,
    accum_c_fp32: bool = False,
    filter_e_grad: bool = True,
    filter_c_grad: bool = True,
    train_only: bool = False,
) -> TransformersModelT | None:
    if isinstance(impl, LinearCrossEntropyImpl):
        impl = impl.name.lower()

    if impl not in (v.name.lower() for v in LinearCrossEntropyImpl):
        raise ValueError(f"Unknown {impl=}")

    if isinstance(model_type_or_model, transformers.PreTrainedModel):
        model_type = model_type_or_model.config.model_type
    elif isinstance(model_type_or_model, transformers.PretrainedConfig):
        model_type = model_type_or_model.model_type
    else:
        model_type = model_type_or_model

    patch_options = PatchOptions(
        impl=impl,
        reduction=reduction,
        filter_eps=filter_eps,
        accum_e_fp32=accum_e_fp32,
        accum_c_fp32=accum_c_fp32,
        filter_e_grad=filter_e_grad,
        filter_c_grad=filter_c_grad,
        train_only=train_only,
    )

    if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING:
        return CUT_CROSS_ENTROPY_MODEL_MAPPING[model_type](
            model_type_or_model, patch_options
        )

    raise RuntimeError(f"Unknown model type {model_type}")
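For orientation, a hedged usage sketch of the dispatcher above: passing an instantiated model patches that instance in place and returns it, while passing a model-type string or config patches the transformers modeling classes globally and returns None. The checkpoint name is illustrative only.

```python
import transformers

# Per-instance patch: the model object is returned with its forward replaced.
model = transformers.AutoModelForCausalLM.from_pretrained("NousResearch/Llama-2-7b-hf")
model = cce_patch(model, train_only=True)  # default impl from LCE_IMPL_DEFAULT

# Global patch by model type: modeling classes are patched, None is returned.
cce_patch("qwen2")
```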
@@ -0,0 +1,40 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.

"""Monkeypatch for apply_lce to add softcap."""

import torch
from cut_cross_entropy import linear_cross_entropy
from cut_cross_entropy.transformers.utils import PatchOptions


def apply_lce(
    e: torch.Tensor,
    c: torch.Tensor,
    labels: torch.Tensor,
    opts: PatchOptions,
    bias: torch.Tensor | None = None,
    softcap: float | None = None,
    **loss_kwargs,
) -> torch.Tensor:
    """Monkey patch for apply_lce to support softcap kwarg."""
    num_items_in_batch = loss_kwargs.get("num_items_in_batch", None)
    cce_kwargs = opts.to_kwargs()
    if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean":
        cce_kwargs["reduction"] = "sum"
    else:
        num_items_in_batch = None

    loss = linear_cross_entropy(
        e,
        c,
        labels.to(e.device),
        bias=bias,
        shift=True,
        softcap=softcap,
        **cce_kwargs,
    )

    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch

    return loss
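The reduction swap above is the gradient-accumulation fix: when the trainer supplies `num_items_in_batch`, averaging inside each micro-batch would weight tokens unevenly, so the loss is summed and divided by the global label-token count instead. A self-contained illustration (toy numbers, not from the patch):

```python
import torch

# Per-token losses split across two micro-batches of different sizes.
micro_a = torch.tensor([0.0, 4.0])       # 2 label tokens
micro_b = torch.tensor([1.0, 1.0, 1.0])  # 3 label tokens
num_items_in_batch = micro_a.numel() + micro_b.numel()

mean_of_means = (micro_a.mean() + micro_b.mean()) / 2                   # 1.5 (biased)
sum_then_divide = (micro_a.sum() + micro_b.sum()) / num_items_in_batch  # 1.4 (correct)
```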
@@ -1,6 +1,7 @@
"""
Grokfast plugin for Axolotl
"""

import logging

from transformers.trainer_callback import TrainerCallback

@@ -1,6 +1,7 @@
"""
config args for grokfast plugin
"""

from typing import Optional

from pydantic import BaseModel
@@ -26,12 +26,12 @@ class KDArgs(BaseModel):
    """

    kd_trainer: Optional[bool] = None  # whether to use KD trainer
    kd_ce_alpha: Optional[
        float
    ] = None  # loss coefficient for cross-entropy loss during KD
    kd_ce_alpha: Optional[float] = (
        None  # loss coefficient for cross-entropy loss during KD
    )
    kd_alpha: Optional[float] = None  # loss coefficient for KD loss
    kd_temperature: Optional[float] = None  # temperature for sampling during KD
    kd_zscore_base_temp: Optional[float] = None  # base temperature for zscore scaling
    kd_top_k_before_softmax: Optional[
        bool
    ] = None  # whether to sample top k before softmax during KD
    kd_top_k_before_softmax: Optional[bool] = (
        None  # whether to sample top k before softmax during KD
    )
@@ -20,6 +20,26 @@ liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```

## Supported Models

- deepseek_v2
- gemma
- gemma2
- gemma3 (partial support, no support for FLCE yet)
- granite
- jamba
- llama
- mistral
- mixtral
- mllama
- mllama_text_model
- olmo2
- paligemma
- phi3
- qwen2
- qwen2_5_vl
- qwen2_vl

## Citation

```bib
@@ -21,6 +21,7 @@ It is designed to be performant, correct, and light-weight.
import inspect
import logging
import sys
from functools import partial

from axolotl.integrations.base import BasePlugin
@@ -41,11 +42,18 @@ class LigerPlugin(BasePlugin):
    def pre_model_load(self, cfg):
        from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
        from liger_kernel.transformers.functional import liger_cross_entropy
        from liger_kernel.transformers.geglu import LigerGEGLUMLP
        from liger_kernel.transformers.layer_norm import LigerLayerNorm
        from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
        from liger_kernel.transformers.rms_norm import LigerRMSNorm
        from liger_kernel.transformers.rope import liger_rotary_pos_emb
        from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

        if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
            raise ValueError(
                "Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
            )

        if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
            apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
            liger_fn_sig = inspect.signature(apply_liger_fn)
@@ -55,9 +63,9 @@ class LigerPlugin(BasePlugin):
            if "cross_entropy" in liger_fn_sig.parameters:
                kwargs["cross_entropy"] = cfg.liger_cross_entropy
            if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
                kwargs[
                    "fused_linear_cross_entropy"
                ] = cfg.liger_fused_linear_cross_entropy
                kwargs["fused_linear_cross_entropy"] = (
                    cfg.liger_fused_linear_cross_entropy
                )
            if "rms_norm" in liger_fn_sig.parameters:
                kwargs["rms_norm"] = cfg.liger_rms_norm
            if "layer_norm" in liger_fn_sig.parameters:
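The `liger_fn_sig.parameters` checks above guard against version skew in liger-kernel: a kwarg is forwarded only when the installed `apply_liger_*` function actually declares it. The general pattern as a stand-alone sketch (`call_with_supported_kwargs` and the toy target are illustrative, not axolotl API):

```python
import inspect

def call_with_supported_kwargs(fn, **candidate_kwargs):
    """Forward only the kwargs that fn declares in its signature."""
    accepted = inspect.signature(fn).parameters
    return fn(**{k: v for k, v in candidate_kwargs.items() if k in accepted})

def apply_liger_kernel(rope=True, rms_norm=True):  # toy target with a narrow signature
    return rope, rms_norm

# "layer_norm" is silently dropped because the target does not accept it
print(call_with_supported_kwargs(apply_liger_kernel, rope=True, layer_norm=True))
```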
@@ -82,6 +90,8 @@ class LigerPlugin(BasePlugin):
                modeling_jamba.JambaRMSNorm = LigerRMSNorm
            if cfg.liger_glu_activation:
                modeling_jamba.JambaMLP = LigerSwiGLUMLP
            if cfg.liger_layer_norm:
                modeling_jamba.nn.LayerNorm = LigerLayerNorm
            if cfg.liger_cross_entropy:
                from transformers.loss.loss_utils import nn
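Assignments like `modeling_jamba.nn.LayerNorm = LigerLayerNorm` work because Python resolves such names through the module namespace at call time, so models constructed after the patch pick up the replacement. A toy demonstration of that late-binding mechanism (names are illustrative):

```python
import types

mod = types.ModuleType("toy_modeling")
mod.Norm = lambda: "original"

def build_model():
    # the name is looked up on the module object each time this runs
    return mod.Norm()

mod.Norm = lambda: "patched"
assert build_model() == "patched"  # new constructions see the replacement
```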
@@ -104,13 +114,51 @@ class LigerPlugin(BasePlugin):
            # The DeepseekV2 version of RoPE is different than upstream LLaMA.
            # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
            logging.warning("Fused liger_rope is not supported for DeepseekV2.")
            if cfg.liger_glu_activation:
                logging.warning("liger_glu_activation is not supported for DeepseekV2.")
            if cfg.liger_rms_norm:
                modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
            if cfg.liger_glu_activation:
                modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
            if cfg.liger_layer_norm:
                modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward
            if cfg.liger_cross_entropy:
                # We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
                # nn.CrossEntropyLoss in the forward method.
                modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
            if cfg.liger_fused_linear_cross_entropy:
                modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
        elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
            from transformers.models.gemma3 import modeling_gemma3

            if cfg.liger_rope:
                modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
            if cfg.liger_rms_norm:

                def _liger_rms_norm_wrapper(dim, **kwargs):
                    "Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
                    return LigerRMSNorm(hidden_size=dim, **kwargs)

                modeling_gemma3.Gemma3RMSNorm = partial(
                    _liger_rms_norm_wrapper,
                    offset=1.0,
                    casting_mode="gemma",
                    init_fn="zeros",
                    in_place=False,
                )
            if cfg.liger_glu_activation:
                modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
            if cfg.liger_layer_norm:
                modeling_gemma3.nn.LayerNorm = LigerLayerNorm

            if cfg.liger_cross_entropy:
                from transformers.loss.loss_utils import nn

                nn.functional.cross_entropy = liger_cross_entropy

            if cfg.liger_fused_linear_cross_entropy:
                raise NotImplementedError(
                    "Fused linear cross entropy is not yet supported for Gemma3."
                )
        elif cfg.model_config_type in ["deepseek_v3"]:
            raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
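The `partial(_liger_rms_norm_wrapper, ...)` trick above exists because `Gemma3RMSNorm` is constructed with a `dim` argument while `LigerRMSNorm` expects `hidden_size`; the wrapper adapts the constructor call without touching any call site. A self-contained sketch of the adapter pattern (`_RMSNormLike` is a stand-in, not the real `LigerRMSNorm`):

```python
from functools import partial

class _RMSNormLike:  # stand-in with LigerRMSNorm's constructor shape
    def __init__(self, hidden_size, eps=1e-6, offset=0.0):
        self.hidden_size, self.eps, self.offset = hidden_size, eps, offset

def _rms_norm_wrapper(dim, **kwargs):
    # adapt the caller's `dim` argument to the replacement's `hidden_size`
    return _RMSNormLike(hidden_size=dim, **kwargs)

PatchedRMSNorm = partial(_rms_norm_wrapper, offset=1.0)
norm = PatchedRMSNorm(2048)  # original call sites keep passing `dim`
assert (norm.hidden_size, norm.offset) == (2048, 1.0)
```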
@@ -1,6 +1,7 @@
"""
DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
"""

# pylint: disable=duplicate-code

from typing import List, Optional, Tuple, Union

@@ -1,6 +1,7 @@
"""
Jamba model with LigerFusedLinearCrossEntropyLoss
"""

# pylint: disable=duplicate-code

from typing import Optional, Tuple, Union

@@ -1,6 +1,7 @@
"""
Module for the Plugin for LM Eval Harness
"""

import subprocess  # nosec

from axolotl.integrations.base import BasePlugin
Some files were not shown because too many files have changed in this diff.