Compare commits

..

4 Commits

Author SHA1 Message Date
Dan Saunders
156fede4f7 Update .pre-commit-config.yaml
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-03-21 10:36:18 -04:00
Dan Saunders
dcbbd7af79 sorry to revert, but pylint complained 2025-03-21 10:36:18 -04:00
Dan Saunders
21bac7ce1a running updated pre-commit plugins 2025-03-21 10:36:18 -04:00
Dan Saunders
aaa4571826 adding pre-commit auto-update GH action and bumping plugin versions 2025-03-21 10:36:17 -04:00
95 changed files with 2125 additions and 6191 deletions

View File

@@ -20,12 +20,9 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install dependencies
- name: install dependencies
run: |
python3 -m pip install jupyter quartodoc
python3 -m pip install -e . --no-deps
- name: Build autodoc
run: quartodoc build
python3 -m pip install jupyter
- name: Publish to GitHub Pages (and render)
uses: quarto-dev/quarto-actions/publish@v2
with:

View File

@@ -98,9 +98,8 @@ jobs:
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest -v tests/patched/
pytest -v tests/cli/
- name: cleanup pip cache
run: |
@@ -173,9 +172,8 @@ jobs:
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest -v tests/patched/
pytest -v tests/cli/
- name: cleanup pip cache
run: |

4
.gitignore vendored
View File

@@ -181,10 +181,6 @@ prepared-datasets/
submit.sh
*.out*
# Quartodoc generated files
objects.json
site_libs/
typings/
out/

View File

@@ -97,7 +97,6 @@ That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
- [API Reference](https://axolotl-ai-cloud.github.io/axolotl/docs/api/) - Auto-generated code documentation
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
## 🤝 Getting Help

View File

@@ -1,179 +1,6 @@
project:
type: website
quartodoc:
dir: docs/api
package: axolotl
title: API Reference
parser: google
sections:
- title: Core
desc: Core functionality for training
contents:
- train
- evaluate
- datasets
- convert
- prompt_tokenizers
- logging_config
- core.trainer_builder
- core.training_args
- core.chat.messages
- core.chat.format.chatml
- core.chat.format.llama3x
- core.chat.format.shared
- core.datasets.chat
- core.datasets.transforms.chat_builder
- title: CLI
desc: Command-line interface
contents:
- cli.main
- cli.train
- cli.evaluate
- cli.args
- cli.checks
- cli.config
- cli.inference
- cli.merge_lora
- cli.merge_sharded_fsdp_weights
- cli.preprocess
- cli.sweeps
- cli.utils
- cli.cloud.base
- cli.cloud.modal_
- title: Trainers
desc: Training implementations
contents:
- core.trainers.base
- core.trainers.trl
- core.trainers.dpo.trainer
- core.trainers.grpo.trainer
- title: Prompt Strategies
desc: Prompt formatting strategies
contents:
- prompt_strategies.base
- prompt_strategies.chat_template
- prompt_strategies.alpaca_chat
- prompt_strategies.alpaca_instruct
- prompt_strategies.alpaca_w_system
- prompt_strategies.user_defined
- prompt_strategies.llama2_chat
- prompt_strategies.completion
- prompt_strategies.input_output
- prompt_strategies.stepwise_supervised
- prompt_strategies.metharme
- prompt_strategies.orcamini
- prompt_strategies.pygmalion
- prompt_strategies.messages.chat
- prompt_strategies.dpo.chat_template
- prompt_strategies.dpo.llama3
- prompt_strategies.dpo.chatml
- prompt_strategies.dpo.zephyr
- prompt_strategies.dpo.user_defined
- prompt_strategies.dpo.passthrough
- prompt_strategies.kto.llama3
- prompt_strategies.kto.chatml
- prompt_strategies.kto.user_defined
- prompt_strategies.orpo.chat_template
- prompt_strategies.bradley_terry.llama3
- title: Kernels
desc: Low-level performance optimizations
contents:
- kernels.lora
- kernels.geglu
- kernels.swiglu
- kernels.quantize
- kernels.utils
- title: MonkeyPatches
desc: Runtime patches for model optimizations
contents:
- monkeypatch.llama_attn_hijack_flash
- monkeypatch.llama_attn_hijack_xformers
- monkeypatch.mistral_attn_hijack_flash
- monkeypatch.multipack
- monkeypatch.relora
- monkeypatch.llama_expand_mask
- monkeypatch.lora_kernels
- monkeypatch.utils
- monkeypatch.btlm_attn_hijack_flash
- monkeypatch.llama_patch_multipack
- monkeypatch.stablelm_attn_hijack_flash
- monkeypatch.trainer_fsdp_optim
- monkeypatch.transformers_fa_utils
- monkeypatch.unsloth_
- monkeypatch.attention.mllama
- monkeypatch.data.batch_dataset_fetcher
- monkeypatch.mixtral
- title: Utils
desc: Utility functions
contents:
- utils.models
- utils.tokenization
- utils.chat_templates
- utils.lora
- utils.lora_embeddings
- utils.model_shard_quant
- utils.bench
- utils.freeze
- utils.trainer
- utils.schedulers
- utils.distributed
- utils.dict
- utils.optimizers.adopt
- utils.data.pretraining
- utils.data.sft
- utils.gradient_checkpointing.unsloth
- title: Schemas
desc: Pydantic data models for Axolotl config
contents:
- utils.schemas.config
- utils.schemas.model
- utils.schemas.training
- utils.schemas.datasets
- utils.schemas.peft
- utils.schemas.trl
- utils.schemas.multimodal
- utils.schemas.integrations
- utils.schemas.enums
- utils.schemas.utils
- title: Integrations
desc: Third-party integrations and extensions
contents:
- integrations.base
- integrations.cut_cross_entropy.args
- integrations.grokfast.optimizer
- integrations.kd.trainer
- integrations.liger.args
- integrations.lm_eval.args
- integrations.spectrum.args
- title: Common
desc: Common utilities and shared functionality
contents:
- common.architectures
- common.const
- common.datasets
- title: Models
desc: Custom model implementations
contents:
- models.mamba.modeling_mamba
- title: Data Processing
desc: Data processing utilities
contents:
- utils.collators.core
- utils.collators.batching
- utils.collators.mamba
- utils.collators.mm_chat
- utils.samplers.multipack
- title: Callbacks
desc: Training callbacks
contents:
- utils.callbacks.perplexity
- utils.callbacks.profiler
- utils.callbacks.lisa
- utils.callbacks.mlflow_
- utils.callbacks.comet_
website:
title: "Axolotl"
description: "We make fine-tuning accessible, scalable, and fun"
@@ -208,8 +35,6 @@ website:
- docs/inference.qmd
- docs/cli.qmd
- docs/config.qmd
- text: "API Reference"
href: docs/api
- section: "Dataset Formats"
contents: docs/dataset-formats/*
@@ -255,22 +80,3 @@ format:
theme: darkly
css: styles.css
toc: true
# Enable better handling of line breaks in markdown
preserve-tabs: true
html-math-method: mathjax
# Improved markdown processing options
md-extensions:
- markdown_it
- def_list
- attr_list
- fenced_divs
- tables
- html_admonition
- lineblocks
- fancy_lists
# Control whitespace handling
whitespace: preserve
# Process newlines in paragraphs
wrap: preserve
# Better line break handling
preserve-linebreaks: true

View File

@@ -33,9 +33,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -3,10 +3,9 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 /workspace/axolotl/tests/cli
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

2
docs/.gitignore vendored
View File

@@ -1,4 +1,2 @@
/.quarto/
_site/
/api/*.qmd
/api/*.html

View File

@@ -1,5 +1,5 @@
---
title: "Command Line Interface (CLI)"
title: "CLI Reference"
format:
html:
toc: true

View File

@@ -32,9 +32,6 @@ tokenizer_legacy:
resize_token_embeddings_to_32x:
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
shrink_embeddings:
# Whether to load the model with randomly initialized weights. Useful for
# pre-training a model from scratch or debugging purposes.
random_init_weights:
# (Internal use only)
# Used to identify which the model is based on
@@ -466,7 +463,6 @@ auto_find_batch_size: # Optional[bool]
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
do_causal_lm_eval: # Whether to run causal language model evaluation for metrics in `eval_causal_lm_metrics`.
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
@@ -507,58 +503,36 @@ lr_div_factor: # Learning rate div factor
# Specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see:
# https://github.com/huggingface/transformers/blob/cbf924b76c03828101a34069a96d209314114fd5/src/transformers/training_args.py#L144-L189
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
#
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
# in the examples/ for your model and fine-tuning use case.
#
# Valid values for 'optimizer' include:
# - adamw_hf
# - adamw_torch
# - adamw_torch_fused
# - adamw_torch_xla
# - adamw_torch_npu_fused
# - adamw_apex_fused
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
# - adafactor
# - adamw_anyprecision
# - adamw_torch_4bit
# - ademamix
# - sgd
# - adagrad
# - adamw_bnb_8bit
# - adamw_8bit # alias for adamw_bnb_8bit
# - ademamix_8bit
# - lion_8bit
# - lion_32bit
# - paged_adamw_32bit
# - paged_adamw_8bit
# - paged_ademamix_32bit
# - paged_ademamix_8bit
# - paged_lion_32bit
# - paged_lion_8bit
# - rmsprop
# - rmsprop_bnb
# - rmsprop_bnb_8bit
# - rmsprop_bnb_32bit
# - galore_adamw
# - galore_adamw_8bit
# - galore_adafactor
# - galore_adamw_layerwise
# - galore_adamw_8bit_layerwise
# - galore_adafactor_layerwise
# - lomo
# - adalomo
# - grokadamw
# - schedule_free_adamw
# - schedule_free_sgd
# - apollo_adamw
# - apollo_adamw_layerwise
#
# Additional custom optimizers include:
# - optimi_adamw
# - ao_adamw_8bit
# - ao_adamw_fp8
optimizer:
# Dictionary of arguments to pass to the optimizer
optim_args:
@@ -610,14 +584,6 @@ resume_from_checkpoint:
# Be careful with this being turned on between different models.
auto_resume_from_checkpoints: false
## Multimodal section
# int | tuple[int, int] | None . Size to resize images to, width x height.
# Will read from model/processor config if not set.
image_size:
# str. Algorithm to use for image resizing. "bilinear", "bicubic", "lanczos". Default is "bilinear".
image_resize_algorithm: 'bilinear'
## End of multimodal section
# Don't mess with this, it's here for accelerate and torchrun
local_rank:
@@ -651,14 +617,6 @@ ddp_timeout:
ddp_bucket_cap_mb:
ddp_broadcast_buffers:
# Sequence parallelism
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
# subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
sequence_parallel_degree:
# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:

View File

@@ -6,7 +6,7 @@ description: How datasets are processed
## Overview
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
the [dataset format](dataset-formats) and prompt strategies to:
the [dataset format](docs/dataset-formats) and prompt strategies to:
- parse the dataset based on the *dataset format*
- transform the dataset to how you would interact with the model based on the *prompt strategy*

View File

@@ -103,7 +103,8 @@ This uses the same tags as the [`main` image](#sec-main-tags).
- `JUPYTER_DISABLE`: Disable Jupyter lab.
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
- `PUBLIC_KEY` / `SSH_KEY`: Add a public key for the SSH service.
- `PUBLIC_KEY`: Add a public key for the SSH service.
- `SSH_KEY`: Add a private key for the SSH service.
#### Volume mounts

View File

@@ -37,10 +37,6 @@ description: Frequently asked questions
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
**Q: How to know the value to use for `fsdp_transformer_layer_cls_to_wrap`?**
> A: This is the class name of the transformer layer to wrap with FSDP. For example, for `LlamaForCausalLM`, the value is `LlamaDecoderLayer`. To find this for a specific model, check the model's `PreTrainedModel` definition and look for `_no_split_modules` variable in the `modeling_<model_name>.py` file within `transformers` library.
### Chat templates
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**

View File

@@ -1,171 +1,28 @@
---
title: MultiModal / Vision Language Models (BETA)
format:
html:
toc: true
toc-depth: 3
---
# MultiModal / Vision Language Models (BETA)
## Supported Models
### Supported Models
- [Mllama](#sec-mllama)
- [Pixtral](#sec-pixtral)
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
- [Gemma-3](#sec-gemma-3)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- Mllama, i.e. llama with vision models
## Usage
### Usage
Multimodal support is limited and doesn't have full feature parity.
Here are the hyperparams you'll need to use to finetune a multimodal model.
Currently multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama w/ LoRA,
you'll need to use the following in YAML in combination with the rest of the required hyperparams.
```yaml
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
processor_type: AutoProcessor
skip_prepare_dataset: true
remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
sample_packing: false # not yet supported with multimodal
chat_template: # see in next section
# example dataset
chat_template: llama3_2_vision
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
remove_unused_columns: false
sample_packing: false
# (optional) if doing lora, only finetune the Language model,
# leave the vision model and vision tower frozen
# load_in_8bit: true
adapter: lora
# only finetune the Language model, leave the vision model and vision tower frozen
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
# (optional) if you want to resize images to a set size
image_size: 512
image_resize_algorithm: bilinear
```
Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.
::: {.callout-warning}
Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
:::
### Mllama {#sec-mllama}
```yaml
base_model: meta-llama/Llama-3.2-11B-Vision-Instruct
chat_template: llama3_2_vision
```
### Pixtral {#sec-pixtral}
```yaml
base_model: mistralai/Pixtral-12B-2409
chat_template: pixtral
```
### Llava-1.5 {#sec-llava-15}
```yaml
base_model: llava-hf/llava-1.5-7b-hf
chat_template: llava
```
### Mistral-Small-3.1 {#sec-mistral-small-31}
```yaml
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
chat_template: mistral_v7_tekken
```
### Gemma-3 {#sec-gemma-3}
::: {.callout-tip}
The Gemma3-1B model is a text-only model, so please train as regular text model.
:::
For multi-modal 4B/12B/27B models, use the following config:
```yaml
base_model: google/gemma-3-4b-it
chat_template: gemma3
```
### Qwen2-VL {#sec-qwen2-vl}
```yaml
base_model: Qwen/Qwen2-VL-7B-Instruct
chat_template: qwen2_vl
```
### Qwen2.5-VL {#sec-qwen25-vl}
```yaml
base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
## Dataset Format
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
- A message is a list of `role` and `content`.
- `role` can be `system`, `user`, `assistant`, etc.
- `content` is a list of `type` and (`text` or `image` or `path` or `url` or `base64`).
::: {.callout-note}
For backwards compatibility:
- If the dataset has a `images` or `image` column of `list[Image]`, it will be appended to the first `content` list as `{"type": "image", "image": ...}`. However, if the content already has a `{"type": "image"}` but no `image` key, it will be set the `image` key.
- If `content` is a string, it will be converted to a list with `type` as `text`.
:::
::: {.callout-tip}
For image loading, you can use the following keys within `content` alongside `"type": "image"`:
- `"path": "/path/to/image.jpg"`
- `"url": "https://example.com/image.jpg"`
- `"base64": "..."`
- `"image": PIL.Image`
:::
Here is an example of a multi-modal dataset:
```json
[
{
"messages": [
{
"role": "system",
"content": [
{"type": "text", "text": "You are a helpful assistant."}
]
},
{
"role": "user",
"content": [
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
{"type": "text", "text": "Describe this image in detail."}
]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "The image is a bee."}
]
}
]
}
]
```

View File

@@ -1,90 +0,0 @@
---
title: Sequence Parallelism
description: Train with long sequences split across multiple GPUs.
---
# Sequence Parallelism
Sequence parallelism is a technique that splits sequences across multiple GPUs,
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
GPU processes a different portion of the sequence, and the results are aggregated
through a ring communication pattern.
## When to Use Sequence Parallelism
Use sequence parallelism when:
- You need to train with sequence lengths that don't fit into a single GPU's memory
- You have multiple GPUs available
- You're experiencing OOM (Out Of Memory) errors with long sequences
## Configuration
To enable sequence parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
## Implementation Details
When sequence parallelism is enabled:
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
4. The trainer uses special ring communication patterns for attention operations
## Requirements
To use sequence parallelism, you need:
- Multiple GPUs (at least 2)
- The `ring-flash-attn` package. Install with:
- `pip install axolotl[ring-flash-attn]` (preferred)
- `pip install ring-flash-attn>=0.1.4`
## Limitations
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
- May have a small performance overhead due to communication between GPUs
## Example
```yaml
# Example config with sequence parallelism
base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192
sequence_parallel_degree: 2 # Split each sequence into 4 parts
flash_attention: true # Required with sequence parallelism
...
```
This will train the Llama 3 8B model with 8K context length, with each sequence split
into 2 subsequences of length 4096 across 2 GPUs.
## Sample Packing with Sequence Parallelism
Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
1. Samples are first packed together
2. The packed sequences are then divided across GPUs in the sequence parallel group
3. Position IDs are automatically adjusted to maintain proper relative positions
## Effect on Batch Size
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases
For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4

View File

@@ -1,71 +0,0 @@
base_model: CohereForAI/c4ai-command-r7b-12-2024
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
# huggingface repo
chat_template: cohere
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,63 +0,0 @@
base_model: google/gemma-3-4b-it
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: gemma3
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 2048
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -1,74 +0,0 @@
base_model: google/gemma-3-1b-it
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: true
strict: false
# huggingface repo
chat_template: gemma3_text
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,63 +0,0 @@
base_model: llava-hf/llava-1.5-7b-hf
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: llava
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -1,66 +0,0 @@
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
processor_type: AutoProcessor
strict: false
load_in_8bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: mistral_v7_tekken
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 2048
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet.
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,65 +0,0 @@
base_model: mistral-community/pixtral-12b
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: pixtral
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <pad>

View File

@@ -1,63 +0,0 @@
base_model: Qwen/Qwen2-VL-7B-Instruct
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: qwen2_vl
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -2,5 +2,3 @@ pre-commit
black
mypy
types-requests
quartodoc
jupyter

View File

@@ -4,6 +4,7 @@
bitsandbytes==0.45.3
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.4.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.3
@@ -12,7 +13,7 @@ liger-kernel==0.5.3
packaging==23.2
peft==0.15.0
transformers==4.50.0
transformers==4.49.0
tokenizers>=0.21.1
accelerate==1.5.2
datasets==3.4.1
@@ -35,7 +36,6 @@ einops
colorama
numba
numpy>=1.24.4,<=2.0.1
# qlora things
evaluate==0.4.1
scipy

315
requirements_env.txt Normal file
View File

@@ -0,0 +1,315 @@
accelerate==0.34.1
addict==2.4.0
aiofiles==23.2.1
aiohttp==3.9.0
aiosignal==1.3.1
aiostream==0.5.2
alembic==1.13.1
annotated-types==0.6.0
annoy==1.17.3
ansible==6.7.0
ansible-core==2.13.13
ansible-vault==2.1.0
anyio==3.7.1
appdirs==1.4.4
art==6.0
asgiref==3.7.2
async-timeout==4.0.2
attrdict==2.0.1
attrs==22.2.0
awscli==1.32.75
-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
backoff==2.2.1
base58==2.1.1
beartype==0.17.2
bitnet==0.2.1
bitsandbytes==0.42.0
bittensor==6.7.0
black==23.7.0
blinker==1.7.0
boto3==1.34.75
botocore==1.34.75
cachetools==5.3.3
cachy==0.1.1
certifi==2023.7.22
cffi==1.16.0
cfgv==3.3.1
chai-guanaco==1.2.4
charset-normalizer==3.2.0
cleo==0.6.8
click==8.1.7
cloudpickle==2.0.0
cohere==4.11.2
colorama==0.4.4
coloredlogs==15.0.1
CoLT5-attention==0.10.20
contextlib2==21.6.0
contourpy==1.2.0
cryptography==41.0.3
cycler==0.12.1
cytoolz==0.12.3
databricks-cli==0.18.0
dataclasses-json==0.5.7
datasets==2.11.0
ddt==1.6.0
decorator==5.1.1
deepspeed==0.15.0
# Editable Git install with no remote (dialogpt==0.1)
-e /Users/wing/Projects/ml/dialogpt/src
dill==0.3.6
distlib==0.3.6
docker==7.0.0
docker-pycreds==0.4.0
docstring-parser==0.15
docutils==0.16
ecdsa==0.18.0
einops==0.7.0
einops-exts==0.0.4
einx==0.1.3
entrypoints==0.4
eth-hash==0.6.0
eth-keys==0.5.0
eth-typing==4.0.0
eth-utils==2.3.1
evaluate==0.4.0
exceptiongroup==1.1.1
fastapi==0.109.2
fastcore==1.5.29
ffmpy==0.4.0
filelock==3.12.2
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
fire==0.5.0
first==2.0.2
flake8==7.0.0
Flask==3.0.1
fonttools==4.47.2
frozendict==2.4.1
frozenlist==1.3.3
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
fsspec==2023.6.0
fuzzywuzzy==0.18.0
gitdb==4.0.10
GitPython==3.1.31
google-pasta==0.2.0
gradio==4.42.0
gradio_client==1.3.0
greenlet==2.0.2
grpclib==0.4.7
gunicorn==21.2.0
h11==0.14.0
h2==4.1.0
hpack==4.0.0
httpcore==0.17.3
httpx==0.24.1
huggingface-hub==0.23.4
humanfriendly==10.0
hyperframe==6.0.1
identify==2.5.24
idna==3.4
immutables==0.20
importlib-metadata==6.7.0
importlib-resources==6.1.1
inflection==0.5.1
iniconfig==2.0.0
itsdangerous==2.1.2
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.3.2
jsonlines==3.1.0
jsonschema==2.6.0
kiwisolver==1.4.5
langchain==0.0.144
Levenshtein==0.24.0
libcst==1.1.0
liger-kernel==0.0.0
lion-pytorch==0.1.2
llama-cpp-python==0.1.36
llvmlite==0.40.1
local-attention==1.9.0
loguru==0.7.0
Mako==1.3.2
Markdown==3.5.2
markdown-it-py==3.0.0
markdown2==2.4.10
MarkupSafe==2.1.2
marshmallow==3.19.0
marshmallow-enum==1.5.1
matplotlib==3.8.2
mccabe==0.7.0
mdurl==0.1.2
MEGABYTE-pytorch==0.0.7
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
mlflow==2.10.0
modal==0.62.77
more-itertools==10.2.0
mpmath==1.2.1
msgpack==1.0.7
msgpack-numpy-opentensor==0.5.0
multidict==6.0.4
multiprocess==0.70.14
munch==2.5.0
mypy==1.3.0
mypy-extensions==1.0.0
nest-asyncio==1.6.0
netaddr==0.10.1
networkx==3.0rc1
nh3==0.2.14
nodeenv==1.8.0
nomic==2.0.2
numba==0.57.1
numexpr==2.8.4
numpy==1.24.4
oauthlib==3.2.2
openai==0.27.4
openapi==1.1.0
openapi-schema-pydantic==1.2.4
optimum==1.8.6
orjson==3.10.7
packaging==23.1
pandas==2.0.0
parameterized==0.9.0
password-strength==0.0.3.post2
pastel==0.1.1
pathos==0.3.0
pathspec==0.11.1
pathtools==0.1.2
peft==0.11.1
pendulum==3.0.0
Pillow==9.5.0
pip-tools==1.11.0
platformdirs==3.2.0
pluggy==1.4.0
poetry==0.7.1
pox==0.3.2
ppft==1.7.6.6
pre-commit==3.3.2
prettytable==3.10.0
prompt-toolkit==3.0.39
protobuf==3.20.2
protobuf3-to-dict==0.1.5
psutil==5.9.5
psycopg==3.1.18
PuLP==2.8.0
py==1.11.0
py-bip39-bindings==0.1.11
py-cpuinfo==9.0.0
py-ed25519-zebra-bindings==1.0.1
py-sr25519-bindings==0.2.0
pyarrow==11.0.0
pyasn1==0.6.0
pycodestyle==2.11.1
pycparser==2.21
pycryptodome==3.20.0
pydantic==2.5.3
pydantic_core==2.14.6
pydub==0.25.1
pyfiglet==0.8.post1
pyflakes==3.2.0
Pygments==2.15.1
PyJWT==2.8.0
pylev==1.4.0
PyNaCl==1.5.0
pynvml==11.5.0
pyparsing==2.4.7
pyrsistent==0.14.11
pytest==8.0.2
pytest-asyncio==0.23.4
python-dateutil==2.8.2
python-dotenv==1.0.1
python-Levenshtein==0.24.0
python-multipart==0.0.9
pytz==2023.3
PyYAML==6.0.1
querystring-parser==1.2.4
rapidfuzz==3.6.1
regex==2023.6.3
requests==2.31.0
requests-toolbelt==0.8.0
resolvelib==0.8.1
responses==0.18.0
retry==0.9.2
rich==13.7.0
rsa==4.7.2
ruff==0.6.3
s3transfer==0.10.1
safetensors==0.4.5
sagemaker==2.148.0
scalecodec==1.2.7
schedulefree==1.2.1
schema==0.7.5
scikit-learn==1.4.0
scipy==1.9.3
seaborn==0.13.2
semantic-version==2.10.0
sentencepiece==0.2.0
sentry-sdk==1.19.1
setproctitle==1.3.2
shellingham==1.5.4
shortuuid==1.0.11
shtab==1.6.5
sigtools==4.0.1
six==1.16.0
skypilot==0.4.1
smdebug-rulesconfig==1.0.1
smmap==5.0.0
sniffio==1.3.0
SQLAlchemy==1.4.47
sqlparse==0.4.4
starlette==0.36.3
substrate-interface==1.5.2
svgwrite==1.4.3
sympy==1.11.1
synchronicity==0.6.7
tabulate==0.9.0
tblib==1.7.0
tenacity==8.2.2
tensor-parallel==2.0.0
termcolor==2.2.0
text2art==0.2.0
threadpoolctl==3.2.0
tiktoken==0.6.0
time-machine==2.14.1
timm==0.9.16
tokenizers==0.19.1
tokenmonster==1.1.12
toml==0.9.6
tomli==2.0.1
tomlkit==0.12.0
toolz==0.12.1
torch==2.2.0
torchdata==0.6.1
torchdiffeq==0.2.3
TorchFix==0.4.0
torchtext==0.15.2
torchvision==0.17.0
tqdm==4.66.2
transformers==4.44.2
trl==0.9.6
typer==0.12.5
types-certifi==2021.10.8.3
types-requests==2.31.0.20240125
types-setuptools==69.0.0.20240125
types-toml==0.10.8.7
typing==3.7.4.3
typing-inspect==0.8.0
typing_extensions==4.9.0
tyro==0.5.18
tzdata==2023.3
unique-names-generator==1.0.2
urllib3==2.2.2
uvicorn==0.22.0
vector_quantize_pytorch==1.14.1
virtualenv==20.23.0
voyager==2.0.2
wandb==0.16.2
watchfiles==0.21.0
wavedrom==2.0.3.post3
wcwidth==0.2.6
websocket-client==1.7.0
websockets==12.0
Werkzeug==3.0.1
wonderwords==2.2.0
xxhash==3.2.0
yarl==1.8.2
zetascale==2.2.7
zipp==3.15.0

View File

@@ -16,7 +16,13 @@ def parse_requirements():
with open("./requirements.txt", encoding="utf-8") as requirements_file:
lines = [r.strip() for r in requirements_file.readlines()]
for line in lines:
is_extras = "deepspeed" in line or "mamba-ssm" in line
is_extras = (
"flash-attn" in line
or "flash-attention" in line
or "deepspeed" in line
or "mamba-ssm" in line
or "lion-pytorch" in line
)
if line.startswith("--extra-index-url"):
# Handle custom index URLs
_, url = line.split()
@@ -33,6 +39,7 @@ def parse_requirements():
"bitsandbytes",
"triton",
"mamba-ssm",
"flash-attn",
"xformers",
"autoawq",
"liger-kernel",
@@ -117,8 +124,9 @@ setup(
],
},
extras_require={
"flash-attn": ["flash-attn==2.7.4.post1"],
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
"flash-attn": [
"flash-attn==2.7.4.post1",
],
"deepspeed": [
"deepspeed==0.16.4",
"deepspeed-kernels",
@@ -133,15 +141,15 @@ setup(
"mlflow": [
"mlflow",
],
"lion-pytorch": [
"lion-pytorch==0.1.2",
],
"galore": [
"galore_torch",
],
"apollo": [
"apollo-torch",
],
"optimizers": [
"galore_torch",
"apollo-torch",
"lion-pytorch==0.1.2",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
],

View File

@@ -56,7 +56,7 @@ def do_inference(
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Inference-specific CLI arguments.
"""
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
prompter = cli_args.prompter
prompter_module = None
@@ -151,7 +151,7 @@ def do_inference_gradio(
"""
import gradio as gr
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
prompter = cli_args.prompter
prompter_module = None

View File

@@ -25,7 +25,7 @@ from axolotl.cli.utils import (
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.schemas.config import AxolotlInputConfig
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@click.group()

View File

@@ -27,7 +27,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
"""
print_axolotl_text_art()
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
safe_serialization = cfg.save_safetensors is True
LOG.info("Running merge of LoRA with base model...")
@@ -44,9 +44,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
if processor:
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""

View File

@@ -17,14 +17,13 @@ from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
"""
Trains a `transformers` model by first loading the dataset(s) specified in the
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
@@ -34,9 +33,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Training-specific CLI arguments.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
print_axolotl_text_art()
check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
@@ -48,13 +44,16 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
del model, tokenizer, trainer
plugin_manager = PluginManager.get_instance()
del model
del tokenizer
del trainer
plugin_manager.post_train_unload(cfg)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_train`.

View File

@@ -13,16 +13,11 @@ from typing import Any, Callable, Type, Union, get_args, get_origin
import click
import requests
from pydantic import BaseModel
from transformers import (
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
ProcessorMixin,
)
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.models import load_model, load_tokenizer
configure_logging()
LOG = logging.getLogger(__name__)
@@ -300,13 +295,9 @@ def load_model_and_tokenizer(
*,
cfg: DictDefault,
inference: bool = False,
) -> tuple[
PreTrainedModel,
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
ProcessorMixin | None,
]:
) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
"""
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
Helper function for loading a model and tokenizer specified in the given `axolotl`
config.
Args:
@@ -314,7 +305,7 @@ def load_model_and_tokenizer(
inference: Boolean denoting inference mode.
Returns:
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
`transformers` model and tokenizer.
"""
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
@@ -322,9 +313,4 @@ def load_model_and_tokenizer(
LOG.info("loading model...")
model, _ = load_model(cfg, tokenizer, inference=inference)
processor = None
if cfg.is_multimodal:
LOG.info("loading processor...")
processor = load_processor(cfg, tokenizer)
return model, tokenizer, processor
return model, tokenizer

View File

@@ -13,7 +13,9 @@
# limitations under the License.
# pylint: disable=too-many-lines
"""Builder for the training args and trainer"""
"""
Builder for the training args and trainer
"""
import abc
import importlib
@@ -36,7 +38,7 @@ from transformers import (
from transformers.training_args import OptimizerNames
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.trainers import (
from axolotl.core.trainers.base import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlMambaTrainer,
@@ -60,7 +62,6 @@ from axolotl.core.training_args import (
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
@@ -84,8 +85,8 @@ from axolotl.utils.collators import (
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
from axolotl.utils.models import ensure_dtype
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
try:
import torch._dynamo # pylint: disable=ungrouped-imports
@@ -748,12 +749,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.accelerator_config
)
if self.cfg.image_size:
training_arguments_kwargs["image_size"] = self.cfg.image_size
if self.cfg.image_resize_algorithm:
training_arguments_kwargs["image_resize_algorithm"] = (
self.cfg.image_resize_algorithm
)
if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_alpha is not None:
@@ -769,10 +764,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.kd_top_k_before_softmax
)
training_arguments_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
elif self.cfg.process_reward_model:
@@ -856,10 +847,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if training_args.pretraining:
if (
self.cfg.pretraining_sample_concatenation is False
or self.cfg.micro_batch_size > 1
):
if self.cfg.pretraining_sample_concatenation is False:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
if self.cfg.micro_batch_size > 1:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None
@@ -887,7 +877,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if "max_length" in kwargs:
kwargs.pop("max_length")
elif use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or (
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
self.cfg.model_config_type in ["llama"]
and self.cfg.flash_attention is not True
):
@@ -897,13 +889,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
else:
if self.cfg.processor_type and self.processor:
collator = MultiModalChatDataCollator
kwargs["processing_strategy"] = get_processing_strategy(
self.processor,
training_args.chat_template,
self.cfg.chat_template,
image_size=training_args.image_size,
image_resize_algorithm=training_args.image_resize_algorithm,
)
kwargs["processor"] = self.processor
kwargs["chat_template"] = training_args.chat_template
elif self.cfg.batch_flattening:
collator = DataCollatorWithFlattening
collator_args.pop(0)
@@ -923,8 +910,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator = DataCollatorForSeq2Seq
kwargs["return_tensors"] = "pt"
if issubclass(collator, DataCollatorForSeq2Seq):
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
return collator(
*collator_args,

View File

@@ -1,18 +0,0 @@
"""Init for axolotl.core.trainers"""
# pylint: disable=unused-import
# flake8: noqa
from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .relora import ReLoRATrainer
from .trl import (
AxolotlCPOTrainer,
AxolotlKTOTrainer,
AxolotlORPOTrainer,
AxolotlPRMTrainer,
AxolotlRewardTrainer,
TRLPPOTrainer,
)

View File

@@ -1,47 +1,365 @@
"""Module for customized trainers"""
# pylint: disable=too-many-lines
"""
module for customized trainers
"""
from __future__ import annotations
# pylint: disable=too-many-lines
import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Any, Literal
from typing import Dict, Literal, Optional
import datasets
import torch
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.utils.data import (
BatchSampler,
DataLoader,
RandomSampler,
Sampler,
SequentialSampler,
)
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.mixins import (
OptimizerMixin,
SchedulerMixin,
SequenceParallelMixin,
)
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.monkeypatch.relora import ReLoRAScheduler
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
)
LOG = logging.getLogger(__name__)
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
LOG = logging.getLogger("axolotl.core.trainer_builder")
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
"""Extend the base Trainer for axolotl helpers"""
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
return kwargs
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
if isinstance(dataset_tags, str):
dataset_tags = [dataset_tags]
if (dataset_tags is not None) and (kwargs is not None):
if "dataset_tags" not in kwargs:
kwargs["dataset_tags"] = dataset_tags
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
kwargs["dataset_tags"].extend(dataset_tags)
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
dataset_tags.append(kwargs["dataset_tags"])
kwargs["dataset_tags"] = dataset_tags
return kwargs
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)
use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if self.args.alternate_lr_scheduler_type == "one_cycle":
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
extra_lr_kwargs = {}
if "pct_start" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["pct_start"] = pct_start
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["anneal_strategy"] = "cos"
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif self.args.alternate_lr_scheduler_type == "rex":
if use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = RexLR(
optimizer=optimizer,
max_lr=self.args.learning_rate,
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
total_steps=num_training_steps,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer=optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler
class OptimizerMixin(Trainer):
"""
Mixin class for shared handling of building custom optimizers
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_optimizer_grouped_parameters(
self, opt_model, optimizer_kwargs
) -> list[dict]:
decay_parameters = self.get_decay_parameter_names(opt_model)
params: dict = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
lr_groups_lookup = {}
lr_groups_learning_rates = {}
if self.args.lr_groups:
for lr_group in self.args.lr_groups:
group_name = lr_group["name"]
group_modules = lr_group["modules"]
for module in group_modules:
lr_groups_lookup[module] = group_name
lr_groups_learning_rates[group_name] = lr_group["lr"]
params[f"to_weight_decay_{group_name}"] = {}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
lr_group_modules = [
group_modules
for group_modules in lr_groups_lookup
if group_modules in name
]
if lr_groups_lookup and any(lr_group_modules):
lr_group_module = lr_group_modules[0]
group_name = lr_groups_lookup[lr_group_module]
params[f"to_weight_decay_{group_name}"][name] = param
else:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
for group_name, group_lr in lr_groups_learning_rates.items():
if params[f"to_weight_decay_{group_name}"]:
optimizer_grouped_parameters.append(
{
"params": list(
params[f"to_weight_decay_{group_name}"].values()
),
"weight_decay": self.args.weight_decay,
"lr": group_lr,
}
)
return optimizer_grouped_parameters
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None
):
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if (
not self.optimizer
and self.optimizer_cls_and_kwargs is not None
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
):
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
self.optimizer = optimizer_factory_cls()(
opt_model, self.args, **optimizer_kwargs
)
if not self.optimizer:
if self.optimizer_cls_and_kwargs is not None:
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
else:
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
self.args, opt_model
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
else:
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for LOMO optimizer.
if "model" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop(
"optimizer_dict"
)
self.optimizer = optimizer_cls(
optimizer_grouped_parameters, **optimizer_kwargs
)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
@@ -58,18 +376,12 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
# Initialize sequence parallelism if enabled
if self.args.sequence_parallel_degree > 1:
self._setup_sequence_parallel()
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
@@ -82,247 +394,142 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def _create_multipack_sampler(
self, base_sampler: Sampler, dataset: Dataset
) -> MultipackBatchSampler:
"""
Helper method to create a `MultipackBatchSampler` for multipacking sequences
for training.
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches:
batch_size = self.args.per_device_train_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
train_batch_size = (
self.state.train_batch_size or self.args.per_device_train_batch_size
)
batch_max_len = train_batch_size * self.args.max_seq_length
Args:
base_sampler: Sampler to wrap with `MultipackBatchSampler`.
dataset: Dataset to sample from.
if self.args.curriculum_sampling:
sampler = SequentialSampler(self.train_dataset)
else:
sampler = RandomSampler(self.train_dataset)
Returns:
Multipack (sample packing) batch sampler.
"""
if self.args.multipack_real_batches:
batch_size = self.args.per_device_train_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
train_batch_size = (
self.state.train_batch_size or self.args.per_device_train_batch_size
return MultipackBatchSampler(
sampler,
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
drop_last=True,
)
batch_max_len = train_batch_size * self.args.max_seq_length
if self.args.curriculum_sampling:
return SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
return MultipackBatchSampler(
base_sampler,
lengths=get_dataset_lengths(dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
drop_last=True,
)
def _get_train_sampler(self) -> Sampler | None:
"""
Helper method to get the sampler for training. Handles cases for sequence
parallelism, sample packing, and curriculum sampling (sequential).
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
depends on the passed training args.
"""
use_sample_packing = self.args.sample_packing and not self.args.pretraining
# Determine the base sampler first
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_train_sampler(self.train_dataset)
elif self.args.curriculum_sampling:
base_sampler = SequentialSampler(self.train_dataset)
elif use_sample_packing:
base_sampler = RandomSampler(self.train_dataset)
else:
# Default to parent class implementation for standard random sampling
return super()._get_train_sampler()
# Apply multipack wrapper if needed
if use_sample_packing:
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=self.train_dataset,
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
if self.args.multipack_real_batches:
batch_size = self.args.per_device_eval_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
batch_max_len = (
self.args.per_device_eval_batch_size * self.args.max_seq_length
)
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
lengths=get_dataset_lengths(self.eval_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
drop_last=True,
)
return super()._get_eval_sampler(eval_dataset)
return base_sampler
def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing and not self.args.pretraining:
train_dataset = self.train_dataset
if "length" in train_dataset.features.keys():
train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = (
self.args.dataloader_prefetch_factor
)
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
"""
Helper method to get the sampler for evaluation. Handles sequence parallelism
and sample packing cases.
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
depends on the passed training args.
"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Multipacking enabled if training is enabled and eval is not explicitly disabled
use_multipack = (
self.args.sample_packing and self.args.eval_sample_packing is not False
)
# Determine the base sampler
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_eval_sampler(eval_dataset)
elif use_multipack:
base_sampler = SequentialSampler(eval_dataset)
else:
return super()._get_eval_sampler(eval_dataset)
# Apply multipack wrapper if needed
if use_multipack:
return self._create_multipack_sampler(
base_sampler=base_sampler,
dataset=eval_dataset,
)
return base_sampler
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
"""Create common dataloader parameters for train or eval."""
batch_size = custom_batch_size or (
self.args.eval_batch_size if is_eval else self._train_batch_size
)
params = {
"batch_size": batch_size,
"collate_fn": self.data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
# Add persistent workers only for training
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
params["persistent_workers"] = self.args.dataloader_persistent_workers
# Add prefetch factor if specified
if self.args.dataloader_prefetch_factor:
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return params
def _prepare_dataloader(
self, dataset, sampler, is_eval=False, custom_batch_size=None
):
"""Prepare a dataloader with the given dataset and sampler."""
# Get base parameters
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
# Add sampler configuration
if not isinstance(dataset, torch.utils.data.IterableDataset):
sampler = self._get_train_sampler()
if isinstance(sampler, BatchSampler):
# batch_size and batch_sampler are mutually exclusive
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
if not is_eval:
dataloader_params["worker_init_fn"] = seed_worker
# Create the dataloader
dataloader = DataLoader(dataset, **dataloader_params)
if self.args.sample_packing and (
(not is_eval and not self.args.pretraining)
or (is_eval and self.args.eval_sample_packing is not False)
):
self.accelerator.even_batches = False
# Return unprepared dataloader if using sequence parallelism
if self.args.sequence_parallel_degree > 1:
return dataloader
# Otherwise prepare with accelerator
return self.accelerator.prepare_data_loader(dataloader)
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
data_collator = self.data_collator # type: ignore
# Handle dataset preprocessing
if isinstance(train_dataset, datasets.Dataset):
if self.args.sample_packing and not self.args.pretraining:
train_dataset = train_dataset.remove_columns(["length"])
if not self.args.sample_packing or self.args.pretraining:
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
data_collator,
description="training",
return self.accelerator.prepare_data_loader(
DataLoader(train_dataset, **dataloader_params)
)
return super().get_train_dataloader()
# Get sampler and create dataloader
sampler = self._get_train_sampler()
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
"""Get dataloader for evaluation"""
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Handle special case: sample packing is enabled but eval_sample_packing is False
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
if self.args.sample_packing and self.args.eval_sample_packing is False:
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
)
if "length" in eval_dataset.column_names:
if eval_dataset:
eval_dataset = eval_dataset.remove_columns(["length"])
dataloader = super().get_eval_dataloader(eval_dataset)
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.train_data_collator
)
return dataloader
# Handle sample packing or sequence parallelism
if (
self.args.sample_packing
and self.args.eval_sample_packing is not False
or self.args.sequence_parallel_degree > 1
):
# Get appropriate data collator
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
if hasattr(self, "eval_data_collator") and self.eval_data_collator
else self.data_collator
)
if "length" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns(["length"])
# Handle dataset preprocessing for SP
if self.args.sequence_parallel_degree > 1:
if isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(
eval_dataset, description="evaluation"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
self.data_collator, description="evaluation"
)
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
batch_size = (
self.args.eval_batch_size
if self.args.sample_packing
else self.args.per_device_eval_batch_size
)
sampler = self._get_eval_sampler(eval_dataset)
dataloader = self._prepare_dataloader(
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
return dataloader
eval_sampler = self._get_eval_sampler(eval_dataset)
eval_dataset = eval_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = (
self.args.dataloader_prefetch_factor
)
if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = eval_sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(eval_dataset, **dataloader_params)
)
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> torch.utils.data.Sampler | None:
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
@@ -347,7 +554,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
@override
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
@@ -364,7 +570,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return super().compute_loss(
model,
inputs,
@@ -539,10 +744,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = sanitize_kwargs_for_ds_tagging(
kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
@@ -559,13 +764,15 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
return res
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs: The values to log.
start_time: The start of training.
logs (`Dict[str, float]`):
The values to log.
start_time (`Optional[float]`):
The start of training.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
@@ -577,7 +784,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
return super().log(logs, start_time)
def store_metrics(
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
) -> None:
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
@@ -590,26 +797,110 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs)
def training_step(
class AxolotlMambaTrainer(AxolotlTrainer):
"""
Mamba specific trainer to handle loss calculation
"""
tag_names = ["axolotl", "mamba"]
def compute_loss(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
num_items_in_batch: int | None = None,
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
Args:
model: Model to perform training step for.
inputs: Dictionary mapping.
"""
# Set up sequence parallelism for this step if enabled
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
labels = input_ids.to(lm_logits.device)
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
# Proceed with normal training step
loss = super().training_step(model, inputs, num_items_in_batch)
loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
return loss
return lm_loss
class ReLoRATrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
tag_names = ["axolotl", "relora"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
anneal_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"]
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
tag_names = ["axolotl", "prm"]

View File

@@ -13,10 +13,10 @@ from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.core.trainers.mixins import SchedulerMixin
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
from axolotl.core.trainers.base import (
SchedulerMixin,
_sanitize_kwargs_for_ds_tagging,
_sanitize_kwargs_for_tagging,
)
if is_sagemaker_mp_enabled():
@@ -74,10 +74,10 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = sanitize_kwargs_for_ds_tagging(
kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)

View File

@@ -9,7 +9,7 @@ import logging
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
LOG = logging.getLogger("axolotl")

View File

@@ -1,32 +0,0 @@
"""Module for mamba trainer"""
import torch
from axolotl.core.trainers.base import AxolotlTrainer
class AxolotlMambaTrainer(AxolotlTrainer):
"""Mamba specific trainer to handle loss calculation"""
tag_names = ["axolotl", "mamba"]
def compute_loss(
self,
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
labels = input_ids.to(lm_logits.device)
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
return lm_loss

View File

@@ -1,8 +0,0 @@
"""Init for axolotl.core.trainers.mixins"""
# pylint: disable=unused-import
# flake8: noqa
from .optimizer import OptimizerMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin

View File

@@ -1,201 +0,0 @@
"""Module for Axolotl trainer optimizer mixin"""
import logging
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from transformers.trainer import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from axolotl.integrations.base import BaseOptimizerFactory
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
LOG = logging.getLogger(__name__)
class OptimizerMixin(Trainer):
"""Mixin class for shared handling of building custom optimizers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_optimizer_grouped_parameters(
self, opt_model, optimizer_kwargs
) -> list[dict]:
decay_parameters = self.get_decay_parameter_names(opt_model)
params: dict = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
lr_groups_lookup = {}
lr_groups_learning_rates = {}
if self.args.lr_groups:
for lr_group in self.args.lr_groups:
group_name = lr_group["name"]
group_modules = lr_group["modules"]
for module in group_modules:
lr_groups_lookup[module] = group_name
lr_groups_learning_rates[group_name] = lr_group["lr"]
params[f"to_weight_decay_{group_name}"] = {}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
lr_group_modules = [
group_modules
for group_modules in lr_groups_lookup
if group_modules in name
]
if lr_groups_lookup and any(lr_group_modules):
lr_group_module = lr_group_modules[0]
group_name = lr_groups_lookup[lr_group_module]
params[f"to_weight_decay_{group_name}"][name] = param
else:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
for group_name, group_lr in lr_groups_learning_rates.items():
if params[f"to_weight_decay_{group_name}"]:
optimizer_grouped_parameters.append(
{
"params": list(
params[f"to_weight_decay_{group_name}"].values()
),
"weight_decay": self.args.weight_decay,
"lr": group_lr,
}
)
return optimizer_grouped_parameters
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None
):
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if (
not self.optimizer
and self.optimizer_cls_and_kwargs is not None
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
):
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
self.optimizer = optimizer_factory_cls()(
opt_model, self.args, **optimizer_kwargs
)
if not self.optimizer:
if self.optimizer_cls_and_kwargs is not None:
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
else:
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
self.args, opt_model
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
else:
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for LOMO optimizer.
if "model" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop(
"optimizer_dict"
)
self.optimizer = optimizer_cls(
optimizer_grouped_parameters, **optimizer_kwargs
)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer

View File

@@ -1,113 +0,0 @@
"""Module for Axolotl trainer scheduler mixin"""
import logging
import torch
from torch.optim.lr_scheduler import OneCycleLR
from transformers.trainer import Trainer
from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
)
LOG = logging.getLogger(__name__)
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)
use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if self.args.alternate_lr_scheduler_type == "one_cycle":
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
extra_lr_kwargs = {}
if "pct_start" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["pct_start"] = pct_start
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["anneal_strategy"] = "cos"
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif self.args.alternate_lr_scheduler_type == "rex":
if use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = RexLR(
optimizer=optimizer,
max_lr=self.args.learning_rate,
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
total_steps=num_training_steps,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer=optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler

View File

@@ -1,131 +0,0 @@
"""Module for Axolotl trainer sequence parallelism mixin"""
import logging
from typing import Any
import torch
import torch.distributed as dist
import torch.nn.functional as F
from datasets import Dataset
from torch.utils.data import DistributedSampler, Sampler
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
LOG = logging.getLogger(__name__)
try:
from ring_flash_attn import update_ring_flash_attn_params
except ImportError:
# We pass silently here, but raise an ImportError in our Axolotl config validation
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
pass
class SequenceParallelMixin:
"""
Mixin class for sequence parallelism support in trainers.
This mixin provides functionality for handling sequence parallelism,
including creating appropriate samplers, managing data partitioning,
and updating ring flash attention parameters during training.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def _setup_sequence_parallel(self):
"""Set up sequence parallelism environment."""
self.ring_attn_group = get_ring_attn_group()
def _create_sequence_parallel_sampler(
self,
dataset: Dataset,
shuffle: bool = True,
is_eval: bool = False,
) -> DistributedSampler:
"""
Helper method to create sampler for sequence parallelism (SP).
We create a distributed sampler with rank equal to the SP group ID, which
means that all ranks in the SP group receive the same sample / set of samples
per training step. We also set the number of replicas equal to the number of
SP groups, which is a bit of a hack / unintended use, but works!
Args:
dataset: Dataset to sample from.
shuffle: Whether to shuffle the dataset.
is_eval: Whether we are creating a sampler for evaluation or training.
Returns:
Distributed sampler.
"""
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
return DistributedSampler(
dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed if shuffle else None,
shuffle=shuffle,
drop_last=not is_eval,
)
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
"""
Get a training sampler configured for sequence parallelism.
Args:
dataset: The training dataset
Returns:
Configured sequence parallel sampler.
"""
return self._create_sequence_parallel_sampler(
dataset,
shuffle=not self.args.curriculum_sampling,
)
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
"""
Get an evaluation sampler configured for sequence parallelism.
Args:
eval_dataset: The evaluation dataset.
Returns:
Configured sequence parallel sampler.
"""
return self._create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)
def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
"""
Calculate the cu_seqlens for the current forward pass and pass the value to
the substituted ring_flash_attn. This is accomplished by using the passed
`input_ids`.
Args:
inputs: Current batch of inputs.
"""
# At this point, inputs should already be partitioned by the sequence
# parallel data collator
batch_size = inputs["input_ids"].shape[0]
seq_len = inputs["input_ids"].shape[1]
packed_seq_lens = [seq_len] * batch_size
# Calculate the full sequence length across all GPUs in this SP group
total_seq_len = seq_len * self.args.sequence_parallel_degree
cu_seqlens = torch.cumsum(
torch.tensor(
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
),
dim=-1,
dtype=torch.int32,
)
cu_seqlens = F.pad(
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
)
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)

View File

@@ -1,43 +0,0 @@
"""Module for ReLoRA trainer"""
import torch
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.monkeypatch.relora import ReLoRAScheduler
class ReLoRATrainer(AxolotlTrainer):
"""Trainer subclass that uses the `OneCycleLR` scheduler"""
tag_names = ["axolotl", "relora"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: torch.optim.Optimizer | None = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
anneal_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler

View File

@@ -1,25 +1,16 @@
"""Module for TRL PPO trainer"""
from typing import Literal, Union
"""
module for TRL PPO training
"""
import torch
from tqdm import tqdm
from trl import (
CPOTrainer,
KTOTrainer,
ORPOTrainer,
PPOTrainer,
PRMTrainer,
RewardTrainer,
)
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
from trl import PPOTrainer
class TRLPPOTrainer(PPOTrainer):
"""Wrapper for TRL PPO trainer to handle customizations"""
tag_names = ["axolotl", "ppo"]
"""
wrapper for ppo trainer to handle customizations
"""
def train(
self,
@@ -40,7 +31,9 @@ class TRLPPOTrainer(PPOTrainer):
"batch_size": 16,
}
for _, batch in tqdm(enumerate(self.dataloader)):
for epoch, batch in tqdm( # pylint: disable=unused-variable
enumerate(self.dataloader)
):
query_tensors = batch["input_ids"]
# generate model response
@@ -72,189 +65,3 @@ class TRLPPOTrainer(PPOTrainer):
rewards,
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
)
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
def get_batch_loss_metrics(
self,
model,
batch: dict[str, Union[list, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
# TODO remove once https://github.com/huggingface/trl/pull/3069 is included in a trl release
metrics = {}
forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = (
self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
)
# full ORPO loss
loss = policy_nll_loss - losses.mean()
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(
chosen_rewards
).mean()
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(
rejected_rewards
).mean()
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(
reward_accuracies
).mean()
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
chosen_rewards - rejected_rewards
).mean()
metrics[f"{prefix}logps/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
)
metrics[f"{prefix}logps/chosen"] = (
self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
)
metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
policy_rejected_logits.detach().mean()
).mean()
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
policy_chosen_logits.detach().mean()
).mean()
metrics[f"{prefix}nll_loss"] = (
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
)
metrics[f"{prefix}log_odds_ratio"] = (
self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
)
metrics[f"{prefix}log_odds_chosen"] = (
self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
)
for k, v in metrics.items():
metrics[k] = v.item()
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss
return loss, metrics
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"]
def get_batch_loss_metrics(
self,
model,
batch: dict[str, Union[list, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]
losses, chosen_rewards, rejected_rewards = self.cpo_loss(
policy_chosen_logps,
policy_rejected_logps,
)
loss = losses.mean() + self.cpo_alpha * policy_nll_loss
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = (
self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
)
metrics[f"{prefix}rewards/rejected"] = (
self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
)
metrics[f"{prefix}rewards/accuracies"] = (
self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
)
metrics[f"{prefix}rewards/margins"] = (
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards)
.mean()
.item()
)
metrics[f"{prefix}logps/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logps)
.detach()
.mean()
.item()
)
metrics[f"{prefix}logps/chosen"] = (
self.accelerator.gather_for_metrics(policy_chosen_logps)
.detach()
.mean()
.item()
)
metrics[f"{prefix}logits/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean())
.mean()
.item()
)
metrics[f"{prefix}logits/chosen"] = (
self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean())
.mean()
.item()
)
metrics[f"{prefix}nll_loss"] = (
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
)
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss
return loss, metrics
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
tag_names = ["axolotl", "prm"]

View File

@@ -1,33 +0,0 @@
"""Utils for Axolotl trainers"""
def sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
return kwargs
def sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
if isinstance(dataset_tags, str):
dataset_tags = [dataset_tags]
if (dataset_tags is not None) and (kwargs is not None):
if "dataset_tags" not in kwargs:
kwargs["dataset_tags"] = dataset_tags
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
kwargs["dataset_tags"].extend(dataset_tags)
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
dataset_tags.append(kwargs["dataset_tags"])
kwargs["dataset_tags"] = dataset_tags
return kwargs

View File

@@ -5,7 +5,6 @@ extra axolotl specific training args
from dataclasses import dataclass, field
from typing import Optional
from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
@@ -208,33 +207,14 @@ class AxolotlTrainingMixins:
},
)
sequence_parallel_degree: Optional[int] = field(
default=1,
metadata={"help": "The number of workers to use in sequence parallelism"},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(
default=None,
metadata={"help": "The size of the image to resize to"},
)
image_resize_algorithm: Resampling | None = field(
default=None,
metadata={"help": "The algorithm to use for image resizing"},
)
# end of multi-modal section
@dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a
default value so it can't be used as a mixin.
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
so it can't be used as a mixin.
"""

View File

@@ -8,8 +8,6 @@ from typing import Dict, Optional
import torch
from accelerate.logging import get_logger
from datasets import Dataset
from transformers.trainer import Trainer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
@@ -27,18 +25,18 @@ LOG = get_logger("axolotl.evaluate")
def evaluate_dataset(
trainer: Trainer, dataset: Dataset, dataset_type: str, flash_optimum: bool = False
trainer, dataset, dataset_type: str, flash_optimum: bool = False
) -> Optional[Dict[str, float]]:
"""Helper function to evaluate a single dataset.
"""Helper function to evaluate a single dataset safely.
Args:
trainer: The trainer instance.
dataset: Dataset to evaluate.
dataset_type: Type of dataset ('train' or 'eval').
flash_optimum: Whether to use flash optimum.
trainer: The trainer instance
dataset: Dataset to evaluate
dataset_type: Type of dataset ('train' or 'eval')
flash_optimum: Whether to use flash optimum
Returns:
Dictionary of metrics or None if dataset is None.
Dictionary of metrics or None if dataset is None
"""
if dataset is None:
return None
@@ -65,14 +63,17 @@ def evaluate_dataset(
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
"""
Evaluate a model on training and validation datasets.
Evaluate a model on training and validation datasets
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
dataset_meta: Dataset metadata containing training and evaluation datasets.
Returns:
Dictionary mapping metric names to their values.
Tuple containing:
- The model (either PeftModel or PreTrainedModel)
- The tokenizer
- Dictionary of evaluation metrics
"""
# pylint: disable=duplicate-code
# Enable expandable segments for cuda allocation to improve VRAM usage

View File

@@ -11,17 +11,19 @@
# the License.
"""
Module to handle merging the plugins' input arguments with the base configurations.
module to handle merging the plugins' input arguments with the base configurations.
This was moved here to prevent circular imports.
this was moved here to prevent circular imports
"""
from typing import Any, Dict, List
from axolotl.utils.schemas.config import (
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
)
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlInputConfig as AxolotlInputConfigBase,
)
def merge_input_args():

View File

@@ -1,6 +1,6 @@
# Cut Cross Entropy
Cut Cross Entropy (CCE) reduces VRAM usage through optimization on the cross-entropy operation during loss calculation.
Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy operation during loss calculation.
See https://github.com/apple/ml-cross-entropy
@@ -29,20 +29,6 @@ plugins:
cut_cross_entropy: true
```
## Supported Models
- llama
- phi3
- gemma
- gemma2
- gemma3
- gemma3_text
- mistral
- mistral3
- qwen2
- cohere
- cohere2
## Citation
```bib

View File

@@ -72,9 +72,7 @@ class CutCrossEntropyPlugin(BasePlugin):
if cfg.cut_cross_entropy:
self._check_requirements()
from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
cce_patch,
)
from cut_cross_entropy.transformers import cce_patch
with zero_only():
LOG.info(

View File

@@ -1,201 +0,0 @@
"""Cohere and Cohere2 CCE patch."""
# This patch is based off transformers 4.50.0.
# It patches the forward function for CohereForCausalLM and Cohere2ForCausalLM.
# It scales the hidden states by the logit scale in advance instead of the logits as the
# operation is done internally and should be mathematically equivalent.
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Tuple, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import (
_CONFIG_FOR_DOC,
COHERE_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>> from transformers import AutoTokenizer, CohereForCausalLM
>> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
>> prompt = "Hey, are you conscious? Can you talk to me?"
>> inputs = tokenizer(prompt, return_tensors="pt")
>> # Generate
>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
# scale weight by logit_scale in-place of logits
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight * self.logit_scale,
labels,
_PATCH_OPTS,
**kwargs,
)
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
logits = logits * self.logit_scale # main diff from Llama
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def patch_cohere(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.cohere import modeling_cohere
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_cohere.CohereForCausalLM
), f"Expected a CohereForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_cohere.CohereForCausalLM.forward = cce_forward
return None
def patch_cohere2(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.cohere2 import modeling_cohere2
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_cohere2.Cohere2ForCausalLM
), f"Expected a Cohere2ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_cohere2.Cohere2ForCausalLM.forward = cce_forward
return None

View File

@@ -1,175 +0,0 @@
"""Gemma CCE patch"""
# This patch is based off transformers 4.50.0.
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Tuple, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma.modeling_gemma import (
_CONFIG_FOR_DOC,
GEMMA_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, GemmaForCausalLM
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
>>> prompt = "What is your favorite condiment?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
**kwargs,
)
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def patch_gemma(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.gemma import modeling_gemma
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_gemma.GemmaForCausalLM
), f"Expected a GemmaForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_gemma.GemmaForCausalLM.forward = cce_forward
return None

View File

@@ -1,465 +0,0 @@
"""Gemma2 and Gemma3 (text and multimodal) CCE patch."""
# Implementation originally adapted from https://github.com/apple/ml-cross-entropy/pull/29
# and updated for transformers 4.50.0.
# This is a modified version of the patch that allows for deferred logits calculation for gemma3 and works
# with both gemma3 (text and multimodal) models.
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Tuple, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from torch import nn
from transformers.cache_utils import Cache, HybridCache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma3.modeling_gemma3 import (
_CONFIG_FOR_DOC,
GEMMA3_INPUTS_DOCSTRING,
Gemma3CausalLMOutputWithPast,
logger,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
defer_logits_calculation: bool = False,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
defer_logits_calculation (`bool`, *optional*):
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Gemma3ForCausalLM
>>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
>>> prompt = "What is your favorite condiment?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**loss_kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
if self.config.final_logit_softcapping is not None:
logger.warning_once(
"final_logit_softcapping is not supported for gemma3_text with CCE. Disabling."
)
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
**loss_kwargs,
)
elif _PATCH_OPTS is not None and defer_logits_calculation:
# defer logits calculation to the ConditionalGeneration forward
logits = hidden_states[:, slice_indices, :]
if self.config.final_logit_softcapping is not None:
logger.warning_once(
"final_logit_softcapping is not supported for gemma3 with CCE. Disabling."
)
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
if self.config.final_logit_softcapping is not None:
logits = logits / self.config.final_logit_softcapping
logits = torch.tanh(logits)
logits = logits * self.config.final_logit_softcapping
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward_multimodal(
self,
input_ids: torch.LongTensor | None = None,
pixel_values: torch.FloatTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
>>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
>>> prompt = "answer en Where is the cow standing?"
>>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=30)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"answer en Where is the cow standing?\nbeach"
```"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
is_training = token_type_ids is not None and labels is not None
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_index
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids # type: ignore
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0 # type: ignore
)
cache_position = torch.arange( # type: ignore
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
# Merge text and images
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(
self.config.image_token_index,
dtype=torch.long,
device=inputs_embeds.device,
)
)
else:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
-1
)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
inputs_embeds.device
)
if (
not is_torchdynamo_compiling()
and inputs_embeds[special_image_mask].numel() != image_features.numel()
):
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
"tokens from image embeddings."
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
# mask out pad-token-ids in labels for BC
if labels is not None and self.pad_token_id in labels:
logger.warning_once(
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
)
labels = torch.where( # type: ignore
input_ids == self.pad_token_id, self.config.ignore_index, labels
)
causal_mask = self._update_causal_mask( # pylint: disable=protected-access
attention_mask,
token_type_ids,
past_key_values,
cache_position,
inputs_embeds,
is_training,
)
outputs = self.language_model(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
defer_logits_calculation=True, # enable deferred logits calculation
**lm_kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states,
self.language_model.lm_head.weight,
labels,
_PATCH_OPTS,
**lm_kwargs,
)
else:
logits = hidden_states
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(
logits.device
)
shift_logits = shift_logits[
shift_attention_mask.to(logits.device) != 0
].contiguous()
shift_labels = shift_labels[
shift_attention_mask.to(shift_labels.device) != 0
].contiguous()
else:
shift_logits = shift_logits.contiguous()
shift_labels = shift_labels.contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
flat_labels = shift_labels.view(-1).to(shift_logits.device)
loss = loss_fct(flat_logits, flat_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Gemma3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
def patch_gemma2(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.gemma2 import modeling_gemma2
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_gemma2.Gemma2ForCausalLM
), f"Expected a Gemma2ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_gemma2.Gemma2ForCausalLM.forward = cce_forward
return None
def patch_gemma3_text(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.gemma3 import modeling_gemma3
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_gemma3.Gemma3ForCausalLM
), f"Expected a Gemma3ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
return None
def patch_gemma3(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.gemma3 import modeling_gemma3
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_gemma3.Gemma3ForConditionalGeneration
), f"Expected a Gemma3ForConditionalGeneration model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
# patch the causal model to enable deferred logits calculation
maybe_model.language_model.forward = MethodType(
cce_forward, maybe_model.language_model
)
return maybe_model
modeling_gemma3.Gemma3ForConditionalGeneration.forward = cce_forward_multimodal
# patch the causal model to enable deferred logits calculation
modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
return None

View File

@@ -1,392 +0,0 @@
"""Mistral and Mistral3 CCE patch."""
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Tuple, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.mistral3.modeling_mistral3 import (
Mistral3CausalLMOutputWithPast,
)
from transformers.models.mistral.modeling_mistral import (
_CONFIG_FOR_DOC,
MISTRAL_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: Optional[torch.Tensor] | None = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
defer_logits_calculation: bool = False,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
defer_logits_calculation (`bool`, *optional*):
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, MistralForCausalLM
>>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
**kwargs,
)
elif _PATCH_OPTS is not None and defer_logits_calculation:
# defer logits calculation to the ConditionalGeneration forward
logits = hidden_states[:, slice_indices, :]
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def cce_forward_multimodal(
self,
input_ids: torch.LongTensor | None = None,
pixel_values: torch.FloatTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, list[int]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: torch.Tensor | None = None,
**lm_kwargs,
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is the image?The image depicts two cats lying on a pink blanket."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
vision_feature_layer = (
vision_feature_layer
if vision_feature_layer is not None
else self.config.vision_feature_layer
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
image_sizes=image_sizes,
)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
inputs_embeds.device
)
if (
not is_torchdynamo_compiling()
and inputs_embeds[special_image_mask].numel() != image_features.numel()
):
n_image_tokens = (input_ids == self.config.image_token_index).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
defer_logits_calculation=True, # enable deferred logits calculation
**lm_kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states,
self.language_model.lm_head.weight,
labels,
_PATCH_OPTS,
**lm_kwargs,
)
else:
logits = hidden_states
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
logits.device
)
shift_logits = logits[..., :-1, :][
shift_attention_mask.to(logits.device) != 0
].contiguous()
shift_labels = labels[..., 1:][
shift_attention_mask.to(labels.device) != 0
].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1).to(shift_logits.device),
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Mistral3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
def patch_mistral(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.mistral import modeling_mistral
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_mistral.MistralForCausalLM
), f"Expected a MistralForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_mistral.MistralForCausalLM.forward = cce_forward
return None
def patch_mistral3(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.mistral import modeling_mistral
from transformers.models.mistral3 import modeling_mistral3
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_mistral3.Mistral3ForConditionalGeneration
), f"Expected a Mistral3ForConditionalGeneration model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
# patch the causal model to enable deferred logits calculation
maybe_model.language_model.forward = MethodType(
cce_forward, maybe_model.language_model
)
return maybe_model
modeling_mistral3.Mistral3ForConditionalGeneration.forward = cce_forward_multimodal
# patch the causal model to enable deferred logits calculation
modeling_mistral.MistralForCausalLM.forward = cce_forward
return None

View File

@@ -1,379 +0,0 @@
"""Mllama CCE patch."""
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Tuple, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.mllama.modeling_mllama import (
MLLAMA_INPUTS_DOCSTRING,
_prepare_cross_attention_mask,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.LongTensor] = None,
cross_attention_mask: Optional[torch.LongTensor] = None,
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
defer_logits_calculation: bool = False,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
defer_logits_calculation (`bool`, *optional*):
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, MllamaForCausalLM
>>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
>>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")
>>> prompt = "If I had to write a haiku, it would be:"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
>>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
>>> print(result)
If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
I love the idea of snowflakes gently falling, each one
```
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
cross_attention_states=cross_attention_states,
attention_mask=attention_mask,
position_ids=position_ids,
cross_attention_mask=cross_attention_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
loss = None
logits = None
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
**loss_kwargs,
)
elif _PATCH_OPTS is not None and defer_logits_calculation:
# defer logits calculation to the ConditionalGeneration forward
logits = hidden_states[:, slice_indices, :]
else:
logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class="MllamaConfig"
)
def cce_forward_multimodal(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
aspect_ratio_mask: Optional[torch.Tensor] = None,
aspect_ratio_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, MllamaForConditionalGeneration
>>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
>>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
>>> processor = AutoProcessor.from_pretrained(checkpoint)
>>> prompt = "<|image|>If I had to write a haiku for this one"
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
>>> # Generate
>>> output = model.generate(**inputs, max_new_tokens=15)
>>> prompt_len = inputs.input_ids.shape[-1]
>>> generated_ids = output[:, prompt_len:]
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
>>> print(generated_text)
[', it would be:.\\nA stop sign in Chinatown.\\n']
```
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None and cross_attention_states is not None:
raise ValueError(
"`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
)
if pixel_values is not None:
if aspect_ratio_ids is None:
raise ValueError(
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
)
# get vision tokens from vision model
vision_outputs = self.vision_model(
pixel_values=pixel_values,
aspect_ratio_ids=aspect_ratio_ids,
aspect_ratio_mask=aspect_ratio_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
)
cross_attention_states = vision_outputs[0]
cross_attention_states = self.multi_modal_projector(
cross_attention_states
).reshape(
-1, cross_attention_states.shape[-2], self.hidden_size # type: ignore
)
if cross_attention_mask is not None:
cross_attention_mask, full_text_row_masked_out_mask = (
_prepare_cross_attention_mask(
cross_attention_mask,
num_vision_tokens=self.vision_model.num_patches,
dtype=self.dtype,
)
)
else:
full_text_row_masked_out_mask = None
if cross_attention_mask is not None and cache_position is not None:
cross_attention_mask = cross_attention_mask[:, :, cache_position]
full_text_row_masked_out_mask = full_text_row_masked_out_mask[
:, :, cache_position
]
outputs = self.language_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
past_key_values=past_key_values,
use_cache=use_cache,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
defer_logits_calculation=True, # enable deferred logits calculation
**loss_kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states,
self.language_model.lm_head.weight,
labels,
_PATCH_OPTS,
**loss_kwargs,
)
else:
# Temporary fix to calculate the loss in main class, as the model's vocab size may be resized
logits = hidden_states
if labels is not None:
loss = self.loss_function(
logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs
)
if not return_dict:
return (loss,) + outputs if loss is not None else outputs
return CausalLMOutputWithPast(
loss=loss,
logits=outputs.logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def patch_mllama(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.mllama import modeling_mllama
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_mllama.MllamaForConditionalGeneration
), f"Expected a MllamaForConditionalGeneration model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
# patch the language model
maybe_model.language_model.forward = MethodType(
cce_forward, maybe_model.language_model
)
return maybe_model
modeling_mllama.MllamaForConditionalGeneration.forward = cce_forward_multimodal
# patch the causal language model
modeling_mllama.MllamaForCausalLM.forward = cce_forward
return None

View File

@@ -1,85 +0,0 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
"""Cut Cross Entropy patcher"""
import transformers
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
from cut_cross_entropy.transformers.llama import patch_llama
from cut_cross_entropy.transformers.phi3 import patch_phi3
from cut_cross_entropy.transformers.qwen2 import patch_qwen2
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT
from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
patch_cohere,
patch_cohere2,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma import patch_gemma
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
patch_gemma2,
patch_gemma3,
patch_gemma3_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
patch_mistral,
patch_mistral3,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama
CUT_CROSS_ENTROPY_MODEL_MAPPING = {
"llama": patch_llama,
"mllama": patch_mllama,
"phi3": patch_phi3,
"gemma": patch_gemma,
"gemma2": patch_gemma2,
"gemma3": patch_gemma3,
"gemma3_text": patch_gemma3_text,
"mistral": patch_mistral,
"mistral3": patch_mistral3,
"qwen2": patch_qwen2,
"cohere": patch_cohere,
"cohere2": patch_cohere2,
}
def cce_patch(
model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig,
impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
reduction: str = "mean",
filter_eps: float | str | None = "auto",
accum_e_fp32: bool = False,
accum_c_fp32: bool = False,
filter_e_grad: bool = True,
filter_c_grad: bool = True,
train_only: bool = False,
) -> TransformersModelT | None:
if isinstance(impl, LinearCrossEntropyImpl):
impl = impl.name.lower()
if impl not in (v.name.lower() for v in LinearCrossEntropyImpl):
raise ValueError(f"Unknown {impl=}")
if isinstance(model_type_or_model, transformers.PreTrainedModel):
model_type = model_type_or_model.config.model_type
elif isinstance(model_type_or_model, transformers.PretrainedConfig):
model_type = model_type_or_model.model_type
else:
model_type = model_type_or_model
patch_options = PatchOptions(
impl=impl,
reduction=reduction,
filter_eps=filter_eps,
accum_e_fp32=accum_e_fp32,
accum_c_fp32=accum_c_fp32,
filter_e_grad=filter_e_grad,
filter_c_grad=filter_c_grad,
train_only=train_only,
)
if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING:
return CUT_CROSS_ENTROPY_MODEL_MAPPING[model_type](
model_type_or_model, patch_options
)
raise RuntimeError(f"Unknown model type {model_type}")

View File

@@ -114,5 +114,3 @@ class LigerPlugin(BasePlugin):
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type in ["gemma3_text", "deepseek_v3"]:
raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")

View File

@@ -1,89 +0,0 @@
"""
Ring attention group registration and flash attention patching.
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
their sequence parallel version of Flash Attention 2.
"""
import torch.distributed as dist
from accelerate.logging import get_logger
from axolotl.logging_config import configure_logging
configure_logging()
LOG = get_logger(__name__)
RING_ATTN_GROUP = None
def get_ring_attn_group() -> dist.ProcessGroup:
"""
Getter for ring attention group on this rank.
Returns:
The process group for ring attention for this rank.
"""
return RING_ATTN_GROUP
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
"""
Setter for ring attention group on this rank.
Args:
Process group for ring attention.
"""
global RING_ATTN_GROUP # pylint: disable=global-statement
RING_ATTN_GROUP = ring_attn_group
def register_ring_attn(sequence_parallel_degree: int):
"""
Create ring attention group and substitute flash attn with ring flash attn.
Args:
sequence_parallel_degree: Sequence parallelism factor.
"""
LOG.info(
"Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
)
world_size = dist.get_world_size()
assert sequence_parallel_degree <= world_size, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must be less than or equal to world_size ({world_size})"
)
assert world_size % sequence_parallel_degree == 0, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must evenly divide world_size ({world_size})"
)
# Detailed logging of group formation
rank = dist.get_rank()
group_assignments = {}
for i in range(world_size // sequence_parallel_degree):
ring_attn_ranks = list(
range(
i * sequence_parallel_degree,
(i + 1) * sequence_parallel_degree,
)
)
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
# Track which GPUs are in which groups
for r in ring_attn_ranks:
group_assignments[r] = i
if rank in ring_attn_ranks:
set_ring_attn_group(group)
# Log the GPU group assignments
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
from ring_flash_attn import substitute_hf_flash_attn
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)

View File

@@ -22,9 +22,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"phi3",
"gemma",
"gemma2",
"gemma3_text",
"cohere",
"cohere2",
"gemmoe",
"starcoder2",
"deepseek_v2",

View File

@@ -1,313 +0,0 @@
"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types"""
import ast
from copy import deepcopy
from typing import Optional
from PIL import Image, ImageOps
from PIL.Image import Resampling
from torch import Tensor
from transformers import ProcessorMixin
from transformers.image_utils import load_image
class ProcessingStrategy:
"""Base Processing Strategy class"""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
self.processor = processor
self.chat_template = chat_template
self.image_token = None
self.image_token_id = None
self.image_size = image_size
self.image_resize_algorithm = (
image_resize_algorithm or Image.Resampling.BILINEAR
)
if hasattr(processor, "image_token"):
self.image_token = processor.image_token
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
self.image_token
)
def __call__(self, examples: list[dict]) -> list[dict]:
"""
Preprocess conversation examples to ensure consistent format.
Converts different conversation formats to OpenAI format with 'messages'.
Supports two formats:
1. OpenAI format with 'messages'
2. Legacy format with 'conversations'
Args:
examples: list of conversation dictionaries
Returns:
list of dicts in OpenAI format with 'messages' key
Raises:
ValueError: If the conversation format is not supported
"""
role_mapping = {
"human": "user",
"gpt": "assistant",
}
def normalize_role(role: str) -> str:
"""Normalize role names to OpenAI format. Default to original role if not found."""
return role_mapping.get(role, role)
def convert_legacy_format(example: dict) -> dict:
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
messages = [
{"role": normalize_role(convo["from"]), "content": convo["value"]}
for convo in example["conversations"]
]
# Create new dict without 'conversations' key
result = deepcopy(example)
result.pop("conversations")
result["messages"] = messages
return result
def convert_multiple_choice_to_multimedia_messages(
messages: dict,
) -> list[dict]:
def construct_prompt(sample):
question = sample["question"]
options = sample["options"]
if isinstance(options, str):
options = ast.literal_eval(options)
example = ""
start_chr = "A"
prediction_range = []
index2ans = {}
for option in options:
prediction_range.append(start_chr)
example += f"({start_chr}) {option}\n"
index2ans[start_chr] = option
start_chr = chr(ord(start_chr) + 1)
empty_prompt_sample_structure = "{}\n\n{}\n\nAnswer with the option's letter from the given choices directly."
empty_prompt = empty_prompt_sample_structure.format(question, example)
return empty_prompt
new_messages = []
user_content = construct_prompt(messages)
assistant_response = messages["answer"]
new_messages.append(
{"role": "user", "content": [{"type": "text", "text": user_content}]}
)
new_messages.append(
{
"role": "assistant",
"content": [{"type": "text", "text": assistant_response}],
}
)
return new_messages
def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]:
"""Convert regular messages format to Messages format with content type"""
new_messages = []
for message in messages:
if isinstance(message["content"], str):
new_messages.append(
{
"role": message["role"],
"content": [
{
"type": "text",
"text": message["content"],
}
],
}
)
elif isinstance(message["content"], list):
content = message["content"]
new_messages.append(
{
"role": message["role"],
"content": content,
}
)
return new_messages
processed_examples = []
for example in examples:
if not (
"messages" in example
or "conversations" in example
or "question" in example
):
raise ValueError(
"Only `messages`, `conversations`, and `question` message keys are currently supported."
)
processed_example = None
if "messages" in example: # OpenAI format
processed_example = example
# convert regular messages format to Messages format with content type
# for compatibility with apply_chat_template
processed_example["messages"] = convert_messages_to_multimedia_messages(
processed_example["messages"]
)
elif "question" in example: # Multiple choice format
processed_example = {}
processed_example["messages"] = (
convert_multiple_choice_to_multimedia_messages(example)
)
else: # Legacy format
processed_example = convert_legacy_format(example)
processed_example["messages"] = convert_messages_to_multimedia_messages(
processed_example["messages"]
)
# find the image key if it exists
image_keys = []
for key in example.keys():
if "image" in key:
image_keys.append(key)
for im_key in image_keys:
if example[im_key] is None:
continue
if isinstance(example[im_key], list):
if len(example[im_key]) == 0:
continue
image_value = example[im_key][0]
else:
image_value = example[im_key]
image_value = load_image(image_value)
if self.image_size is not None:
assert hasattr(
image_value, "resize"
), "Image does not have a resize method"
if isinstance(self.image_size, tuple):
image_value = image_value.resize(
self.image_size, self.image_resize_algorithm
)
else:
# Set the padding value; here we use black (0, 0, 0) for RGB images
padding_color = (0, 0, 0)
# When image_size is an int (square target), preserve aspect ratio then pad
# This is to prevent aspect ratio distortion when resizing to square
image_value = ImageOps.pad(
image_value,
(self.image_size, self.image_size),
method=self.image_resize_algorithm,
color=padding_color,
)
processed_example["messages"][0]["content"].append(
{
"type": "image",
"image": image_value,
}
)
processed_examples.append(processed_example)
return processed_examples
def process_labels(self, input_ids: Tensor) -> Tensor:
labels = input_ids.clone()
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels[labels == self.processor.tokenizer.pad_token_id] = -100
# Ignore the image token index in the loss computation (model specific)
labels[labels == self.image_token_id] = -100
return labels
class Qwen2VLProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for Qwen2-VL"""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.image_token = "<|image_pad|>" # nosec
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
self.image_token
)
class Gemma3ProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for Gemma3"""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.image_token = processor.tokenizer.special_tokens_map["boi_token"]
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
self.image_token
)
def process_labels(self, input_ids):
labels = input_ids.clone()
# Follows https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora
labels[labels == self.processor.tokenizer.pad_token_id] = -100
labels[labels == self.image_token_id] = -100
labels[labels == 262144] = -100 # corresponds to <image_soft_token>
return labels
def get_processing_strategy(
processor: ProcessorMixin,
chat_template,
chat_template_type,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
if chat_template_type == "qwen2_vl":
return Qwen2VLProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
)
if chat_template_type == "gemma3":
return Gemma3ProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
)
if chat_template_type in [
"llama3_2_vision",
"llava",
"mistral_v7_tekken",
"pixtral",
]:
return ProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
)
raise ValueError(f"Unsupported chat template type: {chat_template_type}")

View File

@@ -13,7 +13,7 @@ from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnaly
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.schemas.datasets import DatasetConfig
from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig
# Configure the logger
LOG = logging.getLogger("axolotl")

View File

@@ -3,7 +3,7 @@ DPO prompt strategies for using tokenizer chat templates.
"""
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic
def default(

View File

@@ -169,7 +169,7 @@ def execute_training(
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
):
"""
Execute the training process with appropriate SDP kernel configurations.
Execute the training process with appropriate backend configurations.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
@@ -177,6 +177,9 @@ def execute_training(
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
"""
LOG.info("Starting trainer...")
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
@@ -314,7 +317,6 @@ def save_initial_configs(
tokenizer: PreTrainedTokenizer,
model: PreTrainedModel,
peft_config: PeftConfig | None,
processor: ProcessorMixin | None,
):
"""
Save initial configurations before training.
@@ -342,10 +344,6 @@ def save_initial_configs(
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
model.config.save_pretrained(str(output_dir))
if processor:
LOG.info(f"Pre-saving processor to {cfg.output_dir}...")
processor.save_pretrained(str(output_dir))
def setup_model_card(cfg: DictDefault):
"""
@@ -413,7 +411,6 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
PeftModel | PreTrainedModel,
PreTrainedTokenizer,
PeftConfig | None,
ProcessorMixin | None,
]:
"""
Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
@@ -429,7 +426,6 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
- Model
- Tokenizer
- PEFT config
- Processor
"""
# Load tokenizer, processor and model
model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
@@ -460,7 +456,6 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
model,
tokenizer,
peft_config,
processor,
)
@@ -483,7 +478,6 @@ def train(
model,
tokenizer,
peft_config,
processor,
) = setup_model_and_trainer(cfg, dataset_meta)
# Determine if we need to resume from a checkpoint
@@ -499,7 +493,7 @@ def train(
)
# Save initial configs
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
save_initial_configs(cfg, tokenizer, model, peft_config)
# Set up signal handler for graceful termination
setup_signal_handler(cfg, model, safe_serialization)

View File

@@ -33,6 +33,7 @@ from trl.models import unwrap_model_for_generation
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.callbacks.perplexity import Perplexity
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
from axolotl.utils.distributed import (
barrier,
broadcast_dict,
@@ -42,7 +43,6 @@ from axolotl.utils.distributed import (
is_main_process,
zero_first,
)
from axolotl.utils.schemas.config import AxolotlInputConfig
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments

File diff suppressed because one or more lines are too long

View File

@@ -1,59 +1,14 @@
"""
Data collators for axolotl to pad labels and position_ids for packed sequences. Also
includes logic for handling sequence parallelism collation.
DataCollator for axolotl to pad labels and position_ids for packed sequences
"""
import logging
from dataclasses import dataclass
from typing import Any, Optional, Union
import numpy as np
import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
logger = logging.getLogger(__name__)
def adjust_position_ids_for_slice(
position_ids: torch.Tensor, start_idx: int
) -> torch.Tensor:
"""
Adjust position IDs for a sliced sequence to maintain proper relative positions.
This handles the case where position IDs might not be contiguous due to sample
packing.
"""
# Convert to tensor if not already
# Find the boundaries between samples (where position_ids reset)
adjusted_pos_ids = position_ids.clone()
# Process each sequence in the batch
for i in range(position_ids.shape[0]):
seq = position_ids[i]
# Find sample boundaries
boundaries = []
for j in range(1, len(seq)):
if seq[j] < seq[j - 1]:
boundaries.append(j)
# No need to adjust if there are no boundaries or this is a single sample
if not boundaries:
adjusted_pos_ids[i] = seq - start_idx
continue
# Adjust each segment separately
prev_boundary = 0
for boundary in boundaries:
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
prev_boundary = boundary
# Last segment
adjusted_pos_ids[i, prev_boundary:] -= start_idx
return adjusted_pos_ids
@dataclass
class DataCollatorForSeq2Seq:
@@ -88,8 +43,6 @@ class DataCollatorForSeq2Seq:
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
return_tensors (`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
sequence_parallel_degree (`int`):
The degree of sequence parallelism. Default to 1 for no sequence parallelism.
"""
tokenizer: PreTrainedTokenizerBase
@@ -100,16 +53,6 @@ class DataCollatorForSeq2Seq:
label_pad_token_id: int = -100
position_pad_token_id: int = 0
return_tensors: str = "pt"
sequence_parallel_degree: int = 1
def __post_init__(self):
if self.sequence_parallel_degree > 1:
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
# Get information about our position in the SP group
sp_group = get_ring_attn_group()
self.local_rank = dist.get_rank(group=sp_group)
self.local_world_size = dist.get_world_size(group=sp_group)
def __call__(self, features, return_tensors=None):
labels = None
@@ -176,43 +119,8 @@ class DataCollatorForSeq2Seq:
)
features["decoder_input_ids"] = decoder_input_ids
if self.sequence_parallel_degree > 1:
features = self.apply_sequence_parallelism(features)
return features
def apply_sequence_parallelism(
self, batch: dict[str, torch.Tensor]
) -> torch.Tensor:
"""
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary from parent collator.
Returns:
Sliced batch dictionary.
"""
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
for key in keys_to_slice:
if key in batch:
seq_len = batch[key].shape[1]
slice_size = seq_len // self.local_world_size
start_idx = self.local_rank * slice_size
end_idx = (
start_idx + slice_size
if self.local_rank < self.local_world_size - 1
else seq_len
)
batch[key] = batch[key][:, start_idx:end_idx]
# Special handling for position_ids
if key == "position_ids" and self.local_rank > 0:
batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
return batch
@dataclass
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
@@ -240,7 +148,6 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
np.array(item[feature]) for item in features_ if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
return super().__call__(out_features, return_tensors=return_tensors)
@@ -270,7 +177,6 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
np.array(item[feature]) for item in features_ if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
return super().__call__(out_features, return_tensors=return_tensors)

View File

@@ -2,17 +2,15 @@
Collators for multi-modal chat messages and packing
"""
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Optional, Union
import torch
from torch import Tensor
from transformers import PreTrainedTokenizerBase
from PIL import Image
from transformers import PreTrainedTokenizerBase, ProcessorMixin
from transformers.data.data_collator import DataCollatorMixin
from transformers.utils import PaddingStrategy
from axolotl.processing_strategies import ProcessingStrategy
@dataclass
class MultiModalChatDataCollator(DataCollatorMixin):
@@ -21,9 +19,11 @@ class MultiModalChatDataCollator(DataCollatorMixin):
"""
tokenizer: PreTrainedTokenizerBase
processing_strategy: ProcessingStrategy
packing: bool = False
processor: ProcessorMixin
return_tensors: str = "pt"
chat_template: Optional[str] = None
packing: bool = False
max_images: int = -1
padding: Union[bool, str, PaddingStrategy] = True
pad_to_multiple_of: Optional[int] = None
@@ -31,62 +31,162 @@ class MultiModalChatDataCollator(DataCollatorMixin):
if self.packing:
raise ValueError("Packing is currently not supported.")
def torch_call(self, examples: list[dict]) -> dict[str, Any]:
return self.process_rows(examples)
def torch_call(
self, examples: list[Union[list[int], Any, dict[str, Any]]]
) -> dict[str, Any]:
# Handle dict or lists with proper padding and conversion to tensor.
return self.__class__.process_rows(
examples, self.processor, self.chat_template, self.max_images
)
@staticmethod
def process_rows(examples, processor, chat_template, max_images, length_only=False):
# HINT: use `_torch_collate_batch` to stack and pad tensors
# see also DataCollatorWithFlattening and DefaultDataCollator
# *** This is COPIED from the trl example sft_vlm.py code ***
# use this as a starting point
def _preprocess(examples: list[dict]) -> list[dict]:
"""
Preprocess conversation examples to ensure consistent format.
Converts different conversation formats to OpenAI format with 'messages'.
Supports two formats:
1. OpenAI format with 'messages'
2. Legacy format with 'conversations'
Args:
examples: list of conversation dictionaries
Returns:
dict in OpenAI format with 'messages' key
Raises:
ValueError: If the conversation format is not supported
"""
role_mapping = {
"human": "user",
"gpt": "assistant",
}
def normalize_role(role: str) -> str:
"""Normalize role names to OpenAI format. Default to original role if not found."""
return role_mapping.get(role, role)
def convert_legacy_format(example: dict) -> dict:
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
messages = [
{
"role": normalize_role(convo["from"]),
"content": convo["value"],
}
for convo in example["conversations"]
]
# Create new dict without 'conversations' key
result = deepcopy(example)
result.pop("conversations")
return {"messages": messages, **result}
processed_examples = []
for example in examples:
# OpenAI format
if "messages" in example:
processed_examples.append(example)
# Legacy format
elif "conversations" in example:
processed_examples.append(convert_legacy_format(example))
else:
raise ValueError(
"Only `messages` and `conversations` message keys are currently supported."
)
return processed_examples
def _process_images(examples, max_images):
"""
Process images from examples, ensuring consistency in image presence and applying max_images limit.
Args:
examples: List of dictionaries that may contain 'images' key
max_images: Maximum number of images to keep per example (0 means no limit)
Returns:
Either None (if no images) or List[Image objects] (if all examples have images)
Raises:
ValueError: If there's a mix of None and non-None images
"""
def get_image(example):
if "images" not in example:
return None
images = example["images"]
if isinstance(images, str):
return Image.open(images)
return images
images = [get_image(example) for example in examples]
# Count None and non-None images
none_count = sum(1 for img in images if img is None)
# All images are None
if none_count == len(images):
return None
# Mix of None and non-None images
if none_count > 0:
raise ValueError(
"All images should be either None or not None. "
"Please provide images for all examples or None."
)
# Apply max_images limit if specified
if max_images > 0:
images = [
(
img_batch[:max_images]
if isinstance(img_batch, (list, tuple))
else img_batch
)
for img_batch in images
]
return images
def process_rows(
self,
examples: list[dict],
) -> dict[str, Tensor]:
# Preprocess the examples
examples = self.processing_strategy(examples)
examples = _preprocess(examples)
# Initialize batch
batch: dict[str, Any] = {}
# Process each example
for example in examples:
# Apply chat template to process the example
# This method requires transformers>=4.49.0
result = self.processing_strategy.processor.apply_chat_template(
example["messages"],
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
padding=True,
return_dict=True,
chat_template=self.processing_strategy.chat_template,
# Get the texts and images, and apply the chat template
texts = [
processor.apply_chat_template(
example["messages"], chat_template=chat_template, tokenize=False
)
for example in examples
]
# TODO: Check if need handling for len(input_ids) > sequence_len
images = _process_images(examples, max_images=max_images)
# Add the processed tensors to our batch
for key in result.keys():
if key not in batch:
batch[key] = []
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
batch[key].append(result[key].squeeze(0))
# Pad sequences to the same length
input_ids = torch.nn.utils.rnn.pad_sequence(
batch["input_ids"],
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100 #
# Ignore the image token index in the loss computation (model specific)
image_token_id = processor.tokenizer.convert_tokens_to_ids(
processor.image_token
)
labels[labels == image_token_id] = -100
batch["labels"] = labels
attention_mask = torch.nn.utils.rnn.pad_sequence(
batch["attention_mask"], batch_first=True, padding_value=0
)
# Create the final batch
final_batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
# Process the labels
final_batch["labels"] = self.processing_strategy.process_labels(
final_batch["input_ids"]
)
return final_batch
if length_only:
return {
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
}
return batch

View File

@@ -12,13 +12,19 @@ from transformers.utils.import_utils import is_torch_npu_available
from axolotl.integrations.base import PluginManager
from axolotl.integrations.config import merge_input_args
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import MULTIMODAL_AUTO_MODEL_MAPPING, load_model_config
from axolotl.utils.schemas.config import (
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
)
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlInputConfig as AxolotlInputConfigBase,
)
from axolotl.utils.config.models.input.v0_4_1 import (
DPODataset,
KTODataset,
SFTDataset,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config
LOG = logging.getLogger("axolotl")
@@ -125,9 +131,6 @@ def normalize_config(cfg):
with open(ds_config_path, encoding="utf-8") as f:
cfg.deepspeed = json.load(f)
if cfg.sequence_parallel_degree is None:
cfg.sequence_parallel_degree = 1
if cfg.saves_per_epoch:
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
if save_steps < 1.0: # prevent saves on every step
@@ -158,7 +161,7 @@ def normalize_config(cfg):
cfg.is_multimodal = (
hasattr(model_config, "model_type")
and model_config.model_type in MULTIMODAL_AUTO_MODEL_MAPPING
and model_config.model_type in ["llava", "mllama"]
or any(
multimodal_name in cfg.base_model.lower()
for multimodal_name in [
@@ -171,6 +174,7 @@ def normalize_config(cfg):
cfg.processor_config = (
cfg.processor_config or cfg.base_model_config or cfg.base_model
)
model_config = model_config.text_config
cfg.model_config_type = model_config.model_type

View File

@@ -1,4 +1,8 @@
"""Pydantic models for TRL trainer configuration"""
"""
GRPO specific configuration args
"""
from typing import Optional
from pydantic import BaseModel, Field
@@ -8,11 +12,11 @@ class TRLConfig(BaseModel):
Input args for TRL.
"""
beta: float | None = Field(
beta: Optional[float] = Field(
default=None,
json_schema_extra={"description": "Beta for RL training"},
)
max_completion_length: int | None = Field(
max_completion_length: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the completion for RL training"
@@ -21,50 +25,50 @@ class TRLConfig(BaseModel):
# GRPO specific args
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
use_vllm: bool | None = Field(
use_vllm: Optional[bool] = Field(
default=False,
json_schema_extra={"description": "Whether to use VLLM for RL training"},
)
vllm_device: str | None = Field(
vllm_device: Optional[str] = Field(
default="auto",
json_schema_extra={"description": "Device to use for VLLM"},
)
vllm_gpu_memory_utilization: float | None = Field(
vllm_gpu_memory_utilization: Optional[float] = Field(
default=0.9,
json_schema_extra={"description": "GPU memory utilization for VLLM"},
)
vllm_dtype: str | None = Field(
vllm_dtype: Optional[str] = Field(
default="auto",
json_schema_extra={"description": "Data type for VLLM"},
)
vllm_max_model_len: int | None = Field(
vllm_max_model_len: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the model context for VLLM"
},
)
reward_funcs: list[str] | None = Field(
reward_funcs: Optional[list[str]] = Field(
default=None,
json_schema_extra={"description": "List of reward functions to load"},
)
reward_weights: list[float] | None = Field(
reward_weights: Optional[list[float]] = Field(
default=None,
json_schema_extra={
"description": "Weights for each reward function. Must match the number of reward functions."
},
)
num_generations: int | None = Field(
num_generations: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
},
)
log_completions: bool | None = Field(
log_completions: Optional[bool] = Field(
default=False,
json_schema_extra={"description": "Whether to log completions"},
)
sync_ref_model: bool | None = Field(
sync_ref_model: Optional[bool] = Field(
default=False,
json_schema_extra={
"description": (
@@ -73,13 +77,13 @@ class TRLConfig(BaseModel):
)
},
)
ref_model_mixup_alpha: float | None = Field(
ref_model_mixup_alpha: Optional[float] = Field(
default=0.9,
json_schema_extra={
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
},
)
ref_model_sync_steps: int | None = Field(
ref_model_sync_steps: Optional[int] = Field(
default=64,
json_schema_extra={
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."

View File

@@ -34,16 +34,12 @@ from transformers import ( # noqa: F401
AutoTokenizer,
AwqConfig,
BitsAndBytesConfig,
Gemma3ForConditionalGeneration,
GPTQConfig,
LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration,
MllamaForConditionalGeneration,
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
)
from transformers.integrations.deepspeed import (
HfTrainerDeepSpeedConfig,
@@ -71,16 +67,7 @@ from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrap
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
LOG = logging.getLogger(__name__)
MULTIMODAL_AUTO_MODEL_MAPPING = {
"mllama": MllamaForConditionalGeneration,
"llava": LlavaForConditionalGeneration,
"qwen2_vl": Qwen2VLForConditionalGeneration,
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
"mistral3": Mistral3ForConditionalGeneration,
"gemma3": Gemma3ForConditionalGeneration,
}
LOG = logging.getLogger("axolotl")
# copied from accelerator.FullyShardedDataParallelPlugin
@@ -109,21 +96,7 @@ def get_module_class_from_name(module, name):
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
if cfg.is_multimodal:
if hasattr(model_config, "text_config"):
model_config = model_config.text_config
model_config.use_cache = False
elif hasattr(model_config, "get_text_config"):
model_config = model_config.get_text_config()
model_config.use_cache = False
# check if image_size is not set and load image size from model config if available
if (
cfg.image_size is None
and hasattr(model_config, "vision_config")
and hasattr(model_config.vision_config, "image_size")
):
cfg.image_size = model_config.vision_config.image_size
LOG.debug(f"Loaded image size: {cfg.image_size} from model config")
model_config = model_config.text_config
quant_config_exists = (
hasattr(model_config, "quantization_config")
@@ -462,31 +435,6 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
**processor_kwargs,
)
# Attempt to load image size from processor if available
if (
cfg.image_size is None
and hasattr(processor, "size")
and any(dim in processor.size for dim in ["width", "height"])
):
im_width = None
im_height = None
if "width" in processor.size:
im_width = processor.size["width"]
if "height" in processor.size:
im_height = processor.size["height"]
# If both width and height are set, use a tuple
if im_width is not None and im_height is not None:
cfg.image_size = (im_width, im_height)
# If only width is set, use as integer
elif im_width is not None:
cfg.image_size = im_width
# If only height is set, use as integer
elif im_height is not None:
cfg.image_size = im_height
LOG.debug(f"Loaded image size: {cfg.image_size} from processor")
return processor
@@ -524,15 +472,11 @@ class ModelLoader:
# init model config
self.model_config = load_model_config(cfg)
if cfg.is_multimodal:
if hasattr(self.model_config, "text_config"):
self.text_model_config = self.model_config.text_config
else:
# for qwen2_vl
self.text_model_config = self.model_config.get_text_config()
self.text_model_config = self.model_config.text_config
else:
self.text_model_config = self.model_config
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
self.AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name
def apply_patches(self) -> None:
# load any patches from plugins
@@ -603,14 +547,6 @@ class ModelLoader:
patch_self_attn_lora(self.cfg)
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
# Initialize ring attn for sequence parallelism. This must be done after
# model init but before the first forward pass, since it modifies flash
# attn to use ring comm for SP training across multiple GPUs.
register_ring_attn(self.cfg.sequence_parallel_degree)
def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"):
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
@@ -667,7 +603,7 @@ class ModelLoader:
patch_self_attn_lora()
def patch_llama_derived_model(self):
def patch_llama_derived_model(self) -> None:
"""Modify all llama derived models in one block"""
self.patch_loss_llama()
@@ -717,16 +653,25 @@ class ModelLoader:
"Shifted-sparse attention not currently implemented without flash attention."
)
def set_auto_model_loader(self):
"""
Set self.auto_model_loader. Defaults to `transformers.AutoModelForCausalLM`
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
should be set according to the type of the model.
def set_auto_model_loader(self) -> None:
"""set self.AutoModelLoader
- default value: AutoModelForCausalLM (set at __init__)
- when using a multi modality model, self.AutoModelLoader should
be set according to model type of the model
"""
if self.cfg.is_multimodal:
self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
self.model_config.model_type, AutoModelForVision2Seq
)
if self.model_config.model_type == "llava":
self.AutoModelLoader = ( # pylint: disable=invalid-name
LlavaForConditionalGeneration
)
elif self.model_config.model_type == "mllama":
self.AutoModelLoader = ( # pylint: disable=invalid-name
MllamaForConditionalGeneration
)
else:
self.AutoModelLoader = (
AutoModelForVision2Seq # pylint: disable=invalid-name
)
def set_device_map_config(self) -> None:
device_map = self.cfg.device_map
@@ -750,7 +695,7 @@ class ModelLoader:
from accelerate import infer_auto_device_map
with init_empty_weights():
model_canvas = self.auto_model_loader.from_config(
model_canvas = self.AutoModelLoader.from_config(
self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
)
@@ -971,27 +916,11 @@ class ModelLoader:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
# Load model with random initialization if specified
if self.cfg.random_init_weights:
# AutoModel classes support the from_config method
if self.auto_model_loader in [
AutoModelForCausalLM,
AutoModelForVision2Seq,
]:
self.model = self.auto_model_loader.from_config(
config=self.model_config,
)
else:
self.model = self.auto_model_loader(
config=self.model_config,
)
else:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
**self.model_kwargs,
)
self.model = self.AutoModelLoader.from_pretrained(
self.base_model,
config=self.model_config,
**self.model_kwargs,
)
# TODO (MengqingCao) split these patches seperately
if self.cfg.flash_attention and not self.inference:
@@ -1029,7 +958,7 @@ class ModelLoader:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
if self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained(
self.model = self.AutoModelLoader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
@@ -1062,7 +991,7 @@ class ModelLoader:
if self.cfg.gptq:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.auto_model_loader.from_pretrained(
self.model = self.AutoModelLoader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
@@ -1082,7 +1011,7 @@ class ModelLoader:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.auto_model_loader.from_pretrained(
self.model = self.AutoModelLoader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
@@ -1245,9 +1174,7 @@ class ModelLoader:
)
):
resize_kwargs = {}
if self.cfg.mean_resizing_embeddings is not None and not (
self.model_config.model_type == "llava"
):
if self.cfg.mean_resizing_embeddings is not None:
resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
else:
@@ -1380,7 +1307,7 @@ def load_model(
"""
Load a model for a given configuration and tokenizer.
"""
model_loader = ModelLoader(
loader = ModelLoader(
cfg,
tokenizer,
processor=processor,
@@ -1388,7 +1315,7 @@ def load_model(
reference_model=reference_model,
**kwargs,
)
return model_loader.load_model()
return loader.load_model()
def load_adapter(model, cfg, adapter, inference=False):

View File

@@ -104,7 +104,9 @@ def allocate(
class MultipackBatchSampler(BatchSampler):
"""Batch sampler class for multipack"""
"""
Batch Sampler class for multipack
"""
def __init__(
self,

View File

@@ -1,165 +0,0 @@
"""Pydantic models for datasets-related configuration"""
from pydantic import BaseModel, model_validator
from axolotl.utils.schemas.enums import ChatTemplate
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
class UserDefinedPrompterType(BaseModel):
"""Structure for user defined prompt types"""
system_prompt: str | None = None
system_format: str | None = None
field_system: str | None = None
field_instruction: str | None = None
field_input: str | None = None
field_output: str | None = None
format: str | None = None
no_input_format: str | None = None
field: str | None = None
class SFTDataset(BaseModel):
"""SFT configuration subset"""
path: str | None = None
split: str | None = None
type: str | UserDefinedPrompterType | None = None
input_transform: str | None = None
shards: int | None = None
shards_idx: int | None = None
preprocess_shards: int | None = None
conversation: str | None = None
# Do not make this too strict or it will break the validator to choose different dataset class
chat_template: ChatTemplate | str | None = None
chat_template_jinja: str | None = None
data_files: str | list[str] | None = None
input_format: str | None = None
name: str | None = None
ds_type: str | None = None
train_on_split: str | None = None
field: str | None = None
field_human: str | None = None
field_model: str | None = None
field_messages: str | None = None
# deprecated, use message_property_mappings
message_field_role: str | None = None
# deprecated, use message_property_mappings
message_field_content: str | None = None
message_property_mappings: dict[str, str] | None = None
message_field_training: str | None = None
message_field_training_detail: str | None = None
logprobs_field: str | None = None
temperature: float | None = None
roles_to_train: list[str] | None = None
train_on_eos: str | None = None
roles: dict[str, list[str]] | None = None
drop_system_message: bool | None = None
trust_remote_code: bool | None = False
revision: str | None = None
@model_validator(mode="before")
@classmethod
def handle_legacy_message_fields(cls, data):
"""Handle backwards compatibility between legacy message field mapping and new property mapping system."""
return handle_legacy_message_fields_logic(data)
@model_validator(mode="before")
@classmethod
# pylint: disable=duplicate-code
def check_chat_template_config(cls, data):
if isinstance(data, BaseModel):
data = data.model_dump()
# Set chat_template to tokenizer_default if not set
if data.get("type") == "chat_template" and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.tokenizer_default
# if chat_template is set to jinja, chat_template_jinja is required
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
"chat_template_jinja"
):
raise ValueError(
"chat_template_jinja is required when chat_template is set to jinja"
)
# If chat_template_jinja is set, set chat_template to jinja
if data.get("chat_template_jinja") and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.jinja
return data
class PretrainingDataset(BaseModel):
"""Pretraining dataset configuration subset"""
name: str | None = None
path: str | None = None
split: str | None = "train"
text_column: str | None = "text"
type: str | None = "pretrain"
trust_remote_code: bool | None = False
data_files: str | None = None
skip: int | None = None
class UserDefinedDPOType(BaseModel):
"""User defined typing for DPO"""
field_system: str | None = None
field_prompt: str | None = None
field_chosen: str | None = None
field_rejected: str | None = None
prompt_format: str | None = None
chosen_format: str | None = None
rejected_format: str | None = None
class DPODataset(BaseModel):
"""DPO configuration subset"""
path: str | None = None
split: str | None = None
type: UserDefinedDPOType | str | None = None
data_files: list[str] | None = None
revision: str | None = None
field_messages: str | None = None
class StepwiseSupervisedDataset(BaseModel):
"""Stepwise supervised dataset configuration subset"""
path: str | None = None
split: str | None = None
data_files: list[str] | None = None
revision: str | None = None
step_separator: str | None = None
max_completion_length: int | None = None
train_on_last_step_only: bool | None = None
class UserDefinedKTOType(BaseModel):
"""User defined typing for KTO"""
field_system: str | None = None
field_prompt: str | None = None
field_completion: str | None = None
field_label: bool | None = None
prompt_format: str | None = None
completion_format: str | None = None
class KTODataset(BaseModel):
"""KTO configuration subset"""
path: str | None = None
split: str | None = None
type: UserDefinedKTOType | str | None = None
data_files: list[str] | None = None
trust_remote_code: bool | None = False
revision: str | None = None
DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset

View File

@@ -1,68 +0,0 @@
"""Pydantic models for deprecated and remapped configuration parameters"""
import logging
from typing import Any
from pydantic import BaseModel, Field, field_validator
LOG = logging.getLogger(__name__)
class DeprecatedParameters(BaseModel):
"""configurations that are deprecated"""
max_packed_sequence_len: int | None = None
rope_scaling: Any | None = None
noisy_embedding_alpha: float | None = None
dpo_beta: float | None = None
evaluation_strategy: str | None = None
@field_validator("max_packed_sequence_len")
@classmethod
def validate_max_packed_sequence_len(cls, max_packed_sequence_len):
if max_packed_sequence_len:
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
return max_packed_sequence_len
@field_validator("rope_scaling")
@classmethod
def validate_rope_scaling(cls, rope_scaling):
if rope_scaling:
raise DeprecationWarning(
"`rope_scaling` is no longer supported, it should now be be a key under `model_config`"
)
return rope_scaling
@field_validator("noisy_embedding_alpha")
@classmethod
def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha):
if noisy_embedding_alpha:
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
return noisy_embedding_alpha
@field_validator("dpo_beta")
@classmethod
def validate_dpo_beta(cls, dpo_beta):
if dpo_beta is not None:
LOG.warning("dpo_beta is deprecated, use rl_beta instead")
return dpo_beta
@field_validator("evaluation_strategy")
@classmethod
def validate_evaluation_strategy(cls, evaluation_strategy):
if evaluation_strategy is not None:
LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
return evaluation_strategy
class RemappedParameters(BaseModel):
"""Parameters that have been remapped to other names"""
overrides_of_model_config: dict[str, Any] | None = Field(
default=None, alias="model_config"
)
overrides_of_model_kwargs: dict[str, Any] | None = Field(
default=None, alias="model_kwargs"
)
type_of_model: str | None = Field(default=None, alias="model_type")
revision_of_model: str | None = Field(default=None, alias="model_revision")

View File

@@ -1,54 +0,0 @@
"""Enums for Axolotl input config"""
from enum import Enum
class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
grpo = "grpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
"""Chat templates configuration subset"""
alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
mistral_v7_tekken = "mistral_v7_tekken" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
deepseek_v3 = "deepseek_v3" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
jinja = "jinja" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
exaone = "exaone" # pylint: disable=invalid-name
metharme = "metharme" # pylint: disable=invalid-name
pixtral = "pixtral" # pylint: disable=invalid-name
llava = "llava" # pylint: disable=invalid-name
qwen2_vl = "qwen2_vl" # pylint: disable=invalid-name
gemma3 = "gemma3" # pylint: disable=invalid-name
class CustomSupportedOptimizers(str, Enum):
"""Custom supported optimizers"""
optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name
ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name

View File

@@ -1,108 +0,0 @@
"""Pydantic models for Axolotl integrations"""
import logging
from typing import Any
from pydantic import BaseModel, Field, model_validator
LOG = logging.getLogger(__name__)
class MLFlowConfig(BaseModel):
"""MLFlow configuration subset"""
use_mlflow: bool | None = None
mlflow_tracking_uri: str | None = None
mlflow_experiment_name: str | None = None
mlflow_run_name: str | None = None
hf_mlflow_log_artifacts: bool | None = None
class LISAConfig(BaseModel):
"""LISA configuration subset"""
lisa_n_layers: int | None = Field(
default=None,
json_schema_extra={"description": "the number of activate layers in LISA"},
)
lisa_step_interval: int | None = Field(
default=None,
json_schema_extra={"description": "how often to switch layers in LISA"},
)
lisa_layers_attribute: str | None = Field(
default="model.layers",
json_schema_extra={"description": "path under the model to access the layers"},
)
class WandbConfig(BaseModel):
"""Wandb configuration subset"""
use_wandb: bool | None = None
wandb_name: str | None = None
wandb_run_id: str | None = None
wandb_mode: str | None = None
wandb_project: str | None = None
wandb_entity: str | None = None
wandb_watch: str | None = None
wandb_log_model: str | None = None
@model_validator(mode="before")
@classmethod
def check_wandb_run(cls, data):
if data.get("wandb_run_id") and not data.get("wandb_name"):
data["wandb_name"] = data.get("wandb_run_id")
LOG.warning(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
)
return data
class CometConfig(BaseModel):
"""Comet configuration subset"""
use_comet: bool | None = None
comet_api_key: str | None = None
comet_workspace: str | None = None
comet_project_name: str | None = None
comet_experiment_key: str | None = None
comet_mode: str | None = None
comet_online: bool | None = None
comet_experiment_config: dict[str, Any] | None = None
class GradioConfig(BaseModel):
"""Gradio configuration subset"""
gradio_title: str | None = None
gradio_share: bool | None = None
gradio_server_name: str | None = None
gradio_server_port: int | None = None
gradio_max_new_tokens: int | None = None
gradio_temperature: float | None = None
class RayConfig(BaseModel):
"""Ray launcher configuration subset"""
use_ray: bool = Field(default=False)
ray_run_name: str | None = Field(
default=None,
json_schema_extra={
"help": "The training results will be saved at `saves/ray_run_name`."
},
)
ray_num_workers: int = Field(
default=1,
json_schema_extra={
"help": "The number of workers for Ray training. Default is 1 worker."
},
)
resources_per_worker: dict = Field(
default_factory=lambda: {"GPU": 1},
json_schema_extra={
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
},
)

View File

@@ -1,55 +0,0 @@
"""Pydantic models for model input / output, etc. configuration"""
import logging
from pydantic import BaseModel, Field, field_validator
LOG = logging.getLogger(__name__)
class ModelInputConfig(BaseModel):
"""Model configuration subset"""
model_config = {"protected_namespaces": ()}
base_model: str
base_model_config: str | None = None
cls_model_config: str | None = None
tokenizer_config: str | None = None
tokenizer_use_fast: bool | None = None
tokenizer_legacy: bool | None = None
tokenizer_type: str | None = Field(
default=None, json_schema_extra={"description": "transformers tokenizer class"}
)
processor_type: str | None = Field(
default=None, json_schema_extra={"description": "transformers processor class"}
)
trust_remote_code: bool | None = None
@field_validator("trust_remote_code")
@classmethod
def hint_trust_remote_code(cls, trust_remote_code):
if trust_remote_code:
LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
)
return trust_remote_code
class ModelOutputConfig(BaseModel):
"""model save configuration subset"""
output_dir: str = Field(default="./model-out")
hub_model_id: str | None = None
hub_strategy: str | None = None
save_safetensors: bool | None = True
class SpecialTokensConfig(BaseModel):
"""Special tokens configuration subset"""
bos_token: str | None = None
eos_token: str | None = None
pad_token: str | None = None
unk_token: str | None = None
additional_special_tokens: list[str] | None = None

View File

@@ -1,48 +0,0 @@
"""Pydantic models for multimodal-related configuration"""
from typing import Literal
from PIL.Image import Resampling
from pydantic import BaseModel, Field, field_validator
class MultiModalConfig(BaseModel):
"""Multi-modal configuration subset"""
image_size: int | tuple[int, int] | None = Field(
default=None,
json_schema_extra={
"description": (
"The size of the image to resize to. It can be an integer (resized into padded-square image) or a tuple (width, height)."
"If not provided, we will attempt to load from preprocessor.size, otherwise, images won't be resized."
)
},
)
image_resize_algorithm: (
Literal["bilinear", "bicubic", "lanczos"] | Resampling | None
) = Field(
default=None,
json_schema_extra={
"description": "The resampling algorithm to use for image resizing. Default is bilinear. Please refer to PIL.Image.Resampling for more details."
},
)
@field_validator("image_resize_algorithm", mode="before")
@classmethod
def convert_image_resize_algorithm(cls, image_resize_algorithm):
"""
Convert the image resize algorithm to a PIL.Image.Resampling enum.
"""
if isinstance(image_resize_algorithm, str):
image_resize_algorithm = image_resize_algorithm.lower()
if image_resize_algorithm == "bilinear":
image_resize_algorithm = Resampling.BILINEAR
elif image_resize_algorithm == "bicubic":
image_resize_algorithm = Resampling.BICUBIC
elif image_resize_algorithm == "lanczos":
image_resize_algorithm = Resampling.LANCZOS
else:
raise ValueError(
f"Invalid image resize algorithm: {image_resize_algorithm}"
)
return image_resize_algorithm

View File

@@ -1,132 +0,0 @@
"""Pydantic models for PEFT-related configuration"""
from typing import Any
from pydantic import BaseModel, Field, field_validator, model_validator
class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
loftq_bits: int = Field(
default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
)
# loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
class PeftConfig(BaseModel):
"""peftq configuration subset"""
loftq_config: LoftQConfig | None = None
class LoraConfig(BaseModel):
"""Peft / LoRA configuration subset"""
load_in_8bit: bool | None = Field(default=False)
load_in_4bit: bool | None = Field(default=False)
adapter: str | None = None
lora_model_dir: str | None = None
lora_r: int | None = None
lora_alpha: int | None = None
lora_fan_in_fan_out: bool | None = None
lora_target_modules: str | list[str] | None = None
lora_target_linear: bool | None = None
lora_modules_to_save: list[str] | None = None
lora_dropout: float | None = 0.0
peft_layers_to_transform: list[int] | None = None
peft_layers_pattern: list[str] | None = None
peft: PeftConfig | None = None
peft_use_dora: bool | None = None
peft_use_rslora: bool | None = None
peft_layer_replication: list[tuple[int, int]] | None = None
peft_init_lora_weights: bool | str | None = None
qlora_sharded_model_loading: bool | None = Field(
default=False,
json_schema_extra={
"description": "load qlora model in sharded format for FSDP using answer.ai technique."
},
)
lora_on_cpu: bool | None = None
gptq: bool | None = None
bnb_config_kwargs: dict[str, Any] | None = None
loraplus_lr_ratio: float | None = Field(
default=None,
json_schema_extra={
"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
},
)
loraplus_lr_embedding: float | None = Field(
default=1e-6,
json_schema_extra={
"description": "loraplus learning rate for lora embedding layers."
},
)
merge_lora: bool | None = None
@model_validator(mode="before")
@classmethod
def validate_adapter(cls, data):
if (
not data.get("adapter")
and not data.get("inference")
and (data.get("load_in_8bit") or data.get("load_in_4bit"))
):
raise ValueError(
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
)
return data
@model_validator(mode="after")
def validate_qlora(self):
if self.adapter == "qlora":
if self.merge_lora:
# can't merge qlora if loaded in 8bit or 4bit
if self.load_in_8bit:
raise ValueError("Can't merge qlora if loaded in 8bit")
if self.gptq:
raise ValueError("Can't merge qlora if gptq")
if self.load_in_4bit:
raise ValueError("Can't merge qlora if loaded in 4bit")
else:
if self.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")
if self.gptq:
raise ValueError("Can't load qlora if gptq")
if not self.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
return self
@field_validator("loraplus_lr_embedding")
@classmethod
def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding):
if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str):
loraplus_lr_embedding = float(loraplus_lr_embedding)
return loraplus_lr_embedding
@model_validator(mode="before")
@classmethod
def validate_lora_dropout(cls, data):
if data.get("adapter") is not None and data.get("lora_dropout") is None:
data["lora_dropout"] = 0.0
return data
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""
relora_steps: int | None = None
relora_warmup_steps: int | None = None
relora_anneal_steps: int | None = None
relora_prune_ratio: float | None = None
relora_cpu_offload: bool | None = None

View File

@@ -1,99 +0,0 @@
"""Pydantic models for training hyperparameters"""
import logging
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from transformers import SchedulerType
from transformers.training_args import OptimizerNames
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
LOG = logging.getLogger(__name__)
class LrGroup(BaseModel):
"""Custom learning rate group configuration"""
name: str
modules: list[str]
lr: float
class HyperparametersConfig(BaseModel):
"""Training hyperparams configuration subset"""
gradient_accumulation_steps: int | None = Field(default=1)
micro_batch_size: int | None = Field(
default=1,
json_schema_extra={"description": "per gpu micro batch size for training"},
)
batch_size: int | None = Field(
default=None,
json_schema_extra={
"description": "Total batch size, we do not recommended setting this manually"
},
)
eval_batch_size: int | None = Field(
default=None,
json_schema_extra={
"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
},
)
auto_find_batch_size: bool | None = None
train_on_inputs: bool | None = False
group_by_length: bool | None = None
learning_rate: str | float
embedding_lr: float | None = None
embedding_lr_scale: float | None = None
weight_decay: float | None = 0.0
optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = (
OptimizerNames.ADAMW_TORCH_FUSED
)
optim_args: (str | dict[str, Any]) | None = Field(
default=None,
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
)
optim_target_modules: (list[str] | Literal["all_linear"]) | None = Field(
default=None,
json_schema_extra={
"description": "The target modules to optimize, i.e. the module names that you would like to train."
},
)
torchdistx_path: str | None = None
lr_scheduler: (SchedulerType | Literal["one_cycle"] | Literal["rex"]) | None = (
SchedulerType.COSINE
)
lr_scheduler_kwargs: dict[str, Any] | None = None
lr_quadratic_warmup: bool | None = None
cosine_min_lr_ratio: float | None = None
cosine_constant_lr_ratio: float | None = None
lr_div_factor: float | None = None
lr_groups: list[LrGroup] | None = None
adam_epsilon: float | None = None
adam_beta1: float | None = None
adam_beta2: float | None = None
max_grad_norm: float | None = None
num_epochs: float = Field(default=1.0)
@field_validator("batch_size")
@classmethod
def hint_batch_size_set(cls, batch_size):
if batch_size:
LOG.warning(
"%s\n%s",
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
)
return batch_size
@field_validator("learning_rate")
@classmethod
def convert_learning_rate(cls, learning_rate):
if learning_rate and isinstance(learning_rate, str):
learning_rate = float(learning_rate)
return learning_rate

View File

@@ -1,79 +0,0 @@
"""Utilities for Axolotl Pydantic models"""
import logging
LOG = logging.getLogger(__name__)
def handle_legacy_message_fields_logic(data: dict) -> dict:
"""
Handle backwards compatibility between legacy message field mapping and new property mapping system.
Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options:
- message_field_role: Mapped to the role field
- message_field_content: Mapped to the content field
The new system uses message_property_mappings to support arbitrary field mappings:
message_property_mappings:
role: source_role_field
content: source_content_field
additional_field: source_field
Args:
data: Dictionary containing configuration data
Returns:
Updated dictionary with message field mappings consolidated
Raises:
ValueError: If there are conflicts between legacy and new mappings
"""
data = data.copy() # Create a copy to avoid modifying the original
if data.get("message_property_mappings") is None:
data["message_property_mappings"] = {}
# Check for conflicts and handle role
if "message_field_role" in data:
LOG.warning(
"message_field_role is deprecated, use message_property_mappings instead. "
f"Example: message_property_mappings: {{role: {data['message_field_role']}}}"
)
if (
"role" in data["message_property_mappings"]
and data["message_property_mappings"]["role"] != data["message_field_role"]
):
raise ValueError(
f"Conflicting message role fields: message_field_role='{data['message_field_role']}' "
f"conflicts with message_property_mappings.role='{data['message_property_mappings']['role']}'"
)
data["message_property_mappings"]["role"] = data["message_field_role"] or "role"
del data["message_field_role"]
elif "role" not in data["message_property_mappings"]:
data["message_property_mappings"]["role"] = "role"
# Check for conflicts and handle content
if "message_field_content" in data:
LOG.warning(
"message_field_content is deprecated, use message_property_mappings instead. "
f"Example: message_property_mappings: {{content: {data['message_field_content']}}}"
)
if (
"content" in data["message_property_mappings"]
and data["message_property_mappings"]["content"]
!= data["message_field_content"]
):
raise ValueError(
f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
)
data["message_property_mappings"]["content"] = (
data["message_field_content"] or "content"
)
del data["message_field_content"]
elif "content" not in data["message_property_mappings"]:
data["message_property_mappings"]["content"] = "content"
return data

View File

@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
elif cfg.sample_packing:
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
@@ -356,7 +356,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
**filter_map_kwargs,
**drop_long_kwargs,
)
if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids,
@@ -443,7 +443,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
- 1
)
* cfg.num_epochs
* cfg.sequence_parallel_degree
)
LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
@@ -474,11 +473,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
# FIXME: is there a bug here somewhere? the total num steps depends
# on the agreed on value for sample_packing_eff_est
total_num_steps = int(
math.floor(
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
)
)
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
@@ -499,12 +494,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
)
else:
total_num_steps = int(
math.ceil(
len(train_dataset)
* cfg.num_epochs
* cfg.sequence_parallel_degree
/ cfg.batch_size
)
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
return total_num_steps

View File

@@ -14,7 +14,7 @@
h1 {
font-family: var(--font-title);
font-weight: 400;
font-size: 3rem;
font-size: 5rem;
line-height: 1.1;
letter-spacing: -0.05em;
font-feature-settings: "ss01" on;
@@ -24,7 +24,7 @@ h1 {
h2 {
font-family: var(--font-title);
font-weight: 500;
font-size: 1.5rem;
font-size: 2rem;
line-height: 1.2;
letter-spacing: -0.03em;
font-feature-settings: "ss01" on;
@@ -35,7 +35,7 @@ h3,
h4 {
font-family: var(--font-body);
font-weight: 400;
font-size: 1.25rem;
font-size: 1.5rem;
line-height: 1.5;
letter-spacing: -0.02em;
}
@@ -191,87 +191,3 @@ code span.er {
color: #5cb85c !important;
text-decoration: none !important;
}
/* API Documentation Styling */
/* Improve docstring section rendering */
.level3 p {
white-space: pre-line !important;
}
/* Format docstring sections */
.level3 p strong {
display: block;
margin-top: 1em;
font-weight: bold;
color: var(--cyan);
}
/* Add spacing after sections */
.level3 p:has(strong) {
margin-bottom: 0.5em;
}
/* Format Args and Returns sections */
p:has(code) {
line-height: 1.6;
}
/* Function signatures */
.sourceCode {
margin-bottom: 1.5em;
}
/* Parameter tables */
.doc-section-parameters table,
.doc-section-returns table {
margin-top: 1em;
margin-bottom: 1.5em;
}
/* Make parameter and returns headers smaller */
h2.anchored[data-anchor-id="parameters"],
h2.anchored[data-anchor-id="returns"],
.doc-section-parameters h4,
.doc-section-returns h4 {
font-size: 1.25rem;
margin-top: 2rem;
margin-bottom: 1rem;
color: var(--lime);
border-bottom: 1px solid var(--lime);
padding-bottom: 0.3rem;
font-family: var(--font-body);
font-weight: 500;
letter-spacing: normal;
}
/* Style documentation tables */
table {
width: 100%;
margin-bottom: 1.5rem;
border-collapse: collapse;
}
table th {
background-color: #1a1a1a;
padding: 0.5rem 1rem;
border-bottom: 2px solid var(--greige-600);
text-align: left;
}
table td {
padding: 0.5rem 1rem;
border-bottom: 1px solid var(--greige-600);
}
/* Code in table cells */
table td code {
background-color: transparent !important;
padding: 0;
}
/* Improve spacing in parameter and return tables */
.doc-section-parameters,
.doc-section-returns {
margin-top: 1rem;
}

View File

@@ -144,7 +144,7 @@ def test_swiglu_mlp_integration(small_llama_model):
def test_geglu_model_integration():
"""Test GeGLU activation with Gemma model."""
model = AutoModelForCausalLM.from_pretrained(
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="auto"
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda"
)
peft_config = get_peft_config(
{
@@ -347,7 +347,7 @@ def test_model_architecture(model_config):
"""Test LoRA kernel patches across different model architectures."""
# Load model with appropriate dtype
model = AutoModelForCausalLM.from_pretrained(
model_config["name"], torch_dtype=model_config["dtype"], device_map="auto"
model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda"
)
# Apply LoRA configuration
@@ -408,7 +408,7 @@ def test_kernel_training_integration():
)
# Load model
model, _, _ = load_model_and_tokenizer(cfg=cfg)
model, _ = load_model_and_tokenizer(cfg=cfg)
# Verify correct activation function
layer = model.model.model.layers[0]

View File

@@ -1,209 +0,0 @@
"""Tests for sequence parallelism functionality."""
# pylint: disable=redefined-outer-name,unused-argument
from unittest.mock import MagicMock, patch
import pytest
import torch
from accelerate.state import PartialState
from axolotl.monkeypatch.attention.ring_attn import (
get_ring_attn_group,
set_ring_attn_group,
)
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
from axolotl.utils.dict import DictDefault
@pytest.fixture
def partial_state():
"""Create a real PartialState instance for testing."""
state = PartialState()
return state
@pytest.fixture(name="cfg")
def fixture_cfg():
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-3,
"output_dir": "./model-out",
"sequence_len": 512,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
}
)
return cfg
class TestSequenceParallelHelpers:
"""Test helper functions used in sequence parallelism."""
def test_adjust_position_ids_for_slice(self, partial_state):
"""Test position_ids adjustment for sequence slices."""
# Create sample position_ids with multiple sequences
position_ids = torch.tensor(
[
# First sequence with 2 samples
[0, 1, 2, 3, 4, 0, 1, 2, 3],
# Second sequence with 3 samples
[0, 1, 2, 0, 1, 2, 3, 0, 1],
]
)
# Adjust as if this was the second slice (start_idx = 4)
adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4)
# For first sequence: [0,1,2,3,4,0,1,2,3] -> [-4,-3,-2,-1,0,-4,-3,-2,-1]
# For second sequence: [0,1,2,0,1,2,3,0,1] -> [-4,-3,-2,-4,-3,-2,-1,-4,-3]
expected_first_seq = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3]) - 4
expected_second_seq = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1]) - 4
assert torch.all(adjusted[0] == expected_first_seq)
assert torch.all(adjusted[1] == expected_second_seq)
class TestRingAttention:
"""Tests for the ring attention functionality."""
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
def test_get_ring_attn_group_no_registration(
self, mock_world_size, mock_rank, partial_state
):
"""Test that get_ring_attn_group returns None when no group has been registered."""
# Setup mocks
mock_world_size.return_value = 4
mock_rank.return_value = 0
# Get the group without registration
group = get_ring_attn_group()
# Verify that None was returned
assert group is None
@patch("torch.distributed.new_group")
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
def test_register_ring_attn(
self, mock_world_size, mock_rank, mock_new_group, partial_state
):
"""Test that ring attention groups are created correctly."""
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
# Setup mocks
mock_world_size.return_value = 8 # 8 GPUs total
mock_rank.return_value = 3 # GPU #3
mock_group = MagicMock()
mock_new_group.return_value = mock_group
# Call register_ring_attn with size 4
register_ring_attn(sequence_parallel_degree=4)
# Verify the number of calls without examining the arguments
assert mock_new_group.call_count == 2
# Verify that new_group was called
mock_new_group.assert_called()
# Clean up
set_ring_attn_group(None)
# Mock a simplified DataCollator test
@patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group")
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
def test_sequence_parallel_slicing(
mock_world_size, mock_rank, mock_get_group, partial_state
):
"""Test the basic sequence slicing logic without full collator instantiation."""
# Setup mocks
mock_get_group.return_value = MagicMock()
mock_rank.return_value = 1 # Second GPU
mock_world_size.return_value = 4 # 4 GPUs total
# Create a sample batch
batch = {
"input_ids": torch.tensor(
[
[101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112],
[201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212],
]
),
"attention_mask": torch.ones(2, 12),
}
# Simplified slicing logic from SequenceParallelDataCollator
def slice_batch(batch, rank, world_size):
result = {}
for key in batch:
seq_len = batch[key].shape[1]
slice_size = seq_len // world_size
start_idx = rank * slice_size
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
result[key] = batch[key][:, start_idx:end_idx]
return result
# Slice the batch
result = slice_batch(
batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value
)
# Check slicing
assert result["input_ids"].shape == (2, 3) # 12 tokens / 4 GPUs = 3 tokens per GPU
expected_input_ids = torch.tensor(
[
[104, 105, 106], # Second slice of first sequence
[204, 205, 206], # Second slice of second sequence
]
)
assert torch.all(result["input_ids"] == expected_input_ids)
@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()})
def test_config_validation_with_valid_inputs(cfg):
"""Test that valid sequence parallelism configurations pass validation."""
# Import the actual model class with appropriate mocks
from axolotl.utils.schemas.config import AxolotlInputConfig
# Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
cfg = cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
}
# Should validate without errors
config = AxolotlInputConfig(**cfg)
assert config.sequence_parallel_degree == 2
assert config.flash_attention is True
def test_config_validation_with_invalid_inputs(cfg):
"""Test that invalid sequence parallelism configurations fail validation."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
cfg = cfg | {
"sequence_parallel_degree": 2,
"flash_attention": False,
}
# Should raise ValidationError
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
# Verify error message
assert "flash_attention: true must be set" in str(excinfo.value)

View File

@@ -1,5 +1,5 @@
"""
E2E tests for deepseekv3
E2E tests for lora llama
"""
import logging

View File

@@ -1,133 +0,0 @@
"""
E2E tests for gemma2
"""
import logging
import os
from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestGemma2:
"""
Test case for Gemma2 models
"""
@pytest.mark.parametrize(
"sample_packing",
[True, False],
)
def test_lora_gemma2(self, temp_dir, sample_packing):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/gemma-2-33M",
"trust_remote_code": True,
"sample_packing": sample_packing,
"flash_attention": True,
"sequence_len": 2048,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0,
"datasets": [
{
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"field_messages": "conversations",
"message_property_mappings": {
"role": "from",
"content": "value",
},
"drop_system_message": True,
"split": "train[:1%]",
},
],
"special_tokens": {
"bos_token": "<bos>",
"eos_token": "<eos>",
},
"chat_template": "gemma", # gemma2's template is same as gemma
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@pytest.mark.parametrize(
"sample_packing",
[True, False],
)
def test_fft_gemma2(self, temp_dir, sample_packing):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/gemma-2-33M",
"trust_remote_code": True,
"sample_packing": sample_packing,
"flash_attention": True,
"sequence_len": 2048,
"val_set_size": 0,
"datasets": [
{
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"field_messages": "conversations",
"message_property_mappings": {
"role": "from",
"content": "value",
},
"split": "train[:1%]",
"drop_system_message": True,
},
],
"chat_template": "gemma", # gemma2's template is same as gemma
"special_tokens": {
"bos_token": "<bos>",
"eos_token": "<eos>",
},
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -1,131 +0,0 @@
"""
E2E tests for gemma3_text
"""
import logging
import os
from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestGemma3Text:
"""
Test case for Gemma3Text models
"""
@pytest.mark.parametrize(
"sample_packing",
[True, False],
)
def test_lora_gemma3_text(self, temp_dir, sample_packing):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/gemma-3-34M",
"trust_remote_code": True,
"sample_packing": sample_packing,
"flash_attention": True,
"sequence_len": 2048,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0,
"datasets": [
{
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"field_messages": "conversations",
"message_property_mappings": {
"role": "from",
"content": "value",
},
"split": "train[:1%]",
},
],
"special_tokens": {
"bos_token": "<bos>",
"eos_token": "<eos>",
},
"chat_template": "gemma3",
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@pytest.mark.parametrize(
"sample_packing",
[True, False],
)
def test_fft_gemma3_text(self, temp_dir, sample_packing):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/gemma-3-34M",
"trust_remote_code": True,
"sample_packing": sample_packing,
"flash_attention": True,
"sequence_len": 2048,
"val_set_size": 0,
"datasets": [
{
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"field_messages": "conversations",
"message_property_mappings": {
"role": "from",
"content": "value",
},
"split": "train[:1%]",
},
],
"chat_template": "gemma3",
"special_tokens": {
"bos_token": "<bos>",
"eos_token": "<eos>",
},
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -54,7 +54,7 @@ class TestCustomSchedulers(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_hf",
"max_steps": 20,
"lr_scheduler": "rex",
"warmup_steps": 5,

View File

@@ -11,10 +11,10 @@ from pydantic import ValidationError
from axolotl.utils import is_comet_available
from axolotl.utils.config import validate_config
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
from axolotl.utils.dict import DictDefault
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import check_model_config
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
from axolotl.utils.wandb_ import setup_wandb_env_vars
warnings.filterwarnings("error")

View File

@@ -12,7 +12,6 @@ from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.utils.config import normalize_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets
@@ -263,7 +262,6 @@ class TestDeduplicateNonRL(unittest.TestCase):
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
self.cfg_1 = DictDefault(
{
"base_model": "huggyllama/llama-7b",
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"dataset_exact_deduplication": True,
@@ -284,7 +282,6 @@ class TestDeduplicateNonRL(unittest.TestCase):
"num_epochs": 1,
}
)
normalize_config(self.cfg_1)
def test_prepare_dataset_with_deduplication_train(self):
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""

View File

@@ -6,8 +6,8 @@ from typing import Optional
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.datasets import ChatTemplate
warnings.filterwarnings("error")

View File

@@ -119,7 +119,7 @@ class TestModelsUtils:
def test_message_property_mapping(self):
"""Test message property mapping configuration validation"""
from axolotl.utils.schemas.datasets import SFTDataset
from axolotl.utils.config.models.input.v0_4_1 import SFTDataset
# Test legacy fields are mapped orrectly
dataset = SFTDataset(