Compare commits
1 Commits
revert-290
...
fix/rl-tra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d47093fcdd |
8
.github/workflows/base.yml
vendored
8
.github/workflows/base.yml
vendored
@@ -29,11 +29,11 @@ jobs:
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
pytorch: 2.5.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
@@ -43,7 +43,7 @@ jobs:
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "126"
|
||||
|
||||
18
.github/workflows/main.yml
vendored
18
.github/workflows/main.yml
vendored
@@ -15,15 +15,15 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras: vllm
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
@@ -82,17 +82,17 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
|
||||
8
.github/workflows/multi-gpu-e2e.yml
vendored
8
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -33,11 +33,11 @@ jobs:
|
||||
axolotl_extras:
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras: vllm
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 126
|
||||
|
||||
28
.github/workflows/nightlies.yml
vendored
28
.github/workflows/nightlies.yml
vendored
@@ -12,16 +12,16 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -65,16 +65,16 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
17
.github/workflows/preview-docs.yml
vendored
17
.github/workflows/preview-docs.yml
vendored
@@ -28,8 +28,6 @@ jobs:
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
|
||||
- name: Set up Quarto
|
||||
uses: quarto-dev/quarto-actions/setup@v2
|
||||
@@ -52,11 +50,10 @@ jobs:
|
||||
|
||||
- name: Netlify Publish
|
||||
uses: nwtgck/actions-netlify@v3.0
|
||||
id: netlify
|
||||
with:
|
||||
publish-dir: './_site'
|
||||
enable-pull-request-comment: false
|
||||
enable-github-deployment: false
|
||||
enable-pull-request-comment: true
|
||||
enable-github-deployment: true
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
deploy-message: "Deployed On Netlify"
|
||||
github-deployment-environment: 'preview'
|
||||
@@ -64,13 +61,3 @@ jobs:
|
||||
env:
|
||||
NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }}
|
||||
NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}
|
||||
|
||||
- name: Update PR with preview link
|
||||
if: ${{ steps.netlify.outcome == 'success' }}
|
||||
uses: marocchino/sticky-pull-request-comment@v2
|
||||
with:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
message: |
|
||||
📖 **Documentation Preview**: ${{ steps.netlify.outputs.deploy-url }}
|
||||
|
||||
Deployed on Netlify from commit ${{ github.event.pull_request.head.sha }}
|
||||
|
||||
8
.github/workflows/tests-nightly.yml
vendored
8
.github/workflows/tests-nightly.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.6.0", "2.7.0"]
|
||||
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -80,9 +80,9 @@ jobs:
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||
pytest -v --durations=10 tests/patched/
|
||||
pytest -v --durations=10 tests/cli/
|
||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||
pytest -v tests/patched/
|
||||
pytest -v tests/cli/
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
|
||||
30
.github/workflows/tests.yml
vendored
30
.github/workflows/tests.yml
vendored
@@ -52,7 +52,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
|
||||
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -102,9 +102,9 @@ jobs:
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
@@ -125,7 +125,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
|
||||
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -175,9 +175,9 @@ jobs:
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||
pytest -v --durations=10 tests/patched/
|
||||
pytest -v --durations=10 tests/cli/
|
||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||
pytest -v tests/patched/
|
||||
pytest -v tests/cli/
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
@@ -198,7 +198,7 @@ jobs:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
@@ -252,6 +252,18 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
axolotl_extras: llmcompressor
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
|
||||
@@ -97,7 +97,7 @@
|
||||
# # 'no_input_format' cannot include {input}
|
||||
# no_input_format: "{instruction} "
|
||||
|
||||
# # For `completion` datasets only, uses the provided field instead of `text` column
|
||||
# # For `completion` datsets only, uses the provided field instead of `text` column
|
||||
# field:
|
||||
|
||||
# # Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
|
||||
@@ -55,7 +55,7 @@ Features:
|
||||
|
||||
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
||||
- Python 3.11
|
||||
- PyTorch ≥2.6.0
|
||||
- PyTorch ≥2.5.1
|
||||
|
||||
### Installation
|
||||
|
||||
|
||||
@@ -276,7 +276,6 @@ website:
|
||||
- docs/torchao.qmd
|
||||
- docs/custom_integrations.qmd
|
||||
- docs/sequence_parallelism.qmd
|
||||
- docs/gradient_checkpointing.qmd
|
||||
|
||||
- section: "Troubleshooting"
|
||||
contents:
|
||||
|
||||
@@ -24,9 +24,9 @@ df_template = template_env.get_template("Dockerfile.jinja")
|
||||
df_args = {
|
||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
|
||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
|
||||
"CUDA": os.environ.get("CUDA", "126"),
|
||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
|
||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
|
||||
"CUDA": os.environ.get("CUDA", "124"),
|
||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
||||
|
||||
@@ -24,9 +24,9 @@ df_template = template_env.get_template(dockerfile)
|
||||
df_args = {
|
||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
|
||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
|
||||
"CUDA": os.environ.get("CUDA", "126"),
|
||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
|
||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
|
||||
"CUDA": os.environ.get("CUDA", "124"),
|
||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
||||
|
||||
@@ -187,7 +187,6 @@ Instead of passing `tools` via the system prompt, an alternative method would be
|
||||
"role": "assistant", // call the function via assistant
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "...", // required only for mistral
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "...",
|
||||
@@ -200,7 +199,6 @@ Instead of passing `tools` via the system prompt, an alternative method would be
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "...", // required only for mistral
|
||||
"name": "...",
|
||||
"content": "..."
|
||||
},
|
||||
|
||||
@@ -34,9 +34,9 @@ Tags examples:
|
||||
|
||||
- `main-base-py3.11-cu128-2.7.1`
|
||||
- `main-base-py3.11-cu126-2.7.1`
|
||||
- `main-base-py3.11-cu126-2.7.0`
|
||||
- `main-base-py3.11-cu126-2.6.0`
|
||||
- `main-base-py3.11-cu124-2.6.0`
|
||||
- `main-base-py3.11-cu124-2.5.1`
|
||||
|
||||
## Main
|
||||
|
||||
@@ -76,12 +76,12 @@ Tags examples:
|
||||
|
||||
- `main-py3.11-cu128-2.7.1`
|
||||
- `main-py3.11-cu126-2.7.1`
|
||||
- `main-py3.11-cu126-2.7.0`
|
||||
- `main-py3.11-cu126-2.6.0`
|
||||
- `main-py3.11-cu124-2.6.0`
|
||||
- `main-py3.11-cu124-2.5.1`
|
||||
- `main-latest`
|
||||
- `main-20250303-py3.11-cu124-2.6.0`
|
||||
- `main-20250303-py3.11-cu126-2.6.0`
|
||||
- `main-20250303-py3.11-cu124-2.5.1`
|
||||
- `0.10.1`
|
||||
|
||||
## Cloud
|
||||
|
||||
@@ -20,7 +20,7 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
|
||||
> See the [example config](#example-config) file in addition to reading these instructions.
|
||||
|
||||
1. Set `adapter: qlora` in your axolotl config file.
|
||||
2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
|
||||
2. Enable FSDP in your axolotl config, as [described here](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#fsdp).
|
||||
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
||||
|
||||
## Example Config
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
---
|
||||
title: Gradient Checkpointing and Activation Offloading
|
||||
---
|
||||
|
||||
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
|
||||
models by reducing the memory footprint and improving computational efficiency.
|
||||
|
||||
### Enabling Gradient Checkpointing
|
||||
|
||||
```yaml
|
||||
gradient_checkpointing: true
|
||||
```
|
||||
|
||||
### Enabling Activation Offloading
|
||||
|
||||
```yaml
|
||||
gradient_checkpointing: true # required for activation offloading
|
||||
activation_offloading: true
|
||||
```
|
||||
|
||||
Activation offloading variants:
|
||||
|
||||
The default `activation_offloading: true` offloads activations to CPU and uses CUDA streams
|
||||
to overlap the communications and computations when offloading.
|
||||
|
||||
The `activation_offloading: legacy` naively offloads activations to CPU and without additional optimizations.
|
||||
|
||||
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
|
||||
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
|
||||
@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
|
||||
|
||||
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
||||
- Python ≥3.11
|
||||
- PyTorch ≥2.6.0
|
||||
- PyTorch ≥2.5.1
|
||||
|
||||
## Installation Methods {#sec-installation-methods}
|
||||
|
||||
|
||||
@@ -23,6 +23,8 @@ Axolotl supports several methods for multi-GPU training:
|
||||
|
||||
## DeepSpeed {#sec-deepspeed}
|
||||
|
||||
DeepSpeed is the recommended approach for multi-GPU training due to its stability and performance. It provides various optimization levels through ZeRO stages.
|
||||
|
||||
### Configuration {#sec-deepspeed-config}
|
||||
|
||||
Add to your YAML config:
|
||||
@@ -30,6 +32,7 @@ Add to your YAML config:
|
||||
```{.yaml}
|
||||
deepspeed: deepspeed_configs/zero1.json
|
||||
```
|
||||
|
||||
### Usage {#sec-deepspeed-usage}
|
||||
|
||||
```{.bash}
|
||||
@@ -63,75 +66,9 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
|
||||
|
||||
:::
|
||||
|
||||
::: {.callout-tip}
|
||||
## FSDP {#sec-fsdp}
|
||||
|
||||
Using ZeRO Stage 3 with Single-GPU training
|
||||
|
||||
ZeRO Stage 3 can be used for training on a single GPU by manually setting the environment variables:
|
||||
`WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500`
|
||||
|
||||
:::
|
||||
|
||||
## Fully Sharded Data Parallel (FSDP) {#sec-fsdp}
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
FSDP2 is recommended for new users. FSDP1 is deprecated and will be removed in an upcoming release of Axolotl.
|
||||
|
||||
:::
|
||||
|
||||
### Migrating from FSDP1 to FSDP2 {#sec-migrate-fsdp1-fsdp2}
|
||||
|
||||
To migrate your config from FSDP1 to FSDP2, you must use the `fsdp_version` top-level config field to specify the FSDP version, and
|
||||
also follow the config field mapping below to update field names.
|
||||
|
||||
#### Config mapping
|
||||
|
||||
FSDP1 | FSDP2
|
||||
-------- | --------
|
||||
fsdp_sharding_strategy | reshard_after_forward
|
||||
fsdp_backward_prefetch_policy | **REMOVED**
|
||||
fsdp_backward_prefetch | **REMOVED**
|
||||
fsdp_forward_prefetch | **REMOVED**
|
||||
fsdp_sync_module_states | **REMOVED**
|
||||
fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading
|
||||
fsdp_state_dict_type | state_dict_type
|
||||
fsdp_use_orig_params | **REMOVED**
|
||||
|
||||
|
||||
For example, if you were using the following FSDP1 config:
|
||||
|
||||
```{.yaml}
|
||||
fsdp_version: 1
|
||||
fsdp_config:
|
||||
fsdp_offload_params: false
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
```
|
||||
|
||||
You can migrate to the following FSDP2 config:
|
||||
|
||||
```{.yaml}
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen3DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
reshard_after_forward: true
|
||||
```
|
||||
|
||||
### FSDP1 (deprecated) {#sec-fsdp-config}
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
Using `fsdp` to configure FSDP is deprecated and will be removed in an upcoming release of Axolotl. Please use `fsdp_config` as above instead.
|
||||
|
||||
:::
|
||||
### Basic FSDP Configuration {#sec-fsdp-config}
|
||||
|
||||
```{.yaml}
|
||||
fsdp:
|
||||
@@ -143,7 +80,6 @@ fsdp_config:
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
```
|
||||
|
||||
|
||||
## Sequence parallelism {#sec-sequence-parallelism}
|
||||
|
||||
We support sequence parallelism (SP) via the
|
||||
|
||||
@@ -40,13 +40,13 @@ use_cpu: false
|
||||
|
||||
Configure your model to use FSDP in the Axolotl yaml. For example:
|
||||
```yaml
|
||||
fsdp_version: 2
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
fsdp_config:
|
||||
offload_params: true
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
reshard_after_forward: true
|
||||
fsdp_offload_params: true
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
```
|
||||
|
||||
All you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine.
|
||||
|
||||
@@ -17,6 +17,7 @@ feedback. Various methods include, but not limited to:
|
||||
- [Kahneman-Tversky Optimization (KTO)](#kto)
|
||||
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
|
||||
- [Group Relative Policy Optimization (GRPO)](#grpo)
|
||||
- Proximal Policy Optimization (PPO) (not yet supported in axolotl, if you're interested in contributing, please reach out!)
|
||||
|
||||
|
||||
## RLHF using Axolotl
|
||||
@@ -274,14 +275,15 @@ rl: dpo
|
||||
datasets:
|
||||
- path: ...
|
||||
split: train
|
||||
type:
|
||||
field_prompt: "prompt"
|
||||
field_system: "system"
|
||||
field_chosen: "chosen"
|
||||
field_rejected: "rejected"
|
||||
prompt_format: "{prompt}"
|
||||
chosen_format: "{chosen}"
|
||||
rejected_format: "{rejected}"
|
||||
type: user_defined.default
|
||||
|
||||
field_prompt: "prompt"
|
||||
field_system: "system"
|
||||
field_chosen: "chosen"
|
||||
field_rejected: "rejected"
|
||||
prompt_format: "{prompt}"
|
||||
chosen_format: "{chosen}"
|
||||
rejected_format: "{rejected}"
|
||||
```
|
||||
|
||||
The input format is a simple JSON input with customizable fields based on the above config.
|
||||
@@ -474,13 +476,14 @@ rl: kto
|
||||
datasets:
|
||||
- path: ...
|
||||
split: train
|
||||
type:
|
||||
field_prompt: "prompt"
|
||||
field_system: "system"
|
||||
field_completion: "completion"
|
||||
field_label: "label"
|
||||
prompt_format: "{prompt}"
|
||||
completion_format: "{completion}"
|
||||
type: user_defined.default
|
||||
|
||||
field_prompt: "prompt"
|
||||
field_system: "system"
|
||||
field_completion: "completion"
|
||||
field_label: "label"
|
||||
prompt_format: "{prompt}"
|
||||
completion_format: "{completion}"
|
||||
```
|
||||
|
||||
The input format is a simple JSON input with customizable fields based on the above config.
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
# Archived Examples
|
||||
|
||||
This directory contains examples that are no longer maintained and may no longer be functional.
|
||||
|
||||
We keep them around for archival purposes in case they are useful to others.
|
||||
@@ -1,70 +0,0 @@
|
||||
# Finetune Devstral with Axolotl
|
||||
|
||||
Devstral Small is a 24B parameter opensource model from MistralAI found on HuggingFace [Devstral-Small-2505](https://huggingface.co/mistralai/Devstral-Small-2505) and [Devstral-Small-2507](https://huggingface.co/mistralai/Devstral-Small-2507). `Devstral-Small-2507` is the latest version of the model and has [function calling](https://mistralai.github.io/mistral-common/usage/tools/) support.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations with proper masking.
|
||||
|
||||
The model was fine-tuned ontop of [Mistral-Small-3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503) without the vision layer and has a context of up to 128k tokens.
|
||||
|
||||
Thanks to the team at MistralAI for giving us early access to prepare for this release.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Devstral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0+)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/devstral/devstral-small-qlora.yml
|
||||
```
|
||||
|
||||
This config uses about 21GB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### TIPS
|
||||
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
- Learn how to use function calling with Axolotl at [docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
- [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy)
|
||||
- [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels)
|
||||
|
||||
## Limitations
|
||||
|
||||
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
|
||||
|
||||
In addition, we do not support overriding tokens yet.
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [MistralAI Devstral Blog](https://mistral.ai/news/devstral)
|
||||
- [MistralAI Devstral 1.1 Blog](https://mistral.ai/news/devstral-2507)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
|
||||
|
||||
## Future Work
|
||||
|
||||
- Add parity to Preference Tuning, RL, Multi-modal, etc.
|
||||
- Add parity to other tokenizer configs like overriding tokens.
|
||||
@@ -1,64 +0,0 @@
|
||||
base_model: mistralai/Devstral-Small-2507
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0
|
||||
lora_target_linear: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_ratio: 0.05
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -1,7 +0,0 @@
|
||||
# Liquid Foundation Models 2
|
||||
|
||||
LFM2 support in transformers exists in the main branch, but is not yet included in the transformers release.
|
||||
|
||||
```bash
|
||||
pip install --upgrade --no-deps --force-reinstall git+https://github.com/huggingface/transformers.git
|
||||
```
|
||||
@@ -1,48 +0,0 @@
|
||||
base_model: LiquidAI/LFM2-350M
|
||||
|
||||
chunked_cross_entropy: true
|
||||
|
||||
chat_template: tokenizer_default
|
||||
eot_tokens:
|
||||
- "<|im_end|>"
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 4
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
saves_per_epoch: 1
|
||||
|
||||
weight_decay: 0.0
|
||||
@@ -18,10 +18,16 @@ git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,mistral]'
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
2. Download the example config:
|
||||
|
||||
```bash
|
||||
axolotl fetch examples
|
||||
```
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/magistral/magistral-small-qlora.yaml
|
||||
@@ -36,7 +42,7 @@ Let us know how it goes. Happy finetuning! 🚀
|
||||
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
- The dataset format is the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
@@ -48,7 +54,7 @@ Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
|
||||
|
||||
In addition, we do not support overriding tokens yet.
|
||||
The tokenizer does not work with `dataset.map` with multiprocessing, so we had to disable it. In addition, we do not support overriding tokens yet.
|
||||
|
||||
## Related Resources
|
||||
|
||||
|
||||
@@ -6,19 +6,19 @@ triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
autoawq==0.2.7.post3
|
||||
liger-kernel==0.6.0
|
||||
liger-kernel==0.5.10
|
||||
# END section
|
||||
|
||||
packaging==23.2
|
||||
|
||||
huggingface_hub>=0.33.0
|
||||
peft==0.16.0
|
||||
transformers==4.53.2
|
||||
huggingface_hub==0.32.2
|
||||
peft==0.15.2
|
||||
transformers==4.53.1
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.8.1
|
||||
datasets==4.0.0
|
||||
datasets==3.6.0
|
||||
deepspeed>=0.17.0
|
||||
trl==0.19.1
|
||||
trl==0.18.2
|
||||
hf_xet==1.1.2
|
||||
|
||||
optimum==1.16.2
|
||||
@@ -68,4 +68,4 @@ schedulefree==1.4.1
|
||||
axolotl-contribs-lgpl==0.0.6
|
||||
axolotl-contribs-mit==0.0.3
|
||||
|
||||
mistral-common==1.7.0
|
||||
mistral-common==1.6.3
|
||||
|
||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"'
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"'
|
||||
)
|
||||
|
||||
15
setup.py
15
setup.py
@@ -66,16 +66,13 @@ def parse_requirements(extras_require_map):
|
||||
|
||||
if (major, minor) >= (2, 7):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
_install_requires.append("xformers==0.0.30")
|
||||
else:
|
||||
_install_requires.append("xformers==0.0.31.post1")
|
||||
extras_require_map["vllm"] = ["vllm>=0.9.0"]
|
||||
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
|
||||
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
||||
elif (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.29.post3")
|
||||
# since we only support 2.6.0+cu126
|
||||
_dependency_links.append("https://download.pytorch.org/whl/cu126")
|
||||
_install_requires.append(
|
||||
"xformers==0.0.29.post2"
|
||||
) # vllm needs post2 w torch 2.6
|
||||
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
||||
elif (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
@@ -121,7 +118,7 @@ extras_require = {
|
||||
"yunchang==0.6.0",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.17.2",
|
||||
"deepspeed==0.17.1",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
|
||||
@@ -4,4 +4,4 @@ import pkgutil
|
||||
|
||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||
|
||||
__version__ = "0.12.0.dev"
|
||||
__version__ = "0.11.0.dev"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""CLI to run preprocessing of a dataset."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
@@ -96,7 +95,6 @@ def do_cli(
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parsed_cfg.is_preprocess = True
|
||||
parser = transformers.HfArgumentParser(PreprocessCliArgs)
|
||||
|
||||
@@ -109,13 +109,6 @@ def ray_train_func(kwargs: dict):
|
||||
# initialize accelerator before model instantiation
|
||||
Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
|
||||
|
||||
# Register plugins in Ray workers
|
||||
if cfg.get("plugins"):
|
||||
from axolotl.cli.config import plugin_set_cfg, prepare_plugins
|
||||
|
||||
prepare_plugins(cfg)
|
||||
plugin_set_cfg(cfg)
|
||||
|
||||
kwargs["cfg"] = cfg
|
||||
|
||||
do_train(**kwargs)
|
||||
|
||||
@@ -37,6 +37,7 @@ def do_vllm_serve(
|
||||
Returns:
|
||||
process_id: the process id of the started VLLM server
|
||||
"""
|
||||
patch_vllm_worker()
|
||||
cfg = load_cfg(config)
|
||||
model = cfg.base_model
|
||||
|
||||
@@ -46,9 +47,6 @@ def do_vllm_serve(
|
||||
tensor_parallel_size = (
|
||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||
)
|
||||
data_parallel_size = (
|
||||
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
|
||||
)
|
||||
host = cli_args.get("host") or cfg.vllm.host
|
||||
port = cli_args.get("port") or cfg.vllm.port
|
||||
gpu_memory_utilization = (
|
||||
@@ -70,7 +68,6 @@ def do_vllm_serve(
|
||||
vllm_script_args = AxolotlScriptArguments(
|
||||
model=model,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
data_parallel_size=data_parallel_size,
|
||||
host=host,
|
||||
port=port,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
|
||||
@@ -112,6 +112,13 @@ class TrainerBuilderBase(abc.ABC):
|
||||
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
|
||||
)
|
||||
|
||||
if self.cfg.profiler_steps:
|
||||
callbacks.append(
|
||||
PytorchProfilerCallback(
|
||||
steps_to_profile=self.cfg.profiler_steps,
|
||||
)
|
||||
)
|
||||
|
||||
if self.cfg.gc_steps:
|
||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
||||
|
||||
@@ -138,14 +145,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
callbacks.append(GPUStatsCallback(cfg=self.cfg))
|
||||
|
||||
if self.cfg.profiler_steps:
|
||||
callbacks.append(
|
||||
PytorchProfilerCallback(
|
||||
steps_to_profile=self.cfg.profiler_steps,
|
||||
profiler_steps_start=self.cfg.profiler_steps_start,
|
||||
)
|
||||
)
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
@@ -419,9 +418,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
||||
True
|
||||
)
|
||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||
256
|
||||
)
|
||||
training_args_kwargs["torch_compile"] = self.cfg.torch_compile
|
||||
if self.cfg.torch_compile_backend:
|
||||
training_args_kwargs["torch_compile_backend"] = (
|
||||
@@ -430,16 +426,8 @@ class TrainerBuilderBase(abc.ABC):
|
||||
if self.cfg.torch_compile_mode:
|
||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||
|
||||
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
||||
if self.cfg.accelerator_config:
|
||||
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
|
||||
|
||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||
if self.cfg.activation_offloading is True:
|
||||
# don't use the HF gradient checkpointing, manually wrap
|
||||
training_args_kwargs["gradient_checkpointing"] = False
|
||||
training_args_kwargs["activation_offloading"] = True
|
||||
elif self.cfg.gradient_checkpointing:
|
||||
if self.cfg.gradient_checkpointing:
|
||||
training_args_kwargs["gradient_checkpointing"] = (
|
||||
self.cfg.gradient_checkpointing
|
||||
)
|
||||
@@ -513,15 +501,10 @@ class TrainerBuilderBase(abc.ABC):
|
||||
if self.cfg.reward_model or self.cfg.rl:
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
if self.cfg.fsdp_config or self.cfg.fsdp:
|
||||
training_args_kwargs["fsdp_config"] = self.cfg.fsdp_config
|
||||
training_args_kwargs["fsdp"] = self.cfg.fsdp if self.cfg.fsdp else True
|
||||
|
||||
self._configure_reporting(training_args_kwargs)
|
||||
self._configure_hub_parameters(training_args_kwargs)
|
||||
self._configure_scheduler(training_args_kwargs)
|
||||
self._configure_optimizer(training_args_kwargs, trainer_kwargs)
|
||||
self._configure_torch_compile(training_args_kwargs)
|
||||
self._configure_accelerator_config(training_args_kwargs)
|
||||
|
||||
return training_args_kwargs, trainer_kwargs
|
||||
|
||||
@@ -151,6 +151,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
|
||||
total_num_steps
|
||||
)
|
||||
|
||||
if self.cfg.fsdp:
|
||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||
if self.cfg.fsdp_config:
|
||||
training_arguments_kwargs["fsdp_config"] = {
|
||||
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
|
||||
}
|
||||
|
||||
if self.cfg.adapter == "qlora":
|
||||
training_arguments_kwargs["qlora"] = True
|
||||
|
||||
@@ -310,6 +318,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.neftune_noise_alpha
|
||||
)
|
||||
|
||||
if self.cfg.accelerator_config:
|
||||
training_arguments_kwargs["accelerator_config"] = (
|
||||
self.cfg.accelerator_config
|
||||
)
|
||||
|
||||
if self.cfg.image_size:
|
||||
training_arguments_kwargs["image_size"] = self.cfg.image_size
|
||||
if self.cfg.image_resize_algorithm:
|
||||
|
||||
@@ -6,6 +6,7 @@ from pathlib import Path
|
||||
from axolotl.core.builders.base import TrainerBuilderBase
|
||||
from axolotl.core.trainers import (
|
||||
AxolotlCPOTrainer,
|
||||
AxolotlDPOTrainer,
|
||||
AxolotlKTOTrainer,
|
||||
AxolotlORPOTrainer,
|
||||
)
|
||||
@@ -36,33 +37,23 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
return callbacks
|
||||
|
||||
def _get_trainer_cls(self, trainer_kwargs: dict):
|
||||
"""
|
||||
Returns trainer_cls and trainer_cls_args
|
||||
"""
|
||||
def _get_trainer_cls(self):
|
||||
"""Returns trainer_cls"""
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||
trainer_cls_args = [] # type: ignore
|
||||
|
||||
if trainer_cls is not None:
|
||||
return trainer_cls, trainer_cls_args
|
||||
return trainer_cls
|
||||
|
||||
trainer_cls = None
|
||||
trainer_cls_args = [self.model]
|
||||
|
||||
if self.cfg.rl is RLType.GRPO:
|
||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
||||
)
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
|
||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
|
||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||
trainer_cls = DPOStrategy.get_trainer_class()
|
||||
trainer_cls_args.append(self.model_ref)
|
||||
|
||||
trainer_cls = AxolotlDPOTrainer
|
||||
elif self.cfg.rl is RLType.ORPO:
|
||||
trainer_cls = AxolotlORPOTrainer
|
||||
elif self.cfg.rl is RLType.KTO:
|
||||
@@ -72,7 +63,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
return trainer_cls, trainer_cls_args
|
||||
return trainer_cls
|
||||
|
||||
def _build_training_arguments(self, total_num_steps):
|
||||
"""
|
||||
@@ -182,7 +173,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.precompute_ref_log_probs
|
||||
)
|
||||
|
||||
trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
|
||||
trainer_cls = self._get_trainer_cls()
|
||||
trainer_cls_args = [self.model]
|
||||
|
||||
if self.cfg.rl is RLType.GRPO:
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
|
||||
if self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||
trainer_cls_args.append(self.model_ref)
|
||||
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "tokenizer" in sig.parameters:
|
||||
@@ -190,9 +189,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
trainer_kwargs["processing_class"] = self.tokenizer
|
||||
|
||||
if self.cfg.datasets is not None and (
|
||||
trainer_cls is DPOStrategy.get_trainer_class()
|
||||
):
|
||||
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
|
||||
trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
@@ -208,7 +205,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
callbacks=self.get_callbacks(),
|
||||
**trainer_kwargs,
|
||||
)
|
||||
if self.cfg.fsdp_config or self.cfg.fsdp:
|
||||
if self.cfg.fsdp:
|
||||
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
|
||||
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
|
||||
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
|
||||
@@ -218,3 +215,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
trainer.add_callback(callback)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
class HFPPOTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
HF Factory class for PPO Trainer
|
||||
"""
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
return callbacks
|
||||
|
||||
def build(self, total_num_steps):
|
||||
# TODO: build PPOConfig
|
||||
raise NotImplementedError("PPO trainer builder is not implemented yet.")
|
||||
|
||||
@@ -14,4 +14,5 @@ from .trl import (
|
||||
AxolotlORPOTrainer,
|
||||
AxolotlPRMTrainer,
|
||||
AxolotlRewardTrainer,
|
||||
TRLPPOTrainer,
|
||||
)
|
||||
|
||||
@@ -25,7 +25,6 @@ from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.mixins import (
|
||||
ActivationOffloadingMixin,
|
||||
CheckpointSaveMixin,
|
||||
OptimizerMixin,
|
||||
PackingMixin,
|
||||
@@ -49,7 +48,6 @@ class AxolotlTrainer(
|
||||
OptimizerMixin,
|
||||
RngLoaderMixin,
|
||||
CheckpointSaveMixin,
|
||||
ActivationOffloadingMixin,
|
||||
Trainer,
|
||||
):
|
||||
"""Extend the base Trainer for axolotl helpers"""
|
||||
@@ -77,6 +75,18 @@ class AxolotlTrainer(
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.torch_compile:
|
||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||
256
|
||||
)
|
||||
model = torch.compile(
|
||||
model,
|
||||
backend=self.args.torch_compile_backend,
|
||||
mode=self.args.torch_compile_mode,
|
||||
)
|
||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||
|
||||
def _create_multipack_sampler(
|
||||
self, base_sampler: Sampler, dataset: Dataset
|
||||
) -> MultipackBatchSampler:
|
||||
|
||||
@@ -14,7 +14,6 @@ from axolotl.core.trainers.grpo.trainer import (
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.trl import TRLConfig
|
||||
from axolotl.utils.schemas.vllm import VllmConfig
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -42,18 +41,9 @@ class GRPOStrategy:
|
||||
return grpo_args_kwargs
|
||||
|
||||
trl: TRLConfig = cfg.trl # type: ignore
|
||||
vllm_cfg: VllmConfig = cfg.vllm # type: ignore
|
||||
|
||||
if trl.use_vllm:
|
||||
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
||||
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
|
||||
if trl.vllm_mode == "colocate":
|
||||
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
||||
vllm_cfg.gpu_memory_utilization
|
||||
)
|
||||
grpo_args_kwargs["vllm_tensor_parallel_size"] = (
|
||||
vllm_cfg.tensor_parallel_size
|
||||
)
|
||||
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host # type: ignore[attr-defined]
|
||||
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port # type: ignore[attr-defined]
|
||||
if trl.vllm_server_timeout:
|
||||
|
||||
@@ -59,6 +59,42 @@ class AxolotlGRPOTrainer(
|
||||
|
||||
_tag_names = ["trl", "grpo", "axolotl"]
|
||||
|
||||
def get_train_dataloader(self):
|
||||
if self.train_dataset is None:
|
||||
raise ValueError("Trainer: training requires a train_dataset.")
|
||||
|
||||
train_dataset = self.train_dataset
|
||||
data_collator = self.data_collator
|
||||
if isinstance(train_dataset, datasets.Dataset):
|
||||
train_dataset = self._remove_unused_columns(
|
||||
train_dataset, description="training"
|
||||
)
|
||||
else:
|
||||
data_collator = self._get_collator_with_removed_columns(
|
||||
data_collator, description="training"
|
||||
)
|
||||
|
||||
dataloader_params = {
|
||||
"batch_size": self._train_batch_size
|
||||
* self.args.steps_per_generation, # < this is the change
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
"persistent_workers": self.args.dataloader_persistent_workers,
|
||||
}
|
||||
|
||||
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["sampler"] = self._get_train_sampler()
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["worker_init_fn"] = partial(
|
||||
seed_worker,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
rank=self.args.process_index,
|
||||
)
|
||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
|
||||
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
|
||||
|
||||
|
||||
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
||||
@@ -216,11 +252,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
if not is_eval:
|
||||
dataloader_params["worker_init_fn"] = partial(
|
||||
seed_worker,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
rank=self.args.process_index,
|
||||
)
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
|
||||
# Create the dataloader
|
||||
dataloader = DataLoader(dataset, **dataloader_params)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from .activation_checkpointing import ActivationOffloadingMixin
|
||||
from .checkpoints import CheckpointSaveMixin
|
||||
from .optimizer import OptimizerMixin
|
||||
from .packing import PackingMixin
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
"""
|
||||
Trainer mixin for activation checkpointing w offloading
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
|
||||
from torch import nn
|
||||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||
apply_activation_checkpointing,
|
||||
)
|
||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||
from transformers import GradientCheckpointingLayer, Trainer
|
||||
from trl.models.activation_offloading import get_act_offloading_ctx_manager
|
||||
|
||||
|
||||
class ActivationOffloadingMixin(Trainer):
|
||||
"""
|
||||
Trainer mixin class for activation checkpointing w offloading
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.args.activation_offloading:
|
||||
self.activation_offload_context = get_act_offloading_ctx_manager(
|
||||
self.model, use_streams=True
|
||||
)
|
||||
else:
|
||||
self.activation_offload_context = contextlib.nullcontext()
|
||||
|
||||
def training_step(self, *args, **kwargs):
|
||||
with self.activation_offload_context:
|
||||
return super().training_step(*args, **kwargs)
|
||||
|
||||
|
||||
def ac_wrap_hf_model(model: nn.Module, **kwargs):
|
||||
auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
|
||||
apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)
|
||||
@@ -1,9 +1,12 @@
|
||||
"""Module for TRL RL trainers"""
|
||||
"""Module for TRL PPO trainer"""
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from trl import (
|
||||
CPOTrainer,
|
||||
KTOTrainer,
|
||||
ORPOTrainer,
|
||||
PPOTrainer,
|
||||
PRMTrainer,
|
||||
RewardTrainer,
|
||||
)
|
||||
@@ -13,6 +16,64 @@ from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, Optimizer
|
||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||
|
||||
|
||||
class TRLPPOTrainer(PPOTrainer):
|
||||
"""Wrapper for TRL PPO trainer to handle customizations"""
|
||||
|
||||
tag_names = ["axolotl", "ppo"]
|
||||
|
||||
def train(
|
||||
self,
|
||||
reward_pipe,
|
||||
resume_from_checkpoint=None, # pylint: disable=unused-argument
|
||||
):
|
||||
generation_kwargs = {
|
||||
"min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": self.tokenizer.eos_token_id,
|
||||
"max_new_tokens": 32,
|
||||
}
|
||||
sent_kwargs = {
|
||||
"return_all_scores": True,
|
||||
"function_to_apply": "none",
|
||||
"batch_size": 16,
|
||||
}
|
||||
|
||||
for _, batch in tqdm(enumerate(self.dataloader)):
|
||||
query_tensors = batch["input_ids"]
|
||||
|
||||
# generate model response
|
||||
response_tensors, ref_response_tensors = self.generate(
|
||||
query_tensors,
|
||||
return_prompt=False,
|
||||
generate_ref_response=True,
|
||||
**generation_kwargs,
|
||||
)
|
||||
batch["response"] = self.tokenizer.batch_decode(response_tensors)
|
||||
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
|
||||
|
||||
# Compute sentiment score
|
||||
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
|
||||
pipe_outputs = reward_pipe(texts, **sent_kwargs)
|
||||
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
|
||||
ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
|
||||
ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs)
|
||||
ref_rewards = [
|
||||
torch.tensor(output[1]["score"]) for output in ref_pipe_outputs
|
||||
]
|
||||
batch["ref_rewards"] = ref_rewards
|
||||
|
||||
# Run PPO step
|
||||
stats = self.step(query_tensors, response_tensors, rewards)
|
||||
self.log_stats(
|
||||
stats,
|
||||
batch,
|
||||
rewards,
|
||||
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
||||
)
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(
|
||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer
|
||||
):
|
||||
|
||||
@@ -217,11 +217,6 @@ class AxolotlTrainingMixins:
|
||||
},
|
||||
)
|
||||
|
||||
activation_offloading: bool | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Use activation offloading with CUDA streams for training."},
|
||||
)
|
||||
|
||||
# multi-modal section
|
||||
|
||||
image_size: int | tuple[int, int] | None = field(
|
||||
|
||||
@@ -48,6 +48,13 @@ class TokenizedPromptDataset(Dataset):
|
||||
features = dataset.features.keys()
|
||||
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
|
||||
|
||||
# Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common)
|
||||
if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True):
|
||||
LOG.info(
|
||||
"Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)"
|
||||
)
|
||||
num_proc = 1
|
||||
|
||||
map_kwargs = {}
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
|
||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -32,7 +32,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"`'
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ kd_ce_alpha: 0.1
|
||||
kd_alpha: 0.9
|
||||
kd_temperature: 1.0
|
||||
|
||||
torch_compile: True # torch>=2.6.0, recommended to reduce vram
|
||||
torch_compile: True # torch>=2.5.1, recommended to reduce vram
|
||||
|
||||
datasets:
|
||||
- path: ...
|
||||
|
||||
@@ -6,21 +6,15 @@ from typing import Optional, Union, Unpack
|
||||
|
||||
import torch
|
||||
from transformers import Cache
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.utils import LossKwargs
|
||||
|
||||
try:
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.utils import LossKwargs
|
||||
|
||||
class TransformersKwargs(FlashAttentionKwargs, LossKwargs):
|
||||
"""
|
||||
placeholder kwargs for hf model classes
|
||||
"""
|
||||
|
||||
except ImportError:
|
||||
from transformers.utils.generic import ( # type: ignore[no-redef]
|
||||
TransformersKwargs,
|
||||
)
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
|
||||
"""
|
||||
placeholder kwargs for hf model classes
|
||||
"""
|
||||
|
||||
|
||||
def kldiv_forward_llama_like(
|
||||
@@ -39,7 +33,7 @@ def kldiv_forward_llama_like(
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0, # pylint: disable=unused-argument
|
||||
**kwargs: Unpack[TransformersKwargs], # type: ignore[misc]
|
||||
**kwargs: Unpack[KwargsForCausalLM], # type: ignore[misc]
|
||||
) -> CausalLMOutputWithPast:
|
||||
# pylint: disable=duplicate-code
|
||||
output_attentions = (
|
||||
|
||||
@@ -122,9 +122,9 @@ def load_lora(
|
||||
rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
|
||||
if (
|
||||
cfg.fsdp_config
|
||||
cfg.fsdp
|
||||
and cfg.adapter
|
||||
and cfg.fsdp_config.cpu_ram_efficient_loading
|
||||
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
and rank != 0
|
||||
):
|
||||
setup_quantized_meta_for_peft(model)
|
||||
@@ -152,9 +152,9 @@ def load_lora(
|
||||
"Exception caught during model.print_trainable_parameters(): %s", exc
|
||||
)
|
||||
elif (
|
||||
cfg.fsdp_config
|
||||
cfg.fsdp
|
||||
and cfg.adapter
|
||||
and cfg.fsdp_config.cpu_ram_efficient_loading
|
||||
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
and rank != 0
|
||||
):
|
||||
setup_quantized_peft_meta_for_training(model)
|
||||
|
||||
@@ -140,15 +140,10 @@ class ModelLoader:
|
||||
"""Check if flash attention is installed."""
|
||||
return find_spec("flash_attn") is not None
|
||||
|
||||
@property
|
||||
def is_fsdp_enabled(self):
|
||||
"""Property that determines if FSDP is enabled."""
|
||||
return self.cfg.fsdp_config is not None or self.cfg.fsdp is not None
|
||||
|
||||
@property
|
||||
def is_qlora_and_fsdp_enabled(self):
|
||||
@cached_property
|
||||
def qlora_fsdp(self):
|
||||
"""Property that determines if FSDP with QLoRA is enabled."""
|
||||
return self.is_fsdp_enabled and self.cfg.adapter == "qlora"
|
||||
return self.cfg.fsdp and self.cfg.adapter == "qlora"
|
||||
|
||||
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
|
||||
"""Load and prepare the model with all configurations and patches.
|
||||
@@ -194,25 +189,15 @@ class ModelLoader:
|
||||
# Handle PeftModel if needed
|
||||
if (
|
||||
isinstance(self.model, (peft.PeftModel, peft.PeftModelForCausalLM))
|
||||
and not self.is_qlora_and_fsdp_enabled
|
||||
and not self.qlora_fsdp
|
||||
):
|
||||
self.model = self.model.merge_and_unload()
|
||||
|
||||
self._apply_activation_checkpointing()
|
||||
self._resize_token_embeddings()
|
||||
self._adjust_model_config()
|
||||
self._log_memory_usage()
|
||||
self._configure_embedding_dtypes()
|
||||
self._configure_qat()
|
||||
log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
|
||||
|
||||
def _apply_activation_checkpointing(self):
|
||||
if self.cfg.activation_offloading is True:
|
||||
from axolotl.core.trainers.mixins.activation_checkpointing import (
|
||||
ac_wrap_hf_model,
|
||||
)
|
||||
|
||||
# ^^ importing this at the module level breaks plugins
|
||||
ac_wrap_hf_model(self.model)
|
||||
|
||||
def _resize_token_embeddings(self):
|
||||
"""Resize token embeddings if needed."""
|
||||
@@ -266,13 +251,22 @@ class ModelLoader:
|
||||
):
|
||||
self.model.config.eos_token_id = self.tokenizer.eos_token_id
|
||||
|
||||
def _log_memory_usage(self):
|
||||
"""Log device memory usage after model load."""
|
||||
if hasattr(self.model, "device") and self.model.device.type in (
|
||||
"cuda",
|
||||
"mps",
|
||||
"npu",
|
||||
):
|
||||
log_gpu_memory_usage(LOG, "after model load", self.model.device)
|
||||
|
||||
def _configure_embedding_dtypes(self):
|
||||
"""Configure embedding module dtypes."""
|
||||
# Get embedding modules
|
||||
embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
|
||||
|
||||
# Initial dtype conversion
|
||||
if not self.is_fsdp_enabled:
|
||||
if not self.cfg.fsdp:
|
||||
# We don't run this during FSDP because this will leave mixed and bfloat16
|
||||
# dtypes in the model which FSDP doesn't like
|
||||
if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast:
|
||||
@@ -288,7 +282,7 @@ class ModelLoader:
|
||||
self._set_z3_leaf_modules()
|
||||
|
||||
# Apply gradient checkpointing if needed
|
||||
needs_fa2_dtype = self.cfg.adapter or self.is_fsdp_enabled
|
||||
needs_fa2_dtype = self.cfg.adapter or self.cfg.fsdp
|
||||
if self.cfg.adapter in ["lora", "qlora"]:
|
||||
needs_fa2_dtype = True
|
||||
if self.cfg.gradient_checkpointing:
|
||||
@@ -304,12 +298,10 @@ class ModelLoader:
|
||||
# we need to convert them back to fp16/bf16 for flash-attn compatibility.
|
||||
(
|
||||
(needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention)
|
||||
and not self.is_qlora_and_fsdp_enabled
|
||||
)
|
||||
or (
|
||||
# CCE requires embedding layers to be in fp16/bf16 for backward pass
|
||||
self.cfg.cut_cross_entropy
|
||||
and not self.qlora_fsdp
|
||||
)
|
||||
# CCE requires embedding layers to be in fp16/bf16 for backward pass
|
||||
or self.cfg.cut_cross_entropy
|
||||
)
|
||||
|
||||
if should_convert:
|
||||
@@ -365,6 +357,7 @@ class ModelLoader:
|
||||
and not (self.cfg.rl and self.cfg.load_in_4bit)
|
||||
and not skip_move_to_device
|
||||
):
|
||||
# TODO: validate this conditional
|
||||
self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
|
||||
|
||||
if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
|
||||
@@ -437,17 +430,7 @@ class ModelLoader:
|
||||
|
||||
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
|
||||
|
||||
is_ds_zero3 = is_deepspeed_zero3_enabled()
|
||||
|
||||
# FSDP requires control over device placement, so don't set device_map when FSDP is enabled
|
||||
if self.is_fsdp_enabled:
|
||||
# For QLoRA + FSDP, we still need to set device_map to "auto" for proper initialization
|
||||
if self.is_qlora_and_fsdp_enabled:
|
||||
self.model_kwargs["device_map"] = {
|
||||
"": int(os.environ.get("LOCAL_RANK", 0))
|
||||
}
|
||||
# For other FSDP cases, don't set device_map at all
|
||||
elif not is_ds_zero3:
|
||||
if not is_deepspeed_zero3_enabled():
|
||||
self.model_kwargs["device_map"] = device_map
|
||||
|
||||
cur_device = get_device_type()
|
||||
@@ -516,7 +499,7 @@ class ModelLoader:
|
||||
"bnb_4bit_quant_storage": torch.bfloat16,
|
||||
}
|
||||
if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
|
||||
self.cfg.deepspeed or self.is_fsdp_enabled
|
||||
self.cfg.deepspeed or self.cfg.fsdp
|
||||
):
|
||||
# for some reason, this causes the loss to be off by an order of magnitude
|
||||
# but deepspeed needs this still in bfloat16
|
||||
@@ -621,21 +604,9 @@ class ModelLoader:
|
||||
def _build_model(self) -> bool:
|
||||
"""Load model, with load strategy depending on config."""
|
||||
skip_move_to_device = False
|
||||
if self.is_fsdp_enabled:
|
||||
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
||||
skip_move_to_device = True
|
||||
# Don't delete device_map for QLoRA + FSDP - it was set correctly in _set_device_map
|
||||
if (
|
||||
"device_map" in self.model_kwargs
|
||||
and not self.is_qlora_and_fsdp_enabled
|
||||
):
|
||||
del self.model_kwargs["device_map"]
|
||||
elif self.is_qlora_and_fsdp_enabled:
|
||||
skip_move_to_device = True
|
||||
|
||||
if (
|
||||
self.is_qlora_and_fsdp_enabled
|
||||
and self.cfg.fsdp_config.cpu_ram_efficient_loading
|
||||
self.qlora_fsdp
|
||||
and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
and (
|
||||
self.cfg.model_config_type == "dbrx"
|
||||
or self.cfg.qlora_sharded_model_loading
|
||||
@@ -661,6 +632,12 @@ class ModelLoader:
|
||||
and not self.cfg.trust_remote_code
|
||||
and not self.cfg.gptq
|
||||
):
|
||||
# TODO: Do we need to open this up for all models?
|
||||
if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
||||
skip_move_to_device = True
|
||||
if "device_map" in self.model_kwargs:
|
||||
del self.model_kwargs["device_map"]
|
||||
|
||||
# Please don't remove underscore binding without reading the fn docstring.
|
||||
_ = self._configure_zero3_memory_efficient_loading()
|
||||
|
||||
@@ -714,22 +691,33 @@ class ModelLoader:
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
elif self.cfg.gptq:
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
else:
|
||||
# Please don't remove underscore binding without reading the fn docstring.
|
||||
_ = self._configure_zero3_memory_efficient_loading()
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
if self.cfg.gptq:
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
else:
|
||||
if (
|
||||
self.cfg.fsdp
|
||||
and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
||||
):
|
||||
# disabling either of these two still leads to VRAM spike before setting back down
|
||||
skip_move_to_device = True
|
||||
if "device_map" in self.model_kwargs:
|
||||
del self.model_kwargs["device_map"]
|
||||
|
||||
# Please don't remove underscore binding without reading the fn docstring.
|
||||
_ = self._configure_zero3_memory_efficient_loading()
|
||||
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
config=self.model_config,
|
||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||
**self.model_kwargs,
|
||||
)
|
||||
if is_deepspeed_zero3_enabled():
|
||||
skip_move_to_device = True
|
||||
|
||||
@@ -765,8 +753,8 @@ class ModelLoader:
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if (
|
||||
self.is_qlora_and_fsdp_enabled
|
||||
or (self.is_fsdp_enabled and self.cfg.fsdp_config.cpu_ram_efficient_loading)
|
||||
self.qlora_fsdp
|
||||
or (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
|
||||
or is_deepspeed_zero3_enabled()
|
||||
):
|
||||
# Make sure everything is in the same dtype
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user