Compare commits
47 Commits
pre-commit
...
sequence-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ac65462f0 | ||
|
|
ce35b2a95f | ||
|
|
ab3b36339a | ||
|
|
22cfa42961 | ||
|
|
0b2c2ed68c | ||
|
|
2f0b4626b9 | ||
|
|
a26985c53c | ||
|
|
c1a58339e8 | ||
|
|
411df76a97 | ||
|
|
a09d1ccbf2 | ||
|
|
2727d86544 | ||
|
|
64c203cdef | ||
|
|
7d7042f602 | ||
|
|
d187f1f8e2 | ||
|
|
1cced52719 | ||
|
|
11321b17e7 | ||
|
|
7a1a211c99 | ||
|
|
e1a02a32b5 | ||
|
|
a6ef6c7764 | ||
|
|
cb3a9e99a3 | ||
|
|
3ae47ec7de | ||
|
|
e36dc763ab | ||
|
|
03027cf6bf | ||
|
|
0ade60d455 | ||
|
|
02e1a42f04 | ||
|
|
919b88f11b | ||
|
|
345a9dd831 | ||
|
|
4ff97bc9d4 | ||
|
|
d0e178d52f | ||
|
|
5731cdc0cf | ||
|
|
b7738d57c4 | ||
|
|
698e599bf7 | ||
|
|
1d339e4007 | ||
|
|
4190ad0647 | ||
|
|
b44a207248 | ||
|
|
51c326150b | ||
|
|
14baaf6e0a | ||
|
|
f487910444 | ||
|
|
c5071dfd8a | ||
|
|
e323145ba9 | ||
|
|
7efc787ac8 | ||
|
|
dce61cdab1 | ||
|
|
bd952de9d2 | ||
|
|
3f8a43cab6 | ||
|
|
113e9cd193 | ||
|
|
61825a464a | ||
|
|
c907ac173e |
6
.github/workflows/docs.yml
vendored
6
.github/workflows/docs.yml
vendored
@@ -20,9 +20,11 @@ jobs:
|
|||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.11'
|
python-version: '3.11'
|
||||||
- name: install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install jupyter
|
python3 -m pip install jupyter quartodoc
|
||||||
|
- name: Build autodoc
|
||||||
|
run: quartodoc build
|
||||||
- name: Publish to GitHub Pages (and render)
|
- name: Publish to GitHub Pages (and render)
|
||||||
uses: quarto-dev/quarto-actions/publish@v2
|
uses: quarto-dev/quarto-actions/publish@v2
|
||||||
with:
|
with:
|
||||||
|
|||||||
6
.github/workflows/tests.yml
vendored
6
.github/workflows/tests.yml
vendored
@@ -98,8 +98,9 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
pytest -v tests/patched/
|
pytest -v tests/patched/
|
||||||
|
pytest -v tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
@@ -172,8 +173,9 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
pytest -v tests/patched/
|
pytest -v tests/patched/
|
||||||
|
pytest -v tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -181,6 +181,10 @@ prepared-datasets/
|
|||||||
submit.sh
|
submit.sh
|
||||||
*.out*
|
*.out*
|
||||||
|
|
||||||
|
# Quartodoc generated files
|
||||||
|
objects.json
|
||||||
|
site_libs/
|
||||||
|
|
||||||
typings/
|
typings/
|
||||||
out/
|
out/
|
||||||
|
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github
|
|||||||
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
|
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
|
||||||
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
|
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
|
||||||
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
|
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
|
||||||
|
- [API Reference](https://axolotl-ai-cloud.github.io/axolotl/docs/api/) - Auto-generated code documentation
|
||||||
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
|
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
|
||||||
|
|
||||||
## 🤝 Getting Help
|
## 🤝 Getting Help
|
||||||
|
|||||||
193
_quarto.yml
193
_quarto.yml
@@ -1,6 +1,178 @@
|
|||||||
project:
|
project:
|
||||||
type: website
|
type: website
|
||||||
|
|
||||||
|
quartodoc:
|
||||||
|
dir: docs/api
|
||||||
|
package: axolotl
|
||||||
|
title: API Reference
|
||||||
|
parser: google
|
||||||
|
|
||||||
|
sections:
|
||||||
|
- title: Core
|
||||||
|
desc: Core functionality for training
|
||||||
|
contents:
|
||||||
|
- train
|
||||||
|
- evaluate
|
||||||
|
- datasets
|
||||||
|
- convert
|
||||||
|
- prompt_tokenizers
|
||||||
|
- logging_config
|
||||||
|
- core.trainer_builder
|
||||||
|
- core.training_args
|
||||||
|
- core.chat.messages
|
||||||
|
- core.chat.format.chatml
|
||||||
|
- core.chat.format.llama3x
|
||||||
|
- core.chat.format.shared
|
||||||
|
- core.datasets.chat
|
||||||
|
- core.datasets.transforms.chat_builder
|
||||||
|
- title: CLI
|
||||||
|
desc: Command-line interface
|
||||||
|
contents:
|
||||||
|
- cli.main
|
||||||
|
- cli.train
|
||||||
|
- cli.evaluate
|
||||||
|
- cli.args
|
||||||
|
- cli.checks
|
||||||
|
- cli.config
|
||||||
|
- cli.inference
|
||||||
|
- cli.merge_lora
|
||||||
|
- cli.merge_sharded_fsdp_weights
|
||||||
|
- cli.preprocess
|
||||||
|
- cli.sweeps
|
||||||
|
- cli.utils
|
||||||
|
- cli.cloud.base
|
||||||
|
- cli.cloud.modal_
|
||||||
|
- title: Trainers
|
||||||
|
desc: Training implementations
|
||||||
|
contents:
|
||||||
|
- core.trainers.base
|
||||||
|
- core.trainers.trl
|
||||||
|
- core.trainers.dpo.trainer
|
||||||
|
- core.trainers.grpo.trainer
|
||||||
|
- title: Prompt Strategies
|
||||||
|
desc: Prompt formatting strategies
|
||||||
|
contents:
|
||||||
|
- prompt_strategies.base
|
||||||
|
- prompt_strategies.chat_template
|
||||||
|
- prompt_strategies.alpaca_chat
|
||||||
|
- prompt_strategies.alpaca_instruct
|
||||||
|
- prompt_strategies.alpaca_w_system
|
||||||
|
- prompt_strategies.user_defined
|
||||||
|
- prompt_strategies.llama2_chat
|
||||||
|
- prompt_strategies.completion
|
||||||
|
- prompt_strategies.input_output
|
||||||
|
- prompt_strategies.stepwise_supervised
|
||||||
|
- prompt_strategies.metharme
|
||||||
|
- prompt_strategies.orcamini
|
||||||
|
- prompt_strategies.pygmalion
|
||||||
|
- prompt_strategies.messages.chat
|
||||||
|
- prompt_strategies.dpo.chat_template
|
||||||
|
- prompt_strategies.dpo.llama3
|
||||||
|
- prompt_strategies.dpo.chatml
|
||||||
|
- prompt_strategies.dpo.zephyr
|
||||||
|
- prompt_strategies.dpo.user_defined
|
||||||
|
- prompt_strategies.dpo.passthrough
|
||||||
|
- prompt_strategies.kto.llama3
|
||||||
|
- prompt_strategies.kto.chatml
|
||||||
|
- prompt_strategies.kto.user_defined
|
||||||
|
- prompt_strategies.orpo.chat_template
|
||||||
|
- prompt_strategies.bradley_terry.llama3
|
||||||
|
- title: Kernels
|
||||||
|
desc: Low-level performance optimizations
|
||||||
|
contents:
|
||||||
|
- kernels.lora
|
||||||
|
- kernels.geglu
|
||||||
|
- kernels.swiglu
|
||||||
|
- kernels.quantize
|
||||||
|
- kernels.utils
|
||||||
|
- title: MonkeyPatches
|
||||||
|
desc: Runtime patches for model optimizations
|
||||||
|
contents:
|
||||||
|
- monkeypatch.llama_attn_hijack_flash
|
||||||
|
- monkeypatch.llama_attn_hijack_xformers
|
||||||
|
- monkeypatch.mistral_attn_hijack_flash
|
||||||
|
- monkeypatch.multipack
|
||||||
|
- monkeypatch.relora
|
||||||
|
- monkeypatch.llama_expand_mask
|
||||||
|
- monkeypatch.lora_kernels
|
||||||
|
- monkeypatch.utils
|
||||||
|
- monkeypatch.btlm_attn_hijack_flash
|
||||||
|
- monkeypatch.llama_patch_multipack
|
||||||
|
- monkeypatch.stablelm_attn_hijack_flash
|
||||||
|
- monkeypatch.trainer_fsdp_optim
|
||||||
|
- monkeypatch.transformers_fa_utils
|
||||||
|
- monkeypatch.unsloth_
|
||||||
|
- monkeypatch.attention.mllama
|
||||||
|
- monkeypatch.data.batch_dataset_fetcher
|
||||||
|
- monkeypatch.mixtral
|
||||||
|
- title: Utils
|
||||||
|
desc: Utility functions
|
||||||
|
contents:
|
||||||
|
- utils.models
|
||||||
|
- utils.tokenization
|
||||||
|
- utils.chat_templates
|
||||||
|
- utils.lora
|
||||||
|
- utils.lora_embeddings
|
||||||
|
- utils.model_shard_quant
|
||||||
|
- utils.bench
|
||||||
|
- utils.freeze
|
||||||
|
- utils.trainer
|
||||||
|
- utils.schedulers
|
||||||
|
- utils.distributed
|
||||||
|
- utils.dict
|
||||||
|
- utils.optimizers.adopt
|
||||||
|
- utils.data.pretraining
|
||||||
|
- utils.data.sft
|
||||||
|
- utils.gradient_checkpointing.unsloth
|
||||||
|
- title: Schemas
|
||||||
|
desc: Pydantic data models for Axolotl config
|
||||||
|
contents:
|
||||||
|
- utils.schemas.config
|
||||||
|
- utils.schemas.model
|
||||||
|
- utils.schemas.training
|
||||||
|
- utils.schemas.datasets
|
||||||
|
- utils.schemas.peft
|
||||||
|
- utils.schemas.trl
|
||||||
|
- utils.schemas.integrations
|
||||||
|
- utils.schemas.enums
|
||||||
|
- utils.schemas.utils
|
||||||
|
- title: Integrations
|
||||||
|
desc: Third-party integrations and extensions
|
||||||
|
contents:
|
||||||
|
- integrations.base
|
||||||
|
- integrations.cut_cross_entropy.args
|
||||||
|
- integrations.grokfast.optimizer
|
||||||
|
- integrations.kd.trainer
|
||||||
|
- integrations.liger.args
|
||||||
|
- integrations.lm_eval.args
|
||||||
|
- integrations.spectrum.args
|
||||||
|
- title: Common
|
||||||
|
desc: Common utilities and shared functionality
|
||||||
|
contents:
|
||||||
|
- common.architectures
|
||||||
|
- common.const
|
||||||
|
- common.datasets
|
||||||
|
- title: Models
|
||||||
|
desc: Custom model implementations
|
||||||
|
contents:
|
||||||
|
- models.mamba.modeling_mamba
|
||||||
|
- title: Data Processing
|
||||||
|
desc: Data processing utilities
|
||||||
|
contents:
|
||||||
|
- utils.collators.core
|
||||||
|
- utils.collators.batching
|
||||||
|
- utils.collators.mamba
|
||||||
|
- utils.collators.mm_chat
|
||||||
|
- utils.samplers.multipack
|
||||||
|
- title: Callbacks
|
||||||
|
desc: Training callbacks
|
||||||
|
contents:
|
||||||
|
- utils.callbacks.perplexity
|
||||||
|
- utils.callbacks.profiler
|
||||||
|
- utils.callbacks.lisa
|
||||||
|
- utils.callbacks.mlflow_
|
||||||
|
- utils.callbacks.comet_
|
||||||
|
|
||||||
website:
|
website:
|
||||||
title: "Axolotl"
|
title: "Axolotl"
|
||||||
description: "We make fine-tuning accessible, scalable, and fun"
|
description: "We make fine-tuning accessible, scalable, and fun"
|
||||||
@@ -35,6 +207,8 @@ website:
|
|||||||
- docs/inference.qmd
|
- docs/inference.qmd
|
||||||
- docs/cli.qmd
|
- docs/cli.qmd
|
||||||
- docs/config.qmd
|
- docs/config.qmd
|
||||||
|
- text: "API Reference"
|
||||||
|
href: docs/api
|
||||||
|
|
||||||
- section: "Dataset Formats"
|
- section: "Dataset Formats"
|
||||||
contents: docs/dataset-formats/*
|
contents: docs/dataset-formats/*
|
||||||
@@ -80,3 +254,22 @@ format:
|
|||||||
theme: darkly
|
theme: darkly
|
||||||
css: styles.css
|
css: styles.css
|
||||||
toc: true
|
toc: true
|
||||||
|
# Enable better handling of line breaks in markdown
|
||||||
|
preserve-tabs: true
|
||||||
|
html-math-method: mathjax
|
||||||
|
# Improved markdown processing options
|
||||||
|
md-extensions:
|
||||||
|
- markdown_it
|
||||||
|
- def_list
|
||||||
|
- attr_list
|
||||||
|
- fenced_divs
|
||||||
|
- tables
|
||||||
|
- html_admonition
|
||||||
|
- lineblocks
|
||||||
|
- fancy_lists
|
||||||
|
# Control whitespace handling
|
||||||
|
whitespace: preserve
|
||||||
|
# Process newlines in paragraphs
|
||||||
|
wrap: preserve
|
||||||
|
# Better line break handling
|
||||||
|
preserve-linebreaks: true
|
||||||
|
|||||||
@@ -33,9 +33,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|||||||
|
|
||||||
RUN pip install packaging==23.2 setuptools==75.8.0
|
RUN pip install packaging==23.2 setuptools==75.8.0
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN python scripts/unsloth_install.py | sh
|
RUN python scripts/unsloth_install.py | sh
|
||||||
|
|||||||
@@ -3,9 +3,10 @@ set -e
|
|||||||
|
|
||||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||||
|
|
||||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/
|
||||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
|
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
|
||||||
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
|
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
|
||||||
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
||||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
||||||
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
pytest -v --durations=10 /workspace/axolotl/tests/cli
|
||||||
|
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/
|
||||||
|
|||||||
2
docs/.gitignore
vendored
2
docs/.gitignore
vendored
@@ -1,2 +1,4 @@
|
|||||||
/.quarto/
|
/.quarto/
|
||||||
_site/
|
_site/
|
||||||
|
/api/*.qmd
|
||||||
|
/api/*.html
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
---
|
---
|
||||||
title: "CLI Reference"
|
title: "Command Line Interface (CLI)"
|
||||||
format:
|
format:
|
||||||
html:
|
html:
|
||||||
toc: true
|
toc: true
|
||||||
|
|||||||
@@ -32,6 +32,9 @@ tokenizer_legacy:
|
|||||||
resize_token_embeddings_to_32x:
|
resize_token_embeddings_to_32x:
|
||||||
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
||||||
shrink_embeddings:
|
shrink_embeddings:
|
||||||
|
# Whether to load the model with randomly initialized weights. Useful for
|
||||||
|
# pre-training a model from scratch or debugging purposes.
|
||||||
|
random_init_weights:
|
||||||
|
|
||||||
# (Internal use only)
|
# (Internal use only)
|
||||||
# Used to identify which the model is based on
|
# Used to identify which the model is based on
|
||||||
@@ -617,6 +620,14 @@ ddp_timeout:
|
|||||||
ddp_bucket_cap_mb:
|
ddp_bucket_cap_mb:
|
||||||
ddp_broadcast_buffers:
|
ddp_broadcast_buffers:
|
||||||
|
|
||||||
|
# Sequence parallelism
|
||||||
|
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
|
||||||
|
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
|
||||||
|
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
|
||||||
|
# subsequences, or set to 4 to split into four equal-sized subsequences.
|
||||||
|
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
|
||||||
|
sequence_parallel_degree:
|
||||||
|
|
||||||
# Path to torch distx for optim 'adamw_anyprecision'
|
# Path to torch distx for optim 'adamw_anyprecision'
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ description: How datasets are processed
|
|||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
|
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
|
||||||
the [dataset format](docs/dataset-formats) and prompt strategies to:
|
the [dataset format](dataset-formats) and prompt strategies to:
|
||||||
|
|
||||||
- parse the dataset based on the *dataset format*
|
- parse the dataset based on the *dataset format*
|
||||||
- transform the dataset to how you would interact with the model based on the *prompt strategy*
|
- transform the dataset to how you would interact with the model based on the *prompt strategy*
|
||||||
|
|||||||
@@ -37,6 +37,10 @@ description: Frequently asked questions
|
|||||||
|
|
||||||
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
|
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
|
||||||
|
|
||||||
|
**Q: How to know the value to use for `fsdp_transformer_layer_cls_to_wrap`?**
|
||||||
|
|
||||||
|
> A: This is the class name of the transformer layer to wrap with FSDP. For example, for `LlamaForCausalLM`, the value is `LlamaDecoderLayer`. To find this for a specific model, check the model's `PreTrainedModel` definition and look for `_no_split_modules` variable in the `modeling_<model_name>.py` file within `transformers` library.
|
||||||
|
|
||||||
### Chat templates
|
### Chat templates
|
||||||
|
|
||||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||||
|
|||||||
90
docs/sequence_parallelism.qmd
Normal file
90
docs/sequence_parallelism.qmd
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
---
|
||||||
|
title: Sequence Parallelism
|
||||||
|
description: Train with long sequences split across multiple GPUs.
|
||||||
|
---
|
||||||
|
|
||||||
|
# Sequence Parallelism
|
||||||
|
|
||||||
|
Sequence parallelism is a technique that splits sequences across multiple GPUs,
|
||||||
|
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
|
||||||
|
GPU processes a different portion of the sequence, and the results are aggregated
|
||||||
|
through a ring communication pattern.
|
||||||
|
|
||||||
|
## When to Use Sequence Parallelism
|
||||||
|
|
||||||
|
Use sequence parallelism when:
|
||||||
|
|
||||||
|
- You need to train with sequence lengths that don't fit into a single GPU's memory
|
||||||
|
- You have multiple GPUs available
|
||||||
|
- You're experiencing OOM (Out Of Memory) errors with long sequences
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
To enable sequence parallelism, add the following to your configuration file:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Set to a divisor (> 1) of the number of GPUs available
|
||||||
|
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||||
|
```
|
||||||
|
|
||||||
|
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
||||||
|
|
||||||
|
- With 8 GPUs, valid values would be 2, 4, or 8
|
||||||
|
- With 4 GPUs, valid values would be 2 or 4
|
||||||
|
|
||||||
|
## Implementation Details
|
||||||
|
|
||||||
|
When sequence parallelism is enabled:
|
||||||
|
|
||||||
|
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
|
||||||
|
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
|
||||||
|
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
|
||||||
|
4. The trainer uses special ring communication patterns for attention operations
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
To use sequence parallelism, you need:
|
||||||
|
|
||||||
|
- Multiple GPUs (at least 2)
|
||||||
|
- The `ring-flash-attn` package. Install with:
|
||||||
|
- `pip install axolotl[ring-flash-attn]` (preferred)
|
||||||
|
- `pip install ring-flash-attn>=0.1.4`
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
|
||||||
|
- May have a small performance overhead due to communication between GPUs
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Example config with sequence parallelism
|
||||||
|
base_model: meta-llama/Llama-3-8B-Instruct
|
||||||
|
sequence_len: 8192
|
||||||
|
sequence_parallel_degree: 2 # Split each sequence into 4 parts
|
||||||
|
flash_attention: true # Required with sequence parallelism
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
This will train the Llama 3 8B model with 8K context length, with each sequence split
|
||||||
|
into 2 subsequences of length 4096 across 2 GPUs.
|
||||||
|
|
||||||
|
## Sample Packing with Sequence Parallelism
|
||||||
|
|
||||||
|
Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
|
||||||
|
|
||||||
|
1. Samples are first packed together
|
||||||
|
2. The packed sequences are then divided across GPUs in the sequence parallel group
|
||||||
|
3. Position IDs are automatically adjusted to maintain proper relative positions
|
||||||
|
|
||||||
|
## Effect on Batch Size
|
||||||
|
|
||||||
|
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
||||||
|
|
||||||
|
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
||||||
|
- The number of batches processed per step decreases
|
||||||
|
|
||||||
|
For example:
|
||||||
|
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
||||||
|
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||||
|
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
||||||
@@ -2,3 +2,5 @@ pre-commit
|
|||||||
black
|
black
|
||||||
mypy
|
mypy
|
||||||
types-requests
|
types-requests
|
||||||
|
quartodoc
|
||||||
|
jupyter
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
bitsandbytes==0.45.3
|
bitsandbytes==0.45.3
|
||||||
triton>=3.0.0
|
triton>=3.0.0
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
flash-attn==2.7.4.post1
|
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
autoawq==0.2.7.post3
|
autoawq==0.2.7.post3
|
||||||
liger-kernel==0.5.3
|
liger-kernel==0.5.3
|
||||||
@@ -36,6 +35,7 @@ einops
|
|||||||
colorama
|
colorama
|
||||||
numba
|
numba
|
||||||
numpy>=1.24.4,<=2.0.1
|
numpy>=1.24.4,<=2.0.1
|
||||||
|
|
||||||
# qlora things
|
# qlora things
|
||||||
evaluate==0.4.1
|
evaluate==0.4.1
|
||||||
scipy
|
scipy
|
||||||
|
|||||||
12
setup.py
12
setup.py
@@ -17,11 +17,7 @@ def parse_requirements():
|
|||||||
lines = [r.strip() for r in requirements_file.readlines()]
|
lines = [r.strip() for r in requirements_file.readlines()]
|
||||||
for line in lines:
|
for line in lines:
|
||||||
is_extras = (
|
is_extras = (
|
||||||
"flash-attn" in line
|
"deepspeed" in line or "mamba-ssm" in line or "lion-pytorch" in line
|
||||||
or "flash-attention" in line
|
|
||||||
or "deepspeed" in line
|
|
||||||
or "mamba-ssm" in line
|
|
||||||
or "lion-pytorch" in line
|
|
||||||
)
|
)
|
||||||
if line.startswith("--extra-index-url"):
|
if line.startswith("--extra-index-url"):
|
||||||
# Handle custom index URLs
|
# Handle custom index URLs
|
||||||
@@ -39,7 +35,6 @@ def parse_requirements():
|
|||||||
"bitsandbytes",
|
"bitsandbytes",
|
||||||
"triton",
|
"triton",
|
||||||
"mamba-ssm",
|
"mamba-ssm",
|
||||||
"flash-attn",
|
|
||||||
"xformers",
|
"xformers",
|
||||||
"autoawq",
|
"autoawq",
|
||||||
"liger-kernel",
|
"liger-kernel",
|
||||||
@@ -124,9 +119,8 @@ setup(
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": ["flash-attn==2.7.4.post1"],
|
||||||
"flash-attn==2.7.4.post1",
|
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
|
||||||
],
|
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed==0.16.4",
|
"deepspeed==0.16.4",
|
||||||
"deepspeed-kernels",
|
"deepspeed-kernels",
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from axolotl.cli.utils import (
|
|||||||
)
|
)
|
||||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||||
"""
|
"""
|
||||||
Trains a `transformers` model by first loading the dataset(s) specified in the
|
Trains a `transformers` model by first loading the dataset(s) specified in the
|
||||||
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
|
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
|
||||||
@@ -44,16 +44,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
del model, tokenizer, trainer
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
|
||||||
del model
|
|
||||||
del tokenizer
|
|
||||||
del trainer
|
|
||||||
|
|
||||||
plugin_manager.post_train_unload(cfg)
|
plugin_manager.post_train_unload(cfg)
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
"""
|
"""
|
||||||
Parses `axolotl` config, CLI args, and calls `do_train`.
|
Parses `axolotl` config, CLI args, and calls `do_train`.
|
||||||
|
|
||||||
|
|||||||
@@ -13,9 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
"""
|
"""Builder for the training args and trainer"""
|
||||||
Builder for the training args and trainer
|
|
||||||
"""
|
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import importlib
|
import importlib
|
||||||
@@ -38,7 +36,7 @@ from transformers import (
|
|||||||
from transformers.training_args import OptimizerNames
|
from transformers.training_args import OptimizerNames
|
||||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||||
|
|
||||||
from axolotl.core.trainers.base import (
|
from axolotl.core.trainers import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
AxolotlKTOTrainer,
|
AxolotlKTOTrainer,
|
||||||
AxolotlMambaTrainer,
|
AxolotlMambaTrainer,
|
||||||
@@ -85,8 +83,8 @@ from axolotl.utils.collators import (
|
|||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
|
|
||||||
from axolotl.utils.models import ensure_dtype
|
from axolotl.utils.models import ensure_dtype
|
||||||
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||||
@@ -764,6 +762,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.kd_top_k_before_softmax
|
self.cfg.kd_top_k_before_softmax
|
||||||
)
|
)
|
||||||
|
|
||||||
|
training_arguments_kwargs["sequence_parallel_degree"] = (
|
||||||
|
self.cfg.sequence_parallel_degree
|
||||||
|
)
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_args_cls = AxolotlRewardConfig
|
training_args_cls = AxolotlRewardConfig
|
||||||
elif self.cfg.process_reward_model:
|
elif self.cfg.process_reward_model:
|
||||||
@@ -847,9 +849,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||||
):
|
):
|
||||||
if training_args.pretraining:
|
if training_args.pretraining:
|
||||||
if self.cfg.pretraining_sample_concatenation is False:
|
if (
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
self.cfg.pretraining_sample_concatenation is False
|
||||||
if self.cfg.micro_batch_size > 1:
|
or self.cfg.micro_batch_size > 1
|
||||||
|
):
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -877,9 +880,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if "max_length" in kwargs:
|
if "max_length" in kwargs:
|
||||||
kwargs.pop("max_length")
|
kwargs.pop("max_length")
|
||||||
elif use_batch_sampler_collator:
|
elif use_batch_sampler_collator:
|
||||||
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or (
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
|
||||||
elif (
|
|
||||||
self.cfg.model_config_type in ["llama"]
|
self.cfg.model_config_type in ["llama"]
|
||||||
and self.cfg.flash_attention is not True
|
and self.cfg.flash_attention is not True
|
||||||
):
|
):
|
||||||
@@ -910,6 +911,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
collator = DataCollatorForSeq2Seq
|
collator = DataCollatorForSeq2Seq
|
||||||
|
|
||||||
kwargs["return_tensors"] = "pt"
|
kwargs["return_tensors"] = "pt"
|
||||||
|
if issubclass(collator, DataCollatorForSeq2Seq):
|
||||||
|
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
|
||||||
|
|
||||||
return collator(
|
return collator(
|
||||||
*collator_args,
|
*collator_args,
|
||||||
|
|||||||
@@ -0,0 +1,18 @@
|
|||||||
|
"""Init for axolotl.core.trainers"""
|
||||||
|
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
from .base import AxolotlTrainer
|
||||||
|
from .dpo.trainer import AxolotlDPOTrainer
|
||||||
|
from .grpo.trainer import AxolotlGRPOTrainer
|
||||||
|
from .mamba import AxolotlMambaTrainer
|
||||||
|
from .relora import ReLoRATrainer
|
||||||
|
from .trl import (
|
||||||
|
AxolotlCPOTrainer,
|
||||||
|
AxolotlKTOTrainer,
|
||||||
|
AxolotlORPOTrainer,
|
||||||
|
AxolotlPRMTrainer,
|
||||||
|
AxolotlRewardTrainer,
|
||||||
|
TRLPPOTrainer,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,365 +1,47 @@
|
|||||||
"""
|
"""Module for customized trainers"""
|
||||||
module for customized trainers
|
|
||||||
"""
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Dict, Literal, Optional
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.utils.data import (
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
BatchSampler,
|
||||||
|
DataLoader,
|
||||||
|
RandomSampler,
|
||||||
|
Sampler,
|
||||||
|
SequentialSampler,
|
||||||
|
)
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
|
||||||
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from axolotl.integrations.base import BaseOptimizerFactory
|
from axolotl.core.trainers.mixins import (
|
||||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
OptimizerMixin,
|
||||||
|
SchedulerMixin,
|
||||||
|
SequenceParallelMixin,
|
||||||
|
)
|
||||||
|
from axolotl.core.trainers.utils import (
|
||||||
|
sanitize_kwargs_for_ds_tagging,
|
||||||
|
sanitize_kwargs_for_tagging,
|
||||||
|
)
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
from axolotl.utils.schedulers import (
|
|
||||||
RexLR,
|
|
||||||
get_cosine_schedule_with_min_lr,
|
|
||||||
get_cosine_schedule_with_quadratic_warmup,
|
|
||||||
get_cosine_schedule_with_warmup_decay_constant,
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
LOG = logging.getLogger(__name__)
|
||||||
import smdistributed.modelparallel.torch as smp
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
|
||||||
if isinstance(tag_names, str):
|
"""Extend the base Trainer for axolotl helpers"""
|
||||||
tag_names = [tag_names]
|
|
||||||
|
|
||||||
if kwargs is not None:
|
|
||||||
if "tags" not in kwargs:
|
|
||||||
kwargs["tags"] = tag_names
|
|
||||||
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
|
||||||
kwargs["tags"].extend(tag_names)
|
|
||||||
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
|
||||||
tag_names.append(kwargs["tags"])
|
|
||||||
kwargs["tags"] = tag_names
|
|
||||||
|
|
||||||
return kwargs
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
|
||||||
if isinstance(dataset_tags, str):
|
|
||||||
dataset_tags = [dataset_tags]
|
|
||||||
|
|
||||||
if (dataset_tags is not None) and (kwargs is not None):
|
|
||||||
if "dataset_tags" not in kwargs:
|
|
||||||
kwargs["dataset_tags"] = dataset_tags
|
|
||||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
|
||||||
kwargs["dataset_tags"].extend(dataset_tags)
|
|
||||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
|
||||||
dataset_tags.append(kwargs["dataset_tags"])
|
|
||||||
kwargs["dataset_tags"] = dataset_tags
|
|
||||||
|
|
||||||
return kwargs
|
|
||||||
|
|
||||||
|
|
||||||
class SchedulerMixin(Trainer):
|
|
||||||
"""
|
|
||||||
Mixin class for scheduler setup in CausalTrainer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
|
||||||
|
|
||||||
def create_scheduler(
|
|
||||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
|
||||||
passed as an argument.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_training_steps (int): The number of training steps to do.
|
|
||||||
optimizer (torch.optim.Optimizer): The training optimizer
|
|
||||||
"""
|
|
||||||
use_cosine_quadratic = (
|
|
||||||
self.args.lr_scheduler_type == "cosine"
|
|
||||||
and self.args.lr_quadratic_warmup is True
|
|
||||||
)
|
|
||||||
|
|
||||||
use_cosine_min_lr = (
|
|
||||||
self.args.lr_scheduler_type == "cosine"
|
|
||||||
and self.args.cosine_min_lr_ratio is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
|
||||||
# fmt: on
|
|
||||||
if self.args.alternate_lr_scheduler_type == "one_cycle":
|
|
||||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
|
||||||
pct_start = num_warmup_steps / num_training_steps
|
|
||||||
extra_lr_kwargs = {}
|
|
||||||
if "pct_start" not in self.args.lr_scheduler_kwargs:
|
|
||||||
extra_lr_kwargs["pct_start"] = pct_start
|
|
||||||
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
|
|
||||||
extra_lr_kwargs["anneal_strategy"] = "cos"
|
|
||||||
|
|
||||||
self.lr_scheduler = OneCycleLR(
|
|
||||||
optimizer,
|
|
||||||
max_lr=self.args.learning_rate,
|
|
||||||
total_steps=num_training_steps,
|
|
||||||
**extra_lr_kwargs,
|
|
||||||
**self.args.lr_scheduler_kwargs,
|
|
||||||
)
|
|
||||||
elif self.args.alternate_lr_scheduler_type == "rex":
|
|
||||||
if use_cosine_min_lr:
|
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
|
|
||||||
self.lr_scheduler = RexLR(
|
|
||||||
optimizer=optimizer,
|
|
||||||
max_lr=self.args.learning_rate,
|
|
||||||
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
|
||||||
total_steps=num_training_steps,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
)
|
|
||||||
elif use_cosine_quadratic:
|
|
||||||
if use_cosine_min_lr:
|
|
||||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
|
||||||
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
)
|
|
||||||
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
|
||||||
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
|
||||||
)
|
|
||||||
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
|
||||||
else:
|
|
||||||
if use_cosine_quadratic:
|
|
||||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
|
||||||
|
|
||||||
if use_cosine_min_lr:
|
|
||||||
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
|
||||||
|
|
||||||
return self.lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
class OptimizerMixin(Trainer):
|
|
||||||
"""
|
|
||||||
Mixin class for shared handling of building custom optimizers
|
|
||||||
"""
|
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
|
||||||
|
|
||||||
def create_optimizer_grouped_parameters(
|
|
||||||
self, opt_model, optimizer_kwargs
|
|
||||||
) -> list[dict]:
|
|
||||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
|
||||||
params: dict = {
|
|
||||||
"to_weight_decay": {}, # LayerNorm and bias
|
|
||||||
"embeddings": {}, # lm_head, embed_tokens,
|
|
||||||
"no_weight_decay": {},
|
|
||||||
}
|
|
||||||
lr_groups_lookup = {}
|
|
||||||
lr_groups_learning_rates = {}
|
|
||||||
if self.args.lr_groups:
|
|
||||||
for lr_group in self.args.lr_groups:
|
|
||||||
group_name = lr_group["name"]
|
|
||||||
group_modules = lr_group["modules"]
|
|
||||||
for module in group_modules:
|
|
||||||
lr_groups_lookup[module] = group_name
|
|
||||||
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
|
||||||
params[f"to_weight_decay_{group_name}"] = {}
|
|
||||||
|
|
||||||
for name, param in opt_model.named_parameters():
|
|
||||||
if not param.requires_grad:
|
|
||||||
continue
|
|
||||||
if name.endswith("modules_to_save.default.weight") or any(
|
|
||||||
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
|
||||||
):
|
|
||||||
params["embeddings"][name] = param
|
|
||||||
elif name in decay_parameters:
|
|
||||||
lr_group_modules = [
|
|
||||||
group_modules
|
|
||||||
for group_modules in lr_groups_lookup
|
|
||||||
if group_modules in name
|
|
||||||
]
|
|
||||||
if lr_groups_lookup and any(lr_group_modules):
|
|
||||||
lr_group_module = lr_group_modules[0]
|
|
||||||
group_name = lr_groups_lookup[lr_group_module]
|
|
||||||
params[f"to_weight_decay_{group_name}"][name] = param
|
|
||||||
else:
|
|
||||||
params["to_weight_decay"][name] = param
|
|
||||||
else:
|
|
||||||
params["no_weight_decay"][name] = param
|
|
||||||
optimizer_grouped_parameters = []
|
|
||||||
if params["to_weight_decay"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["to_weight_decay"].values()),
|
|
||||||
"weight_decay": self.args.weight_decay,
|
|
||||||
"lr": optimizer_kwargs["lr"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if params["embeddings"]:
|
|
||||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
|
||||||
if self.args.embedding_lr_scale:
|
|
||||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
|
||||||
elif self.args.embedding_lr:
|
|
||||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["embeddings"].values()),
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"lr": lr,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if params["no_weight_decay"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["no_weight_decay"].values()),
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"lr": optimizer_kwargs["lr"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
for group_name, group_lr in lr_groups_learning_rates.items():
|
|
||||||
if params[f"to_weight_decay_{group_name}"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(
|
|
||||||
params[f"to_weight_decay_{group_name}"].values()
|
|
||||||
),
|
|
||||||
"weight_decay": self.args.weight_decay,
|
|
||||||
"lr": group_lr,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return optimizer_grouped_parameters
|
|
||||||
|
|
||||||
def create_optimizer(self):
|
|
||||||
if (
|
|
||||||
self.args.loraplus_lr_ratio is None
|
|
||||||
and self.args.embedding_lr_scale is None
|
|
||||||
and self.args.embedding_lr is None
|
|
||||||
and self.args.lr_groups is None
|
|
||||||
and self.optimizer_cls_and_kwargs is None
|
|
||||||
):
|
|
||||||
return super().create_optimizer()
|
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
|
||||||
|
|
||||||
if (
|
|
||||||
not self.optimizer
|
|
||||||
and self.optimizer_cls_and_kwargs is not None
|
|
||||||
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
|
||||||
):
|
|
||||||
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
|
||||||
self.optimizer = optimizer_factory_cls()(
|
|
||||||
opt_model, self.args, **optimizer_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.optimizer:
|
|
||||||
if self.optimizer_cls_and_kwargs is not None:
|
|
||||||
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
|
||||||
else:
|
|
||||||
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
|
||||||
self.args, opt_model
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
|
||||||
opt_model, optimizer_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.loraplus_lr_ratio is not None:
|
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
|
||||||
loraplus_lr_embedding = getattr(
|
|
||||||
self.args, "loraplus_lr_embedding", 1e-6
|
|
||||||
)
|
|
||||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
opt_model,
|
|
||||||
optimizer_cls,
|
|
||||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
|
||||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
|
||||||
**optimizer_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
|
||||||
# e.g. for GaLore optimizer.
|
|
||||||
if "params" in optimizer_kwargs:
|
|
||||||
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
|
||||||
|
|
||||||
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
|
||||||
# e.g. for LOMO optimizer.
|
|
||||||
if "model" in optimizer_kwargs:
|
|
||||||
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
|
||||||
|
|
||||||
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
|
||||||
# to avoid arguments conflicts.
|
|
||||||
if "optimizer_dict" in optimizer_kwargs:
|
|
||||||
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
|
||||||
"optimizer_dict"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.optimizer = optimizer_cls(
|
|
||||||
optimizer_grouped_parameters, **optimizer_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if optimizer_cls.__name__ == "Adam8bit":
|
|
||||||
import bitsandbytes
|
|
||||||
|
|
||||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
|
||||||
|
|
||||||
skipped = 0
|
|
||||||
for module in opt_model.modules():
|
|
||||||
if isinstance(module, nn.Embedding):
|
|
||||||
skipped += sum(
|
|
||||||
{
|
|
||||||
p.data_ptr(): p.numel() for p in module.parameters()
|
|
||||||
}.values()
|
|
||||||
)
|
|
||||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
|
||||||
manager.register_module_override(
|
|
||||||
module, "weight", {"optim_bits": 32}
|
|
||||||
)
|
|
||||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
|
||||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
self.optimizer
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.optimizer
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|
||||||
"""
|
|
||||||
Extend the base Trainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
tag_names = ["axolotl"]
|
tag_names = ["axolotl"]
|
||||||
@@ -376,12 +58,18 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
self.eval_data_collator = eval_data_collator
|
self.eval_data_collator = eval_data_collator
|
||||||
self.dataset_tags = dataset_tags
|
self.dataset_tags = dataset_tags
|
||||||
self._signature_columns = None # workaround for pylint
|
self._signature_columns = None # workaround for pylint
|
||||||
|
|
||||||
super().__init__(*_args, **kwargs)
|
super().__init__(*_args, **kwargs)
|
||||||
|
|
||||||
self.train_data_collator = self.data_collator
|
self.train_data_collator = self.data_collator
|
||||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
|
|
||||||
|
# Initialize sequence parallelism if enabled
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
self._setup_sequence_parallel()
|
||||||
|
|
||||||
def _wrap_model(self, model, training=True, dataloader=None):
|
def _wrap_model(self, model, training=True, dataloader=None):
|
||||||
if self.args.torch_compile:
|
if self.args.torch_compile:
|
||||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||||
@@ -394,8 +82,20 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
)
|
)
|
||||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
def _create_multipack_sampler(
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
self, base_sampler: Sampler, dataset: Dataset
|
||||||
|
) -> MultipackBatchSampler:
|
||||||
|
"""
|
||||||
|
Helper method to create a `MultipackBatchSampler` for multipacking sequences
|
||||||
|
for training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_sampler: Sampler to wrap with `MultipackBatchSampler`.
|
||||||
|
dataset: Dataset to sample from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Multipack (sample packing) batch sampler.
|
||||||
|
"""
|
||||||
if self.args.multipack_real_batches:
|
if self.args.multipack_real_batches:
|
||||||
batch_size = self.args.per_device_train_batch_size
|
batch_size = self.args.per_device_train_batch_size
|
||||||
batch_max_len = self.args.max_seq_length
|
batch_max_len = self.args.max_seq_length
|
||||||
@@ -406,130 +106,223 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
)
|
)
|
||||||
batch_max_len = train_batch_size * self.args.max_seq_length
|
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||||
|
|
||||||
if self.args.curriculum_sampling:
|
|
||||||
sampler = SequentialSampler(self.train_dataset)
|
|
||||||
else:
|
|
||||||
sampler = RandomSampler(self.train_dataset)
|
|
||||||
|
|
||||||
return MultipackBatchSampler(
|
return MultipackBatchSampler(
|
||||||
sampler,
|
base_sampler,
|
||||||
lengths=get_dataset_lengths(self.train_dataset),
|
lengths=get_dataset_lengths(dataset),
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
batch_max_len=batch_max_len,
|
batch_max_len=batch_max_len,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
group_size=self.args.sample_packing_group_size,
|
|
||||||
bin_size=self.args.sample_packing_bin_size,
|
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
if self.args.curriculum_sampling:
|
|
||||||
return SequentialSampler(self.train_dataset)
|
def _get_train_sampler(self) -> Sampler | None:
|
||||||
|
"""
|
||||||
|
Helper method to get the sampler for training. Handles cases for sequence
|
||||||
|
parallelism, sample packing, and curriculum sampling (sequential).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If the dataset is non-empty, a sampler is returned, the type of which
|
||||||
|
depends on the passed training args.
|
||||||
|
"""
|
||||||
|
use_sample_packing = self.args.sample_packing and not self.args.pretraining
|
||||||
|
|
||||||
|
# Determine the base sampler first
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
base_sampler = self._sp_get_train_sampler(self.train_dataset)
|
||||||
|
elif self.args.curriculum_sampling:
|
||||||
|
base_sampler = SequentialSampler(self.train_dataset)
|
||||||
|
elif use_sample_packing:
|
||||||
|
base_sampler = RandomSampler(self.train_dataset)
|
||||||
|
else:
|
||||||
|
# Default to parent class implementation for standard random sampling
|
||||||
return super()._get_train_sampler()
|
return super()._get_train_sampler()
|
||||||
|
|
||||||
def _get_eval_sampler(
|
# Apply multipack wrapper if needed
|
||||||
self, eval_dataset: Dataset
|
if use_sample_packing:
|
||||||
) -> Optional[torch.utils.data.Sampler]:
|
return self._create_multipack_sampler(
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
base_sampler=base_sampler,
|
||||||
if self.args.multipack_real_batches:
|
dataset=self.train_dataset,
|
||||||
batch_size = self.args.per_device_eval_batch_size
|
)
|
||||||
batch_max_len = self.args.max_seq_length
|
|
||||||
|
return base_sampler
|
||||||
|
|
||||||
|
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
|
||||||
|
"""
|
||||||
|
Helper method to get the sampler for evaluation. Handles sequence parallelism
|
||||||
|
and sample packing cases.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If the dataset is non-empty, a sampler is returned, the type of which
|
||||||
|
depends on the passed training args.
|
||||||
|
"""
|
||||||
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
|
|
||||||
|
# Multipacking enabled if training is enabled and eval is not explicitly disabled
|
||||||
|
use_multipack = (
|
||||||
|
self.args.sample_packing and self.args.eval_sample_packing is not False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine the base sampler
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
base_sampler = self._sp_get_eval_sampler(eval_dataset)
|
||||||
|
elif use_multipack:
|
||||||
|
base_sampler = SequentialSampler(eval_dataset)
|
||||||
else:
|
else:
|
||||||
batch_size = 1
|
|
||||||
batch_max_len = (
|
|
||||||
self.args.per_device_eval_batch_size * self.args.max_seq_length
|
|
||||||
)
|
|
||||||
return MultipackBatchSampler(
|
|
||||||
SequentialSampler(eval_dataset),
|
|
||||||
lengths=get_dataset_lengths(self.eval_dataset),
|
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
|
||||||
batch_max_len=batch_max_len,
|
|
||||||
batch_size=batch_size,
|
|
||||||
group_size=self.args.sample_packing_group_size,
|
|
||||||
bin_size=self.args.sample_packing_bin_size,
|
|
||||||
drop_last=True,
|
|
||||||
)
|
|
||||||
return super()._get_eval_sampler(eval_dataset)
|
return super()._get_eval_sampler(eval_dataset)
|
||||||
|
|
||||||
def get_train_dataloader(self) -> DataLoader:
|
# Apply multipack wrapper if needed
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
if use_multipack:
|
||||||
train_dataset = self.train_dataset
|
return self._create_multipack_sampler(
|
||||||
if "length" in train_dataset.features.keys():
|
base_sampler=base_sampler,
|
||||||
train_dataset = train_dataset.remove_columns(["length"])
|
dataset=eval_dataset,
|
||||||
data_collator = self.data_collator
|
)
|
||||||
dataloader_params = {
|
|
||||||
"batch_size": self._train_batch_size,
|
return base_sampler
|
||||||
"collate_fn": data_collator,
|
|
||||||
|
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
|
||||||
|
"""Create common dataloader parameters for train or eval."""
|
||||||
|
batch_size = custom_batch_size or (
|
||||||
|
self.args.eval_batch_size if is_eval else self._train_batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"collate_fn": self.data_collator,
|
||||||
"num_workers": self.args.dataloader_num_workers,
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
"pin_memory": self.args.dataloader_pin_memory,
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
}
|
}
|
||||||
if self.args.dataloader_prefetch_factor:
|
|
||||||
dataloader_params["prefetch_factor"] = (
|
|
||||||
self.args.dataloader_prefetch_factor
|
|
||||||
)
|
|
||||||
|
|
||||||
sampler = self._get_train_sampler()
|
# Add persistent workers only for training
|
||||||
|
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
|
||||||
|
params["persistent_workers"] = self.args.dataloader_persistent_workers
|
||||||
|
|
||||||
|
# Add prefetch factor if specified
|
||||||
|
if self.args.dataloader_prefetch_factor:
|
||||||
|
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
def _prepare_dataloader(
|
||||||
|
self, dataset, sampler, is_eval=False, custom_batch_size=None
|
||||||
|
):
|
||||||
|
"""Prepare a dataloader with the given dataset and sampler."""
|
||||||
|
# Get base parameters
|
||||||
|
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
|
||||||
|
|
||||||
|
# Add sampler configuration
|
||||||
|
if not isinstance(dataset, torch.utils.data.IterableDataset):
|
||||||
if isinstance(sampler, BatchSampler):
|
if isinstance(sampler, BatchSampler):
|
||||||
|
# batch_size and batch_sampler are mutually exclusive
|
||||||
dataloader_params["batch_sampler"] = sampler
|
dataloader_params["batch_sampler"] = sampler
|
||||||
del dataloader_params["batch_size"]
|
del dataloader_params["batch_size"]
|
||||||
else:
|
else:
|
||||||
dataloader_params["sampler"] = sampler
|
dataloader_params["sampler"] = sampler
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
|
||||||
|
if not is_eval:
|
||||||
dataloader_params["worker_init_fn"] = seed_worker
|
dataloader_params["worker_init_fn"] = seed_worker
|
||||||
|
|
||||||
self.accelerator.even_batches = False
|
# Create the dataloader
|
||||||
return self.accelerator.prepare_data_loader(
|
dataloader = DataLoader(dataset, **dataloader_params)
|
||||||
DataLoader(train_dataset, **dataloader_params)
|
|
||||||
)
|
|
||||||
return super().get_train_dataloader()
|
|
||||||
|
|
||||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
if self.args.sample_packing and (
|
||||||
|
(not is_eval and not self.args.pretraining)
|
||||||
|
or (is_eval and self.args.eval_sample_packing is not False)
|
||||||
|
):
|
||||||
|
self.accelerator.even_batches = False
|
||||||
|
|
||||||
|
# Return unprepared dataloader if using sequence parallelism
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
return dataloader
|
||||||
|
|
||||||
|
# Otherwise prepare with accelerator
|
||||||
|
return self.accelerator.prepare_data_loader(dataloader)
|
||||||
|
|
||||||
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
|
"""Get dataloader for training"""
|
||||||
|
train_dataset = self.train_dataset
|
||||||
|
data_collator = self.data_collator # type: ignore
|
||||||
|
|
||||||
|
# Handle dataset preprocessing
|
||||||
|
if isinstance(train_dataset, datasets.Dataset):
|
||||||
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
|
train_dataset = train_dataset.remove_columns(["length"])
|
||||||
|
if not self.args.sample_packing or self.args.pretraining:
|
||||||
|
train_dataset = self._remove_unused_columns(
|
||||||
|
train_dataset, description="training"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
||||||
|
data_collator,
|
||||||
|
description="training",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get sampler and create dataloader
|
||||||
|
sampler = self._get_train_sampler()
|
||||||
|
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
|
||||||
|
|
||||||
|
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
|
||||||
|
"""Get dataloader for evaluation"""
|
||||||
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
|
|
||||||
|
# Handle special case: sample packing is enabled but eval_sample_packing is False
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.eval_data_collator
|
self.eval_data_collator
|
||||||
)
|
)
|
||||||
if eval_dataset:
|
if "length" in eval_dataset.column_names:
|
||||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||||
dataloader = super().get_eval_dataloader(eval_dataset)
|
dataloader = super().get_eval_dataloader(eval_dataset)
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.train_data_collator
|
self.train_data_collator
|
||||||
)
|
)
|
||||||
|
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
# Handle sample packing or sequence parallelism
|
||||||
eval_dataset = (
|
if (
|
||||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
self.args.sample_packing
|
||||||
|
and self.args.eval_sample_packing is not False
|
||||||
|
or self.args.sequence_parallel_degree > 1
|
||||||
|
):
|
||||||
|
# Get appropriate data collator
|
||||||
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.eval_data_collator
|
||||||
|
if hasattr(self, "eval_data_collator") and self.eval_data_collator
|
||||||
|
else self.data_collator
|
||||||
)
|
)
|
||||||
|
if "length" in eval_dataset.column_names:
|
||||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
|
||||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||||
data_collator = self.data_collator
|
|
||||||
dataloader_params = {
|
|
||||||
"batch_size": self.args.eval_batch_size,
|
|
||||||
"collate_fn": data_collator,
|
|
||||||
"num_workers": self.args.dataloader_num_workers,
|
|
||||||
"pin_memory": self.args.dataloader_pin_memory,
|
|
||||||
}
|
|
||||||
if self.args.dataloader_prefetch_factor:
|
|
||||||
dataloader_params["prefetch_factor"] = (
|
|
||||||
self.args.dataloader_prefetch_factor
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(eval_sampler, BatchSampler):
|
# Handle dataset preprocessing for SP
|
||||||
dataloader_params["batch_sampler"] = eval_sampler
|
if self.args.sequence_parallel_degree > 1:
|
||||||
del dataloader_params["batch_size"]
|
if isinstance(eval_dataset, datasets.Dataset):
|
||||||
|
eval_dataset = self._remove_unused_columns(
|
||||||
|
eval_dataset, description="evaluation"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
dataloader_params["sampler"] = eval_sampler
|
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
self.data_collator, description="evaluation"
|
||||||
|
|
||||||
self.accelerator.even_batches = False
|
|
||||||
return self.accelerator.prepare_data_loader(
|
|
||||||
DataLoader(eval_dataset, **dataloader_params)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
|
||||||
|
batch_size = (
|
||||||
|
self.args.eval_batch_size
|
||||||
|
if self.args.sample_packing
|
||||||
|
else self.args.per_device_eval_batch_size
|
||||||
|
)
|
||||||
|
sampler = self._get_eval_sampler(eval_dataset)
|
||||||
|
dataloader = self._prepare_dataloader(
|
||||||
|
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
return dataloader
|
||||||
|
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
|
|
||||||
def _get_bench_sampler(
|
def _get_bench_sampler(
|
||||||
self, bench_dataset: Dataset
|
self, bench_dataset: Dataset
|
||||||
) -> Optional[torch.utils.data.Sampler]:
|
) -> torch.utils.data.Sampler | None:
|
||||||
if self.args.world_size <= 1:
|
if self.args.world_size <= 1:
|
||||||
return SequentialSampler(bench_dataset)
|
return SequentialSampler(bench_dataset)
|
||||||
return None
|
return None
|
||||||
@@ -554,6 +347,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
return DataLoader(bench_dataset, **dataloader_params)
|
return DataLoader(bench_dataset, **dataloader_params)
|
||||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||||
|
|
||||||
|
@override
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||||
):
|
):
|
||||||
@@ -570,6 +364,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
return_outputs=return_outputs,
|
return_outputs=return_outputs,
|
||||||
num_items_in_batch=num_items_in_batch,
|
num_items_in_batch=num_items_in_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
return super().compute_loss(
|
return super().compute_loss(
|
||||||
model,
|
model,
|
||||||
inputs,
|
inputs,
|
||||||
@@ -744,10 +539,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
"""
|
"""
|
||||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
kwargs = sanitize_kwargs_for_ds_tagging(
|
||||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
)
|
)
|
||||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
@@ -764,15 +559,13 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
Log `logs` on the various objects watching training, including stored metrics.
|
Log `logs` on the various objects watching training, including stored metrics.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logs (`Dict[str, float]`):
|
logs: The values to log.
|
||||||
The values to log.
|
start_time: The start of training.
|
||||||
start_time (`Optional[float]`):
|
|
||||||
The start of training.
|
|
||||||
"""
|
"""
|
||||||
# logs either has 'loss' or 'eval_loss'
|
# logs either has 'loss' or 'eval_loss'
|
||||||
train_eval = "train" if "loss" in logs else "eval"
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
@@ -784,7 +577,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
return super().log(logs, start_time)
|
return super().log(logs, start_time)
|
||||||
|
|
||||||
def store_metrics(
|
def store_metrics(
|
||||||
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||||
) -> None:
|
) -> None:
|
||||||
for key, value in metrics.items():
|
for key, value in metrics.items():
|
||||||
self._stored_metrics[train_eval][key].append(value)
|
self._stored_metrics[train_eval][key].append(value)
|
||||||
@@ -797,110 +590,26 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, **kwargs)
|
||||||
|
|
||||||
|
def training_step(
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
|
||||||
"""
|
|
||||||
Mamba specific trainer to handle loss calculation
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "mamba"]
|
|
||||||
|
|
||||||
def compute_loss(
|
|
||||||
self,
|
self,
|
||||||
model,
|
model: nn.Module,
|
||||||
inputs,
|
inputs: dict[str, torch.Tensor | Any],
|
||||||
return_outputs=False, # pylint: disable=unused-argument
|
num_items_in_batch: int | None = None,
|
||||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
) -> torch.Tensor:
|
||||||
):
|
|
||||||
input_ids = inputs.pop("input_ids")
|
|
||||||
lm_logits = model(input_ids).logits
|
|
||||||
|
|
||||||
labels = input_ids.to(lm_logits.device)
|
|
||||||
shift_logits = lm_logits[:, :-1, :].contiguous()
|
|
||||||
labels = labels[:, 1:].contiguous()
|
|
||||||
|
|
||||||
loss_fct = torch.nn.CrossEntropyLoss()
|
|
||||||
lm_loss = loss_fct(
|
|
||||||
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
|
||||||
)
|
|
||||||
|
|
||||||
return lm_loss
|
|
||||||
|
|
||||||
|
|
||||||
class ReLoRATrainer(AxolotlTrainer):
|
|
||||||
"""
|
"""
|
||||||
Trainer subclass that uses the OneCycleLR scheduler
|
Perform a training step on a batch of inputs. Overrides the
|
||||||
|
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||||
|
enabled.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model to perform training step for.
|
||||||
|
inputs: Dictionary mapping.
|
||||||
"""
|
"""
|
||||||
|
# Set up sequence parallelism for this step if enabled
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
self._update_ring_flash_attn_params(inputs)
|
||||||
|
|
||||||
tag_names = ["axolotl", "relora"]
|
# Proceed with normal training step
|
||||||
|
loss = super().training_step(model, inputs, num_items_in_batch)
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
return loss
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.lr_scheduler = None
|
|
||||||
|
|
||||||
def create_scheduler(
|
|
||||||
self,
|
|
||||||
num_training_steps: int,
|
|
||||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
|
||||||
):
|
|
||||||
optimizer = self.optimizer if optimizer is None else optimizer
|
|
||||||
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
|
||||||
|
|
||||||
if self.args.relora_steps:
|
|
||||||
warmup_steps = (
|
|
||||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
|
||||||
)
|
|
||||||
anneal_steps = (
|
|
||||||
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
|
||||||
)
|
|
||||||
self.lr_scheduler = ReLoRAScheduler(
|
|
||||||
optimizer,
|
|
||||||
lr_scheduler,
|
|
||||||
self.args.relora_steps,
|
|
||||||
anneal_steps,
|
|
||||||
warmup_steps,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.lr_scheduler = lr_scheduler
|
|
||||||
|
|
||||||
return self.lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base ORPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "orpo"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base KTOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "kto"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base CPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "cpo"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base RewardTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "reward"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base trl.PRMTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "prm"]
|
|
||||||
|
|||||||
@@ -13,10 +13,10 @@ from transformers import Trainer
|
|||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
|
|
||||||
from axolotl.core.trainers.base import (
|
from axolotl.core.trainers.mixins import SchedulerMixin
|
||||||
SchedulerMixin,
|
from axolotl.core.trainers.utils import (
|
||||||
_sanitize_kwargs_for_ds_tagging,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
_sanitize_kwargs_for_tagging,
|
sanitize_kwargs_for_tagging,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
@@ -74,10 +74,10 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
"""
|
"""
|
||||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
kwargs = sanitize_kwargs_for_ds_tagging(
|
||||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
)
|
)
|
||||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import logging
|
|||||||
from trl.trainer.grpo_trainer import RewardFunc
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
|
|
||||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||||
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
|
from axolotl.utils.schemas.trl import TRLConfig
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|||||||
32
src/axolotl/core/trainers/mamba.py
Normal file
32
src/axolotl/core/trainers/mamba.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Module for mamba trainer"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
|
"""Mamba specific trainer to handle loss calculation"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "mamba"]
|
||||||
|
|
||||||
|
def compute_loss(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
inputs,
|
||||||
|
return_outputs=False, # pylint: disable=unused-argument
|
||||||
|
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
input_ids = inputs.pop("input_ids")
|
||||||
|
lm_logits = model(input_ids).logits
|
||||||
|
|
||||||
|
labels = input_ids.to(lm_logits.device)
|
||||||
|
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||||
|
labels = labels[:, 1:].contiguous()
|
||||||
|
|
||||||
|
loss_fct = torch.nn.CrossEntropyLoss()
|
||||||
|
lm_loss = loss_fct(
|
||||||
|
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
return lm_loss
|
||||||
8
src/axolotl/core/trainers/mixins/__init__.py
Normal file
8
src/axolotl/core/trainers/mixins/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
"""Init for axolotl.core.trainers.mixins"""
|
||||||
|
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
from .optimizer import OptimizerMixin
|
||||||
|
from .scheduler import SchedulerMixin
|
||||||
|
from .sequence_parallel import SequenceParallelMixin
|
||||||
201
src/axolotl/core/trainers/mixins/optimizer.py
Normal file
201
src/axolotl/core/trainers/mixins/optimizer.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
"""Module for Axolotl trainer optimizer mixin"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
|
from torch import nn
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
|
|
||||||
|
from axolotl.integrations.base import BaseOptimizerFactory
|
||||||
|
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizerMixin(Trainer):
|
||||||
|
"""Mixin class for shared handling of building custom optimizers"""
|
||||||
|
|
||||||
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|
||||||
|
def create_optimizer_grouped_parameters(
|
||||||
|
self, opt_model, optimizer_kwargs
|
||||||
|
) -> list[dict]:
|
||||||
|
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||||
|
params: dict = {
|
||||||
|
"to_weight_decay": {}, # LayerNorm and bias
|
||||||
|
"embeddings": {}, # lm_head, embed_tokens,
|
||||||
|
"no_weight_decay": {},
|
||||||
|
}
|
||||||
|
lr_groups_lookup = {}
|
||||||
|
lr_groups_learning_rates = {}
|
||||||
|
if self.args.lr_groups:
|
||||||
|
for lr_group in self.args.lr_groups:
|
||||||
|
group_name = lr_group["name"]
|
||||||
|
group_modules = lr_group["modules"]
|
||||||
|
for module in group_modules:
|
||||||
|
lr_groups_lookup[module] = group_name
|
||||||
|
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
||||||
|
params[f"to_weight_decay_{group_name}"] = {}
|
||||||
|
|
||||||
|
for name, param in opt_model.named_parameters():
|
||||||
|
if not param.requires_grad:
|
||||||
|
continue
|
||||||
|
if name.endswith("modules_to_save.default.weight") or any(
|
||||||
|
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||||
|
):
|
||||||
|
params["embeddings"][name] = param
|
||||||
|
elif name in decay_parameters:
|
||||||
|
lr_group_modules = [
|
||||||
|
group_modules
|
||||||
|
for group_modules in lr_groups_lookup
|
||||||
|
if group_modules in name
|
||||||
|
]
|
||||||
|
if lr_groups_lookup and any(lr_group_modules):
|
||||||
|
lr_group_module = lr_group_modules[0]
|
||||||
|
group_name = lr_groups_lookup[lr_group_module]
|
||||||
|
params[f"to_weight_decay_{group_name}"][name] = param
|
||||||
|
else:
|
||||||
|
params["to_weight_decay"][name] = param
|
||||||
|
else:
|
||||||
|
params["no_weight_decay"][name] = param
|
||||||
|
optimizer_grouped_parameters = []
|
||||||
|
if params["to_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["to_weight_decay"].values()),
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["embeddings"]:
|
||||||
|
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||||
|
if self.args.embedding_lr_scale:
|
||||||
|
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||||
|
elif self.args.embedding_lr:
|
||||||
|
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["embeddings"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": lr,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["no_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["no_weight_decay"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
for group_name, group_lr in lr_groups_learning_rates.items():
|
||||||
|
if params[f"to_weight_decay_{group_name}"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(
|
||||||
|
params[f"to_weight_decay_{group_name}"].values()
|
||||||
|
),
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
"lr": group_lr,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return optimizer_grouped_parameters
|
||||||
|
|
||||||
|
def create_optimizer(self):
|
||||||
|
if (
|
||||||
|
self.args.loraplus_lr_ratio is None
|
||||||
|
and self.args.embedding_lr_scale is None
|
||||||
|
and self.args.embedding_lr is None
|
||||||
|
and self.args.lr_groups is None
|
||||||
|
and self.optimizer_cls_and_kwargs is None
|
||||||
|
):
|
||||||
|
return super().create_optimizer()
|
||||||
|
|
||||||
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
|
|
||||||
|
if (
|
||||||
|
not self.optimizer
|
||||||
|
and self.optimizer_cls_and_kwargs is not None
|
||||||
|
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
||||||
|
):
|
||||||
|
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||||
|
self.optimizer = optimizer_factory_cls()(
|
||||||
|
opt_model, self.args, **optimizer_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.optimizer:
|
||||||
|
if self.optimizer_cls_and_kwargs is not None:
|
||||||
|
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||||
|
else:
|
||||||
|
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
||||||
|
self.args, opt_model
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
||||||
|
opt_model, optimizer_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.loraplus_lr_ratio is not None:
|
||||||
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
|
loraplus_lr_embedding = getattr(
|
||||||
|
self.args, "loraplus_lr_embedding", 1e-6
|
||||||
|
)
|
||||||
|
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
opt_model,
|
||||||
|
optimizer_cls,
|
||||||
|
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||||
|
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||||
|
**optimizer_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
|
# e.g. for GaLore optimizer.
|
||||||
|
if "params" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||||
|
|
||||||
|
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
|
# e.g. for LOMO optimizer.
|
||||||
|
if "model" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||||
|
|
||||||
|
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||||
|
# to avoid arguments conflicts.
|
||||||
|
if "optimizer_dict" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
||||||
|
"optimizer_dict"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.optimizer = optimizer_cls(
|
||||||
|
optimizer_grouped_parameters, **optimizer_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if optimizer_cls.__name__ == "Adam8bit":
|
||||||
|
import bitsandbytes
|
||||||
|
|
||||||
|
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||||
|
|
||||||
|
skipped = 0
|
||||||
|
for module in opt_model.modules():
|
||||||
|
if isinstance(module, nn.Embedding):
|
||||||
|
skipped += sum(
|
||||||
|
{
|
||||||
|
p.data_ptr(): p.numel() for p in module.parameters()
|
||||||
|
}.values()
|
||||||
|
)
|
||||||
|
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||||
|
manager.register_module_override(
|
||||||
|
module, "weight", {"optim_bits": 32}
|
||||||
|
)
|
||||||
|
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||||
|
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||||
|
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.optimizer
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.optimizer
|
||||||
113
src/axolotl/core/trainers/mixins/scheduler.py
Normal file
113
src/axolotl/core/trainers/mixins/scheduler.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""Module for Axolotl trainer scheduler mixin"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
|
from axolotl.utils.schedulers import (
|
||||||
|
RexLR,
|
||||||
|
get_cosine_schedule_with_min_lr,
|
||||||
|
get_cosine_schedule_with_quadratic_warmup,
|
||||||
|
get_cosine_schedule_with_warmup_decay_constant,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerMixin(Trainer):
|
||||||
|
"""
|
||||||
|
Mixin class for scheduler setup in CausalTrainer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|
||||||
|
def create_scheduler(
|
||||||
|
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||||
|
passed as an argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_training_steps (int): The number of training steps to do.
|
||||||
|
optimizer (torch.optim.Optimizer): The training optimizer
|
||||||
|
"""
|
||||||
|
use_cosine_quadratic = (
|
||||||
|
self.args.lr_scheduler_type == "cosine"
|
||||||
|
and self.args.lr_quadratic_warmup is True
|
||||||
|
)
|
||||||
|
|
||||||
|
use_cosine_min_lr = (
|
||||||
|
self.args.lr_scheduler_type == "cosine"
|
||||||
|
and self.args.cosine_min_lr_ratio is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||||
|
# fmt: on
|
||||||
|
if self.args.alternate_lr_scheduler_type == "one_cycle":
|
||||||
|
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||||
|
pct_start = num_warmup_steps / num_training_steps
|
||||||
|
extra_lr_kwargs = {}
|
||||||
|
if "pct_start" not in self.args.lr_scheduler_kwargs:
|
||||||
|
extra_lr_kwargs["pct_start"] = pct_start
|
||||||
|
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
|
||||||
|
extra_lr_kwargs["anneal_strategy"] = "cos"
|
||||||
|
|
||||||
|
self.lr_scheduler = OneCycleLR(
|
||||||
|
optimizer,
|
||||||
|
max_lr=self.args.learning_rate,
|
||||||
|
total_steps=num_training_steps,
|
||||||
|
**extra_lr_kwargs,
|
||||||
|
**self.args.lr_scheduler_kwargs,
|
||||||
|
)
|
||||||
|
elif self.args.alternate_lr_scheduler_type == "rex":
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
|
||||||
|
self.lr_scheduler = RexLR(
|
||||||
|
optimizer=optimizer,
|
||||||
|
max_lr=self.args.learning_rate,
|
||||||
|
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
||||||
|
total_steps=num_training_steps,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
)
|
||||||
|
elif use_cosine_quadratic:
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||||
|
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
)
|
||||||
|
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
|
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
||||||
|
)
|
||||||
|
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||||
|
else:
|
||||||
|
if use_cosine_quadratic:
|
||||||
|
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
|
return self.lr_scheduler
|
||||||
131
src/axolotl/core/trainers/mixins/sequence_parallel.py
Normal file
131
src/axolotl/core/trainers/mixins/sequence_parallel.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""Module for Axolotl trainer sequence parallelism mixin"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from datasets import Dataset
|
||||||
|
from torch.utils.data import DistributedSampler, Sampler
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ring_flash_attn import update_ring_flash_attn_params
|
||||||
|
except ImportError:
|
||||||
|
# We pass silently here, but raise an ImportError in our Axolotl config validation
|
||||||
|
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceParallelMixin:
|
||||||
|
"""
|
||||||
|
Mixin class for sequence parallelism support in trainers.
|
||||||
|
|
||||||
|
This mixin provides functionality for handling sequence parallelism,
|
||||||
|
including creating appropriate samplers, managing data partitioning,
|
||||||
|
and updating ring flash attention parameters during training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|
||||||
|
def _setup_sequence_parallel(self):
|
||||||
|
"""Set up sequence parallelism environment."""
|
||||||
|
self.ring_attn_group = get_ring_attn_group()
|
||||||
|
|
||||||
|
def _create_sequence_parallel_sampler(
|
||||||
|
self,
|
||||||
|
dataset: Dataset,
|
||||||
|
shuffle: bool = True,
|
||||||
|
is_eval: bool = False,
|
||||||
|
) -> DistributedSampler:
|
||||||
|
"""
|
||||||
|
Helper method to create sampler for sequence parallelism (SP).
|
||||||
|
|
||||||
|
We create a distributed sampler with rank equal to the SP group ID, which
|
||||||
|
means that all ranks in the SP group receive the same sample / set of samples
|
||||||
|
per training step. We also set the number of replicas equal to the number of
|
||||||
|
SP groups, which is a bit of a hack / unintended use, but works!
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: Dataset to sample from.
|
||||||
|
shuffle: Whether to shuffle the dataset.
|
||||||
|
is_eval: Whether we are creating a sampler for evaluation or training.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Distributed sampler.
|
||||||
|
"""
|
||||||
|
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
||||||
|
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
||||||
|
|
||||||
|
return DistributedSampler(
|
||||||
|
dataset,
|
||||||
|
num_replicas=num_sp_groups,
|
||||||
|
rank=sp_group_id,
|
||||||
|
seed=self.args.seed if shuffle else None,
|
||||||
|
shuffle=shuffle,
|
||||||
|
drop_last=not is_eval,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
|
||||||
|
"""
|
||||||
|
Get a training sampler configured for sequence parallelism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: The training dataset
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured sequence parallel sampler.
|
||||||
|
"""
|
||||||
|
return self._create_sequence_parallel_sampler(
|
||||||
|
dataset,
|
||||||
|
shuffle=not self.args.curriculum_sampling,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
||||||
|
"""
|
||||||
|
Get an evaluation sampler configured for sequence parallelism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
eval_dataset: The evaluation dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured sequence parallel sampler.
|
||||||
|
"""
|
||||||
|
return self._create_sequence_parallel_sampler(
|
||||||
|
eval_dataset, shuffle=False, is_eval=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
|
||||||
|
"""
|
||||||
|
Calculate the cu_seqlens for the current forward pass and pass the value to
|
||||||
|
the substituted ring_flash_attn. This is accomplished by using the passed
|
||||||
|
`input_ids`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Current batch of inputs.
|
||||||
|
"""
|
||||||
|
# At this point, inputs should already be partitioned by the sequence
|
||||||
|
# parallel data collator
|
||||||
|
batch_size = inputs["input_ids"].shape[0]
|
||||||
|
seq_len = inputs["input_ids"].shape[1]
|
||||||
|
packed_seq_lens = [seq_len] * batch_size
|
||||||
|
|
||||||
|
# Calculate the full sequence length across all GPUs in this SP group
|
||||||
|
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
||||||
|
|
||||||
|
cu_seqlens = torch.cumsum(
|
||||||
|
torch.tensor(
|
||||||
|
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||||
|
),
|
||||||
|
dim=-1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
cu_seqlens = F.pad(
|
||||||
|
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||||
|
)
|
||||||
|
|
||||||
|
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
||||||
43
src/axolotl/core/trainers/relora.py
Normal file
43
src/axolotl/core/trainers/relora.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""Module for ReLoRA trainer"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||||
|
|
||||||
|
|
||||||
|
class ReLoRATrainer(AxolotlTrainer):
|
||||||
|
"""Trainer subclass that uses the `OneCycleLR` scheduler"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "relora"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.lr_scheduler = None
|
||||||
|
|
||||||
|
def create_scheduler(
|
||||||
|
self,
|
||||||
|
num_training_steps: int,
|
||||||
|
optimizer: torch.optim.Optimizer | None = None,
|
||||||
|
):
|
||||||
|
optimizer = self.optimizer if optimizer is None else optimizer
|
||||||
|
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||||
|
|
||||||
|
if self.args.relora_steps:
|
||||||
|
warmup_steps = (
|
||||||
|
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||||
|
)
|
||||||
|
anneal_steps = (
|
||||||
|
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
||||||
|
)
|
||||||
|
self.lr_scheduler = ReLoRAScheduler(
|
||||||
|
optimizer,
|
||||||
|
lr_scheduler,
|
||||||
|
self.args.relora_steps,
|
||||||
|
anneal_steps,
|
||||||
|
warmup_steps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.lr_scheduler = lr_scheduler
|
||||||
|
|
||||||
|
return self.lr_scheduler
|
||||||
@@ -1,16 +1,23 @@
|
|||||||
"""
|
"""Module for TRL PPO trainer"""
|
||||||
module for TRL PPO training
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from trl import PPOTrainer
|
from trl import (
|
||||||
|
CPOTrainer,
|
||||||
|
KTOTrainer,
|
||||||
|
ORPOTrainer,
|
||||||
|
PPOTrainer,
|
||||||
|
PRMTrainer,
|
||||||
|
RewardTrainer,
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
class TRLPPOTrainer(PPOTrainer):
|
class TRLPPOTrainer(PPOTrainer):
|
||||||
"""
|
"""Wrapper for TRL PPO trainer to handle customizations"""
|
||||||
wrapper for ppo trainer to handle customizations
|
|
||||||
"""
|
tag_names = ["axolotl", "ppo"]
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self,
|
self,
|
||||||
@@ -31,9 +38,7 @@ class TRLPPOTrainer(PPOTrainer):
|
|||||||
"batch_size": 16,
|
"batch_size": 16,
|
||||||
}
|
}
|
||||||
|
|
||||||
for epoch, batch in tqdm( # pylint: disable=unused-variable
|
for _, batch in tqdm(enumerate(self.dataloader)):
|
||||||
enumerate(self.dataloader)
|
|
||||||
):
|
|
||||||
query_tensors = batch["input_ids"]
|
query_tensors = batch["input_ids"]
|
||||||
|
|
||||||
# generate model response
|
# generate model response
|
||||||
@@ -65,3 +70,43 @@ class TRLPPOTrainer(PPOTrainer):
|
|||||||
rewards,
|
rewards,
|
||||||
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base ORPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base KTOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base CPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "cpo"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base RewardTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "reward"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base trl.PRMTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "prm"]
|
||||||
|
|||||||
33
src/axolotl/core/trainers/utils.py
Normal file
33
src/axolotl/core/trainers/utils.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
"""Utils for Axolotl trainers"""
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
||||||
|
if isinstance(tag_names, str):
|
||||||
|
tag_names = [tag_names]
|
||||||
|
|
||||||
|
if kwargs is not None:
|
||||||
|
if "tags" not in kwargs:
|
||||||
|
kwargs["tags"] = tag_names
|
||||||
|
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
||||||
|
kwargs["tags"].extend(tag_names)
|
||||||
|
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
||||||
|
tag_names.append(kwargs["tags"])
|
||||||
|
kwargs["tags"] = tag_names
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
||||||
|
if isinstance(dataset_tags, str):
|
||||||
|
dataset_tags = [dataset_tags]
|
||||||
|
|
||||||
|
if (dataset_tags is not None) and (kwargs is not None):
|
||||||
|
if "dataset_tags" not in kwargs:
|
||||||
|
kwargs["dataset_tags"] = dataset_tags
|
||||||
|
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
||||||
|
kwargs["dataset_tags"].extend(dataset_tags)
|
||||||
|
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
||||||
|
dataset_tags.append(kwargs["dataset_tags"])
|
||||||
|
kwargs["dataset_tags"] = dataset_tags
|
||||||
|
|
||||||
|
return kwargs
|
||||||
@@ -207,14 +207,19 @@ class AxolotlTrainingMixins:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sequence_parallel_degree: Optional[int] = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "The number of workers to use in sequence parallelism"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||||
"""
|
"""
|
||||||
Training arguments for Causal trainer
|
Training arguments for Causal trainer
|
||||||
|
|
||||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
|
This code is duplicated due to HF TrainingArguments not setting output_dir with a
|
||||||
so it can't be used as a mixin.
|
default value so it can't be used as a mixin.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ from typing import Dict, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
|
from datasets import Dataset
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.train import TrainDatasetMeta
|
from axolotl.train import TrainDatasetMeta
|
||||||
@@ -25,18 +27,18 @@ LOG = get_logger("axolotl.evaluate")
|
|||||||
|
|
||||||
|
|
||||||
def evaluate_dataset(
|
def evaluate_dataset(
|
||||||
trainer, dataset, dataset_type: str, flash_optimum: bool = False
|
trainer: Trainer, dataset: Dataset, dataset_type: str, flash_optimum: bool = False
|
||||||
) -> Optional[Dict[str, float]]:
|
) -> Optional[Dict[str, float]]:
|
||||||
"""Helper function to evaluate a single dataset safely.
|
"""Helper function to evaluate a single dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
trainer: The trainer instance
|
trainer: The trainer instance.
|
||||||
dataset: Dataset to evaluate
|
dataset: Dataset to evaluate.
|
||||||
dataset_type: Type of dataset ('train' or 'eval')
|
dataset_type: Type of dataset ('train' or 'eval').
|
||||||
flash_optimum: Whether to use flash optimum
|
flash_optimum: Whether to use flash optimum.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary of metrics or None if dataset is None
|
Dictionary of metrics or None if dataset is None.
|
||||||
"""
|
"""
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
return None
|
return None
|
||||||
@@ -63,17 +65,14 @@ def evaluate_dataset(
|
|||||||
|
|
||||||
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Evaluate a model on training and validation datasets
|
Evaluate a model on training and validation datasets.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
dataset_meta: Dataset metadata containing training and evaluation datasets.
|
dataset_meta: Dataset metadata containing training and evaluation datasets.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple containing:
|
Dictionary mapping metric names to their values.
|
||||||
- The model (either PeftModel or PreTrainedModel)
|
|
||||||
- The tokenizer
|
|
||||||
- Dictionary of evaluation metrics
|
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
|
|||||||
@@ -11,19 +11,17 @@
|
|||||||
# the License.
|
# the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
module to handle merging the plugins' input arguments with the base configurations.
|
Module to handle merging the plugins' input arguments with the base configurations.
|
||||||
|
|
||||||
this was moved here to prevent circular imports
|
This was moved here to prevent circular imports.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
from axolotl.utils.schemas.config import (
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
)
|
)
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
||||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def merge_input_args():
|
def merge_input_args():
|
||||||
|
|||||||
89
src/axolotl/monkeypatch/attention/ring_attn.py
Normal file
89
src/axolotl/monkeypatch/attention/ring_attn.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
"""
|
||||||
|
Ring attention group registration and flash attention patching.
|
||||||
|
|
||||||
|
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
|
||||||
|
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
|
||||||
|
their sequence parallel version of Flash Attention 2.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch.distributed as dist
|
||||||
|
from accelerate.logging import get_logger
|
||||||
|
|
||||||
|
from axolotl.logging_config import configure_logging
|
||||||
|
|
||||||
|
configure_logging()
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
RING_ATTN_GROUP = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_ring_attn_group() -> dist.ProcessGroup:
|
||||||
|
"""
|
||||||
|
Getter for ring attention group on this rank.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The process group for ring attention for this rank.
|
||||||
|
"""
|
||||||
|
return RING_ATTN_GROUP
|
||||||
|
|
||||||
|
|
||||||
|
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup):
|
||||||
|
"""
|
||||||
|
Setter for ring attention group on this rank.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
Process group for ring attention.
|
||||||
|
"""
|
||||||
|
global RING_ATTN_GROUP # pylint: disable=global-statement
|
||||||
|
RING_ATTN_GROUP = ring_attn_group
|
||||||
|
|
||||||
|
|
||||||
|
def register_ring_attn(sequence_parallel_degree: int):
|
||||||
|
"""
|
||||||
|
Create ring attention group and substitute flash attn with ring flash attn.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sequence_parallel_degree: Sequence parallelism factor.
|
||||||
|
"""
|
||||||
|
LOG.info(
|
||||||
|
"Enabling ring attention sequence parallelism: "
|
||||||
|
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||||
|
)
|
||||||
|
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
assert sequence_parallel_degree <= world_size, (
|
||||||
|
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||||
|
f"must be less than or equal to world_size ({world_size})"
|
||||||
|
)
|
||||||
|
assert world_size % sequence_parallel_degree == 0, (
|
||||||
|
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||||
|
f"must evenly divide world_size ({world_size})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Detailed logging of group formation
|
||||||
|
rank = dist.get_rank()
|
||||||
|
group_assignments = {}
|
||||||
|
|
||||||
|
for i in range(world_size // sequence_parallel_degree):
|
||||||
|
ring_attn_ranks = list(
|
||||||
|
range(
|
||||||
|
i * sequence_parallel_degree,
|
||||||
|
(i + 1) * sequence_parallel_degree,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
||||||
|
|
||||||
|
# Track which GPUs are in which groups
|
||||||
|
for r in ring_attn_ranks:
|
||||||
|
group_assignments[r] = i
|
||||||
|
|
||||||
|
if rank in ring_attn_ranks:
|
||||||
|
set_ring_attn_group(group)
|
||||||
|
|
||||||
|
# Log the GPU group assignments
|
||||||
|
if rank == 0:
|
||||||
|
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||||
|
|
||||||
|
from ring_flash_attn import substitute_hf_flash_attn
|
||||||
|
|
||||||
|
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
|
||||||
@@ -13,7 +13,7 @@ from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnaly
|
|||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig
|
from axolotl.utils.schemas.datasets import DatasetConfig
|
||||||
|
|
||||||
# Configure the logger
|
# Configure the logger
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ DPO prompt strategies for using tokenizer chat templates.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic
|
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
|
||||||
|
|
||||||
|
|
||||||
def default(
|
def default(
|
||||||
|
|||||||
@@ -169,7 +169,7 @@ def execute_training(
|
|||||||
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
|
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Execute the training process with appropriate backend configurations.
|
Execute the training process with appropriate SDP kernel configurations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
@@ -177,9 +177,6 @@ def execute_training(
|
|||||||
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
||||||
"""
|
"""
|
||||||
LOG.info("Starting trainer...")
|
LOG.info("Starting trainer...")
|
||||||
if cfg.group_by_length:
|
|
||||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
|
||||||
|
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
with torch.backends.cuda.sdp_kernel(
|
with torch.backends.cuda.sdp_kernel(
|
||||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ from trl.models import unwrap_model_for_generation
|
|||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.callbacks.perplexity import Perplexity
|
from axolotl.utils.callbacks.perplexity import Perplexity
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
|
||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import (
|
||||||
barrier,
|
barrier,
|
||||||
broadcast_dict,
|
broadcast_dict,
|
||||||
@@ -43,6 +42,7 @@ from axolotl.utils.distributed import (
|
|||||||
is_main_process,
|
is_main_process,
|
||||||
zero_first,
|
zero_first,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||||
|
|||||||
@@ -1,14 +1,59 @@
|
|||||||
"""
|
"""
|
||||||
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
Data collators for axolotl to pad labels and position_ids for packed sequences. Also
|
||||||
|
includes logic for handling sequence parallelism collation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_position_ids_for_slice(
|
||||||
|
position_ids: torch.Tensor, start_idx: int
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Adjust position IDs for a sliced sequence to maintain proper relative positions.
|
||||||
|
This handles the case where position IDs might not be contiguous due to sample
|
||||||
|
packing.
|
||||||
|
"""
|
||||||
|
# Convert to tensor if not already
|
||||||
|
# Find the boundaries between samples (where position_ids reset)
|
||||||
|
adjusted_pos_ids = position_ids.clone()
|
||||||
|
|
||||||
|
# Process each sequence in the batch
|
||||||
|
for i in range(position_ids.shape[0]):
|
||||||
|
seq = position_ids[i]
|
||||||
|
|
||||||
|
# Find sample boundaries
|
||||||
|
boundaries = []
|
||||||
|
for j in range(1, len(seq)):
|
||||||
|
if seq[j] < seq[j - 1]:
|
||||||
|
boundaries.append(j)
|
||||||
|
|
||||||
|
# No need to adjust if there are no boundaries or this is a single sample
|
||||||
|
if not boundaries:
|
||||||
|
adjusted_pos_ids[i] = seq - start_idx
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Adjust each segment separately
|
||||||
|
prev_boundary = 0
|
||||||
|
for boundary in boundaries:
|
||||||
|
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
|
||||||
|
prev_boundary = boundary
|
||||||
|
|
||||||
|
# Last segment
|
||||||
|
adjusted_pos_ids[i, prev_boundary:] -= start_idx
|
||||||
|
|
||||||
|
return adjusted_pos_ids
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataCollatorForSeq2Seq:
|
class DataCollatorForSeq2Seq:
|
||||||
@@ -43,6 +88,8 @@ class DataCollatorForSeq2Seq:
|
|||||||
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
||||||
return_tensors (`str`):
|
return_tensors (`str`):
|
||||||
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
||||||
|
sequence_parallel_degree (`int`):
|
||||||
|
The degree of sequence parallelism. Default to 1 for no sequence parallelism.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tokenizer: PreTrainedTokenizerBase
|
tokenizer: PreTrainedTokenizerBase
|
||||||
@@ -53,6 +100,16 @@ class DataCollatorForSeq2Seq:
|
|||||||
label_pad_token_id: int = -100
|
label_pad_token_id: int = -100
|
||||||
position_pad_token_id: int = 0
|
position_pad_token_id: int = 0
|
||||||
return_tensors: str = "pt"
|
return_tensors: str = "pt"
|
||||||
|
sequence_parallel_degree: int = 1
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.sequence_parallel_degree > 1:
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||||
|
|
||||||
|
# Get information about our position in the SP group
|
||||||
|
sp_group = get_ring_attn_group()
|
||||||
|
self.local_rank = dist.get_rank(group=sp_group)
|
||||||
|
self.local_world_size = dist.get_world_size(group=sp_group)
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
labels = None
|
labels = None
|
||||||
@@ -119,8 +176,43 @@ class DataCollatorForSeq2Seq:
|
|||||||
)
|
)
|
||||||
features["decoder_input_ids"] = decoder_input_ids
|
features["decoder_input_ids"] = decoder_input_ids
|
||||||
|
|
||||||
|
if self.sequence_parallel_degree > 1:
|
||||||
|
features = self.apply_sequence_parallelism(features)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
def apply_sequence_parallelism(
|
||||||
|
self, batch: dict[str, torch.Tensor]
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Apply sequence parallelism slicing to a batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: Batch dictionary from parent collator.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sliced batch dictionary.
|
||||||
|
"""
|
||||||
|
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
|
||||||
|
|
||||||
|
for key in keys_to_slice:
|
||||||
|
if key in batch:
|
||||||
|
seq_len = batch[key].shape[1]
|
||||||
|
slice_size = seq_len // self.local_world_size
|
||||||
|
start_idx = self.local_rank * slice_size
|
||||||
|
end_idx = (
|
||||||
|
start_idx + slice_size
|
||||||
|
if self.local_rank < self.local_world_size - 1
|
||||||
|
else seq_len
|
||||||
|
)
|
||||||
|
batch[key] = batch[key][:, start_idx:end_idx]
|
||||||
|
|
||||||
|
# Special handling for position_ids
|
||||||
|
if key == "position_ids" and self.local_rank > 0:
|
||||||
|
batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||||
@@ -148,6 +240,7 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
np.array(item[feature]) for item in features_ if feature in item
|
np.array(item[feature]) for item in features_ if feature in item
|
||||||
]
|
]
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
|
|
||||||
return super().__call__(out_features, return_tensors=return_tensors)
|
return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
@@ -177,6 +270,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
np.array(item[feature]) for item in features_ if feature in item
|
np.array(item[feature]) for item in features_ if feature in item
|
||||||
]
|
]
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
|
|
||||||
return super().__call__(out_features, return_tensors=return_tensors)
|
return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,19 +12,13 @@ from transformers.utils.import_utils import is_torch_npu_available
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.integrations.config import merge_input_args
|
from axolotl.integrations.config import merge_input_args
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
|
||||||
)
|
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
|
||||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
|
||||||
)
|
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
|
||||||
DPODataset,
|
|
||||||
KTODataset,
|
|
||||||
SFTDataset,
|
|
||||||
)
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model_config
|
from axolotl.utils.models import load_model_config
|
||||||
|
from axolotl.utils.schemas.config import (
|
||||||
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
|
)
|
||||||
|
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
||||||
|
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -131,6 +125,9 @@ def normalize_config(cfg):
|
|||||||
with open(ds_config_path, encoding="utf-8") as f:
|
with open(ds_config_path, encoding="utf-8") as f:
|
||||||
cfg.deepspeed = json.load(f)
|
cfg.deepspeed = json.load(f)
|
||||||
|
|
||||||
|
if cfg.sequence_parallel_degree is None:
|
||||||
|
cfg.sequence_parallel_degree = 1
|
||||||
|
|
||||||
if cfg.saves_per_epoch:
|
if cfg.saves_per_epoch:
|
||||||
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
||||||
if save_steps < 1.0: # prevent saves on every step
|
if save_steps < 1.0: # prevent saves on every step
|
||||||
|
|||||||
@@ -67,7 +67,12 @@ from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrap
|
|||||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MULTIMODEL_AUTO_MODEL_MAPPING = {
|
||||||
|
"llava": LlavaForConditionalGeneration,
|
||||||
|
"mllama": MllamaForConditionalGeneration,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# copied from accelerator.FullyShardedDataParallelPlugin
|
# copied from accelerator.FullyShardedDataParallelPlugin
|
||||||
@@ -476,7 +481,7 @@ class ModelLoader:
|
|||||||
else:
|
else:
|
||||||
self.text_model_config = self.model_config
|
self.text_model_config = self.model_config
|
||||||
|
|
||||||
self.AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name
|
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||||
|
|
||||||
def apply_patches(self) -> None:
|
def apply_patches(self) -> None:
|
||||||
# load any patches from plugins
|
# load any patches from plugins
|
||||||
@@ -547,6 +552,14 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_self_attn_lora(self.cfg)
|
patch_self_attn_lora(self.cfg)
|
||||||
|
|
||||||
|
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
||||||
|
|
||||||
|
# Initialize ring attn for sequence parallelism. This must be done after
|
||||||
|
# model init but before the first forward pass, since it modifies flash
|
||||||
|
# attn to use ring comm for SP training across multiple GPUs.
|
||||||
|
register_ring_attn(self.cfg.sequence_parallel_degree)
|
||||||
|
|
||||||
def patch_attention(self) -> None:
|
def patch_attention(self) -> None:
|
||||||
if hasattr(self.model_config, "model_type"):
|
if hasattr(self.model_config, "model_type"):
|
||||||
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
||||||
@@ -603,7 +616,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_self_attn_lora()
|
patch_self_attn_lora()
|
||||||
|
|
||||||
def patch_llama_derived_model(self) -> None:
|
def patch_llama_derived_model(self):
|
||||||
"""Modify all llama derived models in one block"""
|
"""Modify all llama derived models in one block"""
|
||||||
self.patch_loss_llama()
|
self.patch_loss_llama()
|
||||||
|
|
||||||
@@ -653,24 +666,15 @@ class ModelLoader:
|
|||||||
"Shifted-sparse attention not currently implemented without flash attention."
|
"Shifted-sparse attention not currently implemented without flash attention."
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_auto_model_loader(self) -> None:
|
def set_auto_model_loader(self):
|
||||||
"""set self.AutoModelLoader
|
"""
|
||||||
- default value: AutoModelForCausalLM (set at __init__)
|
Set self.auto_model_loader. Defaults to `transformers.AutoModelForCausalLM`
|
||||||
- when using a multi modality model, self.AutoModelLoader should
|
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
|
||||||
be set according to model type of the model
|
should be set according to the type of the model.
|
||||||
"""
|
"""
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
if self.model_config.model_type == "llava":
|
self.auto_model_loader = MULTIMODEL_AUTO_MODEL_MAPPING.get(
|
||||||
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
self.model_config.model_type, AutoModelForVision2Seq
|
||||||
LlavaForConditionalGeneration
|
|
||||||
)
|
|
||||||
elif self.model_config.model_type == "mllama":
|
|
||||||
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
|
||||||
MllamaForConditionalGeneration
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.AutoModelLoader = (
|
|
||||||
AutoModelForVision2Seq # pylint: disable=invalid-name
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_device_map_config(self) -> None:
|
def set_device_map_config(self) -> None:
|
||||||
@@ -695,7 +699,7 @@ class ModelLoader:
|
|||||||
from accelerate import infer_auto_device_map
|
from accelerate import infer_auto_device_map
|
||||||
|
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model_canvas = self.AutoModelLoader.from_config(
|
model_canvas = self.auto_model_loader.from_config(
|
||||||
self.model_config,
|
self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
)
|
)
|
||||||
@@ -916,7 +920,23 @@ class ModelLoader:
|
|||||||
|
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
|
||||||
|
# Load model with random initialization if specified
|
||||||
|
if self.cfg.random_init_weights:
|
||||||
|
# AutoModel classes support the from_config method
|
||||||
|
if self.auto_model_loader in [
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoModelForVision2Seq,
|
||||||
|
]:
|
||||||
|
self.model = self.auto_model_loader.from_config(
|
||||||
|
config=self.model_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.model = self.auto_model_loader(
|
||||||
|
config=self.model_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
**self.model_kwargs,
|
**self.model_kwargs,
|
||||||
@@ -958,7 +978,7 @@ class ModelLoader:
|
|||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
@@ -991,7 +1011,7 @@ class ModelLoader:
|
|||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
@@ -1011,7 +1031,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
@@ -1307,7 +1327,7 @@ def load_model(
|
|||||||
"""
|
"""
|
||||||
Load a model for a given configuration and tokenizer.
|
Load a model for a given configuration and tokenizer.
|
||||||
"""
|
"""
|
||||||
loader = ModelLoader(
|
model_loader = ModelLoader(
|
||||||
cfg,
|
cfg,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
@@ -1315,7 +1335,7 @@ def load_model(
|
|||||||
reference_model=reference_model,
|
reference_model=reference_model,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return loader.load_model()
|
return model_loader.load_model()
|
||||||
|
|
||||||
|
|
||||||
def load_adapter(model, cfg, adapter, inference=False):
|
def load_adapter(model, cfg, adapter, inference=False):
|
||||||
|
|||||||
@@ -104,9 +104,7 @@ def allocate(
|
|||||||
|
|
||||||
|
|
||||||
class MultipackBatchSampler(BatchSampler):
|
class MultipackBatchSampler(BatchSampler):
|
||||||
"""
|
"""Batch sampler class for multipack"""
|
||||||
Batch Sampler class for multipack
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
165
src/axolotl/utils/schemas/datasets.py
Normal file
165
src/axolotl/utils/schemas/datasets.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
"""Pydantic models for datasets-related configuration"""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
|
from axolotl.utils.schemas.enums import ChatTemplate
|
||||||
|
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
|
||||||
|
|
||||||
|
|
||||||
|
class UserDefinedPrompterType(BaseModel):
|
||||||
|
"""Structure for user defined prompt types"""
|
||||||
|
|
||||||
|
system_prompt: str | None = None
|
||||||
|
system_format: str | None = None
|
||||||
|
field_system: str | None = None
|
||||||
|
field_instruction: str | None = None
|
||||||
|
field_input: str | None = None
|
||||||
|
field_output: str | None = None
|
||||||
|
|
||||||
|
format: str | None = None
|
||||||
|
no_input_format: str | None = None
|
||||||
|
field: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SFTDataset(BaseModel):
|
||||||
|
"""SFT configuration subset"""
|
||||||
|
|
||||||
|
path: str | None = None
|
||||||
|
split: str | None = None
|
||||||
|
type: str | UserDefinedPrompterType | None = None
|
||||||
|
input_transform: str | None = None
|
||||||
|
shards: int | None = None
|
||||||
|
shards_idx: int | None = None
|
||||||
|
preprocess_shards: int | None = None
|
||||||
|
conversation: str | None = None
|
||||||
|
# Do not make this too strict or it will break the validator to choose different dataset class
|
||||||
|
chat_template: ChatTemplate | str | None = None
|
||||||
|
chat_template_jinja: str | None = None
|
||||||
|
data_files: str | list[str] | None = None
|
||||||
|
input_format: str | None = None
|
||||||
|
name: str | None = None
|
||||||
|
ds_type: str | None = None
|
||||||
|
train_on_split: str | None = None
|
||||||
|
field: str | None = None
|
||||||
|
field_human: str | None = None
|
||||||
|
field_model: str | None = None
|
||||||
|
field_messages: str | None = None
|
||||||
|
# deprecated, use message_property_mappings
|
||||||
|
message_field_role: str | None = None
|
||||||
|
# deprecated, use message_property_mappings
|
||||||
|
message_field_content: str | None = None
|
||||||
|
message_property_mappings: dict[str, str] | None = None
|
||||||
|
message_field_training: str | None = None
|
||||||
|
message_field_training_detail: str | None = None
|
||||||
|
logprobs_field: str | None = None
|
||||||
|
temperature: float | None = None
|
||||||
|
roles_to_train: list[str] | None = None
|
||||||
|
train_on_eos: str | None = None
|
||||||
|
roles: dict[str, list[str]] | None = None
|
||||||
|
drop_system_message: bool | None = None
|
||||||
|
trust_remote_code: bool | None = False
|
||||||
|
revision: str | None = None
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def handle_legacy_message_fields(cls, data):
|
||||||
|
"""Handle backwards compatibility between legacy message field mapping and new property mapping system."""
|
||||||
|
return handle_legacy_message_fields_logic(data)
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
def check_chat_template_config(cls, data):
|
||||||
|
if isinstance(data, BaseModel):
|
||||||
|
data = data.model_dump()
|
||||||
|
|
||||||
|
# Set chat_template to tokenizer_default if not set
|
||||||
|
if data.get("type") == "chat_template" and not data.get("chat_template"):
|
||||||
|
data["chat_template"] = ChatTemplate.tokenizer_default
|
||||||
|
|
||||||
|
# if chat_template is set to jinja, chat_template_jinja is required
|
||||||
|
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
||||||
|
"chat_template_jinja"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"chat_template_jinja is required when chat_template is set to jinja"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If chat_template_jinja is set, set chat_template to jinja
|
||||||
|
if data.get("chat_template_jinja") and not data.get("chat_template"):
|
||||||
|
data["chat_template"] = ChatTemplate.jinja
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class PretrainingDataset(BaseModel):
|
||||||
|
"""Pretraining dataset configuration subset"""
|
||||||
|
|
||||||
|
name: str | None = None
|
||||||
|
path: str | None = None
|
||||||
|
split: str | None = "train"
|
||||||
|
text_column: str | None = "text"
|
||||||
|
type: str | None = "pretrain"
|
||||||
|
trust_remote_code: bool | None = False
|
||||||
|
data_files: str | None = None
|
||||||
|
skip: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class UserDefinedDPOType(BaseModel):
|
||||||
|
"""User defined typing for DPO"""
|
||||||
|
|
||||||
|
field_system: str | None = None
|
||||||
|
field_prompt: str | None = None
|
||||||
|
field_chosen: str | None = None
|
||||||
|
field_rejected: str | None = None
|
||||||
|
prompt_format: str | None = None
|
||||||
|
chosen_format: str | None = None
|
||||||
|
rejected_format: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DPODataset(BaseModel):
|
||||||
|
"""DPO configuration subset"""
|
||||||
|
|
||||||
|
path: str | None = None
|
||||||
|
split: str | None = None
|
||||||
|
type: UserDefinedDPOType | str | None = None
|
||||||
|
data_files: list[str] | None = None
|
||||||
|
revision: str | None = None
|
||||||
|
field_messages: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class StepwiseSupervisedDataset(BaseModel):
|
||||||
|
"""Stepwise supervised dataset configuration subset"""
|
||||||
|
|
||||||
|
path: str | None = None
|
||||||
|
split: str | None = None
|
||||||
|
data_files: list[str] | None = None
|
||||||
|
revision: str | None = None
|
||||||
|
step_separator: str | None = None
|
||||||
|
max_completion_length: int | None = None
|
||||||
|
train_on_last_step_only: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class UserDefinedKTOType(BaseModel):
|
||||||
|
"""User defined typing for KTO"""
|
||||||
|
|
||||||
|
field_system: str | None = None
|
||||||
|
field_prompt: str | None = None
|
||||||
|
field_completion: str | None = None
|
||||||
|
field_label: bool | None = None
|
||||||
|
prompt_format: str | None = None
|
||||||
|
completion_format: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class KTODataset(BaseModel):
|
||||||
|
"""KTO configuration subset"""
|
||||||
|
|
||||||
|
path: str | None = None
|
||||||
|
split: str | None = None
|
||||||
|
type: UserDefinedKTOType | str | None = None
|
||||||
|
data_files: list[str] | None = None
|
||||||
|
trust_remote_code: bool | None = False
|
||||||
|
revision: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
|
||||||
68
src/axolotl/utils/schemas/deprecated.py
Normal file
68
src/axolotl/utils/schemas/deprecated.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""Pydantic models for deprecated and remapped configuration parameters"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DeprecatedParameters(BaseModel):
|
||||||
|
"""configurations that are deprecated"""
|
||||||
|
|
||||||
|
max_packed_sequence_len: int | None = None
|
||||||
|
rope_scaling: Any | None = None
|
||||||
|
noisy_embedding_alpha: float | None = None
|
||||||
|
dpo_beta: float | None = None
|
||||||
|
evaluation_strategy: str | None = None
|
||||||
|
|
||||||
|
@field_validator("max_packed_sequence_len")
|
||||||
|
@classmethod
|
||||||
|
def validate_max_packed_sequence_len(cls, max_packed_sequence_len):
|
||||||
|
if max_packed_sequence_len:
|
||||||
|
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
|
||||||
|
return max_packed_sequence_len
|
||||||
|
|
||||||
|
@field_validator("rope_scaling")
|
||||||
|
@classmethod
|
||||||
|
def validate_rope_scaling(cls, rope_scaling):
|
||||||
|
if rope_scaling:
|
||||||
|
raise DeprecationWarning(
|
||||||
|
"`rope_scaling` is no longer supported, it should now be be a key under `model_config`"
|
||||||
|
)
|
||||||
|
return rope_scaling
|
||||||
|
|
||||||
|
@field_validator("noisy_embedding_alpha")
|
||||||
|
@classmethod
|
||||||
|
def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha):
|
||||||
|
if noisy_embedding_alpha:
|
||||||
|
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
||||||
|
return noisy_embedding_alpha
|
||||||
|
|
||||||
|
@field_validator("dpo_beta")
|
||||||
|
@classmethod
|
||||||
|
def validate_dpo_beta(cls, dpo_beta):
|
||||||
|
if dpo_beta is not None:
|
||||||
|
LOG.warning("dpo_beta is deprecated, use rl_beta instead")
|
||||||
|
return dpo_beta
|
||||||
|
|
||||||
|
@field_validator("evaluation_strategy")
|
||||||
|
@classmethod
|
||||||
|
def validate_evaluation_strategy(cls, evaluation_strategy):
|
||||||
|
if evaluation_strategy is not None:
|
||||||
|
LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
|
||||||
|
return evaluation_strategy
|
||||||
|
|
||||||
|
|
||||||
|
class RemappedParameters(BaseModel):
|
||||||
|
"""Parameters that have been remapped to other names"""
|
||||||
|
|
||||||
|
overrides_of_model_config: dict[str, Any] | None = Field(
|
||||||
|
default=None, alias="model_config"
|
||||||
|
)
|
||||||
|
overrides_of_model_kwargs: dict[str, Any] | None = Field(
|
||||||
|
default=None, alias="model_kwargs"
|
||||||
|
)
|
||||||
|
type_of_model: str | None = Field(default=None, alias="model_type")
|
||||||
|
revision_of_model: str | None = Field(default=None, alias="model_revision")
|
||||||
49
src/axolotl/utils/schemas/enums.py
Normal file
49
src/axolotl/utils/schemas/enums.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""Enums for Axolotl input config"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class RLType(str, Enum):
|
||||||
|
"""RL trainer type configuration subset"""
|
||||||
|
|
||||||
|
dpo = "dpo" # pylint: disable=invalid-name
|
||||||
|
grpo = "grpo" # pylint: disable=invalid-name
|
||||||
|
ipo = "ipo" # pylint: disable=invalid-name
|
||||||
|
orpo = "orpo" # pylint: disable=invalid-name
|
||||||
|
kto = "kto" # pylint: disable=invalid-name
|
||||||
|
simpo = "simpo" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class ChatTemplate(str, Enum):
|
||||||
|
"""Chat templates configuration subset"""
|
||||||
|
|
||||||
|
alpaca = "alpaca" # pylint: disable=invalid-name
|
||||||
|
chatml = "chatml" # pylint: disable=invalid-name
|
||||||
|
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
|
||||||
|
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
|
||||||
|
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
|
||||||
|
gemma = "gemma" # pylint: disable=invalid-name
|
||||||
|
cohere = "cohere" # pylint: disable=invalid-name
|
||||||
|
llama3 = "llama3" # pylint: disable=invalid-name
|
||||||
|
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
||||||
|
phi_3 = "phi_3" # pylint: disable=invalid-name
|
||||||
|
phi_35 = "phi_35" # pylint: disable=invalid-name
|
||||||
|
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
||||||
|
deepseek_v3 = "deepseek_v3" # pylint: disable=invalid-name
|
||||||
|
jamba = "jamba" # pylint: disable=invalid-name
|
||||||
|
jinja = "jinja" # pylint: disable=invalid-name
|
||||||
|
qwen_25 = "qwen_25" # pylint: disable=invalid-name
|
||||||
|
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
||||||
|
exaone = "exaone" # pylint: disable=invalid-name
|
||||||
|
metharme = "metharme" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class CustomSupportedOptimizers(str, Enum):
|
||||||
|
"""Custom supported optimizers"""
|
||||||
|
|
||||||
|
optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name
|
||||||
|
ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name
|
||||||
|
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
|
||||||
|
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
|
||||||
|
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
|
||||||
|
muon = "muon" # pylint: disable=invalid-name
|
||||||
108
src/axolotl/utils/schemas/integrations.py
Normal file
108
src/axolotl/utils/schemas/integrations.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""Pydantic models for Axolotl integrations"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MLFlowConfig(BaseModel):
|
||||||
|
"""MLFlow configuration subset"""
|
||||||
|
|
||||||
|
use_mlflow: bool | None = None
|
||||||
|
mlflow_tracking_uri: str | None = None
|
||||||
|
mlflow_experiment_name: str | None = None
|
||||||
|
mlflow_run_name: str | None = None
|
||||||
|
hf_mlflow_log_artifacts: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class LISAConfig(BaseModel):
|
||||||
|
"""LISA configuration subset"""
|
||||||
|
|
||||||
|
lisa_n_layers: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "the number of activate layers in LISA"},
|
||||||
|
)
|
||||||
|
lisa_step_interval: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "how often to switch layers in LISA"},
|
||||||
|
)
|
||||||
|
lisa_layers_attribute: str | None = Field(
|
||||||
|
default="model.layers",
|
||||||
|
json_schema_extra={"description": "path under the model to access the layers"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WandbConfig(BaseModel):
|
||||||
|
"""Wandb configuration subset"""
|
||||||
|
|
||||||
|
use_wandb: bool | None = None
|
||||||
|
wandb_name: str | None = None
|
||||||
|
wandb_run_id: str | None = None
|
||||||
|
wandb_mode: str | None = None
|
||||||
|
wandb_project: str | None = None
|
||||||
|
wandb_entity: str | None = None
|
||||||
|
wandb_watch: str | None = None
|
||||||
|
wandb_log_model: str | None = None
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_wandb_run(cls, data):
|
||||||
|
if data.get("wandb_run_id") and not data.get("wandb_name"):
|
||||||
|
data["wandb_name"] = data.get("wandb_run_id")
|
||||||
|
|
||||||
|
LOG.warning(
|
||||||
|
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class CometConfig(BaseModel):
|
||||||
|
"""Comet configuration subset"""
|
||||||
|
|
||||||
|
use_comet: bool | None = None
|
||||||
|
comet_api_key: str | None = None
|
||||||
|
comet_workspace: str | None = None
|
||||||
|
comet_project_name: str | None = None
|
||||||
|
comet_experiment_key: str | None = None
|
||||||
|
comet_mode: str | None = None
|
||||||
|
comet_online: bool | None = None
|
||||||
|
comet_experiment_config: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class GradioConfig(BaseModel):
|
||||||
|
"""Gradio configuration subset"""
|
||||||
|
|
||||||
|
gradio_title: str | None = None
|
||||||
|
gradio_share: bool | None = None
|
||||||
|
gradio_server_name: str | None = None
|
||||||
|
gradio_server_port: int | None = None
|
||||||
|
gradio_max_new_tokens: int | None = None
|
||||||
|
gradio_temperature: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class RayConfig(BaseModel):
|
||||||
|
"""Ray launcher configuration subset"""
|
||||||
|
|
||||||
|
use_ray: bool = Field(default=False)
|
||||||
|
ray_run_name: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"help": "The training results will be saved at `saves/ray_run_name`."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ray_num_workers: int = Field(
|
||||||
|
default=1,
|
||||||
|
json_schema_extra={
|
||||||
|
"help": "The number of workers for Ray training. Default is 1 worker."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resources_per_worker: dict = Field(
|
||||||
|
default_factory=lambda: {"GPU": 1},
|
||||||
|
json_schema_extra={
|
||||||
|
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
|
||||||
|
},
|
||||||
|
)
|
||||||
55
src/axolotl/utils/schemas/model.py
Normal file
55
src/axolotl/utils/schemas/model.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
"""Pydantic models for model input / output, etc. configuration"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInputConfig(BaseModel):
|
||||||
|
"""Model configuration subset"""
|
||||||
|
|
||||||
|
model_config = {"protected_namespaces": ()}
|
||||||
|
|
||||||
|
base_model: str
|
||||||
|
base_model_config: str | None = None
|
||||||
|
cls_model_config: str | None = None
|
||||||
|
tokenizer_config: str | None = None
|
||||||
|
tokenizer_use_fast: bool | None = None
|
||||||
|
tokenizer_legacy: bool | None = None
|
||||||
|
tokenizer_type: str | None = Field(
|
||||||
|
default=None, json_schema_extra={"description": "transformers tokenizer class"}
|
||||||
|
)
|
||||||
|
processor_type: str | None = Field(
|
||||||
|
default=None, json_schema_extra={"description": "transformers processor class"}
|
||||||
|
)
|
||||||
|
trust_remote_code: bool | None = None
|
||||||
|
|
||||||
|
@field_validator("trust_remote_code")
|
||||||
|
@classmethod
|
||||||
|
def hint_trust_remote_code(cls, trust_remote_code):
|
||||||
|
if trust_remote_code:
|
||||||
|
LOG.warning(
|
||||||
|
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||||
|
)
|
||||||
|
return trust_remote_code
|
||||||
|
|
||||||
|
|
||||||
|
class ModelOutputConfig(BaseModel):
|
||||||
|
"""model save configuration subset"""
|
||||||
|
|
||||||
|
output_dir: str = Field(default="./model-out")
|
||||||
|
hub_model_id: str | None = None
|
||||||
|
hub_strategy: str | None = None
|
||||||
|
save_safetensors: bool | None = True
|
||||||
|
|
||||||
|
|
||||||
|
class SpecialTokensConfig(BaseModel):
|
||||||
|
"""Special tokens configuration subset"""
|
||||||
|
|
||||||
|
bos_token: str | None = None
|
||||||
|
eos_token: str | None = None
|
||||||
|
pad_token: str | None = None
|
||||||
|
unk_token: str | None = None
|
||||||
|
additional_special_tokens: list[str] | None = None
|
||||||
132
src/axolotl/utils/schemas/peft.py
Normal file
132
src/axolotl/utils/schemas/peft.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
"""Pydantic models for PEFT-related configuration"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
|
||||||
|
|
||||||
|
class LoftQConfig(BaseModel):
|
||||||
|
"""LoftQ configuration subset"""
|
||||||
|
|
||||||
|
loftq_bits: int = Field(
|
||||||
|
default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
|
||||||
|
)
|
||||||
|
# loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
|
||||||
|
|
||||||
|
|
||||||
|
class PeftConfig(BaseModel):
|
||||||
|
"""peftq configuration subset"""
|
||||||
|
|
||||||
|
loftq_config: LoftQConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class LoraConfig(BaseModel):
|
||||||
|
"""Peft / LoRA configuration subset"""
|
||||||
|
|
||||||
|
load_in_8bit: bool | None = Field(default=False)
|
||||||
|
load_in_4bit: bool | None = Field(default=False)
|
||||||
|
|
||||||
|
adapter: str | None = None
|
||||||
|
lora_model_dir: str | None = None
|
||||||
|
lora_r: int | None = None
|
||||||
|
lora_alpha: int | None = None
|
||||||
|
lora_fan_in_fan_out: bool | None = None
|
||||||
|
lora_target_modules: str | list[str] | None = None
|
||||||
|
lora_target_linear: bool | None = None
|
||||||
|
lora_modules_to_save: list[str] | None = None
|
||||||
|
lora_dropout: float | None = 0.0
|
||||||
|
peft_layers_to_transform: list[int] | None = None
|
||||||
|
peft_layers_pattern: list[str] | None = None
|
||||||
|
peft: PeftConfig | None = None
|
||||||
|
peft_use_dora: bool | None = None
|
||||||
|
peft_use_rslora: bool | None = None
|
||||||
|
peft_layer_replication: list[tuple[int, int]] | None = None
|
||||||
|
peft_init_lora_weights: bool | str | None = None
|
||||||
|
|
||||||
|
qlora_sharded_model_loading: bool | None = Field(
|
||||||
|
default=False,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "load qlora model in sharded format for FSDP using answer.ai technique."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
lora_on_cpu: bool | None = None
|
||||||
|
gptq: bool | None = None
|
||||||
|
bnb_config_kwargs: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
loraplus_lr_ratio: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
loraplus_lr_embedding: float | None = Field(
|
||||||
|
default=1e-6,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "loraplus learning rate for lora embedding layers."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
merge_lora: bool | None = None
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_adapter(cls, data):
|
||||||
|
if (
|
||||||
|
not data.get("adapter")
|
||||||
|
and not data.get("inference")
|
||||||
|
and (data.get("load_in_8bit") or data.get("load_in_4bit"))
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
|
||||||
|
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_qlora(self):
|
||||||
|
if self.adapter == "qlora":
|
||||||
|
if self.merge_lora:
|
||||||
|
# can't merge qlora if loaded in 8bit or 4bit
|
||||||
|
if self.load_in_8bit:
|
||||||
|
raise ValueError("Can't merge qlora if loaded in 8bit")
|
||||||
|
|
||||||
|
if self.gptq:
|
||||||
|
raise ValueError("Can't merge qlora if gptq")
|
||||||
|
|
||||||
|
if self.load_in_4bit:
|
||||||
|
raise ValueError("Can't merge qlora if loaded in 4bit")
|
||||||
|
|
||||||
|
else:
|
||||||
|
if self.load_in_8bit:
|
||||||
|
raise ValueError("Can't load qlora in 8bit")
|
||||||
|
|
||||||
|
if self.gptq:
|
||||||
|
raise ValueError("Can't load qlora if gptq")
|
||||||
|
|
||||||
|
if not self.load_in_4bit:
|
||||||
|
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||||
|
return self
|
||||||
|
|
||||||
|
@field_validator("loraplus_lr_embedding")
|
||||||
|
@classmethod
|
||||||
|
def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding):
|
||||||
|
if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str):
|
||||||
|
loraplus_lr_embedding = float(loraplus_lr_embedding)
|
||||||
|
return loraplus_lr_embedding
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_lora_dropout(cls, data):
|
||||||
|
if data.get("adapter") is not None and data.get("lora_dropout") is None:
|
||||||
|
data["lora_dropout"] = 0.0
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class ReLoRAConfig(BaseModel):
|
||||||
|
"""ReLoRA configuration subset"""
|
||||||
|
|
||||||
|
relora_steps: int | None = None
|
||||||
|
relora_warmup_steps: int | None = None
|
||||||
|
relora_anneal_steps: int | None = None
|
||||||
|
relora_prune_ratio: float | None = None
|
||||||
|
relora_cpu_offload: bool | None = None
|
||||||
99
src/axolotl/utils/schemas/training.py
Normal file
99
src/axolotl/utils/schemas/training.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
"""Pydantic models for training hyperparameters"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
from transformers import SchedulerType
|
||||||
|
from transformers.training_args import OptimizerNames
|
||||||
|
|
||||||
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LrGroup(BaseModel):
|
||||||
|
"""Custom learning rate group configuration"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
modules: list[str]
|
||||||
|
lr: float
|
||||||
|
|
||||||
|
|
||||||
|
class HyperparametersConfig(BaseModel):
|
||||||
|
"""Training hyperparams configuration subset"""
|
||||||
|
|
||||||
|
gradient_accumulation_steps: int | None = Field(default=1)
|
||||||
|
micro_batch_size: int | None = Field(
|
||||||
|
default=1,
|
||||||
|
json_schema_extra={"description": "per gpu micro batch size for training"},
|
||||||
|
)
|
||||||
|
batch_size: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Total batch size, we do not recommended setting this manually"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
eval_batch_size: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
auto_find_batch_size: bool | None = None
|
||||||
|
|
||||||
|
train_on_inputs: bool | None = False
|
||||||
|
group_by_length: bool | None = None
|
||||||
|
|
||||||
|
learning_rate: str | float
|
||||||
|
embedding_lr: float | None = None
|
||||||
|
embedding_lr_scale: float | None = None
|
||||||
|
weight_decay: float | None = 0.0
|
||||||
|
optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = (
|
||||||
|
OptimizerNames.ADAMW_TORCH_FUSED
|
||||||
|
)
|
||||||
|
optim_args: (str | dict[str, Any]) | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
|
||||||
|
)
|
||||||
|
optim_target_modules: (list[str] | Literal["all_linear"]) | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "The target modules to optimize, i.e. the module names that you would like to train."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
torchdistx_path: str | None = None
|
||||||
|
lr_scheduler: (SchedulerType | Literal["one_cycle"] | Literal["rex"]) | None = (
|
||||||
|
SchedulerType.COSINE
|
||||||
|
)
|
||||||
|
lr_scheduler_kwargs: dict[str, Any] | None = None
|
||||||
|
lr_quadratic_warmup: bool | None = None
|
||||||
|
cosine_min_lr_ratio: float | None = None
|
||||||
|
cosine_constant_lr_ratio: float | None = None
|
||||||
|
lr_div_factor: float | None = None
|
||||||
|
lr_groups: list[LrGroup] | None = None
|
||||||
|
|
||||||
|
adam_epsilon: float | None = None
|
||||||
|
adam_beta1: float | None = None
|
||||||
|
adam_beta2: float | None = None
|
||||||
|
max_grad_norm: float | None = None
|
||||||
|
num_epochs: float = Field(default=1.0)
|
||||||
|
|
||||||
|
@field_validator("batch_size")
|
||||||
|
@classmethod
|
||||||
|
def hint_batch_size_set(cls, batch_size):
|
||||||
|
if batch_size:
|
||||||
|
LOG.warning(
|
||||||
|
"%s\n%s",
|
||||||
|
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||||
|
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||||
|
)
|
||||||
|
return batch_size
|
||||||
|
|
||||||
|
@field_validator("learning_rate")
|
||||||
|
@classmethod
|
||||||
|
def convert_learning_rate(cls, learning_rate):
|
||||||
|
if learning_rate and isinstance(learning_rate, str):
|
||||||
|
learning_rate = float(learning_rate)
|
||||||
|
return learning_rate
|
||||||
@@ -1,8 +1,4 @@
|
|||||||
"""
|
"""Pydantic models for TRL trainer configuration"""
|
||||||
GRPO specific configuration args
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -12,11 +8,11 @@ class TRLConfig(BaseModel):
|
|||||||
Input args for TRL.
|
Input args for TRL.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
beta: Optional[float] = Field(
|
beta: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Beta for RL training"},
|
json_schema_extra={"description": "Beta for RL training"},
|
||||||
)
|
)
|
||||||
max_completion_length: Optional[int] = Field(
|
max_completion_length: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Maximum length of the completion for RL training"
|
"description": "Maximum length of the completion for RL training"
|
||||||
@@ -25,50 +21,50 @@ class TRLConfig(BaseModel):
|
|||||||
|
|
||||||
# GRPO specific args
|
# GRPO specific args
|
||||||
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
||||||
use_vllm: Optional[bool] = Field(
|
use_vllm: bool | None = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
||||||
)
|
)
|
||||||
vllm_device: Optional[str] = Field(
|
vllm_device: str | None = Field(
|
||||||
default="auto",
|
default="auto",
|
||||||
json_schema_extra={"description": "Device to use for VLLM"},
|
json_schema_extra={"description": "Device to use for VLLM"},
|
||||||
)
|
)
|
||||||
vllm_gpu_memory_utilization: Optional[float] = Field(
|
vllm_gpu_memory_utilization: float | None = Field(
|
||||||
default=0.9,
|
default=0.9,
|
||||||
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
||||||
)
|
)
|
||||||
vllm_dtype: Optional[str] = Field(
|
vllm_dtype: str | None = Field(
|
||||||
default="auto",
|
default="auto",
|
||||||
json_schema_extra={"description": "Data type for VLLM"},
|
json_schema_extra={"description": "Data type for VLLM"},
|
||||||
)
|
)
|
||||||
vllm_max_model_len: Optional[int] = Field(
|
vllm_max_model_len: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Maximum length of the model context for VLLM"
|
"description": "Maximum length of the model context for VLLM"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
reward_funcs: Optional[list[str]] = Field(
|
reward_funcs: list[str] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "List of reward functions to load"},
|
json_schema_extra={"description": "List of reward functions to load"},
|
||||||
)
|
)
|
||||||
reward_weights: Optional[list[float]] = Field(
|
reward_weights: list[float] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Weights for each reward function. Must match the number of reward functions."
|
"description": "Weights for each reward function. Must match the number of reward functions."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
num_generations: Optional[int] = Field(
|
num_generations: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
|
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
log_completions: Optional[bool] = Field(
|
log_completions: bool | None = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={"description": "Whether to log completions"},
|
json_schema_extra={"description": "Whether to log completions"},
|
||||||
)
|
)
|
||||||
sync_ref_model: Optional[bool] = Field(
|
sync_ref_model: bool | None = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": (
|
"description": (
|
||||||
@@ -77,13 +73,13 @@ class TRLConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ref_model_mixup_alpha: Optional[float] = Field(
|
ref_model_mixup_alpha: float | None = Field(
|
||||||
default=0.9,
|
default=0.9,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
|
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ref_model_sync_steps: Optional[int] = Field(
|
ref_model_sync_steps: int | None = Field(
|
||||||
default=64,
|
default=64,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
||||||
79
src/axolotl/utils/schemas/utils.py
Normal file
79
src/axolotl/utils/schemas/utils.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""Utilities for Axolotl Pydantic models"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_legacy_message_fields_logic(data: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Handle backwards compatibility between legacy message field mapping and new property mapping system.
|
||||||
|
|
||||||
|
Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options:
|
||||||
|
- message_field_role: Mapped to the role field
|
||||||
|
- message_field_content: Mapped to the content field
|
||||||
|
|
||||||
|
The new system uses message_property_mappings to support arbitrary field mappings:
|
||||||
|
message_property_mappings:
|
||||||
|
role: source_role_field
|
||||||
|
content: source_content_field
|
||||||
|
additional_field: source_field
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary containing configuration data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated dictionary with message field mappings consolidated
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If there are conflicts between legacy and new mappings
|
||||||
|
"""
|
||||||
|
data = data.copy() # Create a copy to avoid modifying the original
|
||||||
|
|
||||||
|
if data.get("message_property_mappings") is None:
|
||||||
|
data["message_property_mappings"] = {}
|
||||||
|
|
||||||
|
# Check for conflicts and handle role
|
||||||
|
if "message_field_role" in data:
|
||||||
|
LOG.warning(
|
||||||
|
"message_field_role is deprecated, use message_property_mappings instead. "
|
||||||
|
f"Example: message_property_mappings: {{role: {data['message_field_role']}}}"
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
"role" in data["message_property_mappings"]
|
||||||
|
and data["message_property_mappings"]["role"] != data["message_field_role"]
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Conflicting message role fields: message_field_role='{data['message_field_role']}' "
|
||||||
|
f"conflicts with message_property_mappings.role='{data['message_property_mappings']['role']}'"
|
||||||
|
)
|
||||||
|
data["message_property_mappings"]["role"] = data["message_field_role"] or "role"
|
||||||
|
|
||||||
|
del data["message_field_role"]
|
||||||
|
elif "role" not in data["message_property_mappings"]:
|
||||||
|
data["message_property_mappings"]["role"] = "role"
|
||||||
|
|
||||||
|
# Check for conflicts and handle content
|
||||||
|
if "message_field_content" in data:
|
||||||
|
LOG.warning(
|
||||||
|
"message_field_content is deprecated, use message_property_mappings instead. "
|
||||||
|
f"Example: message_property_mappings: {{content: {data['message_field_content']}}}"
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
"content" in data["message_property_mappings"]
|
||||||
|
and data["message_property_mappings"]["content"]
|
||||||
|
!= data["message_field_content"]
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
|
||||||
|
f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
|
||||||
|
)
|
||||||
|
data["message_property_mappings"]["content"] = (
|
||||||
|
data["message_field_content"] or "content"
|
||||||
|
)
|
||||||
|
|
||||||
|
del data["message_field_content"]
|
||||||
|
elif "content" not in data["message_property_mappings"]:
|
||||||
|
data["message_property_mappings"]["content"] = "content"
|
||||||
|
|
||||||
|
return data
|
||||||
@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Add position_id column (PoSE)",
|
desc="Add position_id column (PoSE)",
|
||||||
)
|
)
|
||||||
elif cfg.sample_packing:
|
elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
|
||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
||||||
@@ -356,7 +356,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
**filter_map_kwargs,
|
**filter_map_kwargs,
|
||||||
**drop_long_kwargs,
|
**drop_long_kwargs,
|
||||||
)
|
)
|
||||||
if cfg.eval_sample_packing is not False:
|
if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
add_position_ids,
|
add_position_ids,
|
||||||
@@ -443,6 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
|
* cfg.sequence_parallel_degree
|
||||||
)
|
)
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
|
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
|
||||||
@@ -473,7 +474,11 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
||||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||||
# on the agreed on value for sample_packing_eff_est
|
# on the agreed on value for sample_packing_eff_est
|
||||||
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
|
total_num_steps = int(
|
||||||
|
math.floor(
|
||||||
|
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def calc_sample_packing_eff_est(estimates: List[float]):
|
def calc_sample_packing_eff_est(estimates: List[float]):
|
||||||
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
||||||
@@ -494,7 +499,12 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
total_num_steps = int(
|
total_num_steps = int(
|
||||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
math.ceil(
|
||||||
|
len(train_dataset)
|
||||||
|
* cfg.num_epochs
|
||||||
|
* cfg.sequence_parallel_degree
|
||||||
|
/ cfg.batch_size
|
||||||
|
)
|
||||||
)
|
)
|
||||||
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
||||||
return total_num_steps
|
return total_num_steps
|
||||||
|
|||||||
90
styles.css
90
styles.css
@@ -14,7 +14,7 @@
|
|||||||
h1 {
|
h1 {
|
||||||
font-family: var(--font-title);
|
font-family: var(--font-title);
|
||||||
font-weight: 400;
|
font-weight: 400;
|
||||||
font-size: 5rem;
|
font-size: 3rem;
|
||||||
line-height: 1.1;
|
line-height: 1.1;
|
||||||
letter-spacing: -0.05em;
|
letter-spacing: -0.05em;
|
||||||
font-feature-settings: "ss01" on;
|
font-feature-settings: "ss01" on;
|
||||||
@@ -24,7 +24,7 @@ h1 {
|
|||||||
h2 {
|
h2 {
|
||||||
font-family: var(--font-title);
|
font-family: var(--font-title);
|
||||||
font-weight: 500;
|
font-weight: 500;
|
||||||
font-size: 2rem;
|
font-size: 1.5rem;
|
||||||
line-height: 1.2;
|
line-height: 1.2;
|
||||||
letter-spacing: -0.03em;
|
letter-spacing: -0.03em;
|
||||||
font-feature-settings: "ss01" on;
|
font-feature-settings: "ss01" on;
|
||||||
@@ -35,7 +35,7 @@ h3,
|
|||||||
h4 {
|
h4 {
|
||||||
font-family: var(--font-body);
|
font-family: var(--font-body);
|
||||||
font-weight: 400;
|
font-weight: 400;
|
||||||
font-size: 1.5rem;
|
font-size: 1.25rem;
|
||||||
line-height: 1.5;
|
line-height: 1.5;
|
||||||
letter-spacing: -0.02em;
|
letter-spacing: -0.02em;
|
||||||
}
|
}
|
||||||
@@ -191,3 +191,87 @@ code span.er {
|
|||||||
color: #5cb85c !important;
|
color: #5cb85c !important;
|
||||||
text-decoration: none !important;
|
text-decoration: none !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* API Documentation Styling */
|
||||||
|
|
||||||
|
/* Improve docstring section rendering */
|
||||||
|
.level3 p {
|
||||||
|
white-space: pre-line !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Format docstring sections */
|
||||||
|
.level3 p strong {
|
||||||
|
display: block;
|
||||||
|
margin-top: 1em;
|
||||||
|
font-weight: bold;
|
||||||
|
color: var(--cyan);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Add spacing after sections */
|
||||||
|
.level3 p:has(strong) {
|
||||||
|
margin-bottom: 0.5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Format Args and Returns sections */
|
||||||
|
p:has(code) {
|
||||||
|
line-height: 1.6;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Function signatures */
|
||||||
|
.sourceCode {
|
||||||
|
margin-bottom: 1.5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Parameter tables */
|
||||||
|
.doc-section-parameters table,
|
||||||
|
.doc-section-returns table {
|
||||||
|
margin-top: 1em;
|
||||||
|
margin-bottom: 1.5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Make parameter and returns headers smaller */
|
||||||
|
h2.anchored[data-anchor-id="parameters"],
|
||||||
|
h2.anchored[data-anchor-id="returns"],
|
||||||
|
.doc-section-parameters h4,
|
||||||
|
.doc-section-returns h4 {
|
||||||
|
font-size: 1.25rem;
|
||||||
|
margin-top: 2rem;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
color: var(--lime);
|
||||||
|
border-bottom: 1px solid var(--lime);
|
||||||
|
padding-bottom: 0.3rem;
|
||||||
|
font-family: var(--font-body);
|
||||||
|
font-weight: 500;
|
||||||
|
letter-spacing: normal;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Style documentation tables */
|
||||||
|
table {
|
||||||
|
width: 100%;
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
|
border-collapse: collapse;
|
||||||
|
}
|
||||||
|
|
||||||
|
table th {
|
||||||
|
background-color: #1a1a1a;
|
||||||
|
padding: 0.5rem 1rem;
|
||||||
|
border-bottom: 2px solid var(--greige-600);
|
||||||
|
text-align: left;
|
||||||
|
}
|
||||||
|
|
||||||
|
table td {
|
||||||
|
padding: 0.5rem 1rem;
|
||||||
|
border-bottom: 1px solid var(--greige-600);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Code in table cells */
|
||||||
|
table td code {
|
||||||
|
background-color: transparent !important;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Improve spacing in parameter and return tables */
|
||||||
|
.doc-section-parameters,
|
||||||
|
.doc-section-returns {
|
||||||
|
margin-top: 1rem;
|
||||||
|
}
|
||||||
|
|||||||
207
tests/e2e/patched/test_sp.py
Normal file
207
tests/e2e/patched/test_sp.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
"""Tests for sequence parallelism functionality."""
|
||||||
|
|
||||||
|
# pylint: disable=redefined-outer-name,unused-argument
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from accelerate.state import PartialState
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
# Use a single patch for ring_flash_attn if it's not available
|
||||||
|
ring_flash_attn_mock = MagicMock()
|
||||||
|
with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||||
|
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def partial_state():
|
||||||
|
"""Create a real PartialState instance for testing."""
|
||||||
|
state = PartialState()
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="cfg")
|
||||||
|
def fixture_cfg():
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"learning_rate": 1e-3,
|
||||||
|
"output_dir": "./model-out",
|
||||||
|
"sequence_len": 512,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
class TestSequenceParallelHelpers:
|
||||||
|
"""Test helper functions used in sequence parallelism."""
|
||||||
|
|
||||||
|
def test_adjust_position_ids_for_slice(self, partial_state):
|
||||||
|
"""Test position_ids adjustment for sequence slices."""
|
||||||
|
# Create sample position_ids with multiple sequences
|
||||||
|
position_ids = torch.tensor(
|
||||||
|
[
|
||||||
|
# First sequence with 2 samples
|
||||||
|
[0, 1, 2, 3, 4, 0, 1, 2, 3],
|
||||||
|
# Second sequence with 3 samples
|
||||||
|
[0, 1, 2, 0, 1, 2, 3, 0, 1],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Adjust as if this was the second slice (start_idx = 4)
|
||||||
|
adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4)
|
||||||
|
|
||||||
|
# For first sequence: [0,1,2,3,4,0,1,2,3] -> [-4,-3,-2,-1,0,-4,-3,-2,-1]
|
||||||
|
# For second sequence: [0,1,2,0,1,2,3,0,1] -> [-4,-3,-2,-4,-3,-2,-1,-4,-3]
|
||||||
|
expected_first_seq = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3]) - 4
|
||||||
|
expected_second_seq = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1]) - 4
|
||||||
|
|
||||||
|
assert torch.all(adjusted[0] == expected_first_seq)
|
||||||
|
assert torch.all(adjusted[1] == expected_second_seq)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRingAttention:
|
||||||
|
"""Tests for the ring attention functionality."""
|
||||||
|
|
||||||
|
@patch("torch.distributed.new_group")
|
||||||
|
@patch("torch.distributed.get_rank")
|
||||||
|
@patch("torch.distributed.get_world_size")
|
||||||
|
def test_register_ring_attn(
|
||||||
|
self, mock_world_size, mock_rank, mock_new_group, partial_state
|
||||||
|
):
|
||||||
|
"""Test that ring attention groups are created correctly."""
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
||||||
|
|
||||||
|
# Setup mocks
|
||||||
|
mock_world_size.return_value = 8 # 8 GPUs total
|
||||||
|
mock_rank.return_value = 3 # GPU #3
|
||||||
|
mock_group = MagicMock()
|
||||||
|
mock_new_group.return_value = mock_group
|
||||||
|
|
||||||
|
# Call register_ring_attn with size 4
|
||||||
|
register_ring_attn(sequence_parallel_degree=4)
|
||||||
|
|
||||||
|
# Verify the number of calls without examining the arguments
|
||||||
|
assert mock_new_group.call_count == 2
|
||||||
|
|
||||||
|
# Just verify that new_group was called
|
||||||
|
mock_new_group.assert_called()
|
||||||
|
|
||||||
|
@patch("torch.distributed.get_rank")
|
||||||
|
@patch("torch.distributed.get_world_size")
|
||||||
|
def test_get_ring_attn_group_no_registration(
|
||||||
|
self, mock_world_size, mock_rank, partial_state
|
||||||
|
):
|
||||||
|
"""Test that get_ring_attn_group returns None when no group has been registered."""
|
||||||
|
# Setup mocks
|
||||||
|
mock_world_size.return_value = 4
|
||||||
|
mock_rank.return_value = 0
|
||||||
|
|
||||||
|
# Get the group without registration
|
||||||
|
group = get_ring_attn_group()
|
||||||
|
|
||||||
|
# Verify that None was returned
|
||||||
|
assert group is None
|
||||||
|
|
||||||
|
|
||||||
|
# Mock a simplified DataCollator test
|
||||||
|
@patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group")
|
||||||
|
@patch("torch.distributed.get_rank")
|
||||||
|
@patch("torch.distributed.get_world_size")
|
||||||
|
def test_sequence_parallel_slicing(
|
||||||
|
mock_world_size, mock_rank, mock_get_group, partial_state
|
||||||
|
):
|
||||||
|
"""Test the basic sequence slicing logic without full collator instantiation."""
|
||||||
|
# Setup mocks
|
||||||
|
mock_get_group.return_value = MagicMock()
|
||||||
|
mock_rank.return_value = 1 # Second GPU
|
||||||
|
mock_world_size.return_value = 4 # 4 GPUs total
|
||||||
|
|
||||||
|
# Create a sample batch
|
||||||
|
batch = {
|
||||||
|
"input_ids": torch.tensor(
|
||||||
|
[
|
||||||
|
[101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112],
|
||||||
|
[201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212],
|
||||||
|
]
|
||||||
|
),
|
||||||
|
"attention_mask": torch.ones(2, 12),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Simplified slicing logic from SequenceParallelDataCollator
|
||||||
|
def slice_batch(batch, rank, world_size):
|
||||||
|
result = {}
|
||||||
|
for key in batch:
|
||||||
|
seq_len = batch[key].shape[1]
|
||||||
|
slice_size = seq_len // world_size
|
||||||
|
start_idx = rank * slice_size
|
||||||
|
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
|
||||||
|
result[key] = batch[key][:, start_idx:end_idx]
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Slice the batch
|
||||||
|
result = slice_batch(
|
||||||
|
batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check slicing
|
||||||
|
assert result["input_ids"].shape == (2, 3) # 12 tokens / 4 GPUs = 3 tokens per GPU
|
||||||
|
expected_input_ids = torch.tensor(
|
||||||
|
[
|
||||||
|
[104, 105, 106], # Second slice of first sequence
|
||||||
|
[204, 205, 206], # Second slice of second sequence
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert torch.all(result["input_ids"] == expected_input_ids)
|
||||||
|
|
||||||
|
|
||||||
|
@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()})
|
||||||
|
def test_config_validation_with_valid_inputs(cfg):
|
||||||
|
"""Test that valid sequence parallelism configurations pass validation."""
|
||||||
|
# Import the actual model class with appropriate mocks
|
||||||
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
|
# Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
|
||||||
|
cfg = cfg | {
|
||||||
|
"sequence_parallel_degree": 2,
|
||||||
|
"flash_attention": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Should validate without errors
|
||||||
|
config = AxolotlInputConfig(**cfg)
|
||||||
|
assert config.sequence_parallel_degree == 2
|
||||||
|
assert config.flash_attention is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_validation_with_invalid_inputs(cfg):
|
||||||
|
"""Test that invalid sequence parallelism configurations fail validation."""
|
||||||
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
|
# Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
|
||||||
|
cfg = cfg | {
|
||||||
|
"sequence_parallel_degree": 2,
|
||||||
|
"flash_attention": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Should raise ValidationError
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
AxolotlInputConfig(**cfg)
|
||||||
|
|
||||||
|
# Verify error message
|
||||||
|
assert "flash_attention: true must be set" in str(excinfo.value)
|
||||||
@@ -11,10 +11,10 @@ from pydantic import ValidationError
|
|||||||
|
|
||||||
from axolotl.utils import is_comet_available
|
from axolotl.utils import is_comet_available
|
||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||||
from axolotl.utils.models import check_model_config
|
from axolotl.utils.models import check_model_config
|
||||||
|
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
warnings.filterwarnings("error")
|
warnings.filterwarnings("error")
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
|
|||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.data import prepare_dataset
|
from axolotl.utils.data import prepare_dataset
|
||||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
||||||
@@ -262,6 +263,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
|||||||
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
||||||
self.cfg_1 = DictDefault(
|
self.cfg_1 = DictDefault(
|
||||||
{
|
{
|
||||||
|
"base_model": "huggyllama/llama-7b",
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"dataset_exact_deduplication": True,
|
"dataset_exact_deduplication": True,
|
||||||
@@ -282,6 +284,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
|||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
normalize_config(self.cfg_1)
|
||||||
|
|
||||||
def test_prepare_dataset_with_deduplication_train(self):
|
def test_prepare_dataset_with_deduplication_train(self):
|
||||||
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""
|
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ from typing import Optional
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.schemas.datasets import ChatTemplate
|
||||||
|
|
||||||
warnings.filterwarnings("error")
|
warnings.filterwarnings("error")
|
||||||
|
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ class TestModelsUtils:
|
|||||||
|
|
||||||
def test_message_property_mapping(self):
|
def test_message_property_mapping(self):
|
||||||
"""Test message property mapping configuration validation"""
|
"""Test message property mapping configuration validation"""
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import SFTDataset
|
from axolotl.utils.schemas.datasets import SFTDataset
|
||||||
|
|
||||||
# Test legacy fields are mapped orrectly
|
# Test legacy fields are mapped orrectly
|
||||||
dataset = SFTDataset(
|
dataset = SFTDataset(
|
||||||
|
|||||||
Reference in New Issue
Block a user