Compare commits
9 Commits
codecov-pu
...
mistral-su
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4f39aeefb9 | ||
|
|
8f75136ad3 | ||
|
|
70e9cb545d | ||
|
|
aa236a4669 | ||
|
|
65f8988efd | ||
|
|
13ddb8f172 | ||
|
|
b1570ed0fa | ||
|
|
9581a9efed | ||
|
|
7e44445494 |
8
.github/workflows/base.yml
vendored
8
.github/workflows/base.yml
vendored
@@ -16,7 +16,6 @@ on:
|
||||
jobs:
|
||||
build-base:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
timeout-minutes: 480
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: ubuntu-latest-m
|
||||
strategy:
|
||||
@@ -48,14 +47,14 @@ jobs:
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.7.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.7.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "128"
|
||||
@@ -107,7 +106,6 @@ jobs:
|
||||
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
|
||||
build-base-uv:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
timeout-minutes: 480
|
||||
runs-on: ubuntu-latest-m
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -124,7 +122,7 @@ jobs:
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.7.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
steps:
|
||||
|
||||
2
.github/workflows/docs.yml
vendored
2
.github/workflows/docs.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python3 -m pip install jupyter quartodoc
|
||||
python3 -m pip install -e .
|
||||
python3 -m pip install -e . --no-deps
|
||||
- name: Build autodoc
|
||||
run: quartodoc build
|
||||
- name: Publish to GitHub Pages (and render)
|
||||
|
||||
8
.github/workflows/main.yml
vendored
8
.github/workflows/main.yml
vendored
@@ -29,12 +29,12 @@ jobs:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
@@ -97,12 +97,12 @@ jobs:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
|
||||
2
.github/workflows/multi-gpu-e2e.yml
vendored
2
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -43,7 +43,7 @@ jobs:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
|
||||
6
.github/workflows/preview-docs.yml
vendored
6
.github/workflows/preview-docs.yml
vendored
@@ -8,9 +8,7 @@ on:
|
||||
paths:
|
||||
- '**/*.md' # any Markdown file
|
||||
- '**/*.qmd' # any Quarto file
|
||||
- '_quarto.yml'
|
||||
- docs/scripts/generate_config_docs.py
|
||||
- src/axolotl/utils/schemas/**.py
|
||||
- '_quarto.yaml'
|
||||
|
||||
permissions:
|
||||
checks: write
|
||||
@@ -40,7 +38,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python3 -m pip install jupyter quartodoc
|
||||
python3 -m pip install -e .
|
||||
python3 -m pip install -e . --no-deps
|
||||
|
||||
- name: Build autodoc
|
||||
run: quartodoc build
|
||||
|
||||
62
.github/workflows/tests.yml
vendored
62
.github/workflows/tests.yml
vendored
@@ -52,7 +52,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
|
||||
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -106,12 +106,13 @@ jobs:
|
||||
pytest -v tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||
|
||||
- name: Upload coverage artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
name: coverage-${{ matrix.pytorch_version }}-${{ github.run_id }}
|
||||
path: ./coverage.xml
|
||||
retention-days: 1
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
files: ./coverage.xml
|
||||
flags: unittests,pytorch-${{ matrix.pytorch_version }}
|
||||
fail_ci_if_error: false
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
@@ -124,7 +125,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
|
||||
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -187,7 +188,7 @@ jobs:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
timeout-minutes: 90
|
||||
needs: [pre-commit, pytest, pytest-sdist]
|
||||
|
||||
strategy:
|
||||
@@ -233,19 +234,11 @@ jobs:
|
||||
run: |
|
||||
modal run cicd.e2e_tests
|
||||
|
||||
- name: Upload coverage artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: coverage-e2e-1st-${{ github.run_id }}
|
||||
path: ./e2e-coverage.xml
|
||||
retention-days: 1
|
||||
|
||||
docker-e2e-tests:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
timeout-minutes: 90
|
||||
# Only run the remainder of the matrix if the first e2e check passed;
|
||||
# this is to save on wasted compute costs for known failures that get caught in the first run
|
||||
needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
||||
@@ -269,13 +262,13 @@ jobs:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.7.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.7.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
@@ -304,14 +297,6 @@ jobs:
|
||||
run: |
|
||||
modal run cicd.e2e_tests
|
||||
|
||||
- name: Upload coverage artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: coverage-e2e-${{ matrix.cuda }}-${{ matrix.pytorch }}-${{ github.run_id }}
|
||||
path: ./e2e-coverage.xml
|
||||
retention-days: 1
|
||||
|
||||
docker-e2e-cleanup:
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 90
|
||||
@@ -351,26 +336,3 @@ jobs:
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run cicd.cleanup
|
||||
|
||||
upload-coverage:
|
||||
name: Upload Coverage to Codecov
|
||||
runs-on: ubuntu-latest
|
||||
needs: [pytest, docker-e2e-tests, docker-e2e-tests-1st]
|
||||
if: github.event_name == 'pull_request' || github.ref == 'refs/heads/main'
|
||||
|
||||
steps:
|
||||
- name: Download coverage reports
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: coverage-reports
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
directory: coverage-reports
|
||||
fail_ci_if_error: false
|
||||
verbose: true
|
||||
name: codecov-umbrella
|
||||
override_commit: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
override_pr: ${{ github.event.pull_request.number }}
|
||||
|
||||
@@ -328,7 +328,7 @@ The following optimizers are supported:
|
||||
- Use `gradient_checkpointing: true` to reduce memory usage
|
||||
- Adjust `micro_batch_size` and `gradient_accumulation_steps` based on your GPU memory
|
||||
|
||||
For more detailed information, please refer to the [documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/config-reference.html).
|
||||
For more detailed information, please refer to the [documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html).
|
||||
|
||||
### Errors:
|
||||
|
||||
|
||||
76
README.md
76
README.md
@@ -22,32 +22,28 @@
|
||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
|
||||
</p>
|
||||
|
||||
|
||||
## 🎉 Latest Updates
|
||||
|
||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
|
||||
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
||||
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
|
||||
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
|
||||
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
|
||||
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
|
||||
- 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See [docs](https://docs.axolotl.ai/docs/reward_modelling.html).
|
||||
|
||||
## ✨ Overview
|
||||
|
||||
Axolotl is a tool designed to streamline post-training for various AI models.
|
||||
Post-training refers to any modifications or additional training performed on
|
||||
pre-trained models - including full model fine-tuning, parameter-efficient tuning (like
|
||||
LoRA and QLoRA), supervised fine-tuning (SFT), instruction tuning, and alignment
|
||||
techniques. With support for multiple model architectures and training configurations,
|
||||
Axolotl makes it easy to get started with these techniques.
|
||||
|
||||
Axolotl is designed to work with YAML config files that contain everything you need to
|
||||
preprocess a dataset, train or fine-tune a model, run model inference or evaluation,
|
||||
and much more.
|
||||
|
||||
Features:
|
||||
|
||||
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
|
||||
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
||||
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
|
||||
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), Sequence Parallelism (SP), LoRA optimizations, Multi-GPU training (FSDP1, FSDP2, DeepSpeed), Multi-node training (Torchrun, Ray), and many more!
|
||||
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
|
||||
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
|
||||
|
||||
|
||||
- Train various Huggingface models such as llama, pythia, falcon, mpt
|
||||
- Supports fullfinetune, lora, qlora, relora, and gptq
|
||||
- Customize configurations using a simple yaml file or CLI overwrite
|
||||
- Load different dataset formats, use custom formats, or bring your own tokenized datasets
|
||||
- Integrated with [xformers](https://github.com/facebookresearch/xformers), flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
|
||||
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
|
||||
- Easily run with Docker locally or on the cloud
|
||||
- Log results and optionally checkpoints to wandb, mlflow or Comet
|
||||
- And more!
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
@@ -85,12 +81,19 @@ axolotl train examples/llama-3/lora-1b.yml
|
||||
|
||||
That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/getting-started.html) for a more detailed walkthrough.
|
||||
|
||||
## ✨ Key Features
|
||||
|
||||
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more
|
||||
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, and more
|
||||
- **Easy Configuration**: Simple YAML files to control your training setup
|
||||
- **Performance Optimizations**: Flash Attention, xformers, multi-GPU training
|
||||
- **Flexible Dataset Handling**: Use various formats and custom datasets
|
||||
- **Cloud Ready**: Run on cloud platforms or local hardware
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
- [Installation Options](https://docs.axolotl.ai/docs/installation.html) - Detailed setup instructions for different environments
|
||||
- [Configuration Guide](https://docs.axolotl.ai/docs/config-reference.html) - Full configuration options and examples
|
||||
- [Dataset Loading](https://docs.axolotl.ai/docs/dataset_loading.html) - Loading datasets from various sources
|
||||
- [Configuration Guide](https://docs.axolotl.ai/docs/config.html) - Full configuration options and examples
|
||||
- [Dataset Guide](https://docs.axolotl.ai/docs/dataset-formats/) - Supported formats and how to use them
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
@@ -109,6 +112,31 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
|
||||
|
||||
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.
|
||||
|
||||
## Supported Models
|
||||
|
||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
||||
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
|
||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Mixtral8X22 | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
|
||||
| Jamba | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
|
||||
|
||||
✅: supported
|
||||
❌: not supported
|
||||
❓: untested
|
||||
|
||||
## ❤️ Sponsors
|
||||
|
||||
Thank you to our sponsors who help make Axolotl possible:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
project:
|
||||
type: website
|
||||
pre-render: docs/scripts/generate_config_docs.py
|
||||
|
||||
quartodoc:
|
||||
dir: docs/api
|
||||
@@ -236,7 +235,7 @@ website:
|
||||
- docs/installation.qmd
|
||||
- docs/inference.qmd
|
||||
- docs/cli.qmd
|
||||
- docs/config-reference.qmd
|
||||
- docs/config.qmd
|
||||
- text: "API Reference"
|
||||
href: docs/api
|
||||
|
||||
|
||||
@@ -51,3 +51,5 @@ pytest -v --durations=10 \
|
||||
--cov=axolotl \
|
||||
--cov-append \
|
||||
--cov-report=xml:e2e-coverage.xml
|
||||
|
||||
codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION} || true
|
||||
|
||||
@@ -1,34 +1,20 @@
|
||||
"""Modal app to run axolotl GPU tests"""
|
||||
|
||||
import pathlib
|
||||
|
||||
from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
|
||||
|
||||
|
||||
@app.function(
|
||||
image=cicd_image,
|
||||
gpu=GPU_CONFIG,
|
||||
timeout=120 * 60, # 90 min
|
||||
timeout=90 * 60, # 90 min
|
||||
cpu=8.0,
|
||||
memory=131072,
|
||||
volumes=VOLUME_CONFIG,
|
||||
)
|
||||
def cicd_pytest():
|
||||
|
||||
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
|
||||
|
||||
# Read the coverage file if it exists
|
||||
coverage_file = pathlib.Path("/workspace/axolotl/e2e-coverage.xml")
|
||||
if coverage_file.exists():
|
||||
return coverage_file.read_text(encoding="utf-8")
|
||||
return None
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main():
|
||||
coverage = cicd_pytest.remote()
|
||||
|
||||
# Save the coverage file to the local filesystem if it was generated
|
||||
if coverage:
|
||||
with open("e2e-coverage.xml", "w", encoding="utf-8") as f:
|
||||
f.write(coverage)
|
||||
cicd_pytest.remote()
|
||||
|
||||
@@ -69,7 +69,7 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
@app.function(
|
||||
image=cicd_image,
|
||||
gpu=GPU_CONFIG,
|
||||
timeout=120 * 60,
|
||||
timeout=90 * 60,
|
||||
cpu=16.0,
|
||||
memory=131072 * N_GPUS,
|
||||
volumes=VOLUME_CONFIG,
|
||||
@@ -77,18 +77,7 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
def cicd_pytest():
|
||||
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
|
||||
|
||||
# Read the coverage file if it exists
|
||||
coverage_file = pathlib.Path("/workspace/axolotl/multigpu-coverage.xml")
|
||||
if coverage_file.exists():
|
||||
return coverage_file.read_text(encoding="utf-8")
|
||||
return None
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main():
|
||||
coverage = cicd_pytest.remote()
|
||||
|
||||
# Save the coverage file to the local filesystem if it was generated
|
||||
if coverage:
|
||||
with open("multigpu-coverage.xml", "w", encoding="utf-8") as file:
|
||||
file.write(coverage)
|
||||
cicd_pytest.remote()
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
{
|
||||
"compile": {
|
||||
"disable": false,
|
||||
"backend": "inductor"
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu"
|
||||
},
|
||||
"contiguous_gradients": true,
|
||||
"overlap_comm": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 32,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
@@ -38,6 +38,6 @@ RUN git lfs install --skip-repo && \
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
|
||||
pip3 install flash-attn==2.7.4.post1; \
|
||||
fi
|
||||
|
||||
@@ -29,7 +29,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
python3 -m pip install --no-cache-dir -U torch==2.7.1 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \
|
||||
python3 -m pip install --no-cache-dir -U torch==2.7.0 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \
|
||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
||||
|
||||
|
||||
@@ -29,12 +29,8 @@ RUN uv venv --no-project --relocatable axolotl-venv
|
||||
|
||||
ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
|
||||
|
||||
RUN uv pip install packaging setuptools wheel psutil \
|
||||
RUN uv pip install packaging setuptools wheel \
|
||||
&& uv pip install torch==${PYTORCH_VERSION} \
|
||||
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
||||
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
||||
&& uv pip install awscli pydantic
|
||||
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
|
||||
uv pip install --no-build-isolation flash-attn==2.7.4.post1; \
|
||||
fi
|
||||
|
||||
1
docs/.gitignore
vendored
1
docs/.gitignore
vendored
@@ -2,4 +2,3 @@
|
||||
_site/
|
||||
/api/*.qmd
|
||||
/api/*.html
|
||||
config-reference.qmd
|
||||
|
||||
795
docs/config.qmd
Normal file
795
docs/config.qmd
Normal file
@@ -0,0 +1,795 @@
|
||||
---
|
||||
title: Config Reference
|
||||
description: A complete list of all configuration options.
|
||||
---
|
||||
|
||||
```yaml
|
||||
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
# This can also be a relative path to a model on disk
|
||||
base_model: ./llama-7b-hf
|
||||
# You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
|
||||
base_model_ignore_patterns:
|
||||
# If the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# You can set that here, or leave this empty to default to base_model
|
||||
base_model_config: ./llama-7b-hf
|
||||
# You can specify to choose a specific model revision from huggingface hub
|
||||
revision_of_model:
|
||||
# Optional tokenizer configuration path in case you want to use a different tokenizer
|
||||
# than the one defined in the base model
|
||||
tokenizer_config:
|
||||
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
|
||||
model_type: AutoModelForCausalLM
|
||||
# Corresponding tokenizer for the model AutoTokenizer is a good choice
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Trust remote code for untrusted source
|
||||
trust_remote_code:
|
||||
# use_fast option for tokenizer loading from_pretrained, default to True
|
||||
tokenizer_use_fast:
|
||||
# Whether to use the legacy tokenizer setting, defaults to True
|
||||
tokenizer_legacy:
|
||||
# Resize the model embeddings when new tokens are added to multiples of 32
|
||||
# This is reported to improve training speed on some models
|
||||
resize_token_embeddings_to_32x:
|
||||
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
||||
shrink_embeddings:
|
||||
# Optional[bool] Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs
|
||||
embeddings_skip_upcast:
|
||||
# Whether to load the model with randomly initialized weights. Useful for
|
||||
# pre-training a model from scratch or debugging purposes.
|
||||
random_init_weights:
|
||||
|
||||
# (Internal use only)
|
||||
# Used to identify which the model is based on
|
||||
is_falcon_derived_model:
|
||||
is_llama_derived_model:
|
||||
is_qwen_derived_model:
|
||||
# Please note that if you set this to true, `padding_side` will be set to "left" by default
|
||||
is_mistral_derived_model:
|
||||
|
||||
# optional overrides to the base model configuration
|
||||
overrides_of_model_config:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
# optional overrides the base model loading from_pretrained
|
||||
overrides_of_model_kwargs:
|
||||
# use_cache: False
|
||||
|
||||
# optional overrides to the bnb 4bit quantization configuration
|
||||
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
|
||||
bnb_config_kwargs:
|
||||
# These are default values
|
||||
llm_int8_has_fp16_weight: false
|
||||
bnb_4bit_quant_type: nf4
|
||||
bnb_4bit_use_double_quant: true
|
||||
|
||||
# quantization aware training
|
||||
qat:
|
||||
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
|
||||
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
|
||||
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
|
||||
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
|
||||
|
||||
# post-training quantization
|
||||
quantization:
|
||||
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
|
||||
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
|
||||
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
|
||||
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
|
||||
|
||||
|
||||
# Whether you are training a 4-bit GPTQ quantized model
|
||||
gptq: true
|
||||
|
||||
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||
load_in_8bit: true
|
||||
# Use bitsandbytes 4 bit
|
||||
load_in_4bit:
|
||||
|
||||
# Use CUDA bf16
|
||||
bf16: true # bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere
|
||||
# Use CUDA fp16
|
||||
fp16: true
|
||||
# Use CUDA tf32
|
||||
tf32: true # require >=ampere
|
||||
# Note: if bf16 is set to 'auto', and fp16 is set to true, we will prefer the explict fp16 setting
|
||||
|
||||
# No AMP (automatic mixed precision)
|
||||
bfloat16: true # require >=ampere
|
||||
float16: true
|
||||
|
||||
# Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset
|
||||
gpu_memory_limit: 20GiB
|
||||
# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
|
||||
lora_on_cpu: true
|
||||
|
||||
# List[str]. Add plugins to extend the pipeline.
|
||||
# See `src/axolotl/integrations` for the available plugins or doc below for more details.
|
||||
# https://docs.axolotl.ai/docs/custom_integrations.html
|
||||
plugins:
|
||||
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
# A list of one or more datasets to finetune the model with
|
||||
# See https://docs.axolotl.ai/docs/dataset_loading.html for guide on loading datasets
|
||||
# See https://docs.axolotl.ai/docs/dataset-formats/ for guide on dataset formats
|
||||
datasets:
|
||||
# HuggingFace dataset repo | s3:// | gs:// | path to local file or directory
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]
|
||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
|
||||
data_files: # Optional[str] path to source data files
|
||||
|
||||
shards: # Optional[int] split dataset into N pieces (use with shards_idx)
|
||||
shards_idx: # Optional[int] = 0 the index of sharded dataset to use
|
||||
|
||||
preprocess_shards: # Optional[int] process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)
|
||||
|
||||
name: # Optional[str] name of dataset configuration to load
|
||||
split: train # Optional[str] name of dataset split to load from
|
||||
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
|
||||
trust_remote_code: # Optional[bool] Trust remote code for untrusted source
|
||||
|
||||
# Custom user instruction prompt
|
||||
- path: repo
|
||||
type:
|
||||
# The below are defaults. only set what's needed if you use a different column name.
|
||||
system_prompt: ""
|
||||
system_format: "{system}"
|
||||
field_system: system
|
||||
field_instruction: instruction
|
||||
field_input: input
|
||||
field_output: output
|
||||
|
||||
# Customizable to be single line or multi-line
|
||||
# Use {instruction}/{input} as key to be replaced
|
||||
# 'format' can include {input}
|
||||
format: |-
|
||||
User: {instruction} {input}
|
||||
Assistant:
|
||||
# 'no_input_format' cannot include {input}
|
||||
no_input_format: "{instruction} "
|
||||
|
||||
# For `completion` datsets only, uses the provided field instead of `text` column
|
||||
field:
|
||||
|
||||
# Using chat template
|
||||
- path: ...
|
||||
# Set type to `chat_template` to use this strategy
|
||||
type: chat_template
|
||||
# Specify the name of the chat template to use
|
||||
# The name of the chat template to use for training, following values are supported:
|
||||
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default.
|
||||
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
|
||||
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
|
||||
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
|
||||
chat_template: tokenizer_default
|
||||
|
||||
# Custom jinja chat template. Used only if `chat_template: jinja` or empty.
|
||||
chat_template_jinja:
|
||||
|
||||
# Key containing the messages (default: "messages")
|
||||
field_messages: messages
|
||||
|
||||
# Key containing the system message (default: "system")
|
||||
# If the system message is not present in the dataset sample, it will be loaded from the field_system property.
|
||||
field_system: system
|
||||
|
||||
# Mapping of properties from the input dataset to the chat template.
|
||||
# (default: message_property_mappings={'role':'role', 'content':'content'})
|
||||
# If a property exists in the template but not in this mapping, the system will attempt
|
||||
# to load it directly from the message using the property name as the key.
|
||||
# Example: In the mapping below, 'from' is loaded from input dataset and used as 'role',
|
||||
# while 'value' is loaded and used as 'content' in the chat template.
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
# ...
|
||||
|
||||
# Optional[Dict[str, List]]. Roles mapping in the messages.
|
||||
# The format is {target_role: [source_roles]}. All source roles will be mapped to the target role.
|
||||
# The default is:
|
||||
roles:
|
||||
user: ["human", "user"]
|
||||
assistant: ["gpt", "assistant"]
|
||||
system: ["system"]
|
||||
tool: ["tool"]
|
||||
|
||||
# Optional[bool]. Whether to drop the system turn from the dataset. Only works with chat_template.
|
||||
# This does not drop the default system message from chat_template if it exists. If you wish to,
|
||||
# we recommend using a custom jinja template with the default system message removed or
|
||||
# adding a system turn with empty content.
|
||||
drop_system_message:
|
||||
|
||||
# Optional[bool]. (for Qwen3 template only) Whether to split the assistant content based on a reasoning trace inside delimited tags
|
||||
# See example at `docs/dataset-formats/conversation.qmd`
|
||||
split_thinking:
|
||||
|
||||
# IMPORTANT: The following fields determine which parts of the conversation to train on.
|
||||
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
|
||||
# See examples at `docs/dataset-formats/conversation.qmd`
|
||||
# Note: If the below 5 fields are empty, defaults to training only on the last message.
|
||||
|
||||
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
|
||||
roles_to_train: ["assistant"] # default
|
||||
# Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
|
||||
# - all: train on all EOS tokens
|
||||
# - turn (default): train on the EOS token at the end of each trainable turn
|
||||
# - last: train on the last EOS token in the conversation
|
||||
# TIP: Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
|
||||
train_on_eos: turn
|
||||
# Optional[str]. Which EOT (End-of-Turn) tokens to train on in the conversation. Possible values are:
|
||||
# - all: train on all EOT tokens
|
||||
# - turn: train on the EOT token at the end of each trainable turn
|
||||
# - last: train on the last EOT token in the conversation
|
||||
# If not specified, defaults to the value of train_on_eos for backward compatibility.
|
||||
train_on_eot:
|
||||
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
|
||||
message_field_training: training
|
||||
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
|
||||
# The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
|
||||
message_field_training_detail: train_detail
|
||||
|
||||
|
||||
# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
|
||||
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
|
||||
shuffle_merged_datasets: true
|
||||
|
||||
# Deduplicates datasets and test_datasets with identical entries.
|
||||
dataset_exact_deduplication: true
|
||||
|
||||
# A list of one or more datasets to eval the model with.
|
||||
# You can use either test_datasets, or val_set_size, but not both.
|
||||
test_datasets:
|
||||
- path: /workspace/data/eval.jsonl
|
||||
ds_type: json
|
||||
# You need to specify a split. For "json" datasets the default split is called "train".
|
||||
split: train
|
||||
type: completion
|
||||
data_files:
|
||||
- /workspace/data/eval.jsonl
|
||||
|
||||
# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'
|
||||
rl:
|
||||
rl_beta: # Optional[float]. The beta parameter for the RL training.
|
||||
|
||||
# dpo
|
||||
dpo_use_weighting: # Optional[bool]. Whether to perform weighting.
|
||||
rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper.
|
||||
|
||||
# orpo
|
||||
orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping.
|
||||
|
||||
# kto
|
||||
kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss.
|
||||
kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss.
|
||||
|
||||
# simpo
|
||||
cpo_alpha: 1.0 # Weight of the BC regularizer
|
||||
simpo_gamma: 0.5 # Target reward margin for the SimPO loss
|
||||
|
||||
# grpo
|
||||
trl:
|
||||
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
|
||||
vllm_server_host: # Optional[str]. Host of the vLLM server to connect to.
|
||||
vllm_server_port: # Optional[int]. Port of the vLLM server to connect to.
|
||||
vllm_server_timeout: # Optional[int]. Total timeout (in seconds) to wait for the vLLM server to respond.
|
||||
vllm_guided_decoding_regex: # Optional[str]. Regex for vLLM guided decoding.
|
||||
|
||||
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
|
||||
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
|
||||
|
||||
reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir.
|
||||
reward_weights: # Optional[list[float]]. List of reward weights for the reward functions.
|
||||
|
||||
num_generations: # Optional[int]. Number of generations to sample.
|
||||
log_completions: # Optional[bool]. Whether to log completions.
|
||||
num_completions_to_print: # Optional[int]. Number of completions to print when log_completions is True.
|
||||
|
||||
sync_ref_model: # Optional[bool]. Whether to sync the reference model.
|
||||
ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
|
||||
ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
|
||||
scale_rewards: # Optional[bool]. Whether to scale rewards by their standard deviation.
|
||||
|
||||
temperature: # Optional[float]. Sampling temperature for the GRPO policy.
|
||||
top_p: # Optional[float]. Top-p sampling probability for the generation policy.
|
||||
top_k: # Optional[int]. Top-k sampling for the generation policy.
|
||||
min_p: # Optional[float]. Minimum probability for the generation policy.
|
||||
repetition_penalty: # Optional[float]. Penalty for tokens that appear in prompt and generated text.
|
||||
|
||||
num_iterations: # Optional[int]. Number of iterations per batch (μ) for GRPO.
|
||||
epsilon: # Optional[float]. Epsilon value for clipping in the GRPO algorithm.
|
||||
epsilon_high: # Optional[float]. Upper-bound epsilon value for clipping in the GRPO algorithm.
|
||||
use_liger_loss: # Optional[bool]. Whether to use Liger loss for GRPO.
|
||||
loss_type: # Optional[str]. Loss formulation to use. Supported values: grpo, bnpo, dr_grpo.
|
||||
mask_truncated_completions: # Optional[bool]. Whether to exclude truncated completions from loss calculation.
|
||||
|
||||
|
||||
# reward modelling: `True` or `False`
|
||||
reward_model:
|
||||
|
||||
# process reward modelling: `True` or `False`
|
||||
process_reward_model:
|
||||
|
||||
# The name of the chat template to use for training, following values are supported:
|
||||
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
|
||||
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
|
||||
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
|
||||
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
|
||||
# The selected chat template will be saved to the tokenizer_config.json for easier inferencing
|
||||
# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template.
|
||||
chat_template: tokenizer_default
|
||||
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
|
||||
chat_template_jinja: null
|
||||
# Optional[List[str]]. Custom EOT (End-of-Turn) tokens to mask/unmask during training.
|
||||
# These tokens mark the boundaries between conversation turns.
|
||||
# For example: ["/INST", "</s>", "[/SYSTEM_PROMPT]"]
|
||||
# If not specified, defaults to just the model's eos_token.
|
||||
# This is useful for templates that use multiple delimiter tokens.
|
||||
eot_tokens:
|
||||
# - "</s>"
|
||||
# - "[/INST]"
|
||||
# - "[/SYSTEM_PROMPT]"
|
||||
# Changes the default system message
|
||||
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
|
||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# subsequent training attempts load faster, relative path
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
# Push prepared dataset to hub
|
||||
push_dataset_to_hub: # Optional[str] repo_org/repo_name
|
||||
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
||||
# if not set.
|
||||
dataset_processes: # defaults to os.cpu_count() if not set
|
||||
# Keep dataset in memory while preprocessing
|
||||
# Only needed if cached dataset is taking too much storage
|
||||
dataset_keep_in_memory:
|
||||
# push checkpoints to hub
|
||||
hub_model_id: # private repo path to push finetuned model
|
||||
# how to push checkpoints to hub
|
||||
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
|
||||
hub_strategy:
|
||||
# Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||
# Required to be true when used in combination with `push_dataset_to_hub`
|
||||
hf_use_auth_token: # boolean
|
||||
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
|
||||
val_set_size: 0.04
|
||||
# Num shards for whole dataset
|
||||
dataset_shard_num:
|
||||
# Index of shard to use for whole dataset
|
||||
dataset_shard_idx:
|
||||
|
||||
# The maximum length of an input to train with, this should typically be less than 2048
|
||||
# as most models have a token/context limit of 2048
|
||||
sequence_len: 2048
|
||||
# Pad inputs so each step uses constant sized buffers
|
||||
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
||||
pad_to_sequence_len:
|
||||
# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
||||
sample_packing:
|
||||
# Set to 'false' if getting errors during eval with sample_packing on.
|
||||
eval_sample_packing:
|
||||
# You can set these packing optimizations AFTER starting a training at least once.
|
||||
# The trainer will provide recommended values for these values.
|
||||
sample_packing_eff_est:
|
||||
total_num_tokens:
|
||||
# Increasing the following values helps with packing, but usually only slightly (<%1.)
|
||||
# The number of samples packed at a time.
|
||||
sample_packing_group_size: 100000
|
||||
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
|
||||
sample_packing_bin_size: 200
|
||||
sample_pack_sequentially: # Optional[bool]. Whether to pack samples sequentially.
|
||||
|
||||
# whether to concatenate samples during pretraining
|
||||
pretraining_sample_concatenation:
|
||||
|
||||
curriculum_sampling: # Optional[bool]. Whether to use sequential sampling for curriculum learning
|
||||
|
||||
# Use batch flattening for speedups when not using sample_packing
|
||||
batch_flattening:
|
||||
|
||||
# Passed through to transformers when loading the model when launched without accelerate
|
||||
# Use `sequential` when training w/ model parallelism to limit memory
|
||||
device_map:
|
||||
# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
|
||||
max_memory:
|
||||
|
||||
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||
adapter: lora
|
||||
# If you already have a lora model trained that you want to load, put that here.
|
||||
# This means after training, if you want to test the model, you should set this to the value of `output_dir`.
|
||||
# Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`.
|
||||
lora_model_dir:
|
||||
|
||||
# LoRA hyperparameters
|
||||
# For more details about the following options, see:
|
||||
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
# - k_proj
|
||||
# - o_proj
|
||||
# - gate_proj
|
||||
# - down_proj
|
||||
# - up_proj
|
||||
lora_target_linear: # If true, will target all linear modules
|
||||
|
||||
# List[int] | int. # The layer indices to transform, otherwise, apply to all layers
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/package_reference/lora#peft.LoraConfig.layers_to_transform
|
||||
peft_layers_to_transform:
|
||||
|
||||
# Optional[bool]. Whether to use DoRA.
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#weight-decomposed-low-rank-adaptation-dora
|
||||
peft_use_dora:
|
||||
|
||||
# Optional[bool]. Whether to use RSLoRA.
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#rank-stabilized-lora
|
||||
peft_use_rslora:
|
||||
|
||||
# Optional[list[tuple[int, int]]]. List of layer indices to replicate.
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#memory-efficient-layer-replication-with-lora
|
||||
peft_layer_replication:
|
||||
|
||||
# bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"]
|
||||
# How to initialize LoRA weights. Default to True which is MS original implementation.
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#initialization
|
||||
peft_init_lora_weights:
|
||||
|
||||
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
||||
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
||||
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
|
||||
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
|
||||
lora_modules_to_save:
|
||||
# - embed_tokens
|
||||
# - lm_head
|
||||
|
||||
lora_fan_in_fan_out: false
|
||||
|
||||
# Apply custom LoRA autograd functions and activation function Triton kernels for
|
||||
# speed and memory savings
|
||||
# See: https://docs.axolotl.ai/docs/lora_optims.html
|
||||
lora_mlp_kernel: true
|
||||
lora_qkv_kernel: true
|
||||
lora_o_kernel: true
|
||||
|
||||
# LoRA+ hyperparameters
|
||||
# For more details about the following options, see:
|
||||
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
|
||||
loraplus_lr_ratio: # loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4.
|
||||
loraplus_lr_embedding: # loraplus learning rate for lora embedding layers. Default value is 1e-6.
|
||||
|
||||
peft:
|
||||
# Configuration options for loftq initialization for LoRA
|
||||
# https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
|
||||
loftq_config:
|
||||
loftq_bits: # typically 4 bits
|
||||
|
||||
# ReLoRA configuration
|
||||
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
||||
relora_steps: # Number of steps per ReLoRA restart
|
||||
relora_warmup_steps: # Number of per-restart warmup steps
|
||||
relora_anneal_steps: # Number of anneal steps for each relora cycle
|
||||
relora_prune_ratio: # threshold for optimizer magnitude when pruning
|
||||
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
||||
|
||||
# wandb configuration if you're using it
|
||||
# Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
|
||||
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
||||
wandb_project: # Your wandb project name
|
||||
wandb_entity: # A wandb Team name if using a Team
|
||||
wandb_watch:
|
||||
wandb_name: # Set the name of your wandb run
|
||||
wandb_run_id: # Set the ID of your wandb run
|
||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||
|
||||
# mlflow configuration if you're using it
|
||||
mlflow_tracking_uri: # URI to mlflow
|
||||
mlflow_experiment_name: # Your experiment name
|
||||
mlflow_run_name: # Your run name
|
||||
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
|
||||
|
||||
# Comet configuration if you're using it
|
||||
# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`.
|
||||
# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
|
||||
use_comet: # Enable or disable Comet integration.
|
||||
comet_api_key: # API key for Comet. Recommended to set via `comet login`.
|
||||
comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
|
||||
comet_project_name: # Project name in Comet. Defaults to Uncategorized.
|
||||
comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key.
|
||||
comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.
|
||||
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
|
||||
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
|
||||
|
||||
# Tensorboard
|
||||
use_tensorboard: # Optional[bool]
|
||||
|
||||
# Where to save the full-finetuned model to
|
||||
output_dir: ./completed-model
|
||||
|
||||
# Whether to use torch.compile and which backend to use
|
||||
# setting to `auto` will enable torch compile when torch>=2.5.1
|
||||
torch_compile: # Optional[Union[Literal["auto"], bool]]
|
||||
torch_compile_backend: # Optional[str]
|
||||
torch_compile_mode: # 'default' | 'reduce-overhead' | 'max-autotune'
|
||||
|
||||
# Training hyperparameters
|
||||
|
||||
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
|
||||
gradient_accumulation_steps: 1
|
||||
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
||||
# Batch size per gpu = micro_batch_size * gradient_accumulation_steps
|
||||
micro_batch_size: 2
|
||||
eval_batch_size:
|
||||
num_epochs: 4
|
||||
warmup_steps: 100 # cannot use with warmup_ratio
|
||||
warmup_ratio: 0.05 # cannot use with warmup_steps
|
||||
learning_rate: 0.00003
|
||||
lr_quadratic_warmup:
|
||||
logging_steps:
|
||||
eval_steps: # Leave empty to eval at each epoch, integer for every N steps. float for fraction of total steps
|
||||
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
|
||||
eval_strategy: # Set to `"no"` to skip evaluation, `"epoch"` at end of each epoch, leave empty to infer from `eval_steps`.
|
||||
save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of each epoch, `"best"` when better result is achieved, leave empty to infer from `save_steps`.
|
||||
save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps
|
||||
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
|
||||
save_total_limit: # Checkpoints saved at a time
|
||||
save_only_model: # Save only the model weights, skipping the optimizer. Using this means you can't resume from checkpoints.
|
||||
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
||||
# if both are set, num_epochs will not be guaranteed.
|
||||
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
|
||||
max_steps:
|
||||
|
||||
# bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
|
||||
include_tokens_per_second: # Optional[bool]
|
||||
|
||||
# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
|
||||
auto_find_batch_size: # Optional[bool]
|
||||
|
||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
do_causal_lm_eval: # Whether to run causal language model evaluation for metrics in `eval_causal_lm_metrics`.
|
||||
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
|
||||
|
||||
profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
|
||||
# see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information
|
||||
# snapshots can be visualized @ https://pytorch.org/memory_viz
|
||||
|
||||
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
||||
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
||||
|
||||
# Save model as safetensors (require safetensors package). Default True
|
||||
save_safetensors:
|
||||
|
||||
# Whether to mask out or include the human's prompt from the training labels
|
||||
train_on_inputs: false
|
||||
# Group similarly sized data to minimize padding.
|
||||
# May be slower to start, as it must download and sort the entire dataset.
|
||||
# Note that training loss may have an oscillating pattern with this enabled.
|
||||
group_by_length: false
|
||||
|
||||
# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk".
|
||||
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
gradient_checkpointing: false
|
||||
# additional kwargs to pass to the trainer for gradient checkpointing
|
||||
# gradient_checkpointing_kwargs:
|
||||
# use_reentrant: true
|
||||
|
||||
# Stop training after this many evaluation losses have increased in a row
|
||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||
early_stopping_patience: 3
|
||||
|
||||
# Specify a scheduler and kwargs to use with the optimizer
|
||||
# Valid values are driven by the Transformers SchedulerType class, see:
|
||||
# https://github.com/huggingface/transformers/blob/5f4ecf2d9f867a1255131d2461d75793c0cf1db2/src/transformers/trainer_utils.py#L420
|
||||
# Valid values include
|
||||
# - 'linear'
|
||||
# - 'cosine' (default)
|
||||
# - 'cosine_with_restarts'
|
||||
# - 'polynomial'
|
||||
# - 'constant'
|
||||
# - 'constant_with_warmup'
|
||||
# - 'inverse_sqrt'
|
||||
# - 'reduce_lr_on_plateau'
|
||||
# - 'cosine_with_min_lr'
|
||||
# - 'warmup_stable_decay'
|
||||
|
||||
# Additional schedulers include:
|
||||
# - 'one_cycle'
|
||||
# - 'rex'
|
||||
lr_scheduler:
|
||||
lr_scheduler_kwargs:
|
||||
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
|
||||
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
|
||||
|
||||
# For one_cycle optim
|
||||
lr_div_factor: # Learning rate div factor
|
||||
|
||||
# Specify optimizer
|
||||
# Valid values are driven by the Transformers OptimizerNames class, see:
|
||||
# https://github.com/huggingface/transformers/blob/cbf924b76c03828101a34069a96d209314114fd5/src/transformers/training_args.py#L144-L189
|
||||
#
|
||||
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
|
||||
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
|
||||
# in the examples/ for your model and fine-tuning use case.
|
||||
#
|
||||
# Valid values for 'optimizer' include:
|
||||
# - adamw_torch
|
||||
# - adamw_torch_fused (default)
|
||||
# - adamw_torch_xla
|
||||
# - adamw_torch_npu_fused
|
||||
# - adamw_apex_fused
|
||||
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
|
||||
# - adafactor
|
||||
# - adamw_anyprecision
|
||||
# - adamw_torch_4bit
|
||||
# - ademamix
|
||||
# - sgd
|
||||
# - adagrad
|
||||
# - adamw_bnb_8bit
|
||||
# - adamw_8bit # alias for adamw_bnb_8bit
|
||||
# - ademamix_8bit
|
||||
# - lion_8bit
|
||||
# - lion_32bit
|
||||
# - paged_adamw_32bit
|
||||
# - paged_adamw_8bit
|
||||
# - paged_ademamix_32bit
|
||||
# - paged_ademamix_8bit
|
||||
# - paged_lion_32bit
|
||||
# - paged_lion_8bit
|
||||
# - rmsprop
|
||||
# - rmsprop_bnb
|
||||
# - rmsprop_bnb_8bit
|
||||
# - rmsprop_bnb_32bit
|
||||
# - galore_adamw
|
||||
# - galore_adamw_8bit
|
||||
# - galore_adafactor
|
||||
# - galore_adamw_layerwise
|
||||
# - galore_adamw_8bit_layerwise
|
||||
# - galore_adafactor_layerwise
|
||||
# - lomo
|
||||
# - adalomo
|
||||
# - grokadamw
|
||||
# - schedule_free_adamw
|
||||
# - schedule_free_sgd
|
||||
# - apollo_adamw
|
||||
# - apollo_adamw_layerwise
|
||||
#
|
||||
# Additional custom optimizers include:
|
||||
# - optimi_adamw
|
||||
# - ao_adamw_8bit
|
||||
# - ao_adamw_fp8
|
||||
# - came_pytorch
|
||||
optimizer:
|
||||
# Dictionary of arguments to pass to the optimizer
|
||||
optim_args:
|
||||
# For Galore Optimizers the following optim_args are available
|
||||
# rank: # type: int
|
||||
# update_proj_gap # type: int
|
||||
# scale # type: float
|
||||
# proj_type: # type: str, default = std
|
||||
|
||||
# The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm
|
||||
optim_target_modules:
|
||||
# - self_attn # for llama
|
||||
# - mlp
|
||||
|
||||
# Specify weight decay
|
||||
weight_decay:
|
||||
# adamw hyperparams
|
||||
adam_beta1:
|
||||
adam_beta2:
|
||||
adam_beta3: # only used for CAME Optimizer
|
||||
adam_epsilon:
|
||||
adam_epsilon2: # only used for CAME Optimizer
|
||||
# Gradient clipping max norm
|
||||
max_grad_norm:
|
||||
|
||||
# Augmentation techniques
|
||||
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
||||
# currently only supported on Llama and Mistral
|
||||
neftune_noise_alpha:
|
||||
|
||||
# Optional[bool]. Whether to bettertransformers
|
||||
flash_optimum:
|
||||
|
||||
# Note: Only one of the following attention patches can be used at a time.
|
||||
# For example, if you set `xformers_attention` to `true`, do not set `flash_attention` to `true`.
|
||||
|
||||
# Optional[bool]. Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
xformers_attention:
|
||||
# Optional[bool]. Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||
flash_attention:
|
||||
flash_attn_cross_entropy: # Optional[bool]. Whether to use flash-attention cross entropy implementation - advanced use only
|
||||
flash_attn_rms_norm: # Optional[bool]. Whether to use flash-attention rms norm implementation - advanced use only
|
||||
flash_attn_fuse_qkv: # Optional[bool]. Whether to fuse QKV into a single operation
|
||||
flash_attn_fuse_mlp: # Optional[bool]. Whether to fuse part of the MLP into a single operation
|
||||
# Optional[bool]. Whether to use scaled-dot-product attention
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
sdp_attention:
|
||||
# Optional[bool]. Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
|
||||
s2_attention:
|
||||
|
||||
# Optional[bool]. Whether to use low_cpu_mem_usage
|
||||
low_cpu_mem_usage:
|
||||
# Optional[str]. Resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
# Optional[bool]. If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
||||
# Be careful with this being turned on between different models.
|
||||
auto_resume_from_checkpoints: false
|
||||
|
||||
## Multimodal section
|
||||
# int | tuple[int, int] | None . Size to resize images to, width x height.
|
||||
# Will read from model/processor config if not set.
|
||||
image_size:
|
||||
# str. Algorithm to use for image resizing. "bilinear", "bicubic", "lanczos". Default is "bilinear".
|
||||
image_resize_algorithm: 'bilinear'
|
||||
## End of multimodal section
|
||||
|
||||
# Don't mess with this, it's here for accelerate and torchrun
|
||||
local_rank:
|
||||
|
||||
# Add or change special tokens.
|
||||
# If you add tokens here, you don't need to add them to the `tokens` list.
|
||||
special_tokens:
|
||||
# bos_token: "<s>"
|
||||
# eos_token: "</s>"
|
||||
# unk_token: "<unk>"
|
||||
# pad_token: "[PAD]"
|
||||
|
||||
# Optional[list[str]]. Add extra tokens to the tokenizer.
|
||||
tokens:
|
||||
# - "<|startoftext|>"
|
||||
# - "<|endoftext|>"
|
||||
|
||||
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
|
||||
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
|
||||
# Can be checked if they exist in tokenizer.json added_tokens.
|
||||
added_tokens_overrides: # Dict[int, str]
|
||||
# 128041: "<|im_start|>"
|
||||
# 128042: "<|im_end|>"
|
||||
|
||||
# FSDP
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
|
||||
# Deepspeed config path. e.g., deepspeed_configs/zero3.json
|
||||
deepspeed:
|
||||
|
||||
# Advanced DDP Arguments
|
||||
ddp_timeout:
|
||||
ddp_bucket_cap_mb:
|
||||
ddp_broadcast_buffers:
|
||||
|
||||
# Sequence parallelism
|
||||
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
|
||||
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
|
||||
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
|
||||
# subsequences, or set to 4 to split into four equal-sized subsequences.
|
||||
# See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details.
|
||||
sequence_parallel_degree:
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
# Must evenly divide the number of KV heads in your model.
|
||||
heads_k_stride: 1
|
||||
# One of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to "varlen_llama3"
|
||||
# in the sample packing case, and "batch_ring" in the non-sample packing case.
|
||||
ring_attn_func:
|
||||
|
||||
# Path to torch distx for optim 'adamw_anyprecision'
|
||||
torchdistx_path:
|
||||
|
||||
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
|
||||
pretraining_dataset:
|
||||
|
||||
# Debug mode
|
||||
debug:
|
||||
|
||||
# Seed
|
||||
seed:
|
||||
|
||||
# Allow overwrite yml config using from cli
|
||||
strict:
|
||||
```
|
||||
@@ -12,7 +12,7 @@ Chat Template strategy uses a jinja2 template that converts a list of messages i
|
||||
{"conversations": [{"role": "...", "content": "..."}]}
|
||||
```
|
||||
|
||||
See [configs](../config-reference.qmd) for full configs and supported templates.
|
||||
See [configs](../config.qmd) for full configs and supported templates.
|
||||
|
||||
### Migrating from sharegpt
|
||||
|
||||
@@ -52,9 +52,7 @@ We recommend checking the below examples for other usecases.
|
||||
|
||||
### Examples
|
||||
|
||||
#### Training on last message
|
||||
|
||||
(Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
|
||||
1. (Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
@@ -68,9 +66,7 @@ datasets:
|
||||
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
|
||||
:::
|
||||
|
||||
#### Overriding default chat template
|
||||
|
||||
Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
||||
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
chat_template: gemma # this overwrites the tokenizer's chat_template
|
||||
@@ -80,13 +76,7 @@ datasets:
|
||||
roles_to_train: ["assistant"] # default value
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
If you want to use built-in chat_template, use `chat_template: tokenizer_default` (this is set by default).
|
||||
:::
|
||||
|
||||
#### Using default chat template with fallback
|
||||
|
||||
Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
|
||||
3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
|
||||
@@ -95,9 +85,7 @@ datasets:
|
||||
type: chat_template
|
||||
```
|
||||
|
||||
#### Custom Jinja template
|
||||
|
||||
Using a custom jinja template on OpenAI messages format, training on all assistant messages.
|
||||
4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
|
||||
@@ -112,9 +100,7 @@ datasets:
|
||||
Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
|
||||
:::
|
||||
|
||||
#### Using template with different token for EOT and EOS
|
||||
|
||||
- If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
|
||||
5. If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
|
||||
|
||||
```yaml
|
||||
eot_tokens:
|
||||
@@ -130,16 +116,16 @@ datasets:
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
See [config documentation](../config-reference.qmd) for detailed explanations of "turn", "last", and "all" options for training on tokens.
|
||||
See [config documentation](../config.qmd) for detailed explanations of "turn", "last", and "all" options for training on tokens.
|
||||
:::
|
||||
|
||||
::: {.callout-note}
|
||||
Using `eot_tokens` requires each token that exists in `chat_template` to be a single token in the tokenizer. Otherwise, the tokenizer will split the token and cause unexpected behavior.
|
||||
|
||||
You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config-reference.qmd) for more details.
|
||||
You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config.qmd) for more details.
|
||||
:::
|
||||
|
||||
- Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
|
||||
6. Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
|
||||
|
||||
```yaml
|
||||
eot_tokens:
|
||||
@@ -159,73 +145,7 @@ If EOS token only appears at the end of a prompt, `train_on_eos: last` is equiva
|
||||
:::
|
||||
|
||||
|
||||
#### Using tool use
|
||||
|
||||
Instead of passing `tools` via the system prompt, an alternative method would be to have the `tools` in a separate column and loaded via `chat_template` to let the template dynamically build it.
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": [
|
||||
{
|
||||
"type": "...",
|
||||
"function": {
|
||||
"name": "...",
|
||||
"description": "...",
|
||||
"parameters": {
|
||||
"type": "...",
|
||||
"properties": {
|
||||
// ...
|
||||
},
|
||||
"required": ["..."],
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
"messages": [
|
||||
// ...
|
||||
{
|
||||
"role": "assistant", // call the function via assistant
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "...",
|
||||
"arguments": {
|
||||
"...": "...",
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"name": "...",
|
||||
"content": "..."
|
||||
},
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
|
||||
:::
|
||||
|
||||
```yaml
|
||||
chat_template: llama4
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
# field_tools: tools # default is `tools`
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
Look into the `chat_template` you are using to see if it supports `tools` and what the expected role is for the tool answer. In the example above, the tool answer is expected to be in the `tool` or `ipython` role for `llama4` template.
|
||||
:::
|
||||
|
||||
|
||||
#### Using fine-grained control over token masking
|
||||
|
||||
(Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
||||
7. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
||||
|
||||
For a data sample that looks like:
|
||||
|
||||
@@ -276,9 +196,7 @@ datasets:
|
||||
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
|
||||
:::
|
||||
|
||||
#### Reasoning split
|
||||
|
||||
(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
|
||||
8. (For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
|
||||
@@ -186,4 +186,4 @@ datasets:
|
||||
no_input_format: "[INST] {instruction} [/INST]"
|
||||
```
|
||||
|
||||
See full config options under [here](../config-reference.qmd).
|
||||
See full config options under [here](../config.qmd).
|
||||
|
||||
@@ -36,7 +36,7 @@ This matches the API of [`datasets.load_dataset`](https://github.com/huggingface
|
||||
|
||||
For HuggingFace's guide to load different dataset types, see [here](https://huggingface.co/docs/datasets/loading).
|
||||
|
||||
For full details on the config, see [config-reference.qmd](config-reference.qmd).
|
||||
For full details on the config, see [config.qmd](config.qmd).
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ format:
|
||||
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use the tags with Pytorch 2.7.1 and CUDA 12.8.
|
||||
For Blackwell GPUs, please use the tags with Pytorch 2.7.0 and CUDA 12.8.
|
||||
:::
|
||||
|
||||
## Base
|
||||
@@ -32,8 +32,8 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-base-py3.11-cu128-2.7.1`
|
||||
- `main-base-py3.11-cu126-2.7.1`
|
||||
- `main-base-py3.11-cu128-2.7.0`
|
||||
- `main-base-py3.11-cu126-2.7.0`
|
||||
- `main-base-py3.11-cu124-2.6.0`
|
||||
- `main-base-py3.11-cu124-2.5.1`
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ output_dir: ./outputs/lora-out
|
||||
- To perform QLoRA finetuning, replace with `load_in_4bit: true` and `adapter: qlora`.
|
||||
:::
|
||||
|
||||
See our [config options](config-reference.qmd) for more details.
|
||||
See our [Config options](config.qmd) for more details.
|
||||
|
||||
### Training {#sec-training}
|
||||
|
||||
@@ -179,7 +179,7 @@ Now that you have the basics, you might want to:
|
||||
|
||||
Check our other guides for details on these topics:
|
||||
|
||||
- [Configuration Guide](config-reference.qmd) - Full configuration options
|
||||
- [Configuration Guide](config.qmd) - Full configuration options
|
||||
- [Dataset Loading](dataset_loading.qmd) - Loading datasets from various sources
|
||||
- [Dataset Formats](dataset-formats) - Working with different data formats
|
||||
- [Multi-GPU Training](multi-gpu.qmd)
|
||||
|
||||
@@ -14,7 +14,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
|
||||
## Requirements {#sec-requirements}
|
||||
|
||||
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
||||
- Python ≥3.11
|
||||
- Python ≥3.10
|
||||
- PyTorch ≥2.5.1
|
||||
|
||||
## Installation Methods {#sec-installation-methods}
|
||||
@@ -153,7 +153,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
|
||||
|
||||
### Conda/Pip venv {#sec-conda}
|
||||
|
||||
1. Install Python ≥3.11
|
||||
1. Install Python ≥3.10
|
||||
2. Install PyTorch: https://pytorch.org/get-started/locally/
|
||||
3. Install Axolotl:
|
||||
```{.bash}
|
||||
|
||||
@@ -29,4 +29,4 @@ qat:
|
||||
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
|
||||
```
|
||||
|
||||
Once you have finished training, you must quantize your model by using the same quantization configuration which you used to train the model with. You can use the [`quantize`](./quantize.qmd) command to do this.
|
||||
Once you have finished training, you must quantize your model by using the same quantization configuration which you used to train the model with. You can use the [`quantize` command](./quantize.md) to do this.
|
||||
|
||||
@@ -32,7 +32,7 @@ output_dir: # The path to the output directory.
|
||||
|
||||
Once quantization is complete, your quantized model will be saved in the `{output_dir}/quantized` directory.
|
||||
|
||||
You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.qmd) - you can do this by using the existing QAT configuration file which
|
||||
You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.md) - you can do this by using the existing QAT configuration file which
|
||||
you used to train the model:
|
||||
|
||||
```yaml
|
||||
|
||||
@@ -500,7 +500,7 @@ The input format is a simple JSON input with customizable fields based on the ab
|
||||
### GRPO
|
||||
|
||||
::: {.callout-tip}
|
||||
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/grpo_code).
|
||||
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
|
||||
:::
|
||||
|
||||
In the latest GRPO implementation, `vLLM` is used to significantly speedup trajectory generation during training. In this example, we're using 4 GPUs - 2 for training, and 2 for vLLM:
|
||||
|
||||
@@ -1,752 +0,0 @@
|
||||
# type: ignore
|
||||
|
||||
"""
|
||||
Quarto documentation generation from Pydantic models. Uses Pydantic model source code
|
||||
to automatically group fields, including inherited fields from parent classes.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
import types
|
||||
import typing
|
||||
from typing import Any, FrozenSet, Type, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
|
||||
class QuartoGenerator:
|
||||
"""Generate Quarto documentation from Pydantic models."""
|
||||
|
||||
def __init__(self):
|
||||
self._class_fields_cache = {}
|
||||
self._inheritance_map_cache = {}
|
||||
self._nested_models_cache = {}
|
||||
|
||||
def _get_direct_fields(self, cls: Type[BaseModel]) -> FrozenSet[str]:
|
||||
"""Get fields defined directly in a single class (not inherited)."""
|
||||
if cls in self._class_fields_cache:
|
||||
return self._class_fields_cache[cls]
|
||||
|
||||
fields = set()
|
||||
|
||||
# Get annotated fields
|
||||
if hasattr(cls, "__annotations__"):
|
||||
fields.update(cls.__annotations__.keys())
|
||||
|
||||
# Filter out private/special methods
|
||||
fields = {f for f in fields if not f.startswith("_")}
|
||||
|
||||
result = frozenset(fields)
|
||||
self._class_fields_cache[cls] = result
|
||||
return result
|
||||
|
||||
def _is_pydantic_model(self, type_obj) -> bool:
|
||||
"""Check if a type is a Pydantic BaseModel."""
|
||||
return inspect.isclass(type_obj) and issubclass(type_obj, BaseModel)
|
||||
|
||||
# pylint: disable=too-many-return-statements
|
||||
def _extract_nested_type(self, field_type) -> Any:
|
||||
"""Extract the actual type from complex type annotations."""
|
||||
# Handle Annotated types (Python 3.9+)
|
||||
if hasattr(typing, "get_origin") and hasattr(typing, "get_args"):
|
||||
origin = typing.get_origin(field_type)
|
||||
args = typing.get_args(field_type)
|
||||
|
||||
if origin is not None:
|
||||
# Handle Annotated[SomeType, ...] - extract the first argument
|
||||
if hasattr(typing, "Annotated") and origin is typing.Annotated:
|
||||
if args:
|
||||
return self._extract_nested_type(
|
||||
args[0]
|
||||
) # Recursively process the actual type
|
||||
|
||||
# Handle list[SomeType], List[SomeType], etc.
|
||||
elif origin in (list, typing.List):
|
||||
if args:
|
||||
return self._extract_nested_type(
|
||||
args[0]
|
||||
) # Extract element type
|
||||
|
||||
# Handle Union types (including | syntax)
|
||||
elif origin is typing.Union:
|
||||
# Get non-None types from the Union
|
||||
non_none_types = [arg for arg in args if arg is not type(None)]
|
||||
if len(non_none_types) >= 1:
|
||||
# Prioritize Pydantic models over primitive types
|
||||
pydantic_models = [
|
||||
arg
|
||||
for arg in non_none_types
|
||||
if self._is_pydantic_model(arg)
|
||||
]
|
||||
if pydantic_models:
|
||||
# Return the first Pydantic model found
|
||||
return self._extract_nested_type(pydantic_models[0])
|
||||
|
||||
# No Pydantic models, return the first non-None type
|
||||
return self._extract_nested_type(non_none_types[0])
|
||||
|
||||
# Handle new Python 3.10+ union syntax (PeftConfig | None)
|
||||
if hasattr(field_type, "__class__") and field_type.__class__ is types.UnionType:
|
||||
# Get non-None types from the Union
|
||||
non_none_types = [
|
||||
arg for arg in field_type.__args__ if arg is not type(None)
|
||||
]
|
||||
if len(non_none_types) >= 1:
|
||||
# Prioritize Pydantic models over primitive types
|
||||
pydantic_models = [
|
||||
arg for arg in non_none_types if self._is_pydantic_model(arg)
|
||||
]
|
||||
if pydantic_models:
|
||||
return self._extract_nested_type(pydantic_models[0])
|
||||
return self._extract_nested_type(non_none_types[0])
|
||||
|
||||
# Handle old typing.Union syntax (fallback)
|
||||
if hasattr(field_type, "__origin__"):
|
||||
if field_type.__origin__ is Union:
|
||||
# Get non-None types from the Union
|
||||
non_none_types = [
|
||||
arg for arg in field_type.__args__ if arg is not type(None)
|
||||
]
|
||||
if len(non_none_types) >= 1:
|
||||
# Prioritize Pydantic models over primitive types
|
||||
pydantic_models = [
|
||||
arg for arg in non_none_types if self._is_pydantic_model(arg)
|
||||
]
|
||||
if pydantic_models:
|
||||
return self._extract_nested_type(pydantic_models[0])
|
||||
return self._extract_nested_type(non_none_types[0])
|
||||
# Handle other generic types like dict[str, Any], etc.
|
||||
elif hasattr(field_type, "__args__"):
|
||||
return field_type
|
||||
|
||||
return field_type
|
||||
|
||||
# pylint: disable=too-many-return-statements
|
||||
def _extract_all_pydantic_models_from_type(
|
||||
self, field_type
|
||||
) -> list[type[BaseModel]]:
|
||||
"""Extract all Pydantic models from a type annotation, including from Unions."""
|
||||
models = []
|
||||
|
||||
if field_type is None:
|
||||
return models
|
||||
|
||||
# Handle Annotated types
|
||||
if hasattr(typing, "get_origin") and hasattr(typing, "get_args"):
|
||||
origin = typing.get_origin(field_type)
|
||||
args = typing.get_args(field_type)
|
||||
|
||||
if origin is not None:
|
||||
# Handle Annotated[SomeType, ...] - extract from the first argument
|
||||
if hasattr(typing, "Annotated") and origin is typing.Annotated:
|
||||
if args:
|
||||
models.extend(
|
||||
self._extract_all_pydantic_models_from_type(args[0])
|
||||
)
|
||||
return models
|
||||
|
||||
# Handle list[SomeType], List[SomeType], etc.
|
||||
if origin in (list, typing.List):
|
||||
if args:
|
||||
models.extend(
|
||||
self._extract_all_pydantic_models_from_type(args[0])
|
||||
)
|
||||
return models
|
||||
|
||||
# Handle Union types
|
||||
if origin is typing.Union:
|
||||
for arg in args:
|
||||
if arg is not type(None): # Skip None type
|
||||
models.extend(
|
||||
self._extract_all_pydantic_models_from_type(arg)
|
||||
)
|
||||
return models
|
||||
|
||||
# Handle new Python 3.10+ union syntax
|
||||
if hasattr(field_type, "__class__") and field_type.__class__ is types.UnionType:
|
||||
for arg in field_type.__args__:
|
||||
if arg is not type(None): # Skip None type
|
||||
models.extend(self._extract_all_pydantic_models_from_type(arg))
|
||||
return models
|
||||
|
||||
# Handle old typing.Union syntax (fallback)
|
||||
if hasattr(field_type, "__origin__") and field_type.__origin__ is Union:
|
||||
for arg in field_type.__args__:
|
||||
if arg is not type(None): # Skip None type
|
||||
models.extend(self._extract_all_pydantic_models_from_type(arg))
|
||||
return models
|
||||
|
||||
# Check if this type itself is a Pydantic model
|
||||
if self._is_pydantic_model(field_type):
|
||||
models.append(field_type)
|
||||
|
||||
return models
|
||||
|
||||
def _get_nested_models(
|
||||
self, model_class: type[BaseModel], visited=None
|
||||
) -> dict[str, type[BaseModel]]:
|
||||
"""Get all nested Pydantic models from a model class."""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
# Avoid infinite recursion
|
||||
if model_class in visited:
|
||||
return {}
|
||||
|
||||
if model_class in self._nested_models_cache:
|
||||
return self._nested_models_cache[model_class]
|
||||
|
||||
visited.add(model_class)
|
||||
nested_models = {}
|
||||
|
||||
# Check all fields in the model
|
||||
for field_info in model_class.model_fields.values():
|
||||
field_type = self._extract_nested_type(field_info.annotation)
|
||||
|
||||
if self._is_pydantic_model(field_type):
|
||||
nested_models[field_type.__name__] = field_type
|
||||
# Recursively get nested models from this nested model
|
||||
deeper_nested = self._get_nested_models(field_type, visited.copy())
|
||||
nested_models.update(deeper_nested)
|
||||
|
||||
self._nested_models_cache[model_class] = nested_models
|
||||
return nested_models
|
||||
|
||||
def _build_inheritance_map(self, child_class: Type[BaseModel]):
|
||||
"""Build inheritance map for a class and all its parents."""
|
||||
if child_class in self._inheritance_map_cache:
|
||||
return self._inheritance_map_cache[child_class]
|
||||
|
||||
inheritance_map = {}
|
||||
|
||||
# Get MRO and filter out BaseModel and object
|
||||
mro_classes = [
|
||||
cls
|
||||
for cls in child_class.__mro__
|
||||
if cls not in (BaseModel, object) and hasattr(cls, "__annotations__")
|
||||
]
|
||||
|
||||
# Process each class in the MRO
|
||||
for cls in mro_classes:
|
||||
inheritance_map[cls] = self._get_direct_fields(cls)
|
||||
|
||||
self._inheritance_map_cache[child_class] = inheritance_map
|
||||
return inheritance_map
|
||||
|
||||
def _wrap_comment(self, text: str, width: int = 88) -> list[str]:
|
||||
"""Wrap a comment to specified width, accounting for '# ' prefix."""
|
||||
if not text.strip():
|
||||
return ["#"]
|
||||
|
||||
# Account for "# " prefix (2 characters)
|
||||
content_width = width - 2
|
||||
wrapped_lines = textwrap.wrap(text, width=content_width)
|
||||
return [f"# {line}" for line in wrapped_lines]
|
||||
|
||||
def _extract_type_from_source(
|
||||
self, model_class: type[BaseModel], field_name: str
|
||||
) -> str:
|
||||
"""Extract the actual type annotation text from source code, checking inheritance chain."""
|
||||
# Use inheritance map to check classes efficiently
|
||||
inheritance_map = self._build_inheritance_map(model_class)
|
||||
|
||||
# Check classes in MRO order
|
||||
for cls in model_class.__mro__:
|
||||
if cls in inheritance_map and field_name in inheritance_map[cls]:
|
||||
type_annotation = self._get_type_from_class_source(cls, field_name)
|
||||
if type_annotation != "unknown":
|
||||
return type_annotation
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _get_type_from_class_source(self, class_obj: type, field_name: str) -> str:
|
||||
"""Extract type annotation from a specific class's source code."""
|
||||
try:
|
||||
source = inspect.getsource(class_obj)
|
||||
tree = ast.parse(source)
|
||||
except (OSError, TypeError):
|
||||
return "unknown"
|
||||
|
||||
# Find the class definition
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.ClassDef) and node.name == class_obj.__name__:
|
||||
# Find the field assignment
|
||||
for body_node in node.body:
|
||||
if isinstance(body_node, ast.AnnAssign) and isinstance(
|
||||
body_node.target, ast.Name
|
||||
):
|
||||
if body_node.target.id == field_name and body_node.annotation:
|
||||
return ast.unparse(body_node.annotation)
|
||||
break
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _extract_field_groups_from_all_classes(
|
||||
self, model_class: type[BaseModel]
|
||||
) -> list[dict]:
|
||||
"""Extract field groups from all classes in the inheritance hierarchy."""
|
||||
all_groups = []
|
||||
inheritance_map = self._build_inheritance_map(model_class)
|
||||
|
||||
# Get all Pydantic base classes in MRO order (most specific first)
|
||||
# This puts AxolotlInputConfig fields first, then parent class fields
|
||||
pydantic_classes = [
|
||||
cls
|
||||
for cls in model_class.__mro__
|
||||
if cls in inheritance_map and inheritance_map[cls]
|
||||
]
|
||||
|
||||
# Extract groups from each class
|
||||
for cls in pydantic_classes:
|
||||
class_groups = self._extract_field_groups_from_source(cls)
|
||||
for group in class_groups:
|
||||
all_groups.append(group)
|
||||
|
||||
# If no groups found, create a default grouping by class
|
||||
if not all_groups:
|
||||
for cls in pydantic_classes:
|
||||
fields_in_class = inheritance_map[cls]
|
||||
if fields_in_class:
|
||||
all_groups.append(
|
||||
{
|
||||
"fields": list(fields_in_class),
|
||||
}
|
||||
)
|
||||
|
||||
return all_groups
|
||||
|
||||
# pylint: disable=too-many-return-statements
|
||||
def _extract_field_groups_from_source(
|
||||
self, model_class: type[BaseModel]
|
||||
) -> list[dict]:
|
||||
"""Extract field groups from source code based on blank lines and comments."""
|
||||
try:
|
||||
source = inspect.getsource(model_class)
|
||||
tree = ast.parse(source)
|
||||
except (OSError, TypeError):
|
||||
# Fallback if we can't get source code
|
||||
fields_in_class = self._get_direct_fields(model_class)
|
||||
if fields_in_class:
|
||||
return [
|
||||
{
|
||||
"fields": list(fields_in_class),
|
||||
}
|
||||
]
|
||||
return []
|
||||
|
||||
groups = []
|
||||
current_group_fields = []
|
||||
current_group_comment = None
|
||||
|
||||
# Find the class definition
|
||||
class_node = None
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef) and node.name == model_class.__name__:
|
||||
class_node = node
|
||||
break
|
||||
|
||||
if not class_node:
|
||||
fields_in_class = self._get_direct_fields(model_class)
|
||||
if fields_in_class:
|
||||
return [
|
||||
{
|
||||
"fields": list(fields_in_class),
|
||||
}
|
||||
]
|
||||
return []
|
||||
|
||||
# Parse the source lines to detect groupings
|
||||
source_lines = source.split("\n")
|
||||
|
||||
# Get fields that are actually defined in this specific class
|
||||
fields_in_class = self._get_direct_fields(model_class)
|
||||
|
||||
# Find assignments that correspond to model fields for THIS class only
|
||||
field_assignments = []
|
||||
for node in class_node.body:
|
||||
if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
|
||||
field_name = node.target.id
|
||||
if field_name in fields_in_class:
|
||||
field_assignments.append(
|
||||
{
|
||||
"name": field_name,
|
||||
"lineno": node.lineno,
|
||||
"end_lineno": getattr(node, "end_lineno", node.lineno),
|
||||
}
|
||||
)
|
||||
|
||||
if not field_assignments:
|
||||
if fields_in_class:
|
||||
return [
|
||||
{
|
||||
"fields": list(fields_in_class),
|
||||
}
|
||||
]
|
||||
return []
|
||||
|
||||
# Sort by line number
|
||||
field_assignments.sort(key=lambda x: x["lineno"])
|
||||
|
||||
# Group fields based on blank lines and comments
|
||||
for i, field_info in enumerate(field_assignments):
|
||||
field_name = field_info["name"]
|
||||
current_line = field_info["lineno"]
|
||||
|
||||
# Check if this starts a new group (blank line before or significant gap)
|
||||
is_new_group = False
|
||||
|
||||
if i == 0:
|
||||
is_new_group = True
|
||||
else:
|
||||
prev_end_line = field_assignments[i - 1]["end_lineno"]
|
||||
|
||||
# Check for blank lines or comments between fields
|
||||
lines_between = source_lines[prev_end_line : current_line - 1]
|
||||
has_blank_line = any(line.strip() == "" for line in lines_between)
|
||||
has_comment = any(
|
||||
line.strip().startswith("#") for line in lines_between
|
||||
)
|
||||
|
||||
# Start new group if there's a blank line or comment, or significant gap
|
||||
if has_blank_line or has_comment or (current_line - prev_end_line > 3):
|
||||
is_new_group = True
|
||||
|
||||
if is_new_group and current_group_fields:
|
||||
# Save the previous group
|
||||
groups.append(
|
||||
{
|
||||
"fields": current_group_fields.copy(),
|
||||
"description": current_group_comment,
|
||||
}
|
||||
)
|
||||
current_group_fields = []
|
||||
current_group_comment = None
|
||||
|
||||
current_group_fields.append(field_name)
|
||||
|
||||
# Add the final group
|
||||
if current_group_fields:
|
||||
groups.append(
|
||||
{
|
||||
"fields": current_group_fields,
|
||||
"description": current_group_comment,
|
||||
}
|
||||
)
|
||||
|
||||
return groups
|
||||
|
||||
def _generate_field_documentation(
|
||||
self,
|
||||
model_class: type[BaseModel],
|
||||
field_name: str,
|
||||
field_info: dict,
|
||||
field_type_str: str,
|
||||
is_required: bool,
|
||||
indent_level: int = 0,
|
||||
visited_models: set = None,
|
||||
) -> list[str]:
|
||||
"""Generate documentation for a single field, expanding nested models inline."""
|
||||
if visited_models is None:
|
||||
visited_models = set()
|
||||
|
||||
lines = []
|
||||
indent = " " * indent_level
|
||||
|
||||
# Get the actual field type for nested model detection
|
||||
if field_name in model_class.model_fields:
|
||||
pydantic_field_info = model_class.model_fields[field_name]
|
||||
actual_field_type = pydantic_field_info.annotation
|
||||
else:
|
||||
actual_field_type = None
|
||||
|
||||
# Add description comment if available
|
||||
description = field_info.get("description", "")
|
||||
if description:
|
||||
wrapped_lines = self._wrap_comment(description, width=88 - len(indent))
|
||||
for line in wrapped_lines:
|
||||
lines.append(f"{indent}{line}")
|
||||
|
||||
# Extract nested Pydantic models from the type annotation
|
||||
nested_models = self._extract_all_pydantic_models_from_type(actual_field_type)
|
||||
|
||||
# Filter out already visited models to prevent infinite recursion
|
||||
expandable_models = [
|
||||
model for model in nested_models if model not in visited_models
|
||||
]
|
||||
|
||||
if expandable_models:
|
||||
# This field contains Pydantic models that can be expanded
|
||||
|
||||
# Show the field with its full type annotation
|
||||
field_line = f"{indent}{field_name}: {field_type_str}"
|
||||
if field_info.get("default") is not None:
|
||||
field_line += f" = {field_info['default']}"
|
||||
if is_required:
|
||||
field_line += " (required)"
|
||||
lines.append(field_line)
|
||||
|
||||
# Add to visited to prevent infinite recursion
|
||||
new_visited = visited_models.copy()
|
||||
new_visited.update(expandable_models)
|
||||
|
||||
# Expand each nested Pydantic model
|
||||
for i, nested_model in enumerate(expandable_models):
|
||||
if i > 0:
|
||||
lines.append("\n")
|
||||
lines.append(f"{indent} # For {nested_model.__name__}:")
|
||||
|
||||
# Get nested model schema
|
||||
try:
|
||||
nested_schema = nested_model.model_json_schema()
|
||||
nested_properties = nested_schema.get("properties", {})
|
||||
nested_required = nested_schema.get("required", [])
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
# Fallback: use model fields directly
|
||||
nested_properties = {}
|
||||
nested_required = []
|
||||
for (
|
||||
nested_field_name,
|
||||
nested_field_info,
|
||||
) in nested_model.model_fields.items():
|
||||
nested_description = ""
|
||||
if (
|
||||
hasattr(nested_field_info, "json_schema_extra")
|
||||
and nested_field_info.json_schema_extra
|
||||
):
|
||||
nested_description = (
|
||||
nested_field_info.json_schema_extra.get(
|
||||
"description", ""
|
||||
)
|
||||
)
|
||||
elif (
|
||||
hasattr(nested_field_info, "description")
|
||||
and nested_field_info.description
|
||||
):
|
||||
nested_description = nested_field_info.description
|
||||
|
||||
nested_default_val = None
|
||||
if (
|
||||
hasattr(nested_field_info, "default")
|
||||
and nested_field_info.default is not None
|
||||
):
|
||||
if str(nested_field_info.default) != "PydanticUndefined":
|
||||
nested_default_val = nested_field_info.default
|
||||
|
||||
nested_properties[nested_field_name] = {
|
||||
"type": "unknown",
|
||||
"description": nested_description,
|
||||
"default": nested_default_val,
|
||||
}
|
||||
|
||||
if nested_field_info.is_required():
|
||||
nested_required.append(nested_field_name)
|
||||
|
||||
# Get field groups for the nested model
|
||||
nested_field_groups = self._extract_field_groups_from_all_classes(
|
||||
nested_model
|
||||
)
|
||||
|
||||
# Generate nested fields with increased indentation
|
||||
for i, group in enumerate(nested_field_groups):
|
||||
if not group["fields"]:
|
||||
continue
|
||||
|
||||
# Add blank line between groups (except before first group)
|
||||
if i > 0:
|
||||
lines.append("")
|
||||
|
||||
# Process nested fields
|
||||
for nested_field_name in group["fields"]:
|
||||
if nested_field_name not in nested_properties:
|
||||
continue
|
||||
|
||||
nested_field_info = nested_properties[nested_field_name]
|
||||
nested_field_type = self._extract_type_from_source(
|
||||
nested_model, nested_field_name
|
||||
)
|
||||
nested_is_required = nested_field_name in nested_required
|
||||
|
||||
# Recursively generate documentation for nested field
|
||||
nested_lines = self._generate_field_documentation(
|
||||
nested_model,
|
||||
nested_field_name,
|
||||
nested_field_info,
|
||||
nested_field_type,
|
||||
nested_is_required,
|
||||
indent_level + 1,
|
||||
new_visited,
|
||||
)
|
||||
lines.extend(nested_lines)
|
||||
else:
|
||||
# Regular field (no expandable nested models)
|
||||
field_line = f"{indent}{field_name}: {field_type_str}"
|
||||
if field_info.get("default") is not None:
|
||||
field_line += f" = {field_info['default']}"
|
||||
if is_required:
|
||||
field_line += " (required)"
|
||||
lines.append(field_line)
|
||||
|
||||
return lines
|
||||
|
||||
def generate_qmd(
|
||||
self,
|
||||
model_class: type[BaseModel],
|
||||
title: str | None = None,
|
||||
expand_nested: bool = True,
|
||||
) -> str:
|
||||
"""Auto-generate config reference documentation including inherited fields."""
|
||||
|
||||
if title is None:
|
||||
title = f"{model_class.__name__} Reference"
|
||||
|
||||
# Try to get JSON schema, with fallback for serialization issues
|
||||
try:
|
||||
schema = model_class.model_json_schema()
|
||||
properties = schema.get("properties", {})
|
||||
required = schema.get("required", [])
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
print(
|
||||
f"Warning: Could not generate JSON schema ({e}). Using model fields instead."
|
||||
)
|
||||
# Fallback: use model fields directly
|
||||
properties = {}
|
||||
required = []
|
||||
for field_name, field_info in model_class.model_fields.items():
|
||||
# Extract description from json_schema_extra or field info
|
||||
description = ""
|
||||
if (
|
||||
hasattr(field_info, "json_schema_extra")
|
||||
and field_info.json_schema_extra
|
||||
):
|
||||
description = field_info.json_schema_extra.get("description", "")
|
||||
elif hasattr(field_info, "description") and field_info.description:
|
||||
description = field_info.description
|
||||
|
||||
# Get default value
|
||||
default_val = None
|
||||
if hasattr(field_info, "default") and field_info.default is not None:
|
||||
# Handle special Pydantic default markers
|
||||
if str(field_info.default) != "PydanticUndefined":
|
||||
default_val = field_info.default
|
||||
|
||||
properties[field_name] = {
|
||||
"type": "unknown",
|
||||
"description": description,
|
||||
"default": default_val,
|
||||
}
|
||||
|
||||
if field_info.is_required():
|
||||
required.append(field_name)
|
||||
|
||||
# Extract field groups from all classes in inheritance hierarchy
|
||||
field_groups = self._extract_field_groups_from_all_classes(model_class)
|
||||
|
||||
# Start building QMD content
|
||||
qmd_lines = [
|
||||
"---",
|
||||
f"title: {title}",
|
||||
"description: A complete list of all configuration options.",
|
||||
"---",
|
||||
"",
|
||||
]
|
||||
|
||||
# Generate one big code block with all fields (inline nested expansion)
|
||||
qmd_lines.append("```yaml")
|
||||
|
||||
for i, group in enumerate(field_groups):
|
||||
if not group["fields"]:
|
||||
continue
|
||||
|
||||
# Add blank line between groups (except before first group)
|
||||
if i > 0:
|
||||
qmd_lines.append("")
|
||||
|
||||
# Process fields in the order they appear in source
|
||||
for field_name in group["fields"]:
|
||||
if field_name not in properties:
|
||||
continue
|
||||
|
||||
field_info = properties[field_name]
|
||||
field_type = self._extract_type_from_source(model_class, field_name)
|
||||
is_required = field_name in required
|
||||
|
||||
if expand_nested:
|
||||
# Check if this field has nested models
|
||||
if field_name in model_class.model_fields:
|
||||
pydantic_field_info = model_class.model_fields[field_name]
|
||||
nested_models = self._extract_all_pydantic_models_from_type(
|
||||
pydantic_field_info.annotation
|
||||
)
|
||||
has_nested = bool(nested_models)
|
||||
else:
|
||||
has_nested = False
|
||||
|
||||
# Add blank line before nested config
|
||||
if has_nested:
|
||||
qmd_lines.append("")
|
||||
|
||||
# Use the new inline generation method
|
||||
field_lines = self._generate_field_documentation(
|
||||
model_class,
|
||||
field_name,
|
||||
field_info,
|
||||
field_type,
|
||||
is_required,
|
||||
indent_level=0,
|
||||
visited_models=set(),
|
||||
)
|
||||
qmd_lines.extend(field_lines)
|
||||
|
||||
# Add blank line after nested config
|
||||
if has_nested:
|
||||
qmd_lines.append("")
|
||||
else:
|
||||
# Original simple approach
|
||||
description = field_info.get("description", "")
|
||||
default = field_info.get("default")
|
||||
|
||||
# Add wrapped comment for description
|
||||
if description:
|
||||
wrapped_lines = self._wrap_comment(description)
|
||||
qmd_lines.extend(wrapped_lines)
|
||||
|
||||
line = f"{field_name}: {field_type}"
|
||||
if default is not None:
|
||||
line += f" = {default}"
|
||||
if is_required:
|
||||
line += " (required)"
|
||||
qmd_lines.append(line)
|
||||
|
||||
qmd_lines.append("```")
|
||||
|
||||
# Join all lines and clean up any double newlines
|
||||
content = "\n".join(qmd_lines)
|
||||
|
||||
# Replace multiple consecutive newlines with just two newlines (one blank line)
|
||||
import re
|
||||
|
||||
content = re.sub(r"\n{3,}", "\n\n", content)
|
||||
|
||||
# Ensure single newline at the very end
|
||||
content = content.rstrip("\n") + "\n"
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def main():
|
||||
generator = QuartoGenerator()
|
||||
|
||||
print("Generating config reference content...")
|
||||
qmd_content = generator.generate_qmd(AxolotlInputConfig, "Config Reference", True)
|
||||
|
||||
print("Writing to file...")
|
||||
with open("docs/config-reference.qmd", "w", encoding="utf-8") as f:
|
||||
f.write(qmd_content)
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -5,10 +5,6 @@ tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
eos_token: <|eot_id|>
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
# Finetune Magistral Small with Axolotl
|
||||
|
||||
Magistral Small is a 24B parameter opensource model from MistralAI found on [HuggingFace](https://huggingface.co/mistralai/Magistral-Small-2506). This guide shows how to fine-tune it with Axolotl with multi-turn conversations with proper masking.
|
||||
|
||||
MistralAI has also released a proprietary medium-sized version called Magistral Medium.
|
||||
|
||||
Thanks to the team at MistralAI for giving us early access to prepare for this release.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Magistral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 recommended)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,mistral]'
|
||||
```
|
||||
|
||||
2. Download the example config:
|
||||
|
||||
```bash
|
||||
axolotl fetch examples
|
||||
```
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/magistral/magistral-small-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 24GB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format is the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
|
||||
## Limitations
|
||||
|
||||
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
|
||||
|
||||
The tokenizer does not work with `dataset.map` with multiprocessing, so we had to disable it. In addition, we do not support overriding tokens yet.
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [MistralAI Magistral Blog](https://mistral.ai/news/magistral/)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
|
||||
|
||||
## Future Work
|
||||
|
||||
- Add parity to Preference Tuning, RL, Multi-modal, etc.
|
||||
- Add parity to other tokenizer configs like overriding tokens.
|
||||
@@ -1,72 +0,0 @@
|
||||
base_model: mistralai/Magistral-Small-2506
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing:
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
fsdp_config:
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
|
||||
fsdp_activation_checkpointing: true
|
||||
@@ -1,63 +0,0 @@
|
||||
base_model: mistralai/Magistral-Small-2506
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
@@ -25,7 +25,7 @@ pad_to_sequence_len: false
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
|
||||
BIN
favicon.jpg
BIN
favicon.jpg
Binary file not shown.
|
Before Width: | Height: | Size: 4.7 KiB After Width: | Height: | Size: 4.5 KiB |
@@ -13,13 +13,14 @@ packaging==23.2
|
||||
|
||||
huggingface_hub==0.32.2
|
||||
peft==0.15.2
|
||||
transformers==4.52.4
|
||||
transformers==4.52.3
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.7.0
|
||||
datasets==3.6.0
|
||||
deepspeed>=0.17.0
|
||||
trl==0.18.1
|
||||
hf_xet==1.1.2
|
||||
mistral-common[hf-hub]==1.6.0
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
@@ -67,5 +68,3 @@ schedulefree==1.4.1
|
||||
|
||||
axolotl-contribs-lgpl==0.0.6
|
||||
axolotl-contribs-mit==0.0.3
|
||||
|
||||
mistral-common==1.6.0
|
||||
|
||||
2
setup.py
2
setup.py
@@ -118,7 +118,7 @@ extras_require = {
|
||||
"yunchang==0.6.0",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.17.1",
|
||||
"deepspeed==0.17.0",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
|
||||
@@ -4,4 +4,4 @@ import pkgutil
|
||||
|
||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||
|
||||
__version__ = "0.11.0.dev"
|
||||
__version__ = "0.10.0.dev0"
|
||||
|
||||
@@ -305,8 +305,8 @@ def load_model_and_tokenizer(
|
||||
ProcessorMixin | None,
|
||||
]:
|
||||
"""
|
||||
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
|
||||
config.
|
||||
Helper function for loading a model, tokenizer, and processor specified in the
|
||||
given `axolotl` config.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Various shared constants"""
|
||||
"""
|
||||
Various shared constants
|
||||
"""
|
||||
|
||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||
|
||||
@@ -3,13 +3,15 @@
|
||||
import math
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
@@ -28,7 +30,16 @@ class TrainDatasetMeta:
|
||||
|
||||
|
||||
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
|
||||
"""Randomly sample `num_samples` samples with replacement from `dataset`."""
|
||||
"""
|
||||
Randomly sample `num_samples` samples from `dataset`.
|
||||
|
||||
Args:
|
||||
dataset: Dataset.
|
||||
num_samples: Number of samples to return.
|
||||
|
||||
Returns:
|
||||
Random sample (with replacement) of examples in `dataset`.
|
||||
"""
|
||||
return dataset.select(
|
||||
[random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec
|
||||
)
|
||||
@@ -40,37 +51,44 @@ def load_datasets(
|
||||
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
|
||||
debug: bool = False,
|
||||
) -> TrainDatasetMeta:
|
||||
"""Loads one or more training or evaluation datasets, calling
|
||||
`axolotl.utils.data.prepare_datasets`. Optionally, logs out debug information.
|
||||
"""
|
||||
Loads one or more training or evaluation datasets, calling
|
||||
`axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: Command-specific CLI arguments.
|
||||
debug: Whether to print out tokenization of sample. This is duplicated in
|
||||
`cfg` and `cli_args`, but is kept due to use in our Colab notebooks.
|
||||
debug: Whether to print out tokenization of sample
|
||||
|
||||
Returns:
|
||||
Dataclass with fields for training and evaluation datasets and the computed
|
||||
`total_num_steps`.
|
||||
`total_num_steps`.
|
||||
"""
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
||||
preprocess_iterable = getattr(cli_args, "iterable", False)
|
||||
preprocess_iterable = (
|
||||
cli_args
|
||||
and hasattr(cli_args, "iterable")
|
||||
and cli_args.iterable is not None
|
||||
and cli_args.iterable
|
||||
)
|
||||
|
||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
|
||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
||||
cfg,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
preprocess_iterable=preprocess_iterable,
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.debug
|
||||
or getattr(cli_args, "debug", False)
|
||||
or getattr(cli_args, "debug_text_only", False)
|
||||
or getattr(cli_args, "debug_num_examples", 0) > 0
|
||||
or debug
|
||||
):
|
||||
if ( # pylint: disable=too-many-boolean-expressions
|
||||
cli_args
|
||||
and (
|
||||
cli_args.debug
|
||||
or cfg.debug
|
||||
or cli_args.debug_text_only
|
||||
or int(cli_args.debug_num_examples) > 0
|
||||
)
|
||||
) or debug:
|
||||
LOG.info("check_dataset_labels...")
|
||||
|
||||
num_examples = cli_args.debug_num_examples if cli_args else 1
|
||||
@@ -95,10 +113,13 @@ def load_datasets(
|
||||
|
||||
|
||||
def load_preference_datasets(
|
||||
*, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
|
||||
) -> TrainDatasetMeta:
|
||||
"""Loads one or more training or evaluation datasets for RL training using paired
|
||||
preference data, calling `axolotl.utils.data.rl.prepare_preference_datasets`.
|
||||
"""
|
||||
Loads one or more training or evaluation datasets for RL training using paired
|
||||
preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
|
||||
Optionally, logs out debug information.
|
||||
|
||||
Args:
|
||||
@@ -109,28 +130,23 @@ def load_preference_datasets(
|
||||
Dataclass with fields for training and evaluation datasets and the computed
|
||||
`total_num_steps`.
|
||||
"""
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer)
|
||||
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
|
||||
total_num_steps: Optional[int] = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
if cfg.rl is RLType.GRPO:
|
||||
total_num_steps = None
|
||||
|
||||
total_num_steps: int | None = None
|
||||
if cfg.rl is not RLType.GRPO:
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
|
||||
if (cli_args and cli_args.debug) or cfg.debug:
|
||||
if cli_args.debug or cfg.debug:
|
||||
LOG.info("check_dataset_labels...")
|
||||
|
||||
num_examples = cli_args.debug_num_examples if cli_args else 1
|
||||
text_only = cli_args.debug_text_only if cli_args else False
|
||||
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_samples = sample_dataset(train_dataset, num_examples)
|
||||
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
|
||||
check_dataset_labels(
|
||||
dataset=train_samples,
|
||||
tokenizer=tokenizer,
|
||||
num_examples=num_examples,
|
||||
text_only=text_only,
|
||||
train_samples,
|
||||
tokenizer,
|
||||
num_examples=cli_args.debug_num_examples,
|
||||
text_only=cli_args.debug_text_only,
|
||||
rl_mode=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -380,16 +380,14 @@ class TrainerBuilderBase(abc.ABC):
|
||||
)
|
||||
|
||||
# eval_strategy and eval_steps
|
||||
if not self.eval_dataset and self.cfg.val_set_size == 0:
|
||||
# do not eval if no eval_dataset and val_set_size=0
|
||||
if not self.eval_dataset or self.cfg.val_set_size == 0:
|
||||
# do not eval if no eval_dataset or val_set_size=0
|
||||
training_args_kwargs["eval_strategy"] = "no"
|
||||
elif self.cfg.eval_steps:
|
||||
training_args_kwargs["eval_strategy"] = "steps"
|
||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
training_args_kwargs["eval_on_start"] = True
|
||||
elif self.cfg.eval_strategy:
|
||||
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
||||
training_args_kwargs["eval_on_start"] = True
|
||||
|
||||
def _configure_reporting(self, training_args_kwargs: dict):
|
||||
report_to = []
|
||||
@@ -492,9 +490,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
|
||||
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||
|
||||
if self.cfg.dataset_processes:
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
# max_length is not used in CausalTrainer
|
||||
if self.cfg.reward_model or self.cfg.rl:
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
@@ -21,12 +21,18 @@ from axolotl.core.trainers import (
|
||||
AxolotlTrainer,
|
||||
ReLoRATrainer,
|
||||
)
|
||||
from axolotl.core.training_args import (
|
||||
AxolotlPRMConfig,
|
||||
AxolotlRewardConfig,
|
||||
AxolotlTrainingArguments,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||
from axolotl.processing_strategies import get_processing_strategy
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
LossWatchDogCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
@@ -57,6 +63,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
callbacks.append(EvalFirstStepCallback())
|
||||
|
||||
if self.cfg.relora_steps:
|
||||
callbacks.append(ReLoRACallback(self.cfg))
|
||||
@@ -123,9 +130,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return callbacks
|
||||
|
||||
def _get_trainer_cls(self):
|
||||
"""
|
||||
Gets the trainer class for the given configuration.
|
||||
"""
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||
@@ -142,12 +146,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
from axolotl.core.training_args import (
|
||||
AxolotlPRMConfig,
|
||||
AxolotlRewardConfig,
|
||||
AxolotlTrainingArguments,
|
||||
)
|
||||
|
||||
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
|
||||
total_num_steps
|
||||
)
|
||||
@@ -316,12 +314,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["image_resize_algorithm"] = (
|
||||
self.cfg.image_resize_algorithm
|
||||
)
|
||||
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_training_args = plugin_manager.get_training_args(self.cfg)
|
||||
if plugin_training_args:
|
||||
training_arguments_kwargs.update(plugin_training_args)
|
||||
if self.cfg.kd_ce_alpha is not None:
|
||||
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
||||
if self.cfg.kd_alpha is not None:
|
||||
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
|
||||
if self.cfg.kd_temperature is not None:
|
||||
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
|
||||
if self.cfg.kd_zscore_base_temp is not None:
|
||||
training_arguments_kwargs["kd_zscore_base_temp"] = (
|
||||
self.cfg.kd_zscore_base_temp
|
||||
)
|
||||
if self.cfg.kd_top_k_before_softmax is not None:
|
||||
training_arguments_kwargs["kd_top_k_before_softmax"] = (
|
||||
self.cfg.kd_top_k_before_softmax
|
||||
)
|
||||
|
||||
if self.cfg.reward_model:
|
||||
training_args_cls = AxolotlRewardConfig
|
||||
@@ -375,7 +381,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
elif "tokenizer" in sig.parameters:
|
||||
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
if (
|
||||
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
|
||||
not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer])
|
||||
and self.cfg.datasets is not None
|
||||
):
|
||||
trainer_kwargs["dataset_tags"] = [
|
||||
@@ -402,10 +408,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return trainer
|
||||
|
||||
def build_collator(
|
||||
self,
|
||||
training_args, # type: "AxolotlTrainingArguments" # type: ignore
|
||||
is_eval=False,
|
||||
**kwargs,
|
||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||
):
|
||||
if training_args.pretraining:
|
||||
if (
|
||||
@@ -434,19 +437,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
]
|
||||
]
|
||||
collator_args = [self.tokenizer]
|
||||
|
||||
collator_cls_and_kwargs = None
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs(
|
||||
self.cfg, is_eval=is_eval
|
||||
)
|
||||
|
||||
if collator_cls_and_kwargs:
|
||||
collator = collator_cls_and_kwargs[0]
|
||||
if kwargs and isinstance(kwargs, dict):
|
||||
kwargs.update(collator_cls_and_kwargs[1])
|
||||
elif self.cfg.reward_model:
|
||||
if self.cfg.reward_model:
|
||||
collator = RewardDataCollatorWithPadding
|
||||
elif use_batch_sampler_collator:
|
||||
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
|
||||
@@ -477,6 +468,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
collator_args.pop(0)
|
||||
kwargs.pop("pad_to_multiple_of", None)
|
||||
kwargs.pop("padding", None)
|
||||
elif self.cfg.kd_trainer:
|
||||
from axolotl.integrations.kd.collator import (
|
||||
DataCollatorForKD,
|
||||
KDBatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
|
||||
if self.cfg.sample_packing:
|
||||
collator = KDBatchSamplerDataCollatorForSeq2Seq
|
||||
else:
|
||||
collator = DataCollatorForKD
|
||||
else:
|
||||
collator = DataCollatorForSeq2Seq
|
||||
|
||||
|
||||
@@ -12,9 +12,13 @@ from axolotl.core.trainers import (
|
||||
from axolotl.core.trainers.dpo import DPOStrategy
|
||||
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
from axolotl.core.training_args import (
|
||||
AxolotlCPOConfig,
|
||||
AxolotlKTOConfig,
|
||||
AxolotlORPOConfig,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.loaders.utils import ensure_dtype
|
||||
from axolotl.utils.callbacks.qat import QATCallback
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
@@ -27,9 +31,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
|
||||
if self.cfg.qat:
|
||||
callbacks.append(QATCallback(self.cfg.qat))
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
@@ -78,12 +79,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
Returns training_args and trainer_kwargs
|
||||
"""
|
||||
from axolotl.core.training_args import (
|
||||
AxolotlCPOConfig,
|
||||
AxolotlKTOConfig,
|
||||
AxolotlORPOConfig,
|
||||
)
|
||||
|
||||
training_args_kwargs, trainer_kwargs = self._set_base_training_args(
|
||||
total_num_steps=total_num_steps
|
||||
)
|
||||
@@ -95,6 +90,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
training_args_kwargs["remove_unused_columns"] = False
|
||||
|
||||
# only rlhf
|
||||
if self.cfg.dataset_processes:
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
if self.cfg.trl and self.cfg.trl.beta is not None:
|
||||
training_args_kwargs["beta"] = self.cfg.trl.beta
|
||||
elif self.cfg.rl_beta is not None:
|
||||
@@ -143,7 +142,22 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||
training_args_cls = AxolotlDPOConfig
|
||||
training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg))
|
||||
if self.cfg.rl is RLType.IPO:
|
||||
training_args_kwargs["loss_type"] = "ipo"
|
||||
|
||||
# Not compatible with IPO
|
||||
if self.cfg.rl is RLType.DPO and self.cfg.dpo_label_smoothing:
|
||||
training_args_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||
if self.cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||
if self.cfg.dpo_use_logits_to_keep is not None:
|
||||
training_args_kwargs["use_logits_to_keep"] = (
|
||||
self.cfg.dpo_use_logits_to_keep
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
@@ -151,12 +165,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if blocklist_key in training_args_kwargs:
|
||||
del training_args_kwargs[blocklist_key]
|
||||
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_training_args = plugin_manager.get_training_args(self.cfg)
|
||||
if plugin_training_args:
|
||||
training_args_kwargs.update(plugin_training_args)
|
||||
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
logging_first_step=True,
|
||||
**training_args_kwargs,
|
||||
|
||||
@@ -25,7 +25,6 @@ from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.mixins import (
|
||||
CheckpointSaveMixin,
|
||||
OptimizerMixin,
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
@@ -34,16 +33,13 @@ from axolotl.core.trainers.utils import (
|
||||
sanitize_kwargs_for_ds_tagging,
|
||||
sanitize_kwargs_for_tagging,
|
||||
)
|
||||
from axolotl.utils import get_not_null
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class AxolotlTrainer(
|
||||
SchedulerMixin, OptimizerMixin, RngLoaderMixin, CheckpointSaveMixin, Trainer
|
||||
):
|
||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
|
||||
"""Extend the base Trainer for axolotl helpers"""
|
||||
|
||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||
@@ -105,7 +101,7 @@ class AxolotlTrainer(
|
||||
)
|
||||
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||
|
||||
sampler = MultipackBatchSampler(
|
||||
return MultipackBatchSampler(
|
||||
base_sampler,
|
||||
lengths=get_dataset_lengths(dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
@@ -115,12 +111,8 @@ class AxolotlTrainer(
|
||||
bin_size=self.args.sample_packing_bin_size,
|
||||
sequential=self.args.sample_packing_sequentially,
|
||||
drop_last=True,
|
||||
num_processes=self.args.dataset_num_proc,
|
||||
)
|
||||
|
||||
len(sampler)
|
||||
return sampler
|
||||
|
||||
def _get_train_sampler(
|
||||
self, train_dataset: Optional[Dataset] = None
|
||||
) -> Optional[Sampler]:
|
||||
@@ -228,9 +220,7 @@ class AxolotlTrainer(
|
||||
}
|
||||
|
||||
if not isinstance(dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["drop_last"] = get_not_null(
|
||||
self.args.dataloader_drop_last, True
|
||||
)
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
if sampler_fn is not None:
|
||||
sampler = sampler_fn(dataset)
|
||||
if isinstance(sampler, BatchSampler):
|
||||
|
||||
@@ -22,19 +22,10 @@ class DPOStrategy:
|
||||
training_args_kwargs = {}
|
||||
if cfg.rl is RLType.IPO:
|
||||
training_args_kwargs["loss_type"] = "ipo"
|
||||
# Label smoothing is not compatible with IPO
|
||||
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
|
||||
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
|
||||
if cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||
if cfg.dpo_padding_free is not None:
|
||||
training_args_kwargs["padding_free"] = cfg.dpo_padding_free
|
||||
if cfg.dpo_norm_loss is not None:
|
||||
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
|
||||
if cfg.dpo_use_logits_to_keep is not None:
|
||||
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
|
||||
return training_args_kwargs
|
||||
|
||||
@@ -14,5 +14,3 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||
"""
|
||||
DPO config for DPO training
|
||||
"""
|
||||
|
||||
dpo_norm_loss: bool | None = False
|
||||
|
||||
@@ -83,20 +83,3 @@ class AxolotlDPOTrainer(
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return loss
|
||||
|
||||
def concatenated_forward(
|
||||
self,
|
||||
model: nn.Module,
|
||||
batch: dict[str, Union[list, torch.LongTensor]],
|
||||
is_ref_model: bool = False,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
if self.args.dpo_norm_loss:
|
||||
# fmt: off
|
||||
loss_type: str = self.loss_type # type: ignore[has-type] # pylint: disable=access-member-before-definition
|
||||
# fmt: on
|
||||
# concatenated_forward handles avg token logprob for ipo case already
|
||||
self.loss_type = "ipo" # pylint: disable=attribute-defined-outside-init
|
||||
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
|
||||
self.loss_type = loss_type # pylint: disable=attribute-defined-outside-init
|
||||
return res
|
||||
return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from .checkpoints import CheckpointSaveMixin
|
||||
from .optimizer import OptimizerMixin
|
||||
from .rng_state_loader import RngLoaderMixin
|
||||
from .scheduler import SchedulerMixin
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
"""Custom handling to not fail training if fsdp optimizer is not savable"""
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class CheckpointSaveMixin(Trainer):
|
||||
"""Mixin to handle saving the optimizer and scheduler if they are not savable."""
|
||||
|
||||
def _save_optimizer_and_scheduler(self, output_dir):
|
||||
try:
|
||||
super()._save_optimizer_and_scheduler(output_dir)
|
||||
except NotImplementedError as exc:
|
||||
LOG.warning(
|
||||
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
|
||||
"Optimizer and scheduler states were not saved - resuming from checkpoints "
|
||||
"for this training run will not be possible."
|
||||
)
|
||||
@@ -2,17 +2,238 @@
|
||||
extra axolotl specific training args
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Type
|
||||
from typing import Optional
|
||||
|
||||
from PIL.Image import Resampling
|
||||
from transformers import TrainingArguments
|
||||
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||
|
||||
from axolotl.integrations.config import merge_training_args
|
||||
|
||||
AxolotlTrainingMixins: Type = merge_training_args()
|
||||
@dataclass
|
||||
class AxolotlTrainingMixins:
|
||||
"""
|
||||
Mixin class for the Axolotl training args.
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
model_type: Optional[str] = field(
|
||||
default=None, metadata={"help": "HF model configuration model_type."}
|
||||
)
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
pretraining: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Indicates to trainer whether we are doing continued pretraining."
|
||||
},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
sample_packing_sequentially: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||
},
|
||||
)
|
||||
multipack_real_batches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
)
|
||||
sample_packing_efficiency: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||
)
|
||||
sample_packing_bin_size: int = field(
|
||||
default=200,
|
||||
metadata={
|
||||
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
|
||||
},
|
||||
)
|
||||
sample_packing_group_size: int = field(
|
||||
default=100000,
|
||||
metadata={
|
||||
"help": "The number of samples to group together for packing. Increase for better packing."
|
||||
},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
relora_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for ReLoRA"},
|
||||
)
|
||||
relora_warmup_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_anneal_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_prune_ratio: Optional[float] = field(
|
||||
default=0.9,
|
||||
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
bench_dataset: Optional[str] = field(
|
||||
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||
metadata={
|
||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||
},
|
||||
)
|
||||
do_bench_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||
)
|
||||
do_causal_lm_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
|
||||
)
|
||||
max_bench_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||
},
|
||||
)
|
||||
bench_source_max_len: int = field(
|
||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||
)
|
||||
dataloader_prefetch_factor: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "prefetch_factor argument to the dataloader"},
|
||||
)
|
||||
cosine_min_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
|
||||
)
|
||||
cosine_constant_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
|
||||
},
|
||||
)
|
||||
loraplus_lr_ratio: Optional[float] = field(
|
||||
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
|
||||
)
|
||||
loraplus_lr_embedding: Optional[float] = field(
|
||||
default=1e-6,
|
||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||
)
|
||||
embedding_lr_scale: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||
)
|
||||
lr_groups: Optional[list[dict]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Specify learning rate groups for with different LRs."},
|
||||
)
|
||||
embedding_lr: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||
)
|
||||
qlora: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "whether this is a qlora training"},
|
||||
)
|
||||
orpo_alpha: Optional[float] = field(
|
||||
default=None,
|
||||
)
|
||||
lisa_n_layers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "the number of activate layers in LISA"},
|
||||
)
|
||||
lisa_step_interval: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to switch layers in LISA"},
|
||||
)
|
||||
lisa_layers_attribute: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "path under the model to access the layers"},
|
||||
)
|
||||
curriculum_sampling: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||
)
|
||||
alternate_lr_scheduler_type: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
||||
},
|
||||
)
|
||||
chat_template: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Chat template converting chat messages to text"},
|
||||
)
|
||||
|
||||
kd_ce_alpha: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
|
||||
},
|
||||
)
|
||||
|
||||
kd_alpha: Optional[float] = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The alpha scaling parameter for KD loss"},
|
||||
)
|
||||
|
||||
kd_temperature: Optional[float] = field(
|
||||
default=1.0,
|
||||
metadata={
|
||||
"help": "the temperature parameter for KL divergence loss when using KD"
|
||||
},
|
||||
)
|
||||
|
||||
kd_zscore_base_temp: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "the base temperature parameter for KL divergence with z-score when using KD"
|
||||
},
|
||||
)
|
||||
|
||||
kd_top_k_before_softmax: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
|
||||
},
|
||||
)
|
||||
|
||||
adam_beta3: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
|
||||
},
|
||||
)
|
||||
adam_epsilon2: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
|
||||
},
|
||||
)
|
||||
|
||||
# multi-modal section
|
||||
|
||||
image_size: int | tuple[int, int] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The size of the image to resize to"},
|
||||
)
|
||||
|
||||
image_resize_algorithm: Resampling | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The algorithm to use for image resizing"},
|
||||
)
|
||||
|
||||
# end of multi-modal section
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,224 +0,0 @@
|
||||
"""
|
||||
Base Axolotl Training Mixins shared across various trainer configs
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from PIL.Image import Resampling
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingMixins:
|
||||
"""
|
||||
Mixin class for the Axolotl training args.
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
model_type: Optional[str] = field(
|
||||
default=None, metadata={"help": "HF model configuration model_type."}
|
||||
)
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
pretraining: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Indicates to trainer whether we are doing continued pretraining."
|
||||
},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
sample_packing_sequentially: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||
},
|
||||
)
|
||||
multipack_real_batches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
)
|
||||
sample_packing_efficiency: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||
)
|
||||
sample_packing_bin_size: int = field(
|
||||
default=200,
|
||||
metadata={
|
||||
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
|
||||
},
|
||||
)
|
||||
sample_packing_group_size: int = field(
|
||||
default=100000,
|
||||
metadata={
|
||||
"help": "The number of samples to group together for packing. Increase for better packing."
|
||||
},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
dataset_num_proc: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for data processing"},
|
||||
)
|
||||
relora_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for ReLoRA"},
|
||||
)
|
||||
relora_warmup_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_anneal_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_prune_ratio: Optional[float] = field(
|
||||
default=0.9,
|
||||
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
bench_dataset: Optional[str] = field(
|
||||
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||
metadata={
|
||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||
},
|
||||
)
|
||||
do_bench_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||
)
|
||||
do_causal_lm_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
|
||||
)
|
||||
max_bench_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||
},
|
||||
)
|
||||
bench_source_max_len: int = field(
|
||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||
)
|
||||
dataloader_prefetch_factor: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "prefetch_factor argument to the dataloader"},
|
||||
)
|
||||
cosine_min_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
|
||||
)
|
||||
cosine_constant_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
|
||||
},
|
||||
)
|
||||
loraplus_lr_ratio: Optional[float] = field(
|
||||
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
|
||||
)
|
||||
loraplus_lr_embedding: Optional[float] = field(
|
||||
default=1e-6,
|
||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||
)
|
||||
embedding_lr_scale: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||
)
|
||||
lr_groups: Optional[list[dict]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Specify learning rate groups for with different LRs."},
|
||||
)
|
||||
embedding_lr: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||
)
|
||||
qlora: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "whether this is a qlora training"},
|
||||
)
|
||||
orpo_alpha: Optional[float] = field(
|
||||
default=None,
|
||||
)
|
||||
lisa_n_layers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "the number of activate layers in LISA"},
|
||||
)
|
||||
lisa_step_interval: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to switch layers in LISA"},
|
||||
)
|
||||
lisa_layers_attribute: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "path under the model to access the layers"},
|
||||
)
|
||||
curriculum_sampling: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||
)
|
||||
alternate_lr_scheduler_type: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
||||
},
|
||||
)
|
||||
chat_template: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Chat template converting chat messages to text"},
|
||||
)
|
||||
|
||||
# kd_ce_alpha: Optional[float] = field(
|
||||
# default=None,
|
||||
# metadata={
|
||||
# "help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
|
||||
# },
|
||||
# )
|
||||
#
|
||||
# kd_alpha: Optional[float] = field(
|
||||
# default=1.0,
|
||||
# metadata={"help": "The alpha scaling parameter for KD loss"},
|
||||
# )
|
||||
#
|
||||
# kd_temperature: Optional[float] = field(
|
||||
# default=1.0,
|
||||
# metadata={
|
||||
# "help": "the temperature parameter for KL divergence loss when using KD"
|
||||
# },
|
||||
# )
|
||||
|
||||
adam_beta3: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
|
||||
},
|
||||
)
|
||||
adam_epsilon2: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
|
||||
},
|
||||
)
|
||||
|
||||
# multi-modal section
|
||||
|
||||
image_size: int | tuple[int, int] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The size of the image to resize to"},
|
||||
)
|
||||
|
||||
image_resize_algorithm: Resampling | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The algorithm to use for image resizing"},
|
||||
)
|
||||
|
||||
# end of multi-modal section
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Module containing Dataset functionality"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, IterableDataset
|
||||
@@ -19,21 +20,21 @@ LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class TokenizedPromptDataset(Dataset):
|
||||
"""Dataset that returns tokenized prompts from a stream of text files.
|
||||
|
||||
Args:
|
||||
prompt_tokenizer: The prompt tokenizing method for processing the data.
|
||||
dataset: Dataset with text files.
|
||||
process_count: Number of processes to use for tokenizing.
|
||||
keep_in_memory: Whether to keep the tokenized dataset in memory.
|
||||
"""
|
||||
Dataset that returns tokenized prompts from a stream of text files.
|
||||
Args:
|
||||
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
process_count (int): Number of processes to use for tokenizing.
|
||||
keep_in_memory (bool): Whether to keep the tokenized dataset in memory.
|
||||
"""
|
||||
|
||||
def __init__( # pylint: disable=super-init-not-called
|
||||
self,
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: Dataset,
|
||||
process_count: int | None = None,
|
||||
keep_in_memory: bool | None = False,
|
||||
process_count: Optional[int] = None,
|
||||
keep_in_memory: Optional[bool] = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.prompt_tokenizer = prompt_tokenizer
|
||||
@@ -48,13 +49,6 @@ class TokenizedPromptDataset(Dataset):
|
||||
features = dataset.features.keys()
|
||||
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
|
||||
|
||||
# Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common)
|
||||
if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True):
|
||||
LOG.info(
|
||||
"Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)"
|
||||
)
|
||||
num_proc = 1
|
||||
|
||||
map_kwargs = {}
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
@@ -70,6 +64,10 @@ class TokenizedPromptDataset(Dataset):
|
||||
desc="Strategy Filtering Rows",
|
||||
)
|
||||
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
|
||||
return dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
@@ -82,14 +80,14 @@ class TokenizedPromptDataset(Dataset):
|
||||
|
||||
def wrap_dataset_for_tokenized_prompt(
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: Dataset | IterableDataset,
|
||||
dataset: Union[Dataset, IterableDataset],
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(dataset, IterableDataset):
|
||||
map_kwargs = {}
|
||||
if prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
features = list(dataset.features.keys())
|
||||
features = dataset.features.keys()
|
||||
return dataset.map(
|
||||
prompt_tokenizer.tokenize_prompt,
|
||||
remove_columns=features,
|
||||
@@ -100,13 +98,12 @@ def wrap_dataset_for_tokenized_prompt(
|
||||
|
||||
# TODO this isn't the best since it can't interleave datasets
|
||||
class ConstantLengthDataset(IterableDataset):
|
||||
"""Iterable dataset that returns constant length chunks of tokens from stream of
|
||||
text files.
|
||||
|
||||
Args:
|
||||
tokenizer: The processor used for processing the data.
|
||||
dataset: Dataset with text files.
|
||||
seq_length: Length of token sequences to return.
|
||||
"""
|
||||
Iterable dataset that returns constant length chunks of tokens from stream of text files.
|
||||
Args:
|
||||
tokenizer (Tokenizer): The processor used for processing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
seq_length (int): Length of token sequences to return.
|
||||
"""
|
||||
|
||||
def __init__( # pylint: disable=super-init-not-called
|
||||
@@ -117,7 +114,7 @@ class ConstantLengthDataset(IterableDataset):
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.concat_token_id = tokenizer.eos_token_id
|
||||
self.datasets: list[IterableDataset] = datasets
|
||||
self.datasets: List[IterableDataset] = datasets
|
||||
self.seq_length = seq_length
|
||||
|
||||
vocab_size = len(tokenizer.get_vocab())
|
||||
@@ -181,10 +178,7 @@ class ConstantLengthDataset(IterableDataset):
|
||||
}
|
||||
else:
|
||||
LOG.warning(
|
||||
"Dropping batch due to tensor size mismatch "
|
||||
f"input_ids: {input_ids.size()}, "
|
||||
f"labels: {labels.size()}, "
|
||||
f"attention_mask: {attention_mask.size()}"
|
||||
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
|
||||
)
|
||||
buffer = {
|
||||
"input_ids": [],
|
||||
|
||||
@@ -7,6 +7,7 @@ from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import Dataset
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
@@ -16,7 +17,6 @@ from axolotl.train import (
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import cleanup_distributed
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
@@ -22,7 +22,6 @@ from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import importlib
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
|
||||
|
||||
from peft import PeftModel
|
||||
@@ -84,11 +83,6 @@ class BasePlugin:
|
||||
def get_input_args(self) -> str | None:
|
||||
"""Returns a pydantic model for the plugin's input arguments."""
|
||||
|
||||
def get_training_args_mixin(self) -> str | None:
|
||||
"""
|
||||
Returns a dataclass model for the plugin's training arguments.
|
||||
"""
|
||||
|
||||
def load_datasets(
|
||||
self, cfg: DictDefault, preprocess: bool = False
|
||||
) -> Union["TrainDatasetMeta", None]:
|
||||
@@ -164,31 +158,6 @@ class BasePlugin:
|
||||
trainer: The trainer object for training.
|
||||
"""
|
||||
|
||||
def get_training_args(self, cfg: DictDefault): # pylint: disable=unused-argument):
|
||||
"""
|
||||
Returns custom training arguments to set on TrainingArgs.
|
||||
|
||||
Args:
|
||||
cfg: The global axolotl configuration.
|
||||
|
||||
Returns:
|
||||
object: dict containing the training arguments.
|
||||
"""
|
||||
|
||||
def get_collator_cls_and_kwargs(
|
||||
self, cfg: DictDefault, is_eval: bool = False
|
||||
): # pylint: disable=unused-argument):
|
||||
"""
|
||||
Returns a custom class for the collator.
|
||||
|
||||
Args:
|
||||
cfg: The global axolotl configuration.
|
||||
is_eval: Whether this is an eval split.
|
||||
|
||||
Returns:
|
||||
class: The class for the collator.
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
|
||||
"""Creates and returns an optimizer for training.
|
||||
@@ -309,7 +278,7 @@ def load_plugin(plugin_name: str) -> BasePlugin:
|
||||
return plugin
|
||||
|
||||
|
||||
class PluginManager: # pylint: disable=too-many-public-methods
|
||||
class PluginManager:
|
||||
"""The `PluginManager` class is responsible for loading and managing plugins. It
|
||||
should be a singleton so it can be accessed from anywhere in the codebase.
|
||||
|
||||
@@ -368,11 +337,8 @@ class PluginManager: # pylint: disable=too-many-public-methods
|
||||
plugin = load_plugin(plugin_name)
|
||||
self.plugins[plugin_name] = plugin
|
||||
LOG.info(f"Plugin loaded successfully: {plugin_name}")
|
||||
except ImportError as exc:
|
||||
except ImportError:
|
||||
LOG.error(f"Failed to load plugin: {plugin_name}")
|
||||
# print stacktrace
|
||||
traceback.print_exc()
|
||||
print(f"Error: {exc}")
|
||||
|
||||
def get_input_args(self) -> list[str]:
|
||||
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'
|
||||
@@ -387,20 +353,6 @@ class PluginManager: # pylint: disable=too-many-public-methods
|
||||
input_args.append(input_args_from_plugin)
|
||||
return input_args
|
||||
|
||||
def get_training_args_mixin(self):
|
||||
"""
|
||||
Returns a list of dataclasses for all registered plugins' training args mixins'
|
||||
|
||||
Returns:
|
||||
list[str]: A list of dataclsses
|
||||
"""
|
||||
training_args = []
|
||||
for plugin in self.plugins.values():
|
||||
training_args_from_plugin = plugin.get_training_args_mixin()
|
||||
if training_args_from_plugin is not None:
|
||||
training_args.append(training_args_from_plugin)
|
||||
return training_args
|
||||
|
||||
def load_datasets(
|
||||
self, cfg: DictDefault, preprocess: bool = False
|
||||
) -> Union["TrainDatasetMeta", None]:
|
||||
@@ -490,42 +442,6 @@ class PluginManager: # pylint: disable=too-many-public-methods
|
||||
return trainer_cls
|
||||
return None
|
||||
|
||||
def get_training_args(self, cfg):
|
||||
"""
|
||||
Calls the get_training_args method of all registered plugins and returns the combined training arguments.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugins.
|
||||
|
||||
Returns:
|
||||
object: The training arguments
|
||||
"""
|
||||
training_args_kwargs = {}
|
||||
for plugin in self.plugins.values():
|
||||
training_args = plugin.get_training_args(cfg)
|
||||
if training_args is not None:
|
||||
training_args_kwargs.update(training_args)
|
||||
|
||||
return training_args_kwargs
|
||||
|
||||
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
|
||||
"""
|
||||
Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugins.
|
||||
is_eval (bool): Whether this is an eval split.
|
||||
|
||||
Returns:
|
||||
object: The collator class, or None if none was found.
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
collator = plugin.get_collator_cls_and_kwargs(cfg, is_eval=is_eval)
|
||||
if collator is not None:
|
||||
collator_cls, collator_kwargs = collator
|
||||
return collator_cls, collator_kwargs
|
||||
return None
|
||||
|
||||
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
|
||||
"""Calls the `post_trainer_create` method of all registered plugins.
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ Module to handle merging the plugins' input arguments with the base configuratio
|
||||
This was moved here to prevent circular imports.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Type
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from axolotl.utils.schemas.config import (
|
||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||
@@ -61,43 +61,3 @@ def merge_input_args():
|
||||
]
|
||||
return AxolotlConfigWCapabilities, AxolotlInputConfig
|
||||
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
|
||||
|
||||
|
||||
def merge_training_args() -> Type:
|
||||
"""
|
||||
Merges training arguments from registered plugins with the base TrainingArguments.
|
||||
|
||||
This function retrieves the training arguments from registered plugins using the PluginManager.
|
||||
It then dynamically creates new classes, AxolotlTrainingMixins,
|
||||
that inherit from the base configurations and include the training arguments from the plugins.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the newly created classes, AxolotlTrainingMixins.
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
from axolotl.core.training_args_base import (
|
||||
AxolotlTrainingMixins as AxolotlTrainingMixinsBase,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
training_args_mixins: List[str] = plugin_manager.get_training_args_mixin()
|
||||
mixin_classes = []
|
||||
dynamic_input = ""
|
||||
for plugin_args in training_args_mixins:
|
||||
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
|
||||
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
|
||||
mixin_classes.append(plugin_cls)
|
||||
if dynamic_input:
|
||||
dynamic_input += f"class AxolotlTrainingMixins(AxolotlTrainingMixinsBase, {', '.join(mixin_classes)}):\n pass\n"
|
||||
|
||||
namespace: Dict[Any, Any] = {}
|
||||
local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase}
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
dynamic_input, {**globals(), **local_vars}, namespace
|
||||
)
|
||||
AxolotlTrainingMixins = namespace[ # pylint: disable=invalid-name
|
||||
"AxolotlTrainingMixins"
|
||||
]
|
||||
return AxolotlTrainingMixins
|
||||
return AxolotlTrainingMixinsBase
|
||||
|
||||
@@ -24,14 +24,6 @@ pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transform
|
||||
|
||||
## Usage
|
||||
|
||||
**NOTE**: If you are training a VLM model, please use older version of Axolotl as upstream has applied a major VLM refactor, and our patches have not been updated yet.
|
||||
|
||||
```bash
|
||||
git checkout 787880215b3ab32ccaf81c1b2e9588c6f3e6e764
|
||||
|
||||
pip3 install --no-build-isolation -e .
|
||||
```
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
@@ -15,12 +15,7 @@
|
||||
"""
|
||||
Plugin init to add KD support to Axolotl.
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback
|
||||
|
||||
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401
|
||||
|
||||
@@ -33,75 +28,9 @@ class KDPlugin(BasePlugin):
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.kd.KDArgs"
|
||||
|
||||
def get_training_args_mixin(self):
|
||||
return "axolotl.integrations.kd.args.KDTrainingArgsMixin"
|
||||
|
||||
def get_trainer_cls(self, cfg):
|
||||
if cfg.kd_trainer:
|
||||
from .trainer import AxolotlKDTrainer
|
||||
|
||||
return AxolotlKDTrainer
|
||||
return None
|
||||
|
||||
def get_training_args(self, cfg):
|
||||
return {
|
||||
"kd_ce_alpha": cfg.kd_ce_alpha,
|
||||
"kd_alpha": cfg.kd_alpha,
|
||||
"kd_temperature": cfg.kd_temperature,
|
||||
"kd_beta": cfg.kd_beta,
|
||||
"kd_normalize_topk": cfg.kd_normalize_topk,
|
||||
}
|
||||
|
||||
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
|
||||
if not cfg.kd_trainer:
|
||||
return None, None
|
||||
|
||||
from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq
|
||||
|
||||
use_batch_sampler_collator = False
|
||||
if is_eval is False and cfg.sample_packing:
|
||||
use_batch_sampler_collator = True
|
||||
if cfg.eval_sample_packing and is_eval:
|
||||
use_batch_sampler_collator = True
|
||||
|
||||
if cfg.kd_online_server_base_url:
|
||||
from .collator_online_teacher import OnlineTeacherCollator
|
||||
|
||||
return OnlineTeacherCollator, {
|
||||
"kd_online_server_base_url": cfg.kd_online_server_base_url,
|
||||
"kd_online_topk": cfg.kd_online_topk,
|
||||
"kd_temperature": cfg.kd_temperature,
|
||||
"kd_online_server": cfg.kd_online_server,
|
||||
"kd_online_timeout": cfg.kd_online_timeout,
|
||||
"kd_normalize_topk": cfg.kd_normalize_topk,
|
||||
}
|
||||
|
||||
if use_batch_sampler_collator:
|
||||
return KDBatchSamplerDataCollatorForSeq2Seq, {}
|
||||
return DataCollatorForKD, {}
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
from .kernels.models import apply_kernel
|
||||
|
||||
apply_kernel(cfg.model_config_type)
|
||||
|
||||
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
|
||||
"""
|
||||
Adds temp scheduler callback to the Trainer instance.
|
||||
|
||||
Args:
|
||||
cfg (Any): Configuration object containing the sparse recipe.
|
||||
trainer (Trainer): Huggingface Trainer instance.
|
||||
|
||||
Returns:
|
||||
list: List containing the configured callback instances.
|
||||
"""
|
||||
if cfg.kd_temperature_min is not None and cfg.kd_online_server_base_url:
|
||||
callback = KDTemperatureSchedulerCallback(
|
||||
cfg.kd_temperature,
|
||||
cfg.kd_temperature_min,
|
||||
trainer,
|
||||
)
|
||||
return [callback]
|
||||
|
||||
return []
|
||||
|
||||
@@ -15,19 +15,9 @@
|
||||
"""
|
||||
Plugin args for KD support.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class InferenceServerType(str, Enum):
|
||||
"""
|
||||
Online inferences server types to handle different request args
|
||||
"""
|
||||
|
||||
vllm = "vllm" # pylint: disable=invalid-name
|
||||
sglang = "sglang" # pylint: disable=invalid-name
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class KDArgs(BaseModel):
|
||||
@@ -35,41 +25,13 @@ class KDArgs(BaseModel):
|
||||
Input args for knowledge distillation.
|
||||
"""
|
||||
|
||||
kd_trainer: float | None = None # whether to use KD trainer
|
||||
kd_ce_alpha: float | None = (
|
||||
kd_trainer: Optional[bool] = None # whether to use KD trainer
|
||||
kd_ce_alpha: Optional[float] = (
|
||||
None # loss coefficient for cross-entropy loss during KD
|
||||
)
|
||||
kd_alpha: float | None = None # loss coefficient for KD loss
|
||||
kd_temperature: float | None = None # temperature for sampling during KD
|
||||
kd_beta: float | None = 0.0 # beta coefficient for ratio of fwd and reverse KL
|
||||
kd_normalize_topk: bool | None = (
|
||||
None # whether to normalize student logits during KD
|
||||
)
|
||||
|
||||
# TODO online kd
|
||||
kd_online_server_base_url: str | None = None
|
||||
kd_online_topk: int | None = None
|
||||
kd_online_server: InferenceServerType | None = Field(
|
||||
default_factory=lambda: InferenceServerType.vllm
|
||||
)
|
||||
kd_online_timeout: int | None = 120
|
||||
kd_temperature_min: float | None = (
|
||||
None # kd temperature scheduling during online kd
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class KDTrainingArgsMixin:
|
||||
"""
|
||||
Additional args for KD training.
|
||||
"""
|
||||
|
||||
kd_ce_alpha: float | None = (
|
||||
None # loss coefficient for cross-entropy loss during KD
|
||||
)
|
||||
kd_alpha: float | None = None # loss coefficient for KD loss
|
||||
kd_temperature: float | None = None # temperature for sampling during KD
|
||||
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
|
||||
kd_normalize_topk: float | None = (
|
||||
None # whether to normalize student logits during KD
|
||||
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
||||
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
||||
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
|
||||
kd_top_k_before_softmax: Optional[bool] = (
|
||||
None # whether to sample top k before softmax during KD
|
||||
)
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
"""
|
||||
Transformers trainer callbacks to schedule the KD temperature during training
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
|
||||
|
||||
class KDTemperatureSchedulerCallback(TrainerCallback):
|
||||
"""
|
||||
KD temperature scheduler callback for the trainer.
|
||||
"""
|
||||
|
||||
def __init__(self, temperature_start, temperature_min, trainer):
|
||||
self.temperature_start = temperature_start
|
||||
self.temperature_min = temperature_min
|
||||
self.temperature = temperature_start
|
||||
|
||||
self.trainer = trainer
|
||||
|
||||
def on_step_end(
|
||||
self, args, state, control, **kwargs
|
||||
): # pylint: disable=unused-argument
|
||||
# cosine decay temperature over the max steps
|
||||
|
||||
progress = state.global_step / state.max_steps
|
||||
# Cosine decay factor: 0.5 * (1 + cos(pi * progress))
|
||||
# This factor goes from 1 (at progress=0) to 0 (at progress=1)
|
||||
decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
|
||||
self.temperature = self.temperature_start - (
|
||||
(self.temperature_start - self.temperature_min) * (1.0 - decay_factor)
|
||||
)
|
||||
|
||||
if hasattr(self.trainer.data_collator, "kd_temperature"):
|
||||
self.trainer.data_collator.kd_temperature = self.temperature
|
||||
@@ -15,15 +15,12 @@
|
||||
"""
|
||||
Chat template prompt strategy loader with KD support
|
||||
"""
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
"""
|
||||
@@ -104,8 +101,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
# fill with -inf for padding_len tokens for top_k tokens
|
||||
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
||||
|
||||
# we shift for causal models in the trainer, so start the range from 0
|
||||
for _ in range(0, input_padding_len):
|
||||
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
|
||||
# otherwise, we need to shift in the trainer
|
||||
shift = 0
|
||||
for _ in range(shift, input_padding_len):
|
||||
target_logprobs.append([-float("inf")] * top_k)
|
||||
target_token_ids.append(list(range(top_k)))
|
||||
target_mask.append([0] * top_k)
|
||||
@@ -144,10 +143,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
#
|
||||
# Convert from log to probability
|
||||
teacher_probs_t1 = position_logprobs_tensor.exp()
|
||||
# normalize probabilities to sum to 1 in case they aren't already
|
||||
teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
|
||||
if teacher_probs_t1_sum > 1e-9:
|
||||
teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
|
||||
if self.kd_temperature != self.gen_temperature:
|
||||
# Exponentiate by factor (T1 / T2)
|
||||
exponent = self.gen_temperature / self.kd_temperature
|
||||
@@ -167,115 +162,12 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
target_logprobs.append(position_logprobs_scaled)
|
||||
target_token_ids.append(position_token_ids)
|
||||
|
||||
# Update sample with transformed logprobs
|
||||
sample["target_logprobs"] = target_logprobs
|
||||
sample["target_token_ids"] = target_token_ids
|
||||
sample["target_mask"] = target_mask
|
||||
|
||||
return sample
|
||||
|
||||
def _tokenize_single_prompt(self, prompt):
|
||||
logprobs = prompt.pop(self.logprobs_field)
|
||||
tokenized_prompt = super()._tokenize_single_prompt(prompt)
|
||||
tokenized_prompt[self.logprobs_field] = logprobs
|
||||
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
|
||||
"""
|
||||
Strat for datasets with complete structured KD logprob data
|
||||
"""
|
||||
|
||||
def transform_logprobs(self, sample):
|
||||
"""
|
||||
Transform logprobs to target format for KD training
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
logprobs = sample.pop(self.logprobs_field)
|
||||
target_seq_len = len(logprobs)
|
||||
input_seq_len = len(sample["input_ids"])
|
||||
input_padding_len = input_seq_len - target_seq_len
|
||||
# get non-zero top-k (prune None logprobs from vllm data step)
|
||||
top_k_vals = [
|
||||
len(logprobs[i])
|
||||
for i in range(len(logprobs))
|
||||
if logprobs[i] is not None and len(logprobs[i])
|
||||
]
|
||||
max_top_k = max(set(top_k_vals), key=top_k_vals.count)
|
||||
min_top_k = min(set(top_k_vals), key=top_k_vals.count)
|
||||
top_k = min(max_top_k, min_top_k)
|
||||
if top_k == 0:
|
||||
raise ValueError("No non-zero top-k logprobs found.")
|
||||
|
||||
target_logprobs = []
|
||||
target_token_ids = []
|
||||
target_mask = []
|
||||
|
||||
if input_padding_len < 0:
|
||||
# logprobs is longer than target_seq_len,
|
||||
# so we need to slice from the left/beginning of logprobs
|
||||
logprobs = logprobs[:-input_seq_len]
|
||||
input_padding_len = 0
|
||||
# target_seq_len = input_seq_len
|
||||
|
||||
# truncate the second dimension of the logprobs to top_k
|
||||
logprobs = [row[:top_k] for row in logprobs]
|
||||
|
||||
# fill with -inf for padding_len tokens for top_k tokens
|
||||
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
||||
|
||||
# we shift for causal models in the trainer, so start the range from 0
|
||||
for _ in range(0, input_padding_len):
|
||||
if shift == 1:
|
||||
# since we started at index 1 for causal, we need one more padding token
|
||||
target_logprobs.append([-float("inf")] * top_k)
|
||||
target_token_ids.append(list(range(top_k)))
|
||||
target_mask.append([0] * top_k)
|
||||
|
||||
for position in range(input_padding_len, input_seq_len):
|
||||
if sample["labels"][position] == -100:
|
||||
target_mask.append([0] * top_k)
|
||||
else:
|
||||
target_mask.append([1] * top_k)
|
||||
|
||||
for token_pos_logprobs, pos_target_token_ids in zip(
|
||||
logprobs, sample["target_token_ids"]
|
||||
):
|
||||
# Convert to a tensor for easier manipulation
|
||||
position_logprobs_tensor = torch.tensor(
|
||||
token_pos_logprobs, dtype=torch.float
|
||||
)
|
||||
|
||||
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
|
||||
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
|
||||
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
|
||||
#
|
||||
# Convert from log to probability
|
||||
teacher_probs_t1 = position_logprobs_tensor.exp()
|
||||
# normalize probabilities to sum to 1 in case they aren't already
|
||||
teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
|
||||
if teacher_probs_t1_sum > 1e-9:
|
||||
teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
|
||||
if self.kd_temperature != self.gen_temperature:
|
||||
# Exponentiate by factor (T1 / T2)
|
||||
exponent = self.gen_temperature / self.kd_temperature
|
||||
teacher_probs_t2 = teacher_probs_t1**exponent
|
||||
else:
|
||||
teacher_probs_t2 = teacher_probs_t1
|
||||
# Re-normalize
|
||||
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
|
||||
dim=0, keepdim=True
|
||||
)
|
||||
# Convert back to log
|
||||
position_logprobs_tensor = torch.log(teacher_probs_t2)
|
||||
|
||||
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
|
||||
position_logprobs_scaled = position_logprobs_tensor.tolist()
|
||||
|
||||
target_logprobs.append(position_logprobs_scaled)
|
||||
target_token_ids.append(pos_target_token_ids)
|
||||
|
||||
# Update sample with transformed logprobs
|
||||
sample["target_logprobs"] = target_logprobs
|
||||
sample["target_token_ids"] = target_token_ids
|
||||
@@ -285,10 +177,8 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
|
||||
|
||||
def _tokenize_single_prompt(self, prompt):
|
||||
logprobs = prompt.pop(self.logprobs_field)
|
||||
target_token_ids = prompt.pop("target_token_ids")
|
||||
tokenized_prompt = super()._tokenize_single_prompt(prompt)
|
||||
tokenized_prompt[self.logprobs_field] = logprobs
|
||||
tokenized_prompt["target_token_ids"] = target_token_ids
|
||||
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
|
||||
|
||||
return tokenized_prompt
|
||||
@@ -299,7 +189,7 @@ class KDStrategyLoader(StrategyLoader):
|
||||
Load ChatTemplateStrategy with KD support using StrategyLoader.
|
||||
"""
|
||||
|
||||
def _get_strategy_cls(self, cfg): # pylint: disable=unused-argument
|
||||
def _get_strategy_cls(self):
|
||||
return ChatTemplateStrategyWithKD
|
||||
|
||||
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
|
||||
@@ -314,14 +204,4 @@ class KDStrategyLoader(StrategyLoader):
|
||||
return strategy_params
|
||||
|
||||
|
||||
class KDStrategyLoaderV2(KDStrategyLoader):
|
||||
"""
|
||||
Load KD chat template datasets with pre-tokenized logprob data
|
||||
"""
|
||||
|
||||
def _get_strategy_cls(self, cfg): # pylint: disable=unused-argument
|
||||
return ChatTemplateStrategyWithKDv2
|
||||
|
||||
|
||||
load_legacy = KDStrategyLoader()
|
||||
load = KDStrategyLoaderV2()
|
||||
load = KDStrategyLoader()
|
||||
|
||||
@@ -47,16 +47,11 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
|
||||
position_pad_token_id: int = 0
|
||||
return_tensors: str = "pt"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
if return_tensors is None:
|
||||
return_tensors = self.return_tensors
|
||||
|
||||
padding_side = self.tokenizer.padding_side
|
||||
max_len = 0
|
||||
|
||||
# Pad labels and position_ids first
|
||||
for feature_name, pad_token_id in [
|
||||
@@ -107,9 +102,7 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
|
||||
target_mask_list.append(f.pop("target_mask"))
|
||||
|
||||
# Determine max lengths
|
||||
max_teacher_seq_len = max_len or max(
|
||||
len(seq) for seq in target_logprobs_list
|
||||
)
|
||||
max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
|
||||
max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)
|
||||
|
||||
padded_target_logprobs = []
|
||||
@@ -216,9 +209,7 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
|
||||
# We want to produce a single "merged" feature dict for each sub-batch.
|
||||
out_features = [{} for _ in features]
|
||||
|
||||
for i, sub_features in enumerate( # pylint: disable=too-many-nested-blocks
|
||||
features
|
||||
):
|
||||
for i, sub_features in enumerate(features):
|
||||
# sub_features is a list of dicts, each dict = one sequence’s features
|
||||
# We'll merge them into out_features[i].
|
||||
#
|
||||
@@ -252,17 +243,10 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
|
||||
# For example, input_ids or labels are often arrays.
|
||||
arrays = []
|
||||
for feat in sub_features:
|
||||
if field_name in feat and isinstance(
|
||||
feat[field_name], (list, torch.Tensor)
|
||||
):
|
||||
if isinstance(
|
||||
feat[field_name][0], (dict, str)
|
||||
): # pylint: disable=too-many-nested-blocks
|
||||
continue
|
||||
if field_name in feat:
|
||||
arr = np.array(feat[field_name])
|
||||
arrays.append(arr)
|
||||
if arrays:
|
||||
out_features[i][field_name] = np.concatenate(arrays)
|
||||
out_features[i][field_name] = np.concatenate(arrays)
|
||||
|
||||
# 3) Now call the parent collator, which will do:
|
||||
# - padding of labels/position_ids
|
||||
|
||||
@@ -1,561 +0,0 @@
|
||||
"""
|
||||
Packed data loader for online teacher training supporting vllm and sglang.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from orjson import orjson
|
||||
|
||||
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.integrations.kd.utils import normalize_logprobs
|
||||
from axolotl.utils.data.utils import retry_on_request_exceptions
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def hmac_sha_from_int_list(int_list, key, hash_func=hashlib.sha256):
|
||||
"""
|
||||
Create HMAC-SHA hash from a list of integers
|
||||
|
||||
Args:
|
||||
int_list: List of integers
|
||||
key: Secret key (string or bytes)
|
||||
hash_func: Hash function (default: sha256)
|
||||
|
||||
Returns:
|
||||
HMAC digest as hex string
|
||||
"""
|
||||
# Convert key to bytes if it's a string
|
||||
if isinstance(key, str):
|
||||
key = key.encode("utf-8")
|
||||
|
||||
# Convert list of ints to bytes
|
||||
# Method 1: Convert each int to bytes and concatenate
|
||||
data = b"".join(i.to_bytes(4, byteorder="big") for i in int_list)
|
||||
|
||||
# Create HMAC
|
||||
h = hmac.new(key, data, hash_func)
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
||||
"""
|
||||
Collator for online teacher training.
|
||||
"""
|
||||
|
||||
DEFAULT_LABEL_PAD_TOKEN_ID: int = -100
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
kd_online_server_base_url: Optional[str] = None,
|
||||
kd_online_topk: Optional[int] = None,
|
||||
kd_temperature: Optional[float] = 1.0,
|
||||
kd_online_server: Optional[str] = "vllm",
|
||||
kd_online_timeout: Optional[int] = 120,
|
||||
kd_cache_dir: Optional[str] = None,
|
||||
kd_normalize_topk: Optional[bool] = True,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if kd_online_server_base_url is None:
|
||||
raise ValueError(
|
||||
"kd_online_server_base_url must be provided for OnlineTeacherDataloader"
|
||||
)
|
||||
if kd_online_topk is None or kd_online_topk <= 0:
|
||||
raise ValueError(
|
||||
"kd_online_topk must be a positive integer for OnlineTeacherDataloader"
|
||||
)
|
||||
|
||||
self.kd_online_server_base_url = kd_online_server_base_url.rstrip("/")
|
||||
self.kd_online_topk = kd_online_topk
|
||||
self.kd_temperature = kd_temperature
|
||||
self.kd_online_server = kd_online_server
|
||||
self.http_session = requests.Session()
|
||||
self.kd_online_timeout = kd_online_timeout
|
||||
self.kd_cache_dir = kd_cache_dir
|
||||
self.kd_normalize_topk = kd_normalize_topk
|
||||
|
||||
def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
|
||||
"""
|
||||
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
|
||||
"""
|
||||
if not raw_logprobs or self.kd_online_topk == 0:
|
||||
return (
|
||||
[-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
|
||||
)
|
||||
|
||||
raw_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
|
||||
return normalize_logprobs(raw_logprobs_tensor, self.kd_online_topk).tolist()
|
||||
|
||||
@retry_on_request_exceptions(max_retries=10, delay=5)
|
||||
def fetch_online_logprobs_sglang(
|
||||
self, batch_input_ids: List[List[int]], labels: List[List[int]]
|
||||
):
|
||||
"""
|
||||
Fetches logprobs from an online teacher served by sglang for a batch of input_ids.
|
||||
Assumes API returns token IDs as strings in logprob dictionary keys.
|
||||
"""
|
||||
api_endpoint = f"{self.kd_online_server_base_url}/generate"
|
||||
|
||||
payload = {
|
||||
"input_ids": batch_input_ids,
|
||||
"return_logprob": True,
|
||||
"top_logprobs_num": self.kd_online_topk,
|
||||
"logprob_start_len": 0,
|
||||
"return_text_in_logprobs": True,
|
||||
"echo": True,
|
||||
"sampling_params": {
|
||||
"max_new_tokens": 0,
|
||||
"temperature": self.kd_temperature,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
}
|
||||
|
||||
# Initialize with empty lists, so if API call fails, these are returned.
|
||||
ret_data_target_token_ids: List[List[List[int]]] = []
|
||||
ret_data_target_logprobs: List[List[List[float]]] = []
|
||||
ret_data_target_mask: List[List[List[int]]] = []
|
||||
|
||||
try:
|
||||
response = self.http_session.post(
|
||||
api_endpoint, json=payload, timeout=self.kd_online_timeout
|
||||
)
|
||||
response.raise_for_status()
|
||||
api_data: list[dict] = response.json()
|
||||
|
||||
# Ensure api_data is a list, and its length matches batch_input_ids
|
||||
if not isinstance(api_data, list) or len(api_data) != len(batch_input_ids):
|
||||
LOG.error(
|
||||
f"API response format error. Expected a list of {len(batch_input_ids)} "
|
||||
f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
|
||||
)
|
||||
# Return empty data; items processed later will get default empty KD fields
|
||||
return {
|
||||
"target_token_ids": ret_data_target_token_ids,
|
||||
"target_logprobs": ret_data_target_logprobs,
|
||||
"target_mask": ret_data_target_mask,
|
||||
}
|
||||
|
||||
for sequence_data, seq_input_ids, seq_labels in zip(
|
||||
api_data, batch_input_ids, labels
|
||||
):
|
||||
current_target_logprobs = []
|
||||
current_target_token_ids = []
|
||||
current_target_mask = []
|
||||
|
||||
meta_info = sequence_data.pop("meta_info", {})
|
||||
# Ensure input_top_logprobs is a list
|
||||
input_top_logprobs: Optional[list[None | list[tuple]]] = meta_info.pop(
|
||||
"input_top_logprobs", []
|
||||
)
|
||||
if not isinstance(input_top_logprobs, list):
|
||||
LOG.warning(
|
||||
f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
|
||||
)
|
||||
input_top_logprobs = [] # Treat as empty
|
||||
|
||||
# basic check that the logprob data len matches the input len, so no need to handle padding
|
||||
assert len(seq_input_ids) == len(input_top_logprobs)
|
||||
|
||||
for i, _, label in zip(
|
||||
range(len(seq_input_ids)), seq_input_ids, seq_labels
|
||||
):
|
||||
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
|
||||
# this is always the case for the first token.
|
||||
# there is never logprob data for the first token since that's a true input
|
||||
# so we replace the None value with padding data
|
||||
current_target_logprobs.append(
|
||||
[-float("inf")] * self.kd_online_topk
|
||||
)
|
||||
current_target_token_ids.append([0] * self.kd_online_topk)
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
elif (
|
||||
i < len(input_top_logprobs)
|
||||
and input_top_logprobs[i] is not None
|
||||
):
|
||||
pos_top_logprobs_data = input_top_logprobs[i]
|
||||
# Ensure pos_top_logprobs_data is a list of lists as expected
|
||||
if not (
|
||||
isinstance(pos_top_logprobs_data, list)
|
||||
and all(
|
||||
isinstance(item, list) for item in pos_top_logprobs_data
|
||||
)
|
||||
and len(pos_top_logprobs_data) > 0
|
||||
and len(pos_top_logprobs_data[0]) == 3
|
||||
): # [logprob, token_id, token_str]
|
||||
LOG.warning(
|
||||
f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
|
||||
)
|
||||
current_target_logprobs.append(
|
||||
[-float("inf")] * self.kd_online_topk
|
||||
)
|
||||
current_target_token_ids.append([0] * self.kd_online_topk)
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
continue
|
||||
|
||||
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
|
||||
pos_logprobs_raw, pos_token_ids, _ = [
|
||||
list(row) for row in zip(*pos_top_logprobs_data)
|
||||
]
|
||||
|
||||
# Ensure correct length (top_k)
|
||||
if len(pos_logprobs_raw) < self.kd_online_topk:
|
||||
pad_len = self.kd_online_topk - len(pos_logprobs_raw)
|
||||
pos_logprobs_raw.extend([-float("inf")] * pad_len)
|
||||
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
|
||||
|
||||
# truncate to top_k in case the response was longer
|
||||
current_target_token_ids.append(
|
||||
pos_token_ids[: self.kd_online_topk]
|
||||
)
|
||||
|
||||
if self.kd_normalize_topk:
|
||||
normalized_logprobs_for_position = self._normalize_logprobs(
|
||||
pos_logprobs_raw[: self.kd_online_topk]
|
||||
)
|
||||
current_target_logprobs.append(
|
||||
normalized_logprobs_for_position
|
||||
)
|
||||
else:
|
||||
current_target_logprobs.append(
|
||||
pos_logprobs_raw[: self.kd_online_topk]
|
||||
)
|
||||
|
||||
# Mask depends on the corresponding label for the student
|
||||
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
else:
|
||||
current_target_mask.append([1] * self.kd_online_topk)
|
||||
else:
|
||||
# Pad if no logprobs for this position (either due to length mismatch or None entry)
|
||||
current_target_logprobs.append(
|
||||
[-float("inf")] * self.kd_online_topk
|
||||
)
|
||||
current_target_token_ids.append([0] * self.kd_online_topk)
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
|
||||
ret_data_target_token_ids.append(current_target_token_ids)
|
||||
ret_data_target_logprobs.append(current_target_logprobs)
|
||||
ret_data_target_mask.append(current_target_mask)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
LOG.error(f"Error fetching logprobs from online teacher: {e}")
|
||||
raise e
|
||||
# ret_logprobs_data will be returned with empty lists, handled by the caller.
|
||||
except Exception as e: # Catch other potential errors during processing
|
||||
LOG.error(
|
||||
f"Unexpected error processing API response in fetch_online_logprobs: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise e
|
||||
|
||||
return {
|
||||
"target_token_ids": ret_data_target_token_ids,
|
||||
"target_logprobs": ret_data_target_logprobs,
|
||||
"target_mask": ret_data_target_mask,
|
||||
}
|
||||
|
||||
@retry_on_request_exceptions(max_retries=10, delay=5)
|
||||
def fetch_online_logprobs_vllm(
|
||||
self, batch_input_ids: List[List[int]], labels: List[List[int]]
|
||||
):
|
||||
"""
|
||||
Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
|
||||
Assumes API returns token IDs as strings in logprob dictionary keys.
|
||||
"""
|
||||
api_endpoint = f"{self.kd_online_server_base_url}/v1/completions"
|
||||
|
||||
payload = {
|
||||
"prompt": batch_input_ids,
|
||||
"echo": True,
|
||||
"logprobs": True,
|
||||
"prompt_logprobs": self.kd_online_topk,
|
||||
"top_logprobs": self.kd_online_topk,
|
||||
"max_new_tokens": 0,
|
||||
"skip_special_tokens": False,
|
||||
"temperature": self.kd_temperature,
|
||||
"sampling_params": {
|
||||
"max_tokens": 0,
|
||||
},
|
||||
}
|
||||
|
||||
# Initialize with empty lists, so if API call fails, these are returned.
|
||||
ret_data_target_token_ids: List[List[List[int]]] = []
|
||||
ret_data_target_logprobs: List[List[List[float]]] = []
|
||||
ret_data_target_mask: List[List[List[int]]] = []
|
||||
|
||||
try:
|
||||
headers = {"Accept-Encoding": "deflate, gzip, br, zstd"}
|
||||
response = self.http_session.post(
|
||||
api_endpoint,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=self.kd_online_timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
api_data: dict = orjson.loads(response.content)
|
||||
choices: list[dict] = api_data["choices"]
|
||||
|
||||
# Ensure api_data is a list, and its length matches batch_input_ids
|
||||
if not isinstance(choices, list) or len(choices) != len(batch_input_ids):
|
||||
LOG.error(
|
||||
f"API response format error. Expected a list of {len(batch_input_ids)} "
|
||||
f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
|
||||
)
|
||||
# Return empty data; items processed later will get default empty KD fields
|
||||
return {
|
||||
"target_token_ids": ret_data_target_token_ids,
|
||||
"target_logprobs": ret_data_target_logprobs,
|
||||
"target_mask": ret_data_target_mask,
|
||||
}
|
||||
|
||||
for sequence_data, seq_input_ids, seq_labels in zip(
|
||||
choices, batch_input_ids, labels
|
||||
):
|
||||
# seq_input_ids: List[int]
|
||||
# seq_labels: List[int]
|
||||
|
||||
current_target_logprobs = []
|
||||
current_target_token_ids = []
|
||||
current_target_mask = []
|
||||
|
||||
# Ensure input_top_logprobs is a list
|
||||
input_top_logprobs: Optional[list[None | dict[str, dict]]] = (
|
||||
sequence_data.pop("prompt_logprobs", [])
|
||||
)
|
||||
|
||||
if not isinstance(input_top_logprobs, list):
|
||||
LOG.warning(
|
||||
f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
|
||||
)
|
||||
input_top_logprobs = [] # Treat as empty
|
||||
|
||||
# basic check that the logprob data len matches the input len, so no need to handle padding
|
||||
assert len(seq_input_ids) == len(input_top_logprobs)
|
||||
|
||||
seq_len = len(seq_input_ids)
|
||||
|
||||
for i, _, label in zip(range(seq_len), seq_input_ids, seq_labels):
|
||||
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
|
||||
# this is always the case for the first token.
|
||||
# there is never logprob data for the first token since that's a true input
|
||||
continue
|
||||
if (
|
||||
i < len(input_top_logprobs)
|
||||
and input_top_logprobs[i] is not None
|
||||
):
|
||||
pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i] # type: ignore[assignment]
|
||||
# Ensure pos_top_logprobs_data is a list of lists as expected
|
||||
if not (
|
||||
isinstance(pos_top_logprobs_data, dict)
|
||||
and all(
|
||||
isinstance(item, dict)
|
||||
for item in pos_top_logprobs_data.values()
|
||||
)
|
||||
and len(pos_top_logprobs_data.keys()) > 0
|
||||
): # [logprob, token_id, token_str]
|
||||
LOG.warning(
|
||||
f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
|
||||
)
|
||||
current_target_logprobs.append(
|
||||
[-float("inf")] * self.kd_online_topk
|
||||
)
|
||||
current_target_token_ids.append(
|
||||
list(range(self.kd_online_topk))
|
||||
)
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
continue
|
||||
|
||||
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
|
||||
pos_token_ids_str = list(pos_top_logprobs_data.keys())
|
||||
pos_logprobs_dict = pos_top_logprobs_data.values()
|
||||
pos_token_ids = [
|
||||
int(token_id) for token_id in pos_token_ids_str
|
||||
]
|
||||
pos_logprobs_raw = [
|
||||
float(logprob.get("logprob", -float("inf")))
|
||||
for logprob in pos_logprobs_dict
|
||||
]
|
||||
|
||||
# Ensure correct length (top_k)
|
||||
if len(pos_logprobs_raw) < self.kd_online_topk:
|
||||
pad_len = self.kd_online_topk - len(pos_logprobs_raw)
|
||||
LOG.warning(
|
||||
f"Padding position {i} with {pad_len} top-k tokens and logprobs."
|
||||
)
|
||||
pos_logprobs_raw.extend([-float("inf")] * pad_len)
|
||||
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
|
||||
|
||||
# truncate to top_k in case the response was longer
|
||||
current_target_token_ids.append(
|
||||
pos_token_ids[: self.kd_online_topk]
|
||||
)
|
||||
|
||||
if self.kd_normalize_topk:
|
||||
normalized_logprobs_for_position = self._normalize_logprobs(
|
||||
pos_logprobs_raw[: self.kd_online_topk]
|
||||
)
|
||||
current_target_logprobs.append(
|
||||
normalized_logprobs_for_position
|
||||
)
|
||||
else:
|
||||
current_target_logprobs.append(
|
||||
pos_logprobs_raw[: self.kd_online_topk]
|
||||
)
|
||||
|
||||
# Mask depends on the corresponding label for the student
|
||||
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
else:
|
||||
current_target_mask.append([1] * self.kd_online_topk)
|
||||
else:
|
||||
# Pad if no logprobs for this position (either due to length mismatch or None entry)
|
||||
current_target_logprobs.append(
|
||||
[-float("inf")] * self.kd_online_topk
|
||||
)
|
||||
current_target_token_ids.append(
|
||||
list(range(self.kd_online_topk))
|
||||
)
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
for i in range(max(0, seq_len - len(current_target_logprobs))):
|
||||
current_target_logprobs.append(
|
||||
[-float("inf")] * self.kd_online_topk
|
||||
)
|
||||
current_target_token_ids.append(list(range(self.kd_online_topk)))
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
|
||||
ret_data_target_token_ids.append(current_target_token_ids)
|
||||
ret_data_target_logprobs.append(current_target_logprobs)
|
||||
ret_data_target_mask.append(current_target_mask)
|
||||
|
||||
# TODO save and load targets to disk for caching for next epoch
|
||||
# generate a hmac SHA256 hash over the list seq_input_ids and convert it to an int
|
||||
# if self.kd_cache_dir:
|
||||
# hash_input_ids = hmac_sha_from_int_list(
|
||||
# seq_input_ids, f"{self.kd_online_server_base_url}:{self.kd_online_topk}"
|
||||
# )
|
||||
# with open(f"{self.kd_cache_dir}/{hash_input_ids}.parquet", "wb") as f:
|
||||
# pd.DataFrame(ret_logprobs_data).to_parquet(f, index=False)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
LOG.error(f"Error fetching logprobs from online teacher: {e}")
|
||||
raise e
|
||||
# ret_logprobs_data will be returned with empty lists, handled by the caller.
|
||||
except Exception as e: # Catch other potential errors during processing
|
||||
LOG.error(
|
||||
f"Unexpected error processing API response in fetch_online_logprobs: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise e
|
||||
|
||||
return {
|
||||
"target_token_ids": ret_data_target_token_ids,
|
||||
"target_logprobs": ret_data_target_logprobs,
|
||||
"target_mask": ret_data_target_mask,
|
||||
}
|
||||
|
||||
def __call__(
|
||||
self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
if not features:
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
|
||||
for (
|
||||
sub_batch_features
|
||||
) in features: # sub_batch_features is List[Dict[str, Any]]
|
||||
if not sub_batch_features:
|
||||
continue
|
||||
|
||||
input_ids_for_api_call: List[List[int]] = []
|
||||
labels_for_api_call: List[List[int]] = []
|
||||
# Store references to the original item dictionaries to update them in-place
|
||||
items_for_api_call: List[Dict[str, Any]] = []
|
||||
|
||||
for item_dict in sub_batch_features:
|
||||
if not isinstance(item_dict, dict):
|
||||
LOG.warning(
|
||||
f"Skipping non-dict item in sub_batch_features: {item_dict}"
|
||||
)
|
||||
continue
|
||||
|
||||
current_input_ids = item_dict.get("input_ids")
|
||||
current_labels = item_dict.get("labels")
|
||||
|
||||
if current_input_ids is not None and current_labels is not None:
|
||||
# Ensure input_ids and labels are lists of ints for JSON serialization
|
||||
input_ids_list = (
|
||||
current_input_ids.tolist()
|
||||
if hasattr(current_input_ids, "tolist")
|
||||
else list(current_input_ids)
|
||||
)
|
||||
labels_list = (
|
||||
current_labels.tolist()
|
||||
if hasattr(current_labels, "tolist")
|
||||
else list(current_labels)
|
||||
)
|
||||
|
||||
input_ids_for_api_call.append(input_ids_list)
|
||||
labels_for_api_call.append(labels_list)
|
||||
items_for_api_call.append(item_dict)
|
||||
else:
|
||||
# This item will not get teacher logprobs from the API.
|
||||
# Initialize KD fields to empty lists so downstream collators handle them uniformly.
|
||||
item_dict.setdefault("target_token_ids", [])
|
||||
item_dict.setdefault("target_logprobs", [])
|
||||
item_dict.setdefault("target_mask", [])
|
||||
|
||||
# print(items_for_api_call)
|
||||
if items_for_api_call: # Only call API if there's something to process
|
||||
if self.kd_online_server == "sglang":
|
||||
api_responses_for_sub_batch = self.fetch_online_logprobs_sglang(
|
||||
input_ids_for_api_call, labels_for_api_call
|
||||
)
|
||||
else:
|
||||
api_responses_for_sub_batch = self.fetch_online_logprobs_vllm(
|
||||
input_ids_for_api_call, labels_for_api_call
|
||||
)
|
||||
|
||||
# api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask"
|
||||
# Each value is a list, corresponding to items_for_api_call
|
||||
for i, item_to_update in enumerate(items_for_api_call):
|
||||
# TODO make sure to figure out which input in sub_batch_features to update the batch in the original `features` object so the super class can handle it properly.
|
||||
if api_responses_for_sub_batch and i < len(
|
||||
api_responses_for_sub_batch["target_token_ids"]
|
||||
): # Check bounds
|
||||
assert len(
|
||||
api_responses_for_sub_batch["target_token_ids"][i]
|
||||
) == len(item_to_update["input_ids"])
|
||||
assert len(
|
||||
api_responses_for_sub_batch["target_logprobs"][i]
|
||||
) == len(item_to_update["input_ids"])
|
||||
assert len(
|
||||
api_responses_for_sub_batch["target_mask"][i]
|
||||
) == len(item_to_update["labels"])
|
||||
item_to_update["target_token_ids"] = (
|
||||
api_responses_for_sub_batch["target_token_ids"][i]
|
||||
)
|
||||
item_to_update["target_logprobs"] = api_responses_for_sub_batch[
|
||||
"target_logprobs"
|
||||
][i]
|
||||
item_to_update["target_mask"] = api_responses_for_sub_batch[
|
||||
"target_mask"
|
||||
][i]
|
||||
else:
|
||||
# API call failed for this item, or response was shorter than expected.
|
||||
# Ensure KD fields are initialized as empty lists.
|
||||
LOG.warning(
|
||||
f" (index {i}), or API response was too short. "
|
||||
f"API response keys: {list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}"
|
||||
)
|
||||
item_to_update.setdefault("target_token_ids", [])
|
||||
item_to_update.setdefault("target_logprobs", [])
|
||||
item_to_update.setdefault("target_mask", [])
|
||||
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
@@ -1,8 +0,0 @@
|
||||
"""
|
||||
Liger Chunked loss optimizations module
|
||||
"""
|
||||
|
||||
from .liger import LigerFusedLinearKLTopKLogprobLoss
|
||||
from .models import apply_kernel
|
||||
|
||||
__all__ = ["LigerFusedLinearKLTopKLogprobLoss", "apply_kernel"]
|
||||
|
||||
@@ -1,485 +0,0 @@
|
||||
"""
|
||||
Liger Kernels for Chunked Top-K Log-Prob Distillation
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from liger_kernel.chunked_loss.fused_linear_distillation import (
|
||||
LigerFusedLinearDistillationBase,
|
||||
)
|
||||
|
||||
from axolotl.integrations.kd.utils import normalize_logprobs
|
||||
|
||||
|
||||
class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
||||
"""
|
||||
Chunked kl-div loss for top-k logprobs
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def distillation_loss_fn(
|
||||
student_logits_temp_scaled: torch.Tensor, # [chunk_size, vocab_size], already temp-scaled
|
||||
target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k]
|
||||
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
|
||||
target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
|
||||
beta: float = 0.0,
|
||||
normalize_topk: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute Top-K KL divergence loss for a chunk.
|
||||
Args:
|
||||
student_logits_temp_scaled: Student logits, scaled by temperature. Shape: (N, V).
|
||||
target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
|
||||
target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
|
||||
target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
|
||||
beta: Controls the type of KL divergence.
|
||||
0.0 for Forward KL (P_teacher || P_student).
|
||||
1.0 for Reverse KL (P_student || P_teacher).
|
||||
0.5 for Symmetric KL (average of Forward and Reverse).
|
||||
normalize_topk: Whether to normalize the log probabilities
|
||||
Returns:
|
||||
Sum of KL divergence losses for the chunk.
|
||||
"""
|
||||
topk = target_token_ids_chunk.shape[-1]
|
||||
student_logits_temp_scaled = ( # [chunk_size, vocab_size]
|
||||
student_logits_temp_scaled.float()
|
||||
)
|
||||
target_logprobs_chunk = target_logprobs_chunk.float()
|
||||
|
||||
# Gather student logits for the top-k teacher token IDs
|
||||
# target_token_ids_chunk: [chunk_size, top_k]
|
||||
# student_logits_topk_temp_scaled: [chunk_size, top_k]
|
||||
student_logits_topk_temp_scaled = torch.gather(
|
||||
student_logits_temp_scaled, dim=-1, index=target_token_ids_chunk
|
||||
)
|
||||
|
||||
# Student log-probabilities for the gathered top-k tokens
|
||||
student_lse = torch.logsumexp(
|
||||
student_logits_temp_scaled, dim=-1, keepdim=True
|
||||
) # [chunk_size, 1]
|
||||
student_logprobs_topk_temp_scaled = (
|
||||
student_logits_topk_temp_scaled - student_lse
|
||||
)
|
||||
|
||||
# we have the top-k student logprobs, normalize them
|
||||
if normalize_topk:
|
||||
student_logprobs_topk_temp_scaled = normalize_logprobs(
|
||||
student_logprobs_topk_temp_scaled, topk
|
||||
)
|
||||
|
||||
valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k]
|
||||
|
||||
student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
|
||||
teacher_logprobs_valid = target_logprobs_chunk[valid_mask]
|
||||
|
||||
# Teacher probabilities P(y|x_teacher) from logprobs
|
||||
# target_logprobs_valid are already normalized (log(softmax(teacher_logits/T)))
|
||||
teacher_probs_valid = teacher_logprobs_valid.exp()
|
||||
# Student probabilities P_student from log P_student
|
||||
student_probs_topk_valid = student_logprobs_topk_valid.exp()
|
||||
|
||||
# kd_loss_per_token = torch.zeros_like(target_logprobs_valid)
|
||||
|
||||
# KL divergence: sum(P_teacher * (log P_teacher - log P_student))
|
||||
# = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
|
||||
# The distillation loss is often formulated as -sum(P_teacher * log P_student)
|
||||
# or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
|
||||
# Here, target_logprobs_valid are log_softmax_teacher.
|
||||
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
|
||||
if beta == 0.0: # Contribution from Forward KL
|
||||
fwd_kl_per_token = teacher_probs_valid * (
|
||||
teacher_logprobs_valid - student_logprobs_topk_valid
|
||||
)
|
||||
kd_loss = fwd_kl_per_token.sum()
|
||||
elif beta == 1.0: # Contribution from Reverse KL
|
||||
rev_kl_per_token = student_probs_topk_valid * (
|
||||
student_logprobs_topk_valid - teacher_logprobs_valid
|
||||
)
|
||||
kd_loss = rev_kl_per_token.sum()
|
||||
else:
|
||||
# JSD - Jensen-Shannon Divergence / Symmetric
|
||||
mean_probs = (
|
||||
1 - beta
|
||||
) * student_probs_topk_valid + beta * teacher_probs_valid
|
||||
log_mean_probs = mean_probs.log()
|
||||
student_kl = F.kl_div(
|
||||
log_mean_probs,
|
||||
student_logprobs_topk_valid,
|
||||
reduction="sum",
|
||||
log_target=True,
|
||||
)
|
||||
teacher_kl = F.kl_div(
|
||||
log_mean_probs, teacher_logprobs_valid, reduction="sum", log_target=True
|
||||
)
|
||||
jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
|
||||
kd_loss = jsd_loss
|
||||
|
||||
return kd_loss
|
||||
|
||||
@staticmethod
|
||||
def _compute_loss_kl_topk(
|
||||
student_input_chunk: torch.Tensor,
|
||||
student_weight: torch.Tensor,
|
||||
# Args for student_bias, target_token_ids_chunk etc. are passed to the lambda wrapped by grad_and_value
|
||||
# or through `partial`. Let's make them explicit here for clarity.
|
||||
target_token_ids_chunk: torch.Tensor,
|
||||
target_logprobs_chunk: torch.Tensor,
|
||||
target_mask_chunk: torch.Tensor,
|
||||
target_chunk: torch.Tensor, # For hard loss (true labels)
|
||||
student_bias: torch.Tensor = None, # This will be one of the grad targets
|
||||
# Other params passed via `partial` from `forward`
|
||||
distillation_loss_fn=None,
|
||||
ignore_index: int = -100,
|
||||
weight_hard_loss: float = 0.5,
|
||||
weight_soft_loss: float = 0.5,
|
||||
compute_ce_loss: bool = True,
|
||||
temperature: float = 1.0,
|
||||
beta: float = 0.0,
|
||||
normalize_topk: bool = True,
|
||||
):
|
||||
# Compute student logits for the chunk from hidden states and LM head
|
||||
# student_input_chunk: [chunk_size, hidden_dim]
|
||||
# student_lm_head_weight: [vocab_size, hidden_dim]
|
||||
# student_logits_chunk: [chunk_size, vocab_size]
|
||||
student_logits_chunk = F.linear(
|
||||
student_input_chunk, student_weight, student_bias
|
||||
)
|
||||
|
||||
ce_loss = torch.tensor(
|
||||
0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
|
||||
)
|
||||
if compute_ce_loss and weight_hard_loss > 0.0:
|
||||
ce_loss = F.cross_entropy(
|
||||
student_logits_chunk.view(-1, student_logits_chunk.shape[-1]),
|
||||
target_chunk.view(-1),
|
||||
reduction="sum",
|
||||
ignore_index=ignore_index,
|
||||
)
|
||||
|
||||
soft_loss = torch.tensor(
|
||||
0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
|
||||
)
|
||||
if weight_soft_loss > 0.0:
|
||||
student_logits_chunk_temp_scaled = student_logits_chunk / temperature
|
||||
|
||||
# Assuming student_weight.shape[0] (vocab_size) is adequate for target_token_ids_chunk.max()
|
||||
# No explicit padding here; user must ensure vocab alignment or pre-pad student_weight.
|
||||
|
||||
soft_loss = distillation_loss_fn(
|
||||
student_logits_chunk_temp_scaled,
|
||||
target_token_ids_chunk,
|
||||
target_logprobs_chunk,
|
||||
target_mask_chunk,
|
||||
beta=beta,
|
||||
normalize_topk=normalize_topk,
|
||||
)
|
||||
|
||||
return soft_loss, ce_loss
|
||||
|
||||
@classmethod
|
||||
def forward(
|
||||
cls,
|
||||
ctx,
|
||||
student_input: torch.Tensor, # [batch_size, seq_len, dim]
|
||||
student_lm_head_weight: torch.Tensor, # [dim, vocab_size]
|
||||
target_token_ids: torch.Tensor, # [batch_size, seq_len, top_k]
|
||||
target_logprobs: torch.Tensor, # [batch_size, seq_len, top_k]
|
||||
target_mask: torch.Tensor, # [batch_size, seq_len, top_k]
|
||||
true_labels: torch.Tensor, # [batch_size, seq_len]
|
||||
student_lm_head_bias: torch.Tensor = None,
|
||||
weight_hard_loss: float = 0.5,
|
||||
weight_soft_loss: float = 0.5,
|
||||
ignore_index: int = -100,
|
||||
temperature: float = 1.0,
|
||||
beta: float = 0.0,
|
||||
compiled: bool = False,
|
||||
chunk_size: int = 1024,
|
||||
compute_ce_loss: bool = True,
|
||||
normalize_topk: bool = True,
|
||||
):
|
||||
CHUNK_SIZE = chunk_size # pylint: disable=invalid-name
|
||||
grad_weight_acc = torch.zeros_like(student_lm_head_weight)
|
||||
grad_inputs_list = []
|
||||
grad_bias_acc = (
|
||||
torch.zeros_like(student_lm_head_bias)
|
||||
if student_lm_head_bias is not None
|
||||
else None
|
||||
)
|
||||
kd_loss_acc = torch.zeros(
|
||||
(), device=student_input.device, dtype=student_input.dtype
|
||||
)
|
||||
ce_loss_acc = torch.zeros(
|
||||
(), device=student_input.device, dtype=student_input.dtype
|
||||
)
|
||||
|
||||
# This function will be what torch.func.grad_and_value differentiates.
|
||||
# It takes student_input_chunk, student_weight (full), student_bias (full) as primals.
|
||||
# Other necessary data (target_*, etc.) are passed as non-differentiable arguments.
|
||||
def loss_fn_for_grad(
|
||||
_student_input_chunk,
|
||||
_student_lm_head_weight, # full weight
|
||||
_student_lm_head_bias, # full bias
|
||||
# Fixed arguments for a given chunk, not differentiated:
|
||||
_target_token_ids_chunk,
|
||||
_target_logprobs_chunk,
|
||||
_target_mask_chunk,
|
||||
_true_labels_chunk,
|
||||
):
|
||||
return cls._compute_loss_kl_topk(
|
||||
student_input_chunk=_student_input_chunk,
|
||||
student_weight=_student_lm_head_weight,
|
||||
target_token_ids_chunk=_target_token_ids_chunk,
|
||||
target_logprobs_chunk=_target_logprobs_chunk,
|
||||
target_mask_chunk=_target_mask_chunk,
|
||||
target_chunk=_true_labels_chunk,
|
||||
student_bias=_student_lm_head_bias,
|
||||
distillation_loss_fn=cls.distillation_loss_fn,
|
||||
ignore_index=ignore_index,
|
||||
weight_hard_loss=weight_hard_loss,
|
||||
weight_soft_loss=weight_soft_loss,
|
||||
compute_ce_loss=compute_ce_loss,
|
||||
temperature=temperature,
|
||||
beta=beta,
|
||||
normalize_topk=normalize_topk,
|
||||
)
|
||||
|
||||
def accumulate_chunk_grads(
|
||||
student_input_chunk_ac,
|
||||
target_token_ids_chunk_ac,
|
||||
target_logprobs_chunk_ac,
|
||||
target_mask_chunk_ac,
|
||||
true_labels_chunk_ac,
|
||||
):
|
||||
# student_weight and student_bias are closed over from the outer scope (full tensors)
|
||||
if student_lm_head_bias is not None:
|
||||
(
|
||||
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
|
||||
(chunk_kd_loss, chunk_ce_loss),
|
||||
) = torch.func.grad_and_value(
|
||||
loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True
|
||||
)(
|
||||
student_input_chunk_ac,
|
||||
student_lm_head_weight,
|
||||
student_lm_head_bias, # primals
|
||||
target_token_ids_chunk_ac,
|
||||
target_logprobs_chunk_ac,
|
||||
target_mask_chunk_ac,
|
||||
true_labels_chunk_ac,
|
||||
) # non-primals
|
||||
grad_bias_acc.add_(chunk_grad_bias)
|
||||
else:
|
||||
argnums_for_grad = (0, 1) # Differentiate wrt input_chunk, weight
|
||||
(
|
||||
(chunk_grad_input, chunk_grad_weight), # No grad for bias
|
||||
(chunk_kd_loss, chunk_ce_loss),
|
||||
) = torch.func.grad_and_value(
|
||||
loss_fn_for_grad, argnums=argnums_for_grad, has_aux=True
|
||||
)(
|
||||
student_input_chunk_ac,
|
||||
student_lm_head_weight,
|
||||
None, # Pass None for student_bias primal
|
||||
target_token_ids_chunk_ac,
|
||||
target_logprobs_chunk_ac,
|
||||
target_mask_chunk_ac,
|
||||
true_labels_chunk_ac,
|
||||
)
|
||||
|
||||
grad_weight_acc.add_(chunk_grad_weight)
|
||||
kd_loss_acc.add_(chunk_kd_loss)
|
||||
ce_loss_acc.add_(chunk_ce_loss)
|
||||
|
||||
return chunk_grad_input
|
||||
|
||||
if compiled:
|
||||
accumulate_chunk_grads_compiled = torch.compile(
|
||||
accumulate_chunk_grads, dynamic=True, backend="inductor"
|
||||
) # dynamic=True often helpful
|
||||
else:
|
||||
accumulate_chunk_grads_compiled = accumulate_chunk_grads
|
||||
|
||||
# Use the same chunking logic as LigerFusedLinearDistillationBase.forward
|
||||
B, N, D = student_input.shape # pylint: disable=invalid-name
|
||||
K = target_token_ids.shape[-1] # pylint: disable=invalid-name
|
||||
|
||||
student_input_flat = student_input.reshape(-1, student_input.shape[-1])
|
||||
target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1])
|
||||
target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1])
|
||||
target_mask_flat = target_mask.reshape(-1, target_mask.shape[-1])
|
||||
# pad and shift for cross entropy loss
|
||||
true_labels = torch.nn.functional.pad(true_labels, (0, 1), value=ignore_index)
|
||||
true_labels_flat = true_labels[:, 1:].contiguous().view(-1)
|
||||
|
||||
num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE)
|
||||
|
||||
_student_input_chunks = torch.chunk(
|
||||
student_input_flat, chunks=num_chunks, dim=0
|
||||
)
|
||||
_target_token_ids_chunks = torch.chunk(
|
||||
target_token_ids_flat, chunks=num_chunks, dim=0
|
||||
)
|
||||
_target_logprobs_chunks = torch.chunk(
|
||||
target_logprobs_flat, chunks=num_chunks, dim=0
|
||||
)
|
||||
_target_mask_chunks = torch.chunk(target_mask_flat, chunks=num_chunks, dim=0)
|
||||
_true_labels_chunks = torch.chunk(true_labels_flat, chunks=num_chunks, dim=0)
|
||||
|
||||
for i in range(num_chunks):
|
||||
grad_input_chunk = accumulate_chunk_grads_compiled(
|
||||
_student_input_chunks[i],
|
||||
_target_token_ids_chunks[i],
|
||||
_target_logprobs_chunks[i],
|
||||
_target_mask_chunks[i],
|
||||
_true_labels_chunks[i],
|
||||
)
|
||||
grad_inputs_list.append(grad_input_chunk)
|
||||
|
||||
grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)
|
||||
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
|
||||
|
||||
# For matching None returns in backward for non-tensor/non-grad_requiring inputs
|
||||
ctx.hyperparams_count = 9 # Corresponds to number of hyperparams after main tensors in fwd signature
|
||||
ctx.bias_was_none = student_lm_head_bias is None
|
||||
ctx.orig_dims = (B, N, D, K)
|
||||
|
||||
# since this is packed, there is simply a single batch, so batchmean reduction of kl-div is simply the accumulated sum
|
||||
# we still need to scale the kd_loss by the temp^2
|
||||
kd_loss_acc = kd_loss_acc * (temperature**2)
|
||||
final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
|
||||
|
||||
return final_loss
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_input_flat, grad_weight, grad_bias_maybe = (
|
||||
ctx.saved_tensors
|
||||
) # grad_input_flat is (B*N, D)
|
||||
|
||||
# Scale gradients by grad_output if it's not 1.0
|
||||
if not torch.equal(
|
||||
grad_output,
|
||||
torch.tensor(1.0, device=grad_output.device, dtype=grad_output.dtype),
|
||||
):
|
||||
grad_input_flat = grad_input_flat * grad_output
|
||||
grad_weight = grad_weight * grad_output
|
||||
if grad_bias_maybe is not None:
|
||||
grad_bias_maybe = grad_bias_maybe * grad_output
|
||||
|
||||
# Reshape grad_input_flat to match original student_input shape (B, N, D)
|
||||
# ctx.orig_dims stores (B, N, D, K)
|
||||
# We need the first three dimensions for student_input's shape.
|
||||
# Ensure that orig_dims are not (0,0,0,K) for empty inputs leading to view errors
|
||||
if (
|
||||
ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0
|
||||
and grad_input_flat.numel() == 0
|
||||
):
|
||||
# If original input was empty, gradient should also be empty with correct shape
|
||||
grad_input_reshaped = torch.zeros(
|
||||
ctx.orig_dims[0],
|
||||
ctx.orig_dims[1],
|
||||
ctx.orig_dims[2],
|
||||
dtype=grad_input_flat.dtype,
|
||||
device=grad_input_flat.device,
|
||||
)
|
||||
elif grad_input_flat.numel() == 0 and not (
|
||||
ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0
|
||||
):
|
||||
# This case should ideally not happen if forward path is correct (non-empty input -> non-empty flat grad)
|
||||
# but as a safeguard:
|
||||
grad_input_reshaped = torch.zeros(
|
||||
ctx.orig_dims[0],
|
||||
ctx.orig_dims[1],
|
||||
ctx.orig_dims[2],
|
||||
dtype=grad_input_flat.dtype,
|
||||
device=grad_input_flat.device,
|
||||
)
|
||||
else:
|
||||
grad_input_reshaped = grad_input_flat.view(
|
||||
ctx.orig_dims[0], ctx.orig_dims[1], ctx.orig_dims[2]
|
||||
)
|
||||
|
||||
nones_for_hyperparams = [None] * ctx.hyperparams_count
|
||||
grad_bias_return = grad_bias_maybe if not ctx.bias_was_none else None
|
||||
|
||||
return (
|
||||
grad_input_reshaped, # Gradient for student_input (reshaped)
|
||||
grad_weight, # Gradient for student_lm_head_weight
|
||||
None, # Gradient for target_token_ids
|
||||
None, # Gradient for target_logprobs
|
||||
None, # Gradient for target_mask
|
||||
None, # Gradient for true_labels
|
||||
grad_bias_return, # Gradient for student_lm_head_bias
|
||||
*nones_for_hyperparams, # Grads for weight_hard_loss, ..., compute_ce_loss
|
||||
)
|
||||
|
||||
|
||||
class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
|
||||
"""
|
||||
wrapper for chunked top-k logprob kl-d
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_hard_loss: float = 0.5,
|
||||
weight_soft_loss: float = 0.5,
|
||||
temperature: float = 1.0, # This is the kd_temperature
|
||||
beta: float = 1.0,
|
||||
ignore_index: int = -100,
|
||||
compiled: bool = True,
|
||||
chunk_size: int = 1024,
|
||||
compute_ce_loss: bool = True,
|
||||
normalize_topk: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0):
|
||||
raise ValueError("Loss weights must be between 0.0 and 1.0.")
|
||||
if temperature <= 0:
|
||||
raise ValueError("Temperature must be positive.")
|
||||
|
||||
self.weight_hard_loss = weight_hard_loss
|
||||
self.weight_soft_loss = weight_soft_loss
|
||||
self.temperature = temperature
|
||||
self.beta = beta
|
||||
self.ignore_index = ignore_index
|
||||
self.compiled = compiled
|
||||
self.chunk_size = chunk_size
|
||||
self.compute_ce_loss = compute_ce_loss
|
||||
self.normalize_topk = normalize_topk
|
||||
|
||||
if not self.compute_ce_loss and self.weight_hard_loss > 0.0:
|
||||
print(
|
||||
f"Warning: compute_ce_loss is False, but weight_hard_loss ({self.weight_hard_loss}) > 0. Hard loss will effectively be zero."
|
||||
)
|
||||
# self.weight_hard_loss = 0.0 # Or let user manage this
|
||||
if self.weight_soft_loss == 0.0:
|
||||
print(
|
||||
"Warning: weight_soft_loss is 0.0. Soft (KD) loss will not be computed."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
lm_head_weight: torch.Tensor, # Weights of the linear layer in the LM head
|
||||
student_hidden_states: torch.Tensor, # student_hidden_states before the lm_head
|
||||
target_token_ids: torch.Tensor,
|
||||
target_logprobs: torch.Tensor,
|
||||
target_mask: torch.Tensor,
|
||||
true_labels: torch.Tensor,
|
||||
student_bias: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
return LigerFusedLinearKLTopKLogprobFunction.apply(
|
||||
student_hidden_states,
|
||||
lm_head_weight,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
true_labels,
|
||||
student_bias,
|
||||
self.weight_hard_loss,
|
||||
self.weight_soft_loss,
|
||||
self.ignore_index,
|
||||
self.temperature,
|
||||
self.beta,
|
||||
self.compiled,
|
||||
self.chunk_size,
|
||||
self.compute_ce_loss,
|
||||
self.normalize_topk,
|
||||
)
|
||||
@@ -1,97 +0,0 @@
|
||||
"""
|
||||
model patcher for chunked top-k kl-div
|
||||
"""
|
||||
|
||||
from typing import Optional, Union, Unpack
|
||||
|
||||
import torch
|
||||
from transformers import Cache
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.utils import LossKwargs
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
|
||||
"""
|
||||
placeholder kwargs for hf model classes
|
||||
"""
|
||||
|
||||
|
||||
def kldiv_forward_llama_like(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
target_logprobs: Optional[torch.Tensor] = None,
|
||||
target_token_ids: Optional[torch.LongTensor] = None,
|
||||
target_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0, # pylint: disable=unused-argument
|
||||
**kwargs: Unpack[KwargsForCausalLM], # type: ignore[misc]
|
||||
) -> CausalLMOutputWithPast:
|
||||
# pylint: disable=duplicate-code
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
# TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
|
||||
# self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
|
||||
|
||||
loss = self.loss_function(
|
||||
self.lm_head.weight,
|
||||
hidden_states,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
true_labels=labels,
|
||||
)
|
||||
num_items_in_batch = kwargs.pop("num_items_in_batch", -1)
|
||||
if num_items_in_batch is not None and num_items_in_batch > 0:
|
||||
loss = loss / num_items_in_batch
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=None,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def apply_kernel(model_type):
|
||||
# Dynamically import the module and attention class
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
|
||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
|
||||
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
|
||||
model_cls.forward = kldiv_forward_llama_like
|
||||
@@ -16,7 +16,40 @@
|
||||
loss for top_k KL divergence
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def zscore_standardize(
|
||||
logits: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
base_temperature: float = 1.0,
|
||||
eps: float = 1e-9,
|
||||
):
|
||||
"""
|
||||
Z-score standardize along the last dimension of `logits`.
|
||||
i.e., for each [B, seq_len] row, across K entries:
|
||||
z = (logits - mean) / std,
|
||||
then scale by 1 / base_temperature if desired.
|
||||
|
||||
mask can be broadcastable or None. If None, we standardize all elements.
|
||||
"""
|
||||
if mask is None:
|
||||
# shape: [B, seq_len, K]
|
||||
# Mean and std over dim=-1
|
||||
mean = logits.mean(dim=-1, keepdim=True)
|
||||
var = logits.var(dim=-1, unbiased=False, keepdim=True)
|
||||
else:
|
||||
# If you have to exclude some tokens, multiply by mask, etc.
|
||||
float_mask = mask.to(logits.dtype)
|
||||
count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
|
||||
mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
|
||||
var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count
|
||||
|
||||
std = torch.sqrt(var.clamp_min(eps))
|
||||
z = (logits - mean) / std
|
||||
|
||||
# Scale by 1 / base_temperature
|
||||
z = z / base_temperature
|
||||
return z
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@@ -27,6 +60,7 @@ def loss(
|
||||
target_mask: torch.Tensor,
|
||||
num_items_in_batch: int = -1, # Use -1 to indicate "None"
|
||||
kd_temperature: float = 1.0,
|
||||
top_k_before_softmax: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
A KD loss function that is TorchScript-friendly.
|
||||
@@ -43,6 +77,8 @@ def loss(
|
||||
num_items_in_batch (int, optional): The number of items in the batch.
|
||||
kd_temperature (float, optional): The temperature for KD.
|
||||
Default: 1.0
|
||||
top_k_before_softmax (int, optional): Flag of whether to apply softmax before gathering student top-k logits
|
||||
Default: 0
|
||||
"""
|
||||
|
||||
target_logprobs = target_logprobs.float()
|
||||
@@ -52,24 +88,46 @@ def loss(
|
||||
# student_logits shape: [B, student_seq_len, vocab_size]
|
||||
teacher_seq_len = target_token_ids.shape[1]
|
||||
|
||||
# Slice student logits to match teacher-provided sequence length
|
||||
student_logits_for_kd = (
|
||||
student_logits[:, :teacher_seq_len, :] / kd_temperature
|
||||
) # [B, teacher_seq_len, vocab_size]
|
||||
if top_k_before_softmax:
|
||||
# Slice student logits to match teacher-provided sequence length
|
||||
student_logits_for_kd = student_logits[
|
||||
:, :teacher_seq_len, :
|
||||
] # [B, teacher_seq_len, vocab_size]
|
||||
|
||||
# keep in full precision for numerical stability of loss
|
||||
student_logits_for_kd = student_logits_for_kd.float()
|
||||
# Gather student logits for teacher's top-K tokens
|
||||
student_logits_topk = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
) # [B, teacher_seq_len, K]
|
||||
|
||||
# Gather student logits for teacher's top-K tokens
|
||||
student_logits_topk = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
) # [B, teacher_seq_len, K]
|
||||
student_logits_topk = student_logits_topk.float()
|
||||
|
||||
# Compute logsumexp across full vocabulary
|
||||
student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
|
||||
# Apply KD temperature to student’s logits
|
||||
if kd_temperature != 1.0:
|
||||
student_logits_topk = student_logits_topk / kd_temperature
|
||||
|
||||
# Convert just the top-k logits to logprobs
|
||||
student_logprobs_topk = student_logits_topk - student_lse
|
||||
# Convert student top-k logits to logprobs
|
||||
student_logprobs_topk = student_logits_topk - torch.logsumexp(
|
||||
student_logits_topk, dim=-1, keepdim=True
|
||||
) # [B, teacher_seq_len, K]
|
||||
else:
|
||||
# Slice student logits to match teacher-provided sequence length
|
||||
student_logits_for_kd = (
|
||||
student_logits[:, :teacher_seq_len, :] / kd_temperature
|
||||
) # [B, teacher_seq_len, vocab_size]
|
||||
|
||||
# keep in full precision for numerical stability of loss
|
||||
student_logits_for_kd = student_logits_for_kd.float()
|
||||
|
||||
# Gather student logits for teacher's top-K tokens
|
||||
student_logits_topk = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
) # [B, teacher_seq_len, K]
|
||||
|
||||
# Compute logsumexp across full vocabulary
|
||||
student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
|
||||
|
||||
# Convert just the top-k logits to logprobs
|
||||
student_logprobs_topk = student_logits_topk - student_lse
|
||||
|
||||
# Convert teacher_mask to boolean for indexing
|
||||
# In TorchScript, .bool() is sometimes unsupported, so we do:
|
||||
@@ -86,6 +144,10 @@ def loss(
|
||||
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
|
||||
kd_loss = kd_loss_per_token.sum()
|
||||
|
||||
# Multiply by T^2 (classical KD scaling)
|
||||
if kd_temperature != 1.0:
|
||||
kd_loss = kd_loss * (kd_temperature**2)
|
||||
|
||||
# Normalize by number of items (if provided) or by valid tokens
|
||||
if num_items_in_batch > 0:
|
||||
kd_loss = kd_loss / float(num_items_in_batch)
|
||||
@@ -96,74 +158,80 @@ def loss(
|
||||
return kd_loss
|
||||
|
||||
|
||||
class ChunkedTopKKDLoss(nn.Module):
|
||||
def topk_kd_loss_with_zscore(
|
||||
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
|
||||
target_token_ids: torch.Tensor, # [B, seq_len, K]
|
||||
target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space
|
||||
target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len]
|
||||
kd_temperature: float = 1.0, # classic KD temperature
|
||||
zscore_base_temp: float = 1.0, # from the paper
|
||||
num_items_in_batch: int = -1,
|
||||
):
|
||||
"""
|
||||
A wrapper that chunks (splits) the student and teacher outputs along the time dimension
|
||||
to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies.
|
||||
|
||||
Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to top-K teacher logprobs.
|
||||
A variant of top_k KL divergence with Z-score scaling
|
||||
from "Logit Standardization in Knowledge Distillation".
|
||||
"""
|
||||
|
||||
def __init__(self, num_output_chunks: int = 8, kd_temperature: float = 1.0):
|
||||
super().__init__()
|
||||
self.num_output_chunks = num_output_chunks
|
||||
self.kd_temperature = kd_temperature
|
||||
target_logprobs = target_logprobs.float()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
|
||||
target_token_ids: torch.Tensor, # [B, seq_len, K]
|
||||
target_logprobs: torch.Tensor, # [B, seq_len, K]
|
||||
target_mask: torch.Tensor, # [B, seq_len, K]
|
||||
num_items_in_batch: int = -1, # optional batch size for normalization
|
||||
) -> torch.Tensor:
|
||||
B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name
|
||||
# 1) Gather the student's top-k logits to match teacher
|
||||
student_logits_for_kd = student_logits[
|
||||
:, :teacher_seq_len, :
|
||||
] # [B, seq_len, vocab]
|
||||
student_topk_logits = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
) # [B, seq_len, K]
|
||||
|
||||
# 1. Split along the "token" dimension (dim=1).
|
||||
student_logits_chunks = student_logits.chunk(self.num_output_chunks, dim=1)
|
||||
token_ids_chunks = target_token_ids.chunk(self.num_output_chunks, dim=1)
|
||||
logprobs_chunks = target_logprobs.chunk(self.num_output_chunks, dim=1)
|
||||
mask_chunks = target_mask.chunk(self.num_output_chunks, dim=1)
|
||||
student_topk_logits = student_topk_logits.float()
|
||||
|
||||
# We'll accumulate a global "sum of losses" and "sum of valid tokens"
|
||||
# so that our final average is consistent with the entire sequence/batch.
|
||||
total_loss = 0.0
|
||||
total_valid_tokens = 0
|
||||
# 2) If you want to keep the "classical" T scaling, apply it first
|
||||
if kd_temperature != 1.0:
|
||||
student_topk_logits = student_topk_logits / kd_temperature
|
||||
|
||||
# 2. Loop over each chunk and compute a chunk-specific loss.
|
||||
for st_chunk, tid_chunk, lp_chunk, msk_chunk in zip(
|
||||
student_logits_chunks, token_ids_chunks, logprobs_chunks, mask_chunks
|
||||
):
|
||||
# We pass num_items_in_batch=-1 so that the kd_loss
|
||||
# will average over *this chunk's* valid tokens only.
|
||||
chunk_loss = loss(
|
||||
student_logits=st_chunk,
|
||||
target_token_ids=tid_chunk,
|
||||
target_logprobs=lp_chunk,
|
||||
target_mask=msk_chunk,
|
||||
num_items_in_batch=-1, # ensure per-chunk averaging by valid tokens
|
||||
kd_temperature=self.kd_temperature,
|
||||
)
|
||||
# 3) Convert teacher logprobs -> treat them as “logits” for z-score
|
||||
# (They differ by +some_constant from real logits, but in z-score
|
||||
# that constant is subtracted out anyway.)
|
||||
teacher_logits_for_zscore = target_logprobs # rename variable for clarity
|
||||
|
||||
# kd_loss returns an average over the chunk's valid tokens.
|
||||
# We want a global average in the end, so we need to re‐weight
|
||||
# by the number of valid tokens in this chunk and keep track of the total.
|
||||
chunk_valid_mask = msk_chunk.to(torch.bool)
|
||||
chunk_valid_count = chunk_valid_mask.sum() # scalar tensor
|
||||
# 4) Z-score teacher and student
|
||||
# If target_mask is 2D, expand to 3D for the K dimension
|
||||
if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
|
||||
target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)
|
||||
|
||||
# Re-scale "chunk average" back to "chunk sum"
|
||||
chunk_loss_sum = chunk_loss * chunk_valid_count
|
||||
teacher_z = zscore_standardize(
|
||||
teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
|
||||
)
|
||||
student_z = zscore_standardize(
|
||||
student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
|
||||
)
|
||||
|
||||
total_loss += chunk_loss_sum
|
||||
total_valid_tokens += chunk_valid_count
|
||||
# 5) Convert to log-probs for KL
|
||||
teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
|
||||
student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
|
||||
|
||||
# 3. Normalize *once* at the end.
|
||||
if num_items_in_batch > 0:
|
||||
# If the user gave us a manual denominator (e.g. total items in batch),
|
||||
# we divide by it. Typically used if each item is of different length.
|
||||
final_loss = total_loss / float(num_items_in_batch)
|
||||
else:
|
||||
# Otherwise, divide by total valid tokens across all chunks.
|
||||
# to get the same result as a non-chunked approach.
|
||||
final_loss = total_loss / float(total_valid_tokens)
|
||||
# 6) Restrict to valid tokens if needed
|
||||
valid_mask = target_mask.bool() # shape [B, seq_len, K]
|
||||
teacher_probs_z = teacher_logprobs_z.exp()
|
||||
teacher_probs_z = teacher_probs_z[valid_mask]
|
||||
teacher_logprobs_z = teacher_logprobs_z[valid_mask]
|
||||
student_logprobs_z = student_logprobs_z[valid_mask]
|
||||
|
||||
return final_loss
|
||||
# 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] )
|
||||
kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z)
|
||||
kd_loss = kd_loss_per_token.sum()
|
||||
|
||||
# 8) If using classical KD scaling by T^2
|
||||
if kd_temperature != 1.0:
|
||||
kd_loss = kd_loss * (kd_temperature**2)
|
||||
|
||||
# Optionally scale by zscore_base_temp**2 if you want (paper might differ).
|
||||
# kd_loss = kd_loss * (zscore_base_temp**2)
|
||||
|
||||
# 9) Normalize
|
||||
if num_items_in_batch is not None and num_items_in_batch > 0:
|
||||
kd_loss = kd_loss / float(num_items_in_batch)
|
||||
else:
|
||||
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
|
||||
|
||||
return kd_loss
|
||||
|
||||
@@ -18,7 +18,8 @@ KD trainer
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
|
||||
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
|
||||
from .topk_logprob.forward_kl import loss as topk_kd_loss
|
||||
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
|
||||
|
||||
|
||||
class AxolotlKDTrainer(AxolotlTrainer):
|
||||
@@ -26,18 +27,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
Custom trainer subclass for Knowledge Distillation (KD)
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.model_accepts_loss_kwargs = True
|
||||
self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss(
|
||||
self.args.kd_ce_alpha, # hard label loss
|
||||
self.args.kd_alpha, # kd loss
|
||||
self.args.kd_temperature,
|
||||
self.args.kd_beta or 0.0,
|
||||
compute_ce_loss=bool(self.args.kd_ce_alpha),
|
||||
normalize_topk=self.args.kd_normalize_topk,
|
||||
)
|
||||
|
||||
def _set_signature_columns_if_needed(self):
|
||||
super()._set_signature_columns_if_needed()
|
||||
columns_to_add = []
|
||||
@@ -63,12 +52,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
|
||||
Subclass and override for custom behavior.
|
||||
"""
|
||||
if (
|
||||
self.args.sample_packing
|
||||
and hasattr(inputs, "attention_mask")
|
||||
and hasattr(inputs, "position_ids")
|
||||
):
|
||||
del inputs["attention_mask"]
|
||||
|
||||
target_logprobs = inputs.pop("target_logprobs")
|
||||
target_token_ids = inputs.pop("target_token_ids")
|
||||
target_mask = inputs.pop("target_mask")
|
||||
|
||||
seq_len = target_token_ids.shape[1]
|
||||
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss_kwargs = {}
|
||||
@@ -76,4 +65,49 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
||||
inputs = {**inputs, **loss_kwargs}
|
||||
outputs = model(**inputs)
|
||||
return outputs[0]
|
||||
|
||||
# FIXME: account for tokenizer.padding_side
|
||||
student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
|
||||
|
||||
shift_logits = student_logits.contiguous()
|
||||
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
|
||||
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
||||
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
||||
|
||||
if self.args.kd_zscore_base_temp:
|
||||
loss_kd = topk_kd_loss_with_zscore(
|
||||
shift_logits,
|
||||
target_token_ids_for_loss,
|
||||
target_logprobs_for_loss,
|
||||
target_mask_for_loss,
|
||||
kd_temperature=self.args.kd_temperature,
|
||||
zscore_base_temp=self.args.kd_zscore_base_temp,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
else:
|
||||
loss_kd = topk_kd_loss(
|
||||
shift_logits,
|
||||
target_token_ids_for_loss,
|
||||
target_logprobs_for_loss,
|
||||
target_mask_for_loss,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
kd_temperature=self.args.kd_temperature,
|
||||
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
|
||||
)
|
||||
|
||||
if self.args.kd_ce_alpha > 0:
|
||||
kd_alpha = self.args.kd_alpha
|
||||
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
|
||||
else:
|
||||
loss = loss_kd
|
||||
# Save past state if it exists
|
||||
# TODO: this needs to be fixed and made cleaner later.
|
||||
if self.args.past_index >= 0:
|
||||
self._past = outputs[ # pylint: disable=attribute-defined-outside-init
|
||||
self.args.past_index
|
||||
]
|
||||
|
||||
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
|
||||
loss *= self.accelerator.num_processes
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
"""Helper KD utils"""
|
||||
|
||||
import math
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import FloatTensor, Tensor
|
||||
|
||||
|
||||
def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
|
||||
"""
|
||||
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
|
||||
"""
|
||||
# Ensure raw_logprobs matches kd_online_topk length for tensor operations
|
||||
# This should ideally be handled by the caller ensuring correct padding/truncation first
|
||||
if logprobs.shape[-1] != topk:
|
||||
# pad last dimension of logprobs to match topk length with -inf
|
||||
padding_len = topk - logprobs.shape[-1]
|
||||
padding_tensor = torch.full(
|
||||
(
|
||||
*logprobs.shape[:-1],
|
||||
padding_len,
|
||||
), # Takes all dimensions of logprobs except the last, then appends padding_needed
|
||||
float("-inf"),
|
||||
dtype=logprobs.dtype,
|
||||
device=logprobs.device,
|
||||
)
|
||||
logprobs = torch.cat((logprobs, padding_tensor), dim=-1)
|
||||
|
||||
# Convert logprobs at T_online to probabilities
|
||||
# use log sum exp trick to avoid underflow
|
||||
position_logprobs_lse = torch.logsumexp(logprobs, dim=-1, keepdim=True)
|
||||
teacher_probs_t_online = torch.exp(logprobs - position_logprobs_lse)
|
||||
|
||||
# Normalize probabilities (sum to 1)
|
||||
# This is important if the top-k from server aren't a full distribution
|
||||
teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=-1, keepdim=True)
|
||||
teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum
|
||||
|
||||
final_logprobs_tensor = torch.log(teacher_probs_t_online)
|
||||
|
||||
return final_logprobs_tensor
|
||||
|
||||
|
||||
def strided_chunk_views(
|
||||
tensor: Union[np.ndarray, torch.Tensor],
|
||||
chunks: int,
|
||||
dim: int = 0,
|
||||
stride: int = 1,
|
||||
chunk_size: int | None = None,
|
||||
) -> List[Union[np.ndarray, torch.Tensor]]:
|
||||
"""
|
||||
Split a tensor into chunks along a dimension with striding, prioritizing views over copies.
|
||||
|
||||
Args:
|
||||
tensor: Input tensor (numpy array or torch tensor)
|
||||
chunks: Number of chunks to create
|
||||
dim: Dimension along which to chunk (default: 0)
|
||||
stride: Stride between chunk starting positions (default: 1)
|
||||
chunk_size: Size of each chunk. If None, calculated automatically (default: None)
|
||||
|
||||
Returns:
|
||||
List of tensor chunks (views when possible, copies when necessary)
|
||||
"""
|
||||
|
||||
# Get the size of the specified dimension
|
||||
dim_size = tensor.shape[dim]
|
||||
|
||||
# Calculate chunk size if not provided
|
||||
if chunk_size is None:
|
||||
chunk_size = (dim_size + chunks - 1) // chunks # Ceiling division
|
||||
|
||||
chunks_list = []
|
||||
|
||||
for i in range(chunks):
|
||||
start_idx = i * stride
|
||||
end_idx = min(start_idx + chunk_size, dim_size)
|
||||
|
||||
# Break if we've gone beyond the tensor
|
||||
if start_idx >= dim_size:
|
||||
break
|
||||
|
||||
# Create slice objects for all dimensions
|
||||
slices = [slice(None)] * tensor.ndim
|
||||
slices[dim] = slice(start_idx, end_idx)
|
||||
|
||||
chunk = tensor[tuple(slices)]
|
||||
chunks_list.append(chunk)
|
||||
|
||||
return chunks_list
|
||||
|
||||
|
||||
def chunk_overlap(input_tensor: Tensor, chunks: int, dim: int = 0, overlap: int = 1):
|
||||
dim_size = input_tensor.shape[dim]
|
||||
stride = math.ceil(dim_size / chunks)
|
||||
|
||||
return strided_chunk_views(
|
||||
input_tensor, chunks, dim, stride=stride, chunk_size=stride + overlap
|
||||
)
|
||||
@@ -166,17 +166,6 @@ class PatchManager:
|
||||
def _apply_self_attention_lora_patch(self):
|
||||
"""Apply self-attention LoRA patches if configured."""
|
||||
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
|
||||
# Only patch if conditions are met
|
||||
can_patch = (
|
||||
self.cfg.lora_dropout == 0
|
||||
if hasattr(self.cfg, "lora_dropout")
|
||||
else True
|
||||
) # default to True if lora_dropout is not set
|
||||
|
||||
if not can_patch:
|
||||
LOG.warning("Cannot patch self-attention - requires no dropout")
|
||||
return
|
||||
|
||||
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora(self.cfg)
|
||||
|
||||
@@ -2,19 +2,25 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from huggingface_hub import hf_hub_download
|
||||
from mistral_common.protocol.instruct.messages import SystemMessage, UserMessage
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.tokens.tokenizers.mistral import (
|
||||
MistralTokenizer,
|
||||
)
|
||||
from transformers import (
|
||||
AddedToken,
|
||||
AutoTokenizer,
|
||||
PreTrainedTokenizer,
|
||||
)
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import (
|
||||
barrier,
|
||||
is_local_main_process,
|
||||
@@ -25,253 +31,622 @@ from axolotl.utils.logging import get_logger
|
||||
LOG = get_logger(__name__)
|
||||
PLUGIN_MANAGER = PluginManager.get_instance()
|
||||
|
||||
# Constants
|
||||
LLAMA_TOKENIZER_CLASSES = {
|
||||
"LlamaTokenizer",
|
||||
"LlamaTokenizerFast",
|
||||
"CodeLlamaTokenizer",
|
||||
"CodeLlamaTokenizerFast",
|
||||
}
|
||||
|
||||
def modify_tokenizer_files(
|
||||
tokenizer_path: str, token_mappings: dict[int, str], output_dir: str
|
||||
) -> str:
|
||||
FAST_LLAMA_TOKENIZER_CLASSES = {"LlamaTokenizerFast", "CodeLlamaTokenizerFast"}
|
||||
|
||||
QWEN_DEFAULT_TOKEN = "<|endoftext|>"
|
||||
GPTNEOX_PAD_TOKEN = "[PAD]"
|
||||
CHATML_DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."
|
||||
|
||||
|
||||
class MistralTokenizerWrapper:
|
||||
"""
|
||||
Modify tokenizer files to replace added_tokens strings, save to output directory,
|
||||
and return the path to the modified tokenizer.
|
||||
|
||||
This only works with reserved tokens that were added to the tokenizer, not tokens
|
||||
already part of the vocab.
|
||||
|
||||
Args:
|
||||
tokenizer_path: Path or name of the original tokenizer
|
||||
token_mappings: Dict mapping {token_id (int): new_token_string}
|
||||
output_dir: Directory to save the modified tokenizer
|
||||
|
||||
Returns:
|
||||
Path to the modified tokenizer directory
|
||||
|
||||
Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941
|
||||
Wrapper to make MistralTokenizer compatible with Hugging Face tokenizer interface.
|
||||
This provides a bridge between Mistral's native tokenizer and axolotl's expectations.
|
||||
"""
|
||||
# Create the tokenizer directory in output_dir if it doesn't exist
|
||||
tokenizer_dir = os.path.join(output_dir, "tokenizer")
|
||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||
|
||||
if is_local_main_process(): # pylint: disable=too-many-nested-blocks
|
||||
# Load the tokenizer
|
||||
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
|
||||
def __init__(self, mistral_tokenizer: "MistralTokenizer", model_id: str):
|
||||
self.mistral_tokenizer = mistral_tokenizer
|
||||
self.model_id = model_id
|
||||
self._system_prompt = None
|
||||
self.padding_side = "right" # Default padding side
|
||||
self.chat_template = None
|
||||
|
||||
# Save the tokenizer to the output directory
|
||||
temp_tokenizer.save_pretrained(tokenizer_dir)
|
||||
# Cache token IDs by inspecting the actual tokenizer
|
||||
self._token_ids = self._discover_token_ids()
|
||||
|
||||
# Get the token IDs and map them to their new values
|
||||
# Try to load system prompt if available
|
||||
try:
|
||||
self._system_prompt = self._load_system_prompt(
|
||||
model_id, "SYSTEM_PROMPT.txt"
|
||||
)
|
||||
except Exception as e:
|
||||
LOG.debug(f"Could not load system prompt: {e}")
|
||||
|
||||
def _discover_token_ids(self) -> Dict[str, int]:
|
||||
"""Discover the actual token IDs used by this Mistral tokenizer."""
|
||||
token_ids = {}
|
||||
|
||||
try:
|
||||
if hasattr(self.mistral_tokenizer, "instruct_tokenizer"):
|
||||
instruct_tokenizer = self.mistral_tokenizer.instruct_tokenizer
|
||||
|
||||
# Get BOS token ID from instruct_tokenizer
|
||||
token_ids["bos_token_id"] = getattr(instruct_tokenizer, "BOS", 1)
|
||||
|
||||
# Get token IDs from the underlying Tekkenizer
|
||||
if hasattr(instruct_tokenizer, "tokenizer"):
|
||||
tekkenizer = instruct_tokenizer.tokenizer
|
||||
|
||||
# Get BOS ID from tekkenizer (should match instruct_tokenizer.BOS)
|
||||
if hasattr(tekkenizer, "bos_id"):
|
||||
token_ids["bos_token_id"] = tekkenizer.bos_id
|
||||
|
||||
# Get vocab size to help find EOS token
|
||||
vocab_size = getattr(tekkenizer, "_vocab_size", None)
|
||||
|
||||
# Check special tokens
|
||||
if hasattr(tekkenizer, "_all_special_tokens"):
|
||||
special_tokens = tekkenizer._all_special_tokens
|
||||
keys = (
|
||||
list(special_tokens.keys())
|
||||
if hasattr(special_tokens, "keys")
|
||||
else special_tokens
|
||||
)
|
||||
LOG.debug(f"Special tokens available: {keys}")
|
||||
|
||||
# Try to find EOS token in special tokens
|
||||
if hasattr(special_tokens, "get"):
|
||||
# Common EOS token patterns
|
||||
for eos_pattern in ["</s>", "<|endoftext|>", "eos", "EOS"]:
|
||||
if eos_pattern in special_tokens:
|
||||
token_ids["eos_token_id"] = special_tokens[
|
||||
eos_pattern
|
||||
]
|
||||
break
|
||||
|
||||
# Check special tokens reverse vocab
|
||||
if hasattr(tekkenizer, "_special_tokens_reverse_vocab"):
|
||||
reverse_vocab = tekkenizer._special_tokens_reverse_vocab
|
||||
LOG.debug(f"Reverse special tokens: {reverse_vocab}")
|
||||
|
||||
# Look for common special token IDs
|
||||
for token_id, token_str in reverse_vocab.items():
|
||||
if token_str in ["</s>", "<|endoftext|>"]:
|
||||
token_ids["eos_token_id"] = token_id
|
||||
elif token_str in ["<unk>", "<UNK>"]:
|
||||
token_ids["unk_token_id"] = token_id
|
||||
|
||||
# If we have vocab_size, EOS is often vocab_size - 1 or similar
|
||||
if "eos_token_id" not in token_ids and vocab_size:
|
||||
# Common patterns: EOS could be 2, vocab_size-1, or other values
|
||||
# Let's try a safer approach by checking what tokens decode to
|
||||
for candidate_id in [2, vocab_size - 1, vocab_size - 2]:
|
||||
try:
|
||||
# Try to decode and see if it looks like EOS
|
||||
decoded = tekkenizer.decode([candidate_id])
|
||||
if decoded in ["</s>", "<|endoftext|>", ""]:
|
||||
token_ids["eos_token_id"] = candidate_id
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
LOG.debug(f"Could not discover token IDs: {e}")
|
||||
|
||||
# Set reasonable defaults for any missing token IDs
|
||||
token_ids.setdefault("bos_token_id", 1)
|
||||
token_ids.setdefault("eos_token_id", 2)
|
||||
token_ids.setdefault("unk_token_id", 0)
|
||||
token_ids.setdefault(
|
||||
"pad_token_id", token_ids["eos_token_id"]
|
||||
) # Use EOS as pad
|
||||
|
||||
LOG.info(f"Discovered Mistral token IDs: {token_ids}")
|
||||
return token_ids
|
||||
|
||||
def _load_system_prompt(self, repo_id: str, filename: str) -> str:
|
||||
"""Load system prompt from HuggingFace Hub"""
|
||||
file_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
||||
with open(file_path, "r") as file:
|
||||
return file.read()
|
||||
|
||||
def encode(self, text: str, add_special_tokens: bool = True, **kwargs) -> List[int]:
|
||||
"""Encode text to token IDs"""
|
||||
if isinstance(text, str):
|
||||
# For simple string encoding, create a user message
|
||||
messages = []
|
||||
if self._system_prompt and add_special_tokens:
|
||||
messages.append(SystemMessage(content=self._system_prompt))
|
||||
messages.append(UserMessage(content=text))
|
||||
|
||||
tokenized = self.mistral_tokenizer.encode_chat_completion(
|
||||
ChatCompletionRequest(messages=messages)
|
||||
)
|
||||
return tokenized.tokens
|
||||
else:
|
||||
raise ValueError("MistralTokenizer wrapper only supports string input")
|
||||
|
||||
def decode(
|
||||
self,
|
||||
token_ids: Union[List[int], torch.Tensor],
|
||||
skip_special_tokens: bool = True,
|
||||
) -> str:
|
||||
"""Decode token IDs to text"""
|
||||
if isinstance(token_ids, torch.Tensor):
|
||||
token_ids = token_ids.tolist()
|
||||
return self.mistral_tokenizer.decode(token_ids)
|
||||
|
||||
def __call__(self, text: str, **kwargs):
|
||||
"""Make the tokenizer callable like HF tokenizers"""
|
||||
tokens = self.encode(text, **kwargs)
|
||||
return {"input_ids": torch.tensor([tokens])}
|
||||
|
||||
@property
|
||||
def eos_token_id(self):
|
||||
return self._token_ids["eos_token_id"]
|
||||
|
||||
@property
|
||||
def bos_token_id(self):
|
||||
return self._token_ids["bos_token_id"]
|
||||
|
||||
@property
|
||||
def pad_token_id(self):
|
||||
return self._token_ids["pad_token_id"]
|
||||
|
||||
@property
|
||||
def unk_token_id(self):
|
||||
return self._token_ids["unk_token_id"]
|
||||
|
||||
@property
|
||||
def eos_token(self):
|
||||
return "</s>" # Standard Mistral EOS token
|
||||
|
||||
@property
|
||||
def bos_token(self):
|
||||
return "<s>" # Standard Mistral BOS token
|
||||
|
||||
@property
|
||||
def pad_token(self):
|
||||
return self.eos_token # Use EOS as pad token
|
||||
|
||||
@property
|
||||
def unk_token(self):
|
||||
return "<unk>" # Standard UNK token
|
||||
|
||||
@property
|
||||
def __class__(self):
|
||||
# Create a mock class for compatibility checks
|
||||
class MistralTokenizerWrapperClass:
|
||||
__name__ = "MistralTokenizerWrapper"
|
||||
|
||||
return MistralTokenizerWrapperClass
|
||||
|
||||
def add_special_tokens(self, special_tokens_dict: Dict[str, str]) -> int:
|
||||
"""Placeholder for special token addition - Mistral tokenizer handles this internally"""
|
||||
LOG.warning(
|
||||
"add_special_tokens called on MistralTokenizer wrapper - this is handled internally"
|
||||
)
|
||||
return 0
|
||||
|
||||
def add_tokens(self, tokens) -> int:
|
||||
"""Placeholder for token addition - Mistral tokenizer handles this internally"""
|
||||
LOG.warning(
|
||||
"add_tokens called on MistralTokenizer wrapper - this is handled internally"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
class TokenizerFileModifier:
|
||||
"""Handles modification of tokenizer files for token overrides."""
|
||||
|
||||
def __init__(
|
||||
self, tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str
|
||||
):
|
||||
self.tokenizer_path = tokenizer_path
|
||||
self.token_mappings = token_mappings
|
||||
self.output_dir = output_dir
|
||||
self.tokenizer_dir = os.path.join(output_dir, "tokenizer")
|
||||
|
||||
def modify_and_save(self) -> str:
|
||||
"""Modify tokenizer files and return path to modified tokenizer."""
|
||||
os.makedirs(self.tokenizer_dir, exist_ok=True)
|
||||
|
||||
if is_local_main_process():
|
||||
self._perform_modifications()
|
||||
barrier()
|
||||
|
||||
return self.tokenizer_dir
|
||||
|
||||
def _perform_modifications(self):
|
||||
"""Perform the actual file modifications."""
|
||||
# Load and save tokenizer to output directory
|
||||
temp_tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.tokenizer_path, use_fast=True
|
||||
)
|
||||
temp_tokenizer.save_pretrained(self.tokenizer_dir)
|
||||
|
||||
# Convert token mappings to proper format
|
||||
token_id_mappings = {
|
||||
int(token_id): new_value for token_id, new_value in token_mappings.items()
|
||||
int(token_id): new_value
|
||||
for token_id, new_value in self.token_mappings.items()
|
||||
}
|
||||
|
||||
# 1. Update tokenizer_config.json - added_tokens_decoder
|
||||
config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = json.load(f)
|
||||
# Update both tokenizer files
|
||||
self._update_tokenizer_config(token_id_mappings)
|
||||
self._update_tokenizer_json(token_id_mappings)
|
||||
|
||||
# Update added_tokens_decoder
|
||||
if "added_tokens_decoder" in config_data:
|
||||
for token_id, new_value in token_id_mappings.items():
|
||||
token_id_str = str(token_id)
|
||||
if token_id_str in config_data["added_tokens_decoder"]:
|
||||
config_data["added_tokens_decoder"][token_id_str][
|
||||
"content"
|
||||
] = new_value
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Token ID {token_id_str} not found in added_tokens_decoder"
|
||||
)
|
||||
def _update_tokenizer_config(self, token_id_mappings: Dict[int, str]):
|
||||
"""Update tokenizer_config.json with new token mappings."""
|
||||
config_path = os.path.join(self.tokenizer_dir, "tokenizer_config.json")
|
||||
if not os.path.exists(config_path):
|
||||
return
|
||||
|
||||
# Write the updated config back
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config_data, f, indent=2)
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = json.load(f)
|
||||
|
||||
# 2. Update tokenizer.json - added_tokens
|
||||
tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
|
||||
if os.path.exists(tokenizer_path):
|
||||
with open(tokenizer_path, "r", encoding="utf-8") as f:
|
||||
tokenizer_data = json.load(f)
|
||||
if "added_tokens_decoder" in config_data:
|
||||
self._update_added_tokens_decoder(config_data, token_id_mappings)
|
||||
|
||||
# Update added_tokens
|
||||
if "added_tokens" in tokenizer_data:
|
||||
for token_id, new_value in token_id_mappings.items():
|
||||
for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
|
||||
if token_entry["id"] == token_id:
|
||||
tokenizer_data["added_tokens"][i]["content"] = new_value
|
||||
break
|
||||
else:
|
||||
# Reaching this section means the token_id was not found in tokenizer.json added_tokens
|
||||
raise ValueError(
|
||||
f"Token ID {token_id} not found in added_tokens"
|
||||
)
|
||||
if "model" in tokenizer_data and "vocab" in tokenizer_data["model"]:
|
||||
for token_id, new_value in token_id_mappings.items():
|
||||
for entry_val, entry_id in tokenizer_data["model"]["vocab"].items():
|
||||
if entry_id == token_id:
|
||||
del tokenizer_data["model"]["vocab"][entry_val]
|
||||
tokenizer_data["model"]["vocab"][new_value] = token_id
|
||||
break
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config_data, f, indent=2)
|
||||
|
||||
# Write the updated tokenizer data back
|
||||
with open(tokenizer_path, "w", encoding="utf-8") as f:
|
||||
json.dump(tokenizer_data, f, indent=2)
|
||||
def _update_added_tokens_decoder(
|
||||
self, config_data: Dict, token_id_mappings: Dict[int, str]
|
||||
):
|
||||
"""Update the added_tokens_decoder section."""
|
||||
for token_id, new_value in token_id_mappings.items():
|
||||
token_id_str = str(token_id)
|
||||
if token_id_str in config_data["added_tokens_decoder"]:
|
||||
config_data["added_tokens_decoder"][token_id_str]["content"] = new_value
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Token ID {token_id_str} not found in added_tokens_decoder"
|
||||
)
|
||||
|
||||
barrier()
|
||||
return tokenizer_dir
|
||||
def _update_tokenizer_json(self, token_id_mappings: Dict[int, str]):
|
||||
"""Update tokenizer.json with new token mappings."""
|
||||
tokenizer_json_path = os.path.join(self.tokenizer_dir, "tokenizer.json")
|
||||
if not os.path.exists(tokenizer_json_path):
|
||||
return
|
||||
|
||||
with open(tokenizer_json_path, "r", encoding="utf-8") as f:
|
||||
tokenizer_data = json.load(f)
|
||||
|
||||
self._update_added_tokens_list(tokenizer_data, token_id_mappings)
|
||||
self._update_vocab_mappings(tokenizer_data, token_id_mappings)
|
||||
|
||||
with open(tokenizer_json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(tokenizer_data, f, indent=2)
|
||||
|
||||
def _update_added_tokens_list(
|
||||
self, tokenizer_data: Dict, token_id_mappings: Dict[int, str]
|
||||
):
|
||||
"""Update the added_tokens list in tokenizer.json."""
|
||||
if "added_tokens" not in tokenizer_data:
|
||||
return
|
||||
|
||||
for token_id, new_value in token_id_mappings.items():
|
||||
for i, token_entry in enumerate(tokenizer_data["added_tokens"]):
|
||||
if token_entry["id"] == token_id:
|
||||
tokenizer_data["added_tokens"][i]["content"] = new_value
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"Token ID {token_id} not found in added_tokens")
|
||||
|
||||
def _update_vocab_mappings(
|
||||
self, tokenizer_data: Dict, token_id_mappings: Dict[int, str]
|
||||
):
|
||||
"""Update vocab mappings in tokenizer.json."""
|
||||
if not (tokenizer_data.get("model") and tokenizer_data["model"].get("vocab")):
|
||||
return
|
||||
|
||||
vocab = tokenizer_data["model"]["vocab"]
|
||||
for token_id, new_value in token_id_mappings.items():
|
||||
# Find and update the vocab entry
|
||||
for entry_val, entry_id in list(vocab.items()):
|
||||
if entry_id == token_id:
|
||||
del vocab[entry_val]
|
||||
vocab[new_value] = token_id
|
||||
break
|
||||
|
||||
|
||||
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
"""Load and configure the tokenizer based on the provided config."""
|
||||
class TokenizerConfiguration:
|
||||
"""Handles tokenizer configuration and initialization."""
|
||||
|
||||
def _load_mistral_common_tokenizer(cfg: DictDefault):
|
||||
"""Load mistral-common tokenizer"""
|
||||
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self.model_config = load_model_config(cfg)
|
||||
|
||||
# Load the HF-compatible wrapper around MistralTokenizer
|
||||
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
|
||||
def load_mistral_tokenizer(self) -> MistralTokenizerWrapper:
|
||||
"""Load Mistral tokenizer from model configuration."""
|
||||
# Instantiate Mistral tokenizer
|
||||
model_id = self.cfg.base_model
|
||||
mistral_tokenizer = MistralTokenizer.from_hf_hub(model_id)
|
||||
|
||||
# Wrap it for compatibility
|
||||
tokenizer = MistralTokenizerWrapper(mistral_tokenizer, model_id)
|
||||
LOG.info(f"Loaded Mistral tokenizer for model: {model_id}")
|
||||
|
||||
return tokenizer
|
||||
|
||||
if cfg.tokenizer_use_mistral_common:
|
||||
return _load_mistral_common_tokenizer(cfg)
|
||||
def get_tokenizer_class(self):
|
||||
"""Get the appropriate tokenizer class."""
|
||||
if self.cfg.tokenizer_type:
|
||||
return getattr(transformers, self.cfg.tokenizer_type)
|
||||
return AutoTokenizer
|
||||
|
||||
model_config = load_model_config(cfg)
|
||||
tokenizer_kwargs = {}
|
||||
use_fast = True # this is the default
|
||||
def get_tokenizer_kwargs(self) -> Dict[str, Any]:
|
||||
"""Build tokenizer initialization kwargs."""
|
||||
kwargs = {}
|
||||
if self.cfg.tokenizer_legacy is not None:
|
||||
kwargs["legacy"] = self.cfg.tokenizer_legacy
|
||||
return kwargs
|
||||
|
||||
if cfg.tokenizer_use_fast is not None:
|
||||
use_fast = cfg.tokenizer_use_fast
|
||||
if cfg.tokenizer_legacy is not None:
|
||||
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
|
||||
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
|
||||
def get_tokenizer_path(self) -> str:
|
||||
"""Get the tokenizer path, applying overrides if needed."""
|
||||
tokenizer_path = self.cfg.tokenizer_config
|
||||
|
||||
tokenizer_cls = AutoTokenizer
|
||||
if cfg.tokenizer_type:
|
||||
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
||||
if self.cfg.added_tokens_overrides:
|
||||
modifier = TokenizerFileModifier(
|
||||
tokenizer_path, self.cfg.added_tokens_overrides, self.cfg.output_dir
|
||||
)
|
||||
tokenizer_path = modifier.modify_and_save()
|
||||
|
||||
# Set base tokenizer path
|
||||
tokenizer_path = cfg.tokenizer_config
|
||||
return tokenizer_path
|
||||
|
||||
# Apply token string overrides if specified
|
||||
if cfg.added_tokens_overrides:
|
||||
# Modify tokenizer files and get path to modified tokenizer
|
||||
tokenizer_path = modify_tokenizer_files(
|
||||
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
|
||||
def should_use_fast_tokenizer(self) -> bool:
|
||||
"""Determine if fast tokenizer should be used."""
|
||||
return (
|
||||
self.cfg.tokenizer_use_fast
|
||||
if self.cfg.tokenizer_use_fast is not None
|
||||
else True
|
||||
)
|
||||
|
||||
tokenizer = tokenizer_cls.from_pretrained(
|
||||
tokenizer_path,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
|
||||
if (
|
||||
tokenizer.__class__.__name__
|
||||
in [
|
||||
"LlamaTokenizer",
|
||||
"LlamaTokenizerFast",
|
||||
"CodeLlamaTokenizer",
|
||||
"CodeLlamaTokenizerFast",
|
||||
class TokenizerPostProcessor:
|
||||
"""Handles post-processing configuration of loaded tokenizers."""
|
||||
|
||||
def __init__(self, tokenizer, cfg):
|
||||
self.tokenizer = tokenizer
|
||||
self.cfg = cfg
|
||||
self.model_config = load_model_config(cfg)
|
||||
|
||||
def apply_all_configurations(self):
|
||||
"""Apply all post-processing configurations to the tokenizer."""
|
||||
# Skip most configurations for Mistral wrapper
|
||||
if isinstance(self.tokenizer, MistralTokenizerWrapper):
|
||||
self._configure_mistral_wrapper()
|
||||
return
|
||||
|
||||
self._configure_padding_token()
|
||||
self._configure_gptneox_settings()
|
||||
self._configure_mistral_padding()
|
||||
self._configure_qwen_tokens()
|
||||
self._add_special_tokens()
|
||||
self._add_regular_tokens()
|
||||
self._configure_chat_template()
|
||||
|
||||
def _configure_mistral_wrapper(self):
|
||||
"""Apply limited configurations for Mistral wrapper."""
|
||||
# Set padding side if needed
|
||||
if (
|
||||
self.cfg.is_mistral_derived_model
|
||||
and self.cfg.flash_attention
|
||||
and not self.cfg.sample_packing
|
||||
):
|
||||
self.tokenizer.padding_side = "left"
|
||||
|
||||
# Configure chat template for Mistral
|
||||
self._configure_chat_template()
|
||||
|
||||
def _configure_padding_token(self):
|
||||
"""Configure padding token for Llama-based tokenizers."""
|
||||
if (
|
||||
self.tokenizer.__class__.__name__ in LLAMA_TOKENIZER_CLASSES
|
||||
and hasattr(self.tokenizer, "pad_token")
|
||||
and not self.tokenizer.pad_token
|
||||
):
|
||||
self.tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
|
||||
|
||||
def _configure_gptneox_settings(self):
|
||||
"""Configure GPTNeoX-specific settings."""
|
||||
if self.tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||
self.tokenizer.add_special_tokens({"pad_token": GPTNEOX_PAD_TOKEN})
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
def _configure_mistral_padding(self):
|
||||
"""Configure left padding for Mistral models with Flash Attention."""
|
||||
if (
|
||||
self.cfg.is_mistral_derived_model
|
||||
and self.cfg.flash_attention
|
||||
and not self.cfg.sample_packing
|
||||
):
|
||||
self.tokenizer.padding_side = "left"
|
||||
|
||||
def _configure_qwen_tokens(self):
|
||||
"""Configure special tokens for Qwen models."""
|
||||
if not self.cfg.is_qwen_derived_model:
|
||||
return
|
||||
|
||||
# Set token IDs
|
||||
token_id_attributes = [
|
||||
"bos_token_id",
|
||||
"eos_token_id",
|
||||
"pad_token_id",
|
||||
"unk_token_id",
|
||||
]
|
||||
and hasattr(tokenizer, "pad_token")
|
||||
and not tokenizer.pad_token
|
||||
):
|
||||
# set a pad_token, but use eos_token so we don't add a new token
|
||||
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
|
||||
for attr_name in token_id_attributes:
|
||||
if getattr(self.tokenizer, attr_name) is None:
|
||||
setattr(self.tokenizer, attr_name, self.tokenizer.eod_id)
|
||||
|
||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
# Set token strings
|
||||
token_name_attributes = ["bos_token", "eos_token", "pad_token", "unk_token"]
|
||||
for attr_name in token_name_attributes:
|
||||
if getattr(self.tokenizer, attr_name) is None:
|
||||
setattr(self.tokenizer, attr_name, QWEN_DEFAULT_TOKEN)
|
||||
|
||||
# Mistral's official FA implementation requires left padding
|
||||
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
|
||||
tokenizer.padding_side = "left"
|
||||
def _add_special_tokens(self):
|
||||
"""Add special tokens from configuration."""
|
||||
if not self.cfg.special_tokens:
|
||||
return
|
||||
|
||||
# Qwen base only has single token, so we need to set the special tokens
|
||||
if cfg.is_qwen_derived_model:
|
||||
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
|
||||
for attr_name in token_ids:
|
||||
if getattr(tokenizer, attr_name) is None:
|
||||
setattr(tokenizer, attr_name, tokenizer.eod_id)
|
||||
|
||||
token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
|
||||
for attr_name in token_names:
|
||||
if getattr(tokenizer, attr_name) is None:
|
||||
setattr(tokenizer, attr_name, "<|endoftext|>")
|
||||
|
||||
additional_special_tokens = None
|
||||
if cfg.special_tokens:
|
||||
special_tokens = cfg.special_tokens.to_dict()
|
||||
additional_special_tokens = special_tokens.pop(
|
||||
special_tokens_dict = self.cfg.special_tokens.to_dict()
|
||||
additional_special_tokens = special_tokens_dict.pop(
|
||||
"additional_special_tokens", None
|
||||
)
|
||||
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
|
||||
for k, val in special_tokens.items():
|
||||
# check if new special token is not already in tokenizer and
|
||||
# is adapter training to make sure lora_modules_to_save is set
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
if (
|
||||
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
|
||||
and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
|
||||
and cfg.adapter
|
||||
and (
|
||||
not cfg.lora_modules_to_save
|
||||
or not all(
|
||||
x in cfg.lora_modules_to_save for x in lora_modules_to_save
|
||||
)
|
||||
)
|
||||
and k != "pad_token"
|
||||
):
|
||||
lora_modules_to_save_str = ", ".join(
|
||||
[f"`{x}`" for x in lora_modules_to_save]
|
||||
)
|
||||
raise ValueError(
|
||||
f"Please set lora_modules_to_save to [{lora_modules_to_save_str}] "
|
||||
"when using an adapter and changing the special tokens."
|
||||
)
|
||||
|
||||
tokenizer.add_special_tokens(
|
||||
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
||||
self._validate_and_add_special_tokens(special_tokens_dict)
|
||||
self._update_post_processor_if_needed(special_tokens_dict)
|
||||
self._add_additional_special_tokens_if_present(additional_special_tokens)
|
||||
|
||||
def _validate_and_add_special_tokens(self, special_tokens: Dict[str, str]):
|
||||
"""Validate special tokens for adapter training and add them."""
|
||||
lora_modules_to_save = get_linear_embedding_layers(self.model_config.model_type)
|
||||
|
||||
for key, value in special_tokens.items():
|
||||
self._validate_token_for_adapter(key, value, lora_modules_to_save)
|
||||
self.tokenizer.add_special_tokens(
|
||||
{key: AddedToken(value, rstrip=False, lstrip=False, normalized=False)}
|
||||
)
|
||||
|
||||
# If we add bos_token and eos_token, we need to update the post processor to
|
||||
# handle them correctly.
|
||||
# https://github.com/huggingface/transformers/pull/24132
|
||||
bos_or_eos_in_special_tokens = (
|
||||
"bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
|
||||
)
|
||||
if (
|
||||
tokenizer.__class__.__name__
|
||||
in (
|
||||
"LlamaTokenizerFast",
|
||||
"CodeLlamaTokenizerFast",
|
||||
)
|
||||
and bos_or_eos_in_special_tokens
|
||||
def _validate_token_for_adapter(
|
||||
self, key: str, value: str, lora_modules_to_save: List[str]
|
||||
):
|
||||
"""Validate a single token for adapter training requirements."""
|
||||
if not self._should_validate_token_for_adapter(
|
||||
key, value, lora_modules_to_save
|
||||
):
|
||||
tokenizer.update_post_processor()
|
||||
return
|
||||
|
||||
if cfg.tokens:
|
||||
tokenizer.add_tokens(
|
||||
[
|
||||
AddedToken(token, rstrip=False, lstrip=False, normalized=False)
|
||||
for token in cfg.tokens
|
||||
]
|
||||
modules_str = ", ".join(f"`{x}`" for x in lora_modules_to_save)
|
||||
raise ValueError(
|
||||
f"Please set lora_modules_to_save to [{modules_str}] "
|
||||
f"when using an adapter and changing the special tokens."
|
||||
)
|
||||
|
||||
# Additional special tokens are a List, and need to be treated differently than regular special
|
||||
# tokens. We add them after we have called `add_tokens` in case these additional special tokens
|
||||
# are new tokens.
|
||||
#
|
||||
# Usage:
|
||||
#
|
||||
# ```py
|
||||
# special_tokens:
|
||||
# additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
|
||||
# ```
|
||||
if additional_special_tokens is not None:
|
||||
tokenizer.add_special_tokens(
|
||||
{"additional_special_tokens": additional_special_tokens}
|
||||
def _should_validate_token_for_adapter(
|
||||
self, key: str, value: str, lora_modules_to_save: List[str]
|
||||
) -> bool:
|
||||
"""Check if token should be validated for adapter configuration."""
|
||||
if key == "pad_token" or not self.cfg.adapter:
|
||||
return False
|
||||
|
||||
current_token = getattr(self.tokenizer, key)
|
||||
token_changed = current_token is None or current_token != value
|
||||
token_is_multi_char = (
|
||||
len(self.tokenizer.encode(value, add_special_tokens=False)) > 2
|
||||
)
|
||||
lora_modules_missing = not self.cfg.lora_modules_to_save or not all(
|
||||
x in self.cfg.lora_modules_to_save for x in lora_modules_to_save
|
||||
)
|
||||
|
||||
return token_changed and token_is_multi_char and lora_modules_missing
|
||||
|
||||
def _update_post_processor_if_needed(self, special_tokens: Dict[str, str]):
|
||||
"""Update post processor for Llama tokenizers when BOS/EOS tokens are added."""
|
||||
has_bos_and_eos = (
|
||||
"bos_token" in special_tokens and "eos_token" in special_tokens
|
||||
)
|
||||
is_fast_llama = (
|
||||
self.tokenizer.__class__.__name__ in FAST_LLAMA_TOKENIZER_CLASSES
|
||||
)
|
||||
|
||||
if is_fast_llama and has_bos_and_eos:
|
||||
self.tokenizer.update_post_processor()
|
||||
|
||||
def _add_additional_special_tokens_if_present(
|
||||
self, additional_special_tokens: Optional[List[str]]
|
||||
):
|
||||
"""Add additional special tokens if they exist."""
|
||||
if additional_special_tokens is not None:
|
||||
self.tokenizer.add_special_tokens(
|
||||
{"additional_special_tokens": additional_special_tokens}
|
||||
)
|
||||
|
||||
def _add_regular_tokens(self):
|
||||
"""Add regular (non-special) tokens from configuration."""
|
||||
if self.cfg.tokens:
|
||||
self.tokenizer.add_tokens(
|
||||
[
|
||||
AddedToken(token, rstrip=False, lstrip=False, normalized=False)
|
||||
for token in self.cfg.tokens
|
||||
]
|
||||
)
|
||||
|
||||
def _configure_chat_template(self):
|
||||
"""Configure chat template if specified."""
|
||||
if not self.cfg.chat_template:
|
||||
LOG.info(
|
||||
"No Chat template selected. Consider adding a chat template for easier inference."
|
||||
)
|
||||
return
|
||||
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=self.cfg,
|
||||
tokenizer=self.tokenizer,
|
||||
)
|
||||
|
||||
if self._should_replace_default_system_message():
|
||||
chat_template_string = chat_template_string.replace(
|
||||
CHATML_DEFAULT_SYSTEM_MESSAGE, self.cfg.default_system_message
|
||||
)
|
||||
|
||||
self.tokenizer.chat_template = chat_template_string
|
||||
|
||||
def _should_replace_default_system_message(self) -> bool:
|
||||
"""Check if default system message should be replaced."""
|
||||
return self.cfg.default_system_message and self.cfg.chat_template == "chatml"
|
||||
|
||||
|
||||
def load_tokenizer(cfg):
|
||||
"""Load and configure the tokenizer based on the provided config.
|
||||
|
||||
This function handles the complete tokenizer loading pipeline:
|
||||
- Check if Mistral tokenizer should be used
|
||||
- Configure tokenizer parameters and get the appropriate class
|
||||
- Handle token file modifications if needed
|
||||
- Initialize the tokenizer with the correct parameters
|
||||
- Apply all post-processing configurations (padding, special tokens, etc.)
|
||||
- Set up chat templates and logging
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
Returns:
|
||||
Fully configured tokenizer instance.
|
||||
"""
|
||||
# Configure tokenizer parameters
|
||||
config = TokenizerConfiguration(cfg)
|
||||
|
||||
# Check if we should use Mistral tokenizer
|
||||
try:
|
||||
tokenizer = config.load_mistral_tokenizer()
|
||||
except:
|
||||
# Standard tokenizer loading
|
||||
tokenizer_cls = config.get_tokenizer_class()
|
||||
tokenizer_path = config.get_tokenizer_path()
|
||||
use_fast = config.should_use_fast_tokenizer()
|
||||
tokenizer_kwargs = config.get_tokenizer_kwargs()
|
||||
|
||||
# Initialize the tokenizer
|
||||
tokenizer = tokenizer_cls.from_pretrained(
|
||||
tokenizer_path,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
use_fast=use_fast,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
|
||||
# Apply all post-processing configurations
|
||||
post_processor = TokenizerPostProcessor(tokenizer, cfg)
|
||||
post_processor.apply_all_configurations()
|
||||
|
||||
if is_main_process(use_environ=True):
|
||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||
@@ -279,19 +654,4 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||
|
||||
if cfg.chat_template:
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
if cfg.default_system_message and cfg.chat_template == "chatml":
|
||||
chat_template_string = chat_template_string.replace(
|
||||
"You are a helpful assistant.", cfg.default_system_message
|
||||
)
|
||||
|
||||
tokenizer.chat_template = chat_template_string
|
||||
else:
|
||||
LOG.info(
|
||||
"No Chat template selected. Consider adding a chat template for easier inference."
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
@@ -25,20 +25,12 @@ class AxolotlOrWarnErrorFilter(logging.Filter):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
axolotl_log_level = os.getenv(
|
||||
"AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL
|
||||
).upper()
|
||||
other_log_level = os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL).upper()
|
||||
|
||||
try:
|
||||
# py311+ only
|
||||
level_mapping = logging.getLevelNamesMapping()
|
||||
self.axolotl_level = level_mapping[axolotl_log_level]
|
||||
self.other_level = level_mapping[other_log_level]
|
||||
except AttributeError:
|
||||
# For py310, use getLevelName directly
|
||||
self.axolotl_level = logging.getLevelName(axolotl_log_level)
|
||||
self.other_level = logging.getLevelName(other_log_level)
|
||||
self.axolotl_level = logging.getLevelNamesMapping()[
|
||||
os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL)
|
||||
]
|
||||
self.other_level = logging.getLevelNamesMapping()[
|
||||
os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL)
|
||||
]
|
||||
|
||||
def filter(self, record: LogRecord) -> bool:
|
||||
# General filter
|
||||
|
||||
@@ -145,11 +145,6 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
||||
|
||||
return Qwen2Attention
|
||||
|
||||
if model_type == "mllama":
|
||||
from transformers.models.mllama.modeling_mllama import MllamaTextSelfAttention
|
||||
|
||||
return MllamaTextSelfAttention
|
||||
|
||||
try:
|
||||
# Dynamically import the module and attention class
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
@@ -274,29 +269,6 @@ def find_mlp_in_layer(
|
||||
)
|
||||
|
||||
|
||||
def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]:
|
||||
"""
|
||||
Get the layers of the model. Handles text-only and multimodal models.
|
||||
|
||||
Args:
|
||||
model: A PEFT model.
|
||||
|
||||
Returns:
|
||||
A list of layers.
|
||||
"""
|
||||
pretrained_model = model.model
|
||||
|
||||
# check for multimodal models first
|
||||
if hasattr(pretrained_model, "language_model"):
|
||||
return pretrained_model.language_model.layers
|
||||
if hasattr(pretrained_model, "model"):
|
||||
return pretrained_model.model.layers
|
||||
|
||||
raise NotImplementedError(
|
||||
f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
|
||||
)
|
||||
|
||||
|
||||
def apply_lora_kernel_patches(
|
||||
model: PeftModelForCausalLM, cfg: DictDefault
|
||||
) -> PeftModelForCausalLM:
|
||||
@@ -368,7 +340,17 @@ def apply_lora_kernel_patches(
|
||||
if activation not in SUPPORTED_ACTIVATIONS:
|
||||
raise NotImplementedError(f"Activation {activation} is not supported")
|
||||
|
||||
layers = get_layers(model)
|
||||
layers = []
|
||||
# check for multimodal models first
|
||||
pretrained_model = model.model
|
||||
if hasattr(pretrained_model, "language_model"):
|
||||
layers = pretrained_model.language_model.layers
|
||||
elif hasattr(pretrained_model, "model"):
|
||||
layers = pretrained_model.model.layers
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
|
||||
)
|
||||
|
||||
# Patch each layer
|
||||
for layer in layers:
|
||||
|
||||
@@ -17,10 +17,7 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
|
||||
return messages_load(tokenizer, cfg, ds_cfg, processor=processor)
|
||||
load_fn = "load"
|
||||
package = "axolotl.prompt_strategies"
|
||||
if (
|
||||
strategy.split(".")[-1].startswith("load_")
|
||||
or strategy.split(".")[-1] == "load"
|
||||
):
|
||||
if strategy.split(".")[-1].startswith("load_"):
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
elif len(strategy.split(".")) > 1:
|
||||
|
||||
@@ -2,10 +2,8 @@
|
||||
HF Chat Templates prompt strategy
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Set, Union
|
||||
from typing import Any, Dict, List, Set, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from transformers import ProcessorMixin
|
||||
@@ -17,9 +15,6 @@ from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.datasets import DatasetConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
|
||||
|
||||
# Configure the logger
|
||||
LOG = get_logger(__name__)
|
||||
LOG.setLevel("INFO")
|
||||
@@ -39,7 +34,6 @@ class ChatTemplatePrompter(Prompter):
|
||||
message_field_training_detail: str | None = None,
|
||||
field_messages: str = "messages",
|
||||
field_system: str = "system",
|
||||
field_tools: str = "tools",
|
||||
roles: dict[str, list[str]] | None = None,
|
||||
chat_template_kwargs: dict[str, Any] | None = None,
|
||||
drop_system_message: bool = False,
|
||||
@@ -72,7 +66,6 @@ class ChatTemplatePrompter(Prompter):
|
||||
self.message_field_training_detail = message_field_training_detail
|
||||
self.field_messages = field_messages
|
||||
self.field_system = field_system
|
||||
self.field_tools = field_tools
|
||||
self.tokenizer = tokenizer
|
||||
self.processor: ProcessorMixin | None = processor
|
||||
self.chat_template = chat_template
|
||||
@@ -84,38 +77,17 @@ class ChatTemplatePrompter(Prompter):
|
||||
def chat_template_msg_variables(self) -> Set[str]:
|
||||
return self._chat_template_msg_variables
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
conversation: list[dict],
|
||||
add_generation_prompt=False,
|
||||
images=None,
|
||||
tools=None,
|
||||
):
|
||||
"""
|
||||
Build a prompt from a conversation.
|
||||
|
||||
Args:
|
||||
conversation: A list of messages.
|
||||
add_generation_prompt: Whether to add a generation prompt.
|
||||
images: A list of images. (optional)
|
||||
tools: A list of tools. (optional)
|
||||
"""
|
||||
chat_template_kwargs = {
|
||||
"chat_template": self.chat_template,
|
||||
"add_generation_prompt": add_generation_prompt,
|
||||
}
|
||||
|
||||
if tools:
|
||||
chat_template_kwargs["tools"] = tools
|
||||
|
||||
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
|
||||
if self.processor:
|
||||
if not callable(self.processor):
|
||||
raise TypeError("Processor must be callable")
|
||||
|
||||
text = self.processor.apply_chat_template(
|
||||
conversation,
|
||||
chat_template=self.chat_template,
|
||||
tokenize=False,
|
||||
**chat_template_kwargs,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
**self.chat_template_kwargs,
|
||||
)
|
||||
batch = self.processor(
|
||||
text=text,
|
||||
@@ -132,7 +104,9 @@ class ChatTemplatePrompter(Prompter):
|
||||
|
||||
return self.tokenizer.apply_chat_template(
|
||||
conversation,
|
||||
**chat_template_kwargs,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
chat_template=self.chat_template,
|
||||
**self.chat_template_kwargs,
|
||||
)
|
||||
|
||||
def get_offsets_for_train_detail(
|
||||
@@ -276,15 +250,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos
|
||||
|
||||
# Default to eos_token if eot_tokens not provided
|
||||
self.eot_tokens = []
|
||||
if eot_tokens is not None:
|
||||
self.eot_tokens = eot_tokens
|
||||
elif (
|
||||
hasattr(self.tokenizer, "eos_token")
|
||||
and self.tokenizer.eos_token is not None
|
||||
):
|
||||
self.eot_tokens = [self.tokenizer.eos_token]
|
||||
|
||||
self.eot_tokens = (
|
||||
eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token]
|
||||
)
|
||||
self.split_thinking = split_thinking
|
||||
|
||||
self.images = "images"
|
||||
@@ -408,7 +376,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
and not self.prompter.message_field_training_detail # type: ignore
|
||||
):
|
||||
turns = self.get_conversation_thread(prompt)
|
||||
images = self._get_images(prompt)
|
||||
images = self.get_images(prompt)
|
||||
prompt_ids = self.prompter.build_prompt( # type: ignore
|
||||
turns[:-1],
|
||||
add_generation_prompt=True,
|
||||
@@ -437,8 +405,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
return tokenized_prompt
|
||||
|
||||
turns = self.get_conversation_thread(prompt)
|
||||
tools = self._get_tools(prompt)
|
||||
input_ids = self.prompter.build_prompt(turns, tools=tools) # type: ignore
|
||||
input_ids = self.prompter.build_prompt(turns) # type: ignore
|
||||
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
||||
|
||||
last_eos_idx = -1
|
||||
@@ -477,9 +444,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
continue
|
||||
|
||||
turn_start_idx, turn_end_idx = self.find_turn(
|
||||
turns=turns, turn_idx=index, tools=tools
|
||||
)
|
||||
turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index)
|
||||
|
||||
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
|
||||
|
||||
@@ -581,9 +546,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
return i
|
||||
return -1
|
||||
|
||||
def find_turn(
|
||||
self, turns: list[dict], turn_idx: int, tools: list[dict] | None = None
|
||||
):
|
||||
def find_turn(self, turns: list[dict], turn_idx: int):
|
||||
"""
|
||||
Locate the starting and ending indices of the specified turn in a conversation.
|
||||
"""
|
||||
@@ -596,7 +559,11 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
if (
|
||||
turn_idx == 0
|
||||
and turns[0].get("role") == "system"
|
||||
and ("mistral" in self.tokenizer.name_or_path.lower())
|
||||
and (
|
||||
"mistral" in self.tokenizer.name_or_path.lower()
|
||||
or "gemma"
|
||||
in self.tokenizer.name_or_path.lower() # gemma3 uses gemma tokenizer
|
||||
)
|
||||
):
|
||||
return -1, -1
|
||||
|
||||
@@ -610,10 +577,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
turns_with_content = turns[: turn_idx + 1]
|
||||
|
||||
# Generate the conversation up to the turn, with final turn replaced with dummy content
|
||||
dummy_ids = self.prompter.build_prompt(turns_with_empty, tools=tools) # type: ignore
|
||||
dummy_ids = self.prompter.build_prompt(turns_with_empty) # type: ignore
|
||||
|
||||
# Generate the conversation up to the turn, with final turn included
|
||||
full_ids = self.prompter.build_prompt(turns_with_content, tools=tools) # type: ignore
|
||||
full_ids = self.prompter.build_prompt(turns_with_content) # type: ignore
|
||||
|
||||
if not full_ids or not dummy_ids:
|
||||
LOG.warning(f"Empty template generated for turn {turn_idx}")
|
||||
@@ -666,10 +633,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
def get_conversation_thread(self, prompt):
|
||||
turns = []
|
||||
|
||||
messages = self._get_messages(prompt)
|
||||
|
||||
possible_sys_turn = self.transform_message(messages[0])
|
||||
|
||||
possible_sys_turn = self.transform_message(
|
||||
prompt[self.prompter.field_messages][0]
|
||||
)
|
||||
if (
|
||||
possible_sys_turn["role"] != "system"
|
||||
and self.prompter.field_system in prompt
|
||||
@@ -677,7 +643,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
turn = {"role": "system", "content": prompt[self.prompter.field_system]}
|
||||
turns.append(turn)
|
||||
|
||||
for message in messages:
|
||||
for message in prompt[self.prompter.field_messages]:
|
||||
transformed_message = self.transform_message(message)
|
||||
|
||||
turn = {
|
||||
@@ -695,7 +661,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
return turns
|
||||
|
||||
def transform_message(self, message: dict) -> dict:
|
||||
def transform_message(self, message):
|
||||
# Build the initial transformed message from the mappings
|
||||
transformed_message = {}
|
||||
for key, value in self.prompter.message_property_mappings.items():
|
||||
@@ -772,135 +738,18 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
return transformed_message
|
||||
|
||||
def _get_images(self, prompt):
|
||||
def get_images(self, prompt):
|
||||
return prompt.get(self.images, None)
|
||||
|
||||
def _get_tools(self, prompt) -> list[dict] | None:
|
||||
"""Get tools from prompt if available."""
|
||||
tools = prompt.get(self.prompter.field_tools, None)
|
||||
if tools is None:
|
||||
return None
|
||||
|
||||
if isinstance(tools, list):
|
||||
return tools
|
||||
|
||||
raise ValueError(
|
||||
"Unknown tools format. Please convert it into a list[dict].\n"
|
||||
f"Current format: {type(tools)}"
|
||||
)
|
||||
|
||||
def _get_messages(self, prompt):
|
||||
messages = prompt.get(self.prompter.field_messages, None)
|
||||
if messages is None:
|
||||
raise ValueError("Messages is null. Please check `field_messages`.")
|
||||
|
||||
if isinstance(messages, list):
|
||||
return messages
|
||||
|
||||
raise ValueError(
|
||||
"Unknown messages format. Please convert it into a list[dict].\n"
|
||||
f"Current format: {type(messages)}"
|
||||
)
|
||||
|
||||
|
||||
class MistralStrategy(ChatTemplateStrategy):
|
||||
"""
|
||||
Mistral strategy for chat template.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompter: "ChatTemplatePrompter",
|
||||
tokenizer: "HFMistralTokenizer",
|
||||
train_on_inputs: bool,
|
||||
sequence_len: int,
|
||||
roles_to_train: list[str] | None = None,
|
||||
train_on_eos: str | None = None,
|
||||
train_on_eot: str | None = None,
|
||||
eot_tokens: list[str] | None = None,
|
||||
split_thinking: bool | None = False,
|
||||
):
|
||||
# Call the parent's parent __init__ (PromptTokenizingStrategy) to skip ChatTemplateStrategy's validation
|
||||
# pylint: disable=non-parent-init-called,super-init-not-called
|
||||
PromptTokenizingStrategy.__init__(
|
||||
self, prompter, tokenizer, train_on_inputs, sequence_len
|
||||
)
|
||||
self.prompter: ChatTemplatePrompter = prompter
|
||||
|
||||
self.roles_to_train = []
|
||||
if roles_to_train:
|
||||
# map roles if exist in prompter.roles else use the role as is
|
||||
self.roles_to_train = [
|
||||
prompter.roles.get(role, role) for role in roles_to_train
|
||||
]
|
||||
|
||||
self.train_on_eos = train_on_eos
|
||||
# Backward compatibility, load from train_on_eos
|
||||
self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos
|
||||
|
||||
# Default to eos_token if eot_tokens not provided
|
||||
self.eot_tokens = []
|
||||
if eot_tokens is not None:
|
||||
self.eot_tokens = eot_tokens
|
||||
else:
|
||||
# set eot_tokens to the eos_token
|
||||
self.eot_tokens = [self.tokenizer.eos_token]
|
||||
|
||||
self.split_thinking = split_thinking
|
||||
|
||||
self.images = "images"
|
||||
|
||||
LOG.debug(
|
||||
f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
|
||||
)
|
||||
|
||||
# Skip the validation that ChatTemplateStrategy calls
|
||||
# TODO: address this in the future with mistral-specific checks
|
||||
# self._validate_eot_and_eos_tokens()
|
||||
|
||||
@property
|
||||
def supports_multiprocessing(self) -> bool:
|
||||
"""
|
||||
Whether this tokenizing strategy supports multiprocessing.
|
||||
mistral_common tokenizers cannot be pickled for multiprocessing.
|
||||
"""
|
||||
|
||||
return False
|
||||
|
||||
def find_first_eot_token(self, input_ids, start_idx):
|
||||
"""Find the first EOT token in the input_ids starting from start_idx."""
|
||||
# mistral-common tokenizer does not support eot_tokens
|
||||
return self.find_first_eos_token(input_ids, start_idx)
|
||||
|
||||
|
||||
class MistralPrompter(ChatTemplatePrompter):
|
||||
"""
|
||||
Mistral prompter for chat template.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._chat_template_msg_variables = set(["tool_call_id", "name", "tool_calls"])
|
||||
|
||||
|
||||
class StrategyLoader:
|
||||
"""
|
||||
Load chat template strategy based on configuration.
|
||||
"""
|
||||
|
||||
def _get_strategy_cls(self, cfg):
|
||||
if cfg.tokenizer_use_mistral_common:
|
||||
return MistralStrategy
|
||||
|
||||
def _get_strategy_cls(self):
|
||||
return ChatTemplateStrategy
|
||||
|
||||
def _get_prompter_cls(self, cfg):
|
||||
if cfg.tokenizer_use_mistral_common:
|
||||
return MistralPrompter
|
||||
|
||||
return ChatTemplatePrompter
|
||||
|
||||
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
|
||||
return {
|
||||
"train_on_inputs": cfg.train_on_inputs,
|
||||
@@ -926,14 +775,9 @@ class StrategyLoader:
|
||||
else:
|
||||
dataset_config = ds_cfg
|
||||
|
||||
if cfg.tokenizer_use_mistral_common:
|
||||
# mistral-common does not use this, so we pass an empty string
|
||||
chat_template_string = ""
|
||||
else:
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
|
||||
)
|
||||
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
|
||||
)
|
||||
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
|
||||
|
||||
prompter_params = {
|
||||
@@ -959,11 +803,10 @@ class StrategyLoader:
|
||||
}
|
||||
|
||||
strategy_params = self._get_strategy_params(cfg, dataset_config)
|
||||
strategy_cls = self._get_strategy_cls(cfg)
|
||||
prompter_cls = self._get_prompter_cls(cfg)
|
||||
strategy_cls = self._get_strategy_cls()
|
||||
|
||||
strategy = strategy_cls(
|
||||
prompter_cls(**prompter_params),
|
||||
ChatTemplatePrompter(**prompter_params),
|
||||
tokenizer=tokenizer,
|
||||
**strategy_params,
|
||||
)
|
||||
|
||||
@@ -46,14 +46,6 @@ def default(
|
||||
)
|
||||
|
||||
messages = sample[field_messages]
|
||||
if isinstance(messages, str):
|
||||
messages = [
|
||||
{
|
||||
message_property_mappings["role"]: "user",
|
||||
message_property_mappings["content"]: messages,
|
||||
}
|
||||
]
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": role_map[m[message_property_mappings["role"]]],
|
||||
@@ -61,35 +53,13 @@ def default(
|
||||
}
|
||||
for m in messages
|
||||
]
|
||||
|
||||
chosen_raw = sample[field_chosen]
|
||||
if isinstance(chosen_raw, str):
|
||||
chosen_msg = {
|
||||
message_property_mappings["role"]: "assistant",
|
||||
message_property_mappings["content"]: chosen_raw,
|
||||
}
|
||||
elif isinstance(chosen_raw, dict):
|
||||
chosen_msg = chosen_raw
|
||||
else:
|
||||
chosen_msg = chosen_raw[-1]
|
||||
chosen = {
|
||||
"role": role_map[chosen_msg[message_property_mappings["role"]]],
|
||||
"content": chosen_msg[message_property_mappings["content"]],
|
||||
"role": role_map[sample[field_chosen][message_property_mappings["role"]]],
|
||||
"content": sample[field_chosen][message_property_mappings["content"]],
|
||||
}
|
||||
|
||||
rejected_raw = sample[field_rejected]
|
||||
if isinstance(rejected_raw, str):
|
||||
rejected_msg = {
|
||||
message_property_mappings["role"]: "assistant",
|
||||
message_property_mappings["content"]: rejected_raw,
|
||||
}
|
||||
elif isinstance(rejected_raw, dict):
|
||||
rejected_msg = rejected_raw
|
||||
else:
|
||||
rejected_msg = rejected_raw[-1]
|
||||
rejected = {
|
||||
"role": role_map[rejected_msg[message_property_mappings["role"]]],
|
||||
"content": rejected_msg[message_property_mappings["content"]],
|
||||
"role": role_map[sample[field_rejected][message_property_mappings["role"]]],
|
||||
"content": sample[field_rejected][message_property_mappings["content"]],
|
||||
}
|
||||
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from typing import Dict, Optional, Set, TypedDict, Union
|
||||
|
||||
from jinja2 import Environment, meta, nodes
|
||||
from jinja2.ext import Extension
|
||||
|
||||
|
||||
class JinjaTemplateAnalysis(TypedDict):
|
||||
@@ -28,18 +27,6 @@ class JinjaTemplateAnalysis(TypedDict):
|
||||
iteration_target: Optional[Union[str, list[str]]]
|
||||
|
||||
|
||||
class GenerationTagIgnore(Extension):
|
||||
"""
|
||||
Ignores the generation and endgeneration tags in Jinja templates.
|
||||
"""
|
||||
|
||||
tags = {"generation", "endgeneration"}
|
||||
|
||||
def parse(self, parser):
|
||||
parser.stream.skip(1)
|
||||
return nodes.Const("")
|
||||
|
||||
|
||||
class JinjaTemplateAnalyzer:
|
||||
"""
|
||||
Analyzes Jinja templates to extract information about variable usage,
|
||||
@@ -70,9 +57,7 @@ class JinjaTemplateAnalyzer:
|
||||
"""
|
||||
|
||||
def __init__(self, template: str):
|
||||
self.env: Environment = Environment(
|
||||
autoescape=True, extensions=[GenerationTagIgnore]
|
||||
)
|
||||
self.env: Environment = Environment(autoescape=True)
|
||||
self.property_access: Dict[str, Set[str]] = {}
|
||||
self.iteration_targets: Dict[str, Union[str, list[str]]] = {}
|
||||
self.index_access: Dict[str, Set[Union[int, float]]] = {}
|
||||
|
||||
@@ -32,3 +32,4 @@ def load(tokenizer, cfg, ds_cfg, processor=None):
|
||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
|
||||
raise exc
|
||||
return None
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import abc
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from datasets import Dataset
|
||||
from transformers import BatchEncoding, PreTrainedTokenizer
|
||||
|
||||
from axolotl.prompters import Prompter
|
||||
@@ -29,16 +28,6 @@ class DatasetWrappingStrategy(abc.ABC):
|
||||
Abstract class for wrapping datasets for Chat Messages
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def wrap_dataset(
|
||||
self,
|
||||
dataset,
|
||||
process_count: int | None = None,
|
||||
keep_in_memory: bool | None = False,
|
||||
**kwargs,
|
||||
) -> Dataset:
|
||||
pass
|
||||
|
||||
|
||||
class PromptTokenizingStrategy(abc.ABC):
|
||||
"""
|
||||
@@ -70,14 +59,6 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
def supports_batched(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def supports_multiprocessing(self):
|
||||
"""
|
||||
Whether this tokenizing strategy supports multiprocessing.
|
||||
Should return False if the tokenizer has unpicklable objects.
|
||||
"""
|
||||
return True
|
||||
|
||||
def _tokenize(
|
||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||
) -> BatchEncoding:
|
||||
@@ -86,6 +67,10 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
LOG.warning("Empty text requested for tokenization.")
|
||||
return empty
|
||||
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import typing
|
||||
import weakref
|
||||
from contextlib import ExitStack
|
||||
from pathlib import Path
|
||||
@@ -28,6 +25,7 @@ from axolotl.common.datasets import TrainDatasetMeta
|
||||
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||
fix_untrained_tokens,
|
||||
)
|
||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.loaders import (
|
||||
ModelLoader,
|
||||
@@ -47,9 +45,6 @@ try:
|
||||
except ImportError:
|
||||
BetterTransformer = None
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -58,8 +53,8 @@ def setup_model_and_tokenizer(
|
||||
) -> tuple[
|
||||
PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
|
||||
]:
|
||||
"""Load the tokenizer, processor (for multimodal models), and model based on
|
||||
configuration.
|
||||
"""
|
||||
Load the tokenizer, processor (for multimodal models), and model based on configuration.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
@@ -477,7 +472,7 @@ def handle_untrained_tokens_fix(
|
||||
|
||||
|
||||
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
|
||||
"HFRLTrainerBuilder" | "HFCausalTrainerBuilder",
|
||||
HFRLTrainerBuilder | HFCausalTrainerBuilder,
|
||||
PeftModel | PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PeftConfig | None,
|
||||
|
||||
@@ -52,10 +52,3 @@ def patch_optimized_env():
|
||||
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
|
||||
|
||||
def get_not_null(value, default=None):
|
||||
"""
|
||||
return the value if it's not None, otherwise return the default value
|
||||
"""
|
||||
return value if value is not None else default
|
||||
|
||||
@@ -53,6 +53,25 @@ IGNORE_INDEX = -100
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class EvalFirstStepCallback(
|
||||
TrainerCallback
|
||||
): # pylint: disable=too-few-public-methods disable=unused-argument
|
||||
"""
|
||||
Callback to trigger evals on the first step
|
||||
"""
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
|
||||
control.should_evaluate = True
|
||||
return control
|
||||
|
||||
|
||||
class SaveBetterTransformerModelCallback(
|
||||
TrainerCallback
|
||||
): # pylint: disable=too-few-public-methods
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1,7 +1,7 @@
|
||||
"""Data collators for axolotl to pad labels and position_ids for packed sequences"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
@@ -81,11 +81,9 @@ class DataCollatorForSeq2Seq:
|
||||
|
||||
padding_side = self.tokenizer.padding_side
|
||||
for feature in features:
|
||||
remainder_len = max_feature_length - len(feature[feature_name])
|
||||
if feature_name == "position_ids":
|
||||
remainder = list(range(remainder_len))
|
||||
else:
|
||||
remainder = [pad_token_id] * remainder_len
|
||||
remainder = [pad_token_id] * (
|
||||
max_feature_length - len(feature[feature_name])
|
||||
)
|
||||
if isinstance(feature[feature_name], list):
|
||||
feature[feature_name] = (
|
||||
feature[feature_name] + remainder
|
||||
@@ -163,7 +161,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
if not isinstance(features[0], list):
|
||||
features: List[List[dict]] = [features]
|
||||
features = [features]
|
||||
out_features = [{} for _ in features]
|
||||
for i, features_ in enumerate(features):
|
||||
for feature in features_[0].keys():
|
||||
|
||||
@@ -1,21 +1,16 @@
|
||||
"""Init for `axolotl.utils.data` module."""
|
||||
"""
|
||||
Data processing modules
|
||||
"""
|
||||
|
||||
from axolotl.utils.data.pretraining import (
|
||||
from axolotl.utils.data.pretraining import ( # noqa: F401
|
||||
encode_pretraining,
|
||||
wrap_pretraining_dataset,
|
||||
)
|
||||
from axolotl.utils.data.rl import prepare_preference_datasets
|
||||
from axolotl.utils.data.sft import (
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401
|
||||
from axolotl.utils.data.sft import ( # noqa: F401
|
||||
get_dataset_wrapper,
|
||||
prepare_datasets,
|
||||
load_prepare_datasets,
|
||||
load_tokenized_prepared_datasets,
|
||||
prepare_dataset,
|
||||
)
|
||||
from axolotl.utils.data.utils import md5
|
||||
|
||||
__all__ = [
|
||||
"encode_pretraining",
|
||||
"wrap_pretraining_dataset",
|
||||
"prepare_preference_datasets",
|
||||
"get_dataset_wrapper",
|
||||
"prepare_datasets",
|
||||
"md5",
|
||||
]
|
||||
from axolotl.utils.data.utils import md5 # noqa: F401
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
"""Logic for loading / preparing a dataset once over all processes."""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
from filelock import FileLock
|
||||
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOCK_FILE_NAME = "datasets_prep.lock"
|
||||
READY_FILE_NAME = "datasets_ready.flag"
|
||||
PROCESS_COUNTER_FILE_NAME = "process_counter.txt"
|
||||
|
||||
|
||||
class FileLockLoader:
|
||||
"""
|
||||
Simple class for abstracting single process data loading / processing. The first
|
||||
process that creates a lock file does the work; the remaining procesees simply load
|
||||
the preprocessed dataset once the first process is done.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: DictDefault):
|
||||
self.cfg = cfg
|
||||
self.dataset_prepared_path = (
|
||||
cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH
|
||||
)
|
||||
self.lock_file_path = Path(self.dataset_prepared_path) / LOCK_FILE_NAME
|
||||
self.ready_flag_path = Path(self.dataset_prepared_path) / READY_FILE_NAME
|
||||
self.counter_path = Path(self.dataset_prepared_path) / PROCESS_COUNTER_FILE_NAME
|
||||
|
||||
def load(self, load_fn: Callable[[], Any]) -> Any:
|
||||
with FileLock(str(self.lock_file_path)):
|
||||
self._increment_counter()
|
||||
|
||||
if not self.ready_flag_path.exists():
|
||||
result = load_fn()
|
||||
self.ready_flag_path.touch()
|
||||
return result
|
||||
|
||||
while not self.ready_flag_path.exists():
|
||||
time.sleep(1)
|
||||
return load_fn()
|
||||
|
||||
def _increment_counter(self):
|
||||
"""Safely increment the process counter."""
|
||||
if self.counter_path.exists():
|
||||
count = int(self.counter_path.read_text().strip())
|
||||
else:
|
||||
count = 0
|
||||
self.counter_path.write_text(str(count + 1))
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up ready flag when last process is done."""
|
||||
with FileLock(str(self.lock_file_path)):
|
||||
count = int(self.counter_path.read_text().strip())
|
||||
count -= 1
|
||||
|
||||
if count == 0:
|
||||
# Last process cleans everything up
|
||||
self.ready_flag_path.unlink(missing_ok=True)
|
||||
self.counter_path.unlink(missing_ok=True)
|
||||
else:
|
||||
# Still have active processes
|
||||
self.counter_path.write_text(str(count))
|
||||
@@ -250,7 +250,7 @@ def encode_packed_pretraining(
|
||||
# pylint: disable=duplicate-code
|
||||
# tokenize all the examples
|
||||
# rows get split with stride (overlap)
|
||||
train_dataset = ds_wrapper(dataset=Dataset.from_dict(examples))[0]
|
||||
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
|
||||
|
||||
train_dataset = process_pretraining_datasets_for_packing(
|
||||
train_dataset,
|
||||
|
||||
@@ -1,117 +1,75 @@
|
||||
"""Data handling specific to RL trainers."""
|
||||
"""data handling specific to DPO"""
|
||||
|
||||
import inspect
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Literal
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Union
|
||||
|
||||
from datasets import Dataset, DatasetDict
|
||||
from transformers import PreTrainedTokenizer
|
||||
import yaml
|
||||
from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk
|
||||
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.loaders import load_tokenizer
|
||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||
from axolotl.prompt_strategies.kto import load as load_kto
|
||||
from axolotl.prompt_strategies.orpo import load as load_orpo
|
||||
from axolotl.utils.data.lock import FileLockLoader
|
||||
from axolotl.utils.data.shared import (
|
||||
create_train_validation_split,
|
||||
datasets_with_name_generator,
|
||||
generate_dataset_hash_from_config,
|
||||
load_dataset_with_config,
|
||||
load_preprocessed_dataset,
|
||||
merge_datasets,
|
||||
save_preprocessed_dataset,
|
||||
try_load_from_hub,
|
||||
)
|
||||
from axolotl.utils.data.utils import (
|
||||
deduplicate_and_log_datasets,
|
||||
retry_on_request_exceptions,
|
||||
)
|
||||
from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config
|
||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process, zero_first
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@retry_on_request_exceptions(max_retries=3, delay=5)
|
||||
def prepare_preference_datasets(
|
||||
cfg: DictDefault, tokenizer: PreTrainedTokenizer
|
||||
) -> tuple[Dataset, Dataset | None]:
|
||||
"""Load and prepare preference datasets for RL training.
|
||||
def _get_path(ds_hash, cfg):
|
||||
prepared_ds_path = (
|
||||
Path(cfg.dataset_prepared_path) / ds_hash
|
||||
if cfg.dataset_prepared_path
|
||||
else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
|
||||
)
|
||||
|
||||
Loads training and evaluation datasets, handling preprocessing, caching, and
|
||||
deduplication as configured. Uses FileLock for distributed coordination.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object containing dataset and training settings.
|
||||
tokenizer: Tokenizer to use for processing text.
|
||||
|
||||
Returns:
|
||||
Tuple of (train_dataset, eval_dataset). eval_dataset may be None
|
||||
if no evaluation dataset is configured.
|
||||
"""
|
||||
|
||||
def _load_datasets():
|
||||
# Load training dataset
|
||||
train_dataset = _load_or_create_dataset_split(cfg, tokenizer, split="train")
|
||||
|
||||
# Load or create evaluation dataset
|
||||
eval_dataset: Dataset | None = None
|
||||
if cfg.test_datasets:
|
||||
eval_dataset = _load_or_create_dataset_split(cfg, tokenizer, split="test")
|
||||
elif cfg.val_set_size:
|
||||
# Create validation split from training data
|
||||
train_dataset, eval_dataset = create_train_validation_split(
|
||||
train_dataset, cfg, cfg.val_set_size
|
||||
)
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
# Prepare datasets (with file locking logic for multiple ranks)
|
||||
loader = FileLockLoader(cfg)
|
||||
try:
|
||||
train_dataset, eval_dataset = loader.load(_load_datasets)
|
||||
finally:
|
||||
loader.cleanup()
|
||||
|
||||
# Apply deduplication if configured
|
||||
if cfg.dataset_exact_deduplication:
|
||||
train_dataset, eval_dataset = deduplicate_and_log_datasets(
|
||||
dataset=train_dataset, other_dataset=eval_dataset
|
||||
)
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
return prepared_ds_path
|
||||
|
||||
|
||||
def _map_dataset(
|
||||
cfg: DictDefault,
|
||||
dataset: Dataset | DatasetDict,
|
||||
ds_transform_fn: Callable[..., Any],
|
||||
tokenizer: Any | None = None,
|
||||
**map_kwargs: Any,
|
||||
) -> Dataset:
|
||||
"""Apply transformation function to dataset.
|
||||
def _load_preprocessed_ds(cfg, sub_cfg):
|
||||
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
|
||||
prepared_ds_path = _get_path(ds_hash, cfg)
|
||||
dataset = None
|
||||
|
||||
Args:
|
||||
cfg: Configuration object.
|
||||
dataset: Dataset to transform.
|
||||
ds_transform_fn: Transformation function to apply.
|
||||
tokenizer: Optional tokenizer for transformation.
|
||||
**map_kwargs: Additional arguments for dataset mapping.
|
||||
# pylint: disable=duplicate-code
|
||||
if (
|
||||
cfg.dataset_prepared_path
|
||||
and any(prepared_ds_path.glob("*"))
|
||||
and not cfg.is_preprocess
|
||||
):
|
||||
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
||||
dataset = load_from_disk(str(prepared_ds_path))
|
||||
|
||||
Returns:
|
||||
Transformed dataset.
|
||||
"""
|
||||
return dataset
|
||||
|
||||
|
||||
def _save_preprocessed_ds(cfg, sub_cfg, dataset):
|
||||
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
|
||||
prepared_ds_path = _get_path(ds_hash, cfg)
|
||||
|
||||
if cfg.is_preprocess and is_main_process():
|
||||
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
||||
dataset.save_to_disk(str(prepared_ds_path))
|
||||
|
||||
|
||||
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
|
||||
sig = inspect.signature(ds_transform_fn)
|
||||
if "tokenizer" in sig.parameters:
|
||||
if not tokenizer:
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
|
||||
|
||||
if isinstance(dataset, DatasetDict):
|
||||
dataset = dataset["train"]
|
||||
if isinstance(data_set, DatasetDict):
|
||||
data_set = data_set["train"]
|
||||
|
||||
dataset = dataset.map(
|
||||
data_set = data_set.map(
|
||||
ds_transform_fn,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
@@ -119,27 +77,13 @@ def _map_dataset(
|
||||
**map_kwargs,
|
||||
)
|
||||
|
||||
return dataset
|
||||
return data_set
|
||||
|
||||
|
||||
def _drop_long_sequences(
|
||||
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
|
||||
) -> bool:
|
||||
"""Filter out samples that exceed maximum sequence length.
|
||||
|
||||
Args:
|
||||
sample: Dataset sample to check.
|
||||
rl: Reinforcement learning type.
|
||||
tokenizer: Tokenizer for length calculation.
|
||||
sequence_len: Maximum allowed sequence length.
|
||||
|
||||
Returns:
|
||||
True if sample should be kept, False if it should be dropped.
|
||||
|
||||
Raises:
|
||||
ValueError: If required keys are missing or RL type is unknown.
|
||||
"""
|
||||
if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
|
||||
def drop_long_rl_seq(
|
||||
sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name
|
||||
):
|
||||
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
|
||||
if not (
|
||||
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
|
||||
):
|
||||
@@ -179,115 +123,132 @@ def _drop_long_sequences(
|
||||
raise ValueError("Unknown RL type")
|
||||
|
||||
|
||||
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
||||
"""Load and process dataset split for RL training.
|
||||
def load_prepare_preference_datasets(cfg):
|
||||
def load_split(dataset_cfgs, _cfg):
|
||||
split_datasets: List[Any] = []
|
||||
use_auth_token = _cfg.hf_use_auth_token
|
||||
for config_dataset in datasets_w_name_generator(dataset_cfgs):
|
||||
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
|
||||
config_dataset, use_auth_token, streaming=False
|
||||
)
|
||||
split_datasets.append(ds)
|
||||
|
||||
Args:
|
||||
cfg: Configuration object containing dataset settings.
|
||||
split: Dataset split to load ("train" or "test").
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
Returns:
|
||||
Combined and processed dataset for the specified split.
|
||||
"""
|
||||
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
|
||||
split_datasets: list[Dataset | DatasetDict] = []
|
||||
for i, data_set in enumerate(split_datasets):
|
||||
_type = dataset_cfgs[i]["type"]
|
||||
if _type:
|
||||
if isinstance(_type, DictDefault):
|
||||
_type = "user_defined.default"
|
||||
if _cfg.rl is RLType.ORPO:
|
||||
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
|
||||
elif _cfg.rl is RLType.KTO:
|
||||
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
||||
else:
|
||||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
||||
|
||||
for dataset_config in datasets_with_name_generator(datasets_configs):
|
||||
dataset: Dataset | DatasetDict = load_dataset_with_config(
|
||||
dataset_config, cfg.hf_use_auth_token, streaming=False
|
||||
)
|
||||
split_datasets.append(dataset)
|
||||
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
for i, dataset in enumerate(split_datasets):
|
||||
_type = datasets_configs[i]["type"]
|
||||
if _type:
|
||||
if isinstance(_type, DictDefault):
|
||||
_type = "user_defined.default"
|
||||
if cfg.rl is RLType.ORPO:
|
||||
ds_transform_fn = load_orpo(_type, cfg, dataset_idx=i)
|
||||
elif cfg.rl is RLType.KTO:
|
||||
ds_transform_fn = load_kto(_type, cfg, dataset_idx=i)
|
||||
map_kwargs = {}
|
||||
if isinstance(ds_transform_fn, tuple):
|
||||
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||
split_datasets[i] = map_dataset(
|
||||
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||
)
|
||||
elif _cfg.rl is RLType.KTO:
|
||||
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
||||
map_kwargs = {}
|
||||
if isinstance(ds_transform_fn, tuple):
|
||||
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||
split_datasets[i] = map_dataset(
|
||||
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||
)
|
||||
else:
|
||||
ds_transform_fn = load_dpo(_type, cfg, dataset_idx=i)
|
||||
# If no `type` is provided, assume the dataset is already in the expected format with
|
||||
# "prompt", "chosen" and "rejected" already preprocessed
|
||||
split_datasets[i] = data_set
|
||||
|
||||
map_kwargs: dict[str, Any] = {}
|
||||
if isinstance(ds_transform_fn, tuple):
|
||||
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||
split_datasets[i] = _map_dataset(
|
||||
cfg, dataset, ds_transform_fn, tokenizer, **map_kwargs
|
||||
)
|
||||
if not cfg.skip_prepare_dataset:
|
||||
drop_long = partial(
|
||||
drop_long_rl_seq,
|
||||
rl=_cfg.rl,
|
||||
tokenizer=tokenizer,
|
||||
sequence_len=cfg.sequence_len,
|
||||
)
|
||||
|
||||
prior_len = len(split_datasets[i])
|
||||
split_datasets[i] = split_datasets[i].filter(
|
||||
drop_long,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Dropping Long Sequences",
|
||||
)
|
||||
dropped = prior_len - len(split_datasets[i])
|
||||
if dropped:
|
||||
LOG.warning(
|
||||
f"Dropped {dropped} long samples from dataset index {i}"
|
||||
)
|
||||
|
||||
combined_datasets = concatenate_datasets(split_datasets)
|
||||
combined_datasets = combined_datasets.shuffle(seed=cfg.seed or 42)
|
||||
|
||||
return combined_datasets
|
||||
|
||||
with zero_first(is_main_process()):
|
||||
train_is_preprocessed = False
|
||||
eval_is_preprocessed = False
|
||||
if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
|
||||
train_is_preprocessed = True
|
||||
else:
|
||||
# If no `type` is provided, assume the dataset is already in the expected format with
|
||||
# "prompt", "chosen", and "rejected" already preprocessed
|
||||
split_datasets[i] = dataset
|
||||
train_dataset = load_split(cfg.datasets, cfg)
|
||||
|
||||
if not cfg.skip_prepare_dataset:
|
||||
drop_long = partial(
|
||||
_drop_long_sequences,
|
||||
rl=cfg.rl,
|
||||
tokenizer=tokenizer,
|
||||
sequence_len=cfg.sequence_len,
|
||||
)
|
||||
eval_dataset = None
|
||||
if cfg.test_datasets:
|
||||
if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
|
||||
eval_is_preprocessed = True
|
||||
else:
|
||||
eval_dataset = load_split(cfg.test_datasets, cfg)
|
||||
if not eval_dataset:
|
||||
if cfg.val_set_size:
|
||||
seed = cfg.seed if cfg.seed is not None else 42
|
||||
|
||||
prior_len = len(split_datasets[i])
|
||||
split_datasets[i] = split_datasets[i].filter(
|
||||
drop_long,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Dropping Long Sequences",
|
||||
)
|
||||
dropped = prior_len - len(split_datasets[i])
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
|
||||
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
|
||||
to_hash_train = (
|
||||
train_dataset._fingerprint # pylint: disable=protected-access
|
||||
+ "|"
|
||||
+ str(cfg.val_set_size)
|
||||
+ "|"
|
||||
+ "train"
|
||||
+ "|"
|
||||
+ str(cfg.seed or 42)
|
||||
)
|
||||
to_hash_test = (
|
||||
train_dataset._fingerprint # pylint: disable=protected-access
|
||||
+ "|"
|
||||
+ str(cfg.val_set_size)
|
||||
+ "|"
|
||||
+ "test"
|
||||
+ "|"
|
||||
+ str(cfg.seed or 42)
|
||||
)
|
||||
train_fingerprint = md5(to_hash_train)
|
||||
test_fingerprint = md5(to_hash_test)
|
||||
ds_w_test_split = train_dataset.train_test_split(
|
||||
test_size=cfg.val_set_size,
|
||||
seed=seed,
|
||||
shuffle=False,
|
||||
train_new_fingerprint=train_fingerprint,
|
||||
test_new_fingerprint=test_fingerprint,
|
||||
)
|
||||
eval_dataset = ds_w_test_split["test"]
|
||||
train_dataset = ds_w_test_split["train"]
|
||||
|
||||
# Merge datasets
|
||||
dataset = merge_datasets(split_datasets, cfg)
|
||||
if not train_is_preprocessed:
|
||||
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
|
||||
if eval_dataset and not eval_is_preprocessed:
|
||||
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
|
||||
|
||||
if not cfg.skip_prepare_dataset:
|
||||
# Save preprocessed dataset
|
||||
dataset_hash = generate_dataset_hash_from_config(
|
||||
cfg, datasets_configs, tokenizer.name_or_path
|
||||
if cfg.dataset_exact_deduplication:
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=train_dataset, eval_dataset=eval_dataset
|
||||
)
|
||||
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
def _load_or_create_dataset_split(
|
||||
cfg: DictDefault, tokenizer: PreTrainedTokenizer, split: Literal["train", "test"]
|
||||
) -> Dataset:
|
||||
"""Load preprocessed dataset or create new one for given split.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object.
|
||||
tokenizer: Tokenizer to use for processing text.
|
||||
split: Dataset split to load.
|
||||
|
||||
Returns:
|
||||
Tuple of (dataset, is_preprocessed).
|
||||
"""
|
||||
# Select correct dataset configuration based on split
|
||||
datasets_config = cfg.datasets if split == "train" else cfg.test_datasets
|
||||
|
||||
# Generate dataset hash for caching
|
||||
dataset_hash = generate_dataset_hash_from_config(
|
||||
cfg, datasets_config, tokenizer.name_or_path
|
||||
)
|
||||
|
||||
# Try loading from hub if push_dataset_to_hub is configured
|
||||
dataset = None
|
||||
if cfg.push_dataset_to_hub:
|
||||
dataset = try_load_from_hub(cfg, dataset_hash, split)
|
||||
|
||||
# Attempt to load preprocessed dataset
|
||||
if dataset is None:
|
||||
dataset = load_preprocessed_dataset(cfg, dataset_hash)
|
||||
|
||||
# Otherwise, load it
|
||||
if dataset is None:
|
||||
dataset = _load_split(cfg, split=split)
|
||||
|
||||
return dataset
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,21 +1,11 @@
|
||||
"""Dataset loading shared utils."""
|
||||
"""
|
||||
dataset loading shared utils
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from datasets import (
|
||||
Dataset,
|
||||
DatasetDict,
|
||||
IterableDataset,
|
||||
IterableDatasetDict,
|
||||
concatenate_datasets,
|
||||
load_dataset,
|
||||
load_from_disk,
|
||||
)
|
||||
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from huggingface_hub.errors import (
|
||||
HFValidationError,
|
||||
@@ -23,141 +13,78 @@ from huggingface_hub.errors import (
|
||||
RevisionNotFoundError,
|
||||
)
|
||||
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from adlfs import AzureBlobFileSystem
|
||||
from gcsfs import GCSFileSystem
|
||||
from ocifs import OCIFileSystem
|
||||
from s3fs import S3FileSystem
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
EXTENSIONS_TO_DATASET_TYPES = {
|
||||
".parquet": "parquet",
|
||||
".arrow": "arrow",
|
||||
".csv": "csv",
|
||||
".txt": "text",
|
||||
}
|
||||
|
||||
|
||||
def get_dataset_type(dataset_config: DictDefault) -> str:
|
||||
"""Get the dataset type from the path if it's not specified."""
|
||||
if dataset_config.ds_type:
|
||||
return dataset_config.ds_type
|
||||
|
||||
for extension, dataset_type in EXTENSIONS_TO_DATASET_TYPES.items():
|
||||
if extension in dataset_config.path:
|
||||
return dataset_type
|
||||
|
||||
return "json"
|
||||
def get_ds_type(config_dataset: DictDefault):
|
||||
"""
|
||||
Get the dataset type from the path if it's not specified
|
||||
"""
|
||||
ds_type = "json"
|
||||
if config_dataset.ds_type:
|
||||
ds_type = config_dataset.ds_type
|
||||
elif ".parquet" in config_dataset.path:
|
||||
ds_type = "parquet"
|
||||
elif ".arrow" in config_dataset.path:
|
||||
ds_type = "arrow"
|
||||
elif ".csv" in config_dataset.path:
|
||||
ds_type = "csv"
|
||||
elif ".txt" in config_dataset.path:
|
||||
ds_type = "text"
|
||||
return ds_type
|
||||
|
||||
|
||||
def datasets_with_name_generator(
|
||||
dataset_configs: list[DictDefault],
|
||||
) -> Generator[DictDefault, None, None]:
|
||||
"""Yields expanded dataset configurations based on multiple names or preprocessing
|
||||
shards.
|
||||
|
||||
When a dataset config has a list of names, it yields separate configs for each
|
||||
name. When a dataset config specifies preprocessing shards, it yields configs for
|
||||
each shard.
|
||||
def datasets_w_name_generator(dataset_configs: list[DictDefault]):
|
||||
"""
|
||||
Yields dataset configs handling multiple names or preprocess_shards
|
||||
|
||||
Args:
|
||||
dataset_configs: List of dataset configuration objects.
|
||||
|
||||
Yields:
|
||||
Individual dataset configurations, expanded as needed for names or shards.
|
||||
dataset_configs: list of dataset configs (equivalent to cfg.datasets)
|
||||
"""
|
||||
for config in dataset_configs:
|
||||
if config.name and isinstance(config.name, list):
|
||||
for name in config.name:
|
||||
yield DictDefault({**config, "name": name})
|
||||
elif config.preprocess_shards and not config.shards:
|
||||
for shard_idx in range(config.preprocess_shards):
|
||||
for dataset in dataset_configs:
|
||||
if dataset.name and isinstance(dataset.name, list):
|
||||
# load_dataset doesn't properly handle multiple named configurations
|
||||
# at the same time for a given dataset
|
||||
for name in dataset.name:
|
||||
yield DictDefault({**dataset, "name": name})
|
||||
elif dataset.preprocess_shards and not dataset.shards:
|
||||
for shard in range(dataset.preprocess_shards):
|
||||
yield DictDefault(
|
||||
{
|
||||
**config,
|
||||
"shards": config.preprocess_shards,
|
||||
"shards_idx": shard_idx,
|
||||
**dataset,
|
||||
"shards": dataset.preprocess_shards,
|
||||
"shards_idx": shard,
|
||||
}
|
||||
)
|
||||
else:
|
||||
yield config
|
||||
yield dataset
|
||||
|
||||
|
||||
def load_dataset_with_config(
|
||||
dataset_config: DictDefault, use_auth_token: bool, streaming=False
|
||||
) -> Dataset | IterableDataset:
|
||||
"""Load a dataset from a config. Handles datasets that are stored locally, in the
|
||||
HuggingFace Hub, in a remote filesystem (S3, GCS, Azure, OCI), a URL, or
|
||||
`data_files`.
|
||||
def load_dataset_w_config(
|
||||
config_dataset: DictDefault, use_auth_token: bool, streaming=False
|
||||
) -> Union[Dataset, DatasetDict]:
|
||||
"""
|
||||
Load a dataset from a config
|
||||
|
||||
Args:
|
||||
dataset_config: Single dataset config.
|
||||
use_auth_token: Whether to use HF auth token.
|
||||
streaming: Whether to stream the dataset.
|
||||
|
||||
Returns:
|
||||
Loaded dataset.
|
||||
config_dataset: single dataset config
|
||||
use_auth_token: whether to use HF auth token
|
||||
streaming: whether to stream the dataset
|
||||
"""
|
||||
# Set up common kwargs for dataset loading
|
||||
load_dataset_kwargs = {
|
||||
"split": dataset_config.split if dataset_config.split else None,
|
||||
"name": dataset_config.name,
|
||||
"streaming": streaming,
|
||||
"trust_remote_code": dataset_config.trust_remote_code,
|
||||
}
|
||||
|
||||
# First check if it's a local path
|
||||
if Path(dataset_config.path).exists():
|
||||
return _load_from_local_path(dataset_config, load_dataset_kwargs)
|
||||
|
||||
# Check if it's a HuggingFace dataset
|
||||
is_hub_dataset = _check_if_hub_dataset(dataset_config, use_auth_token)
|
||||
|
||||
# Check if it's a cloud storage path and get appropriate filesystem
|
||||
remote_fs, storage_options = _get_remote_filesystem(dataset_config.path)
|
||||
is_cloud_dataset = False
|
||||
if remote_fs:
|
||||
try:
|
||||
is_cloud_dataset = remote_fs.exists(dataset_config.path)
|
||||
except (FileNotFoundError, ConnectionError):
|
||||
pass
|
||||
|
||||
# Load from appropriate source
|
||||
if is_hub_dataset:
|
||||
return _load_from_hub(dataset_config, use_auth_token, load_dataset_kwargs)
|
||||
if is_cloud_dataset:
|
||||
return _load_from_cloud(
|
||||
dataset_config, remote_fs, storage_options, load_dataset_kwargs
|
||||
)
|
||||
if dataset_config.path.startswith("https://"):
|
||||
return _load_from_url(dataset_config, load_dataset_kwargs)
|
||||
if dataset_config.data_files:
|
||||
return _load_from_data_files(dataset_config, load_dataset_kwargs)
|
||||
|
||||
raise ValueError(
|
||||
f"The dataset could not be loaded. This could be due to a misconfigured dataset path "
|
||||
f"({dataset_config.path}). Try double-check your path / name / data_files. "
|
||||
f"This is not caused by the dataset type."
|
||||
)
|
||||
|
||||
|
||||
def _check_if_hub_dataset(dataset_config: DictDefault, use_auth_token: bool) -> bool:
|
||||
"""Check if a dataset exists on the HuggingFace Hub."""
|
||||
# pylint: disable=invalid-name
|
||||
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
|
||||
ds_from_hub = False
|
||||
try:
|
||||
# this is just a basic check to see if the path is a
|
||||
# valid HF dataset that's loadable
|
||||
snapshot_download(
|
||||
repo_id=dataset_config.path,
|
||||
repo_id=config_dataset.path,
|
||||
repo_type="dataset",
|
||||
token=use_auth_token,
|
||||
revision=dataset_config.revision,
|
||||
revision=config_dataset.revision,
|
||||
ignore_patterns=["*"],
|
||||
)
|
||||
return True
|
||||
ds_from_hub = True
|
||||
except (
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
@@ -166,373 +93,198 @@ def _check_if_hub_dataset(dataset_config: DictDefault, use_auth_token: bool) ->
|
||||
HFValidationError,
|
||||
ValueError,
|
||||
):
|
||||
return False
|
||||
pass
|
||||
|
||||
|
||||
def _get_remote_filesystem(
|
||||
path: str,
|
||||
) -> tuple[
|
||||
S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem | None, dict
|
||||
]:
|
||||
"""Get the appropriate filesystem for a remote path."""
|
||||
if path.startswith("s3://"):
|
||||
ds_from_cloud = False
|
||||
storage_options: dict = {}
|
||||
remote_file_system = None
|
||||
if config_dataset.path.startswith("s3://"):
|
||||
try:
|
||||
import s3fs
|
||||
|
||||
storage_options = {"anon": False}
|
||||
return s3fs.S3FileSystem(**storage_options), storage_options
|
||||
import s3fs # type: ignore
|
||||
except ImportError as exc:
|
||||
raise ImportError("s3:// paths require s3fs to be installed") from exc
|
||||
|
||||
elif path.startswith(("gs://", "gcs://")):
|
||||
# Reads env, credentials from ~/.aws/credentials, or IAM metadata provider
|
||||
# https://s3fs.readthedocs.io/en/latest/index.html?highlight=storage_options#credentials
|
||||
storage_options = {"anon": False}
|
||||
remote_file_system = s3fs.S3FileSystem(**storage_options)
|
||||
elif config_dataset.path.startswith("gs://") or config_dataset.path.startswith(
|
||||
"gcs://"
|
||||
):
|
||||
try:
|
||||
import gcsfs
|
||||
|
||||
storage_options = {"token": None} # type: ignore
|
||||
return gcsfs.GCSFileSystem(**storage_options), storage_options
|
||||
import gcsfs # type: ignore
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"gs:// or gcs:// paths require gcsfs to be installed"
|
||||
) from exc
|
||||
|
||||
elif path.startswith(("adl://", "abfs://", "az://")):
|
||||
# gcsfs will use default credentials from the environment else anon
|
||||
# https://gcsfs.readthedocs.io/en/latest/#credentials
|
||||
storage_options = {"token": None}
|
||||
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
|
||||
elif (
|
||||
config_dataset.path.startswith("adl://")
|
||||
or config_dataset.path.startswith("abfs://")
|
||||
or config_dataset.path.startswith("az://")
|
||||
):
|
||||
try:
|
||||
import adlfs
|
||||
|
||||
storage_options = {"anon": False}
|
||||
return adlfs.AzureBlobFileSystem(**storage_options), storage_options
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"adl:// or abfs:// paths require adlfs to be installed"
|
||||
) from exc
|
||||
|
||||
elif path.startswith("oci://"):
|
||||
# # Ensure you have the following environment variables set:
|
||||
# # Gen 1
|
||||
# storage_options = {
|
||||
# "tenant_id": AZURE_STORAGE_TENANT_ID,
|
||||
# "client_id": AZURE_STORAGE_CLIENT_ID,
|
||||
# "client_secret": AZURE_STORAGE_CLIENT_SECRET,
|
||||
# }
|
||||
# # Gen 2
|
||||
# storage_options = {
|
||||
# "account_name": AZURE_STORAGE_ACCOUNT_NAME,
|
||||
# "account_key": AZURE_STORAGE_ACCOUNT_KEY,
|
||||
# }
|
||||
|
||||
# Reads env
|
||||
# https://github.com/fsspec/adlfs?tab=readme-ov-file#setting-credentials
|
||||
storage_options = {"anon": False}
|
||||
remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
|
||||
elif config_dataset.path.startswith("oci://"):
|
||||
try:
|
||||
import ocifs
|
||||
|
||||
storage_options = {}
|
||||
return ocifs.OCIFileSystem(**storage_options), storage_options
|
||||
except ImportError as exc:
|
||||
raise ImportError("oci:// paths require ocifs to be installed") from exc
|
||||
|
||||
return None, {}
|
||||
# https://ocifs.readthedocs.io/en/latest/getting-connected.html#Using-Environment-Variables
|
||||
remote_file_system = ocifs.OCIFileSystem(**storage_options)
|
||||
|
||||
|
||||
def _load_from_local_path(
|
||||
dataset_config: DictDefault, load_dataset_kwargs: dict
|
||||
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
|
||||
"""Load a dataset from a local path."""
|
||||
local_path = Path(dataset_config.path)
|
||||
|
||||
if local_path.is_dir():
|
||||
if dataset_config.data_files:
|
||||
dataset_type = get_dataset_type(dataset_config)
|
||||
return load_dataset(
|
||||
dataset_type,
|
||||
data_files=dataset_config.data_files,
|
||||
**load_dataset_kwargs,
|
||||
)
|
||||
try:
|
||||
return load_from_disk(dataset_config.path)
|
||||
except FileNotFoundError:
|
||||
load_dataset_kwargs["streaming"] = False
|
||||
return load_dataset(dataset_config.path, **load_dataset_kwargs)
|
||||
elif local_path.is_file():
|
||||
dataset_type = get_dataset_type(dataset_config)
|
||||
load_dataset_kwargs["streaming"] = False
|
||||
return load_dataset(
|
||||
dataset_type,
|
||||
data_files=dataset_config.path,
|
||||
**load_dataset_kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unhandled dataset load: local path exists, but is neither a directory or a file"
|
||||
)
|
||||
|
||||
|
||||
def _load_from_hub(
|
||||
dataset_config: DictDefault, use_auth_token: bool, load_dataset_kwargs: dict
|
||||
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
|
||||
"""Load a dataset from the HuggingFace Hub."""
|
||||
return load_dataset(
|
||||
dataset_config.path,
|
||||
data_files=dataset_config.data_files,
|
||||
token=use_auth_token,
|
||||
revision=dataset_config.revision,
|
||||
**load_dataset_kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _load_from_cloud(
|
||||
dataset_config: DictDefault,
|
||||
remote_fs: S3FileSystem | GCSFileSystem | AzureBlobFileSystem | OCIFileSystem,
|
||||
storage_options: dict,
|
||||
load_dataset_kwargs: dict,
|
||||
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
|
||||
"""Load a dataset from cloud storage."""
|
||||
if remote_fs.isdir(dataset_config.path):
|
||||
return load_from_disk(
|
||||
dataset_config.path,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
|
||||
if remote_fs.isfile(dataset_config.path):
|
||||
dataset_type = get_dataset_type(dataset_config)
|
||||
return load_dataset(
|
||||
dataset_type,
|
||||
data_files=dataset_config.path,
|
||||
storage_options=storage_options,
|
||||
**load_dataset_kwargs,
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Cloud path {dataset_config.path} is neither a directory nor a file"
|
||||
)
|
||||
|
||||
|
||||
def _load_from_url(
|
||||
dataset_config: DictDefault, load_dataset_kwargs: dict
|
||||
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
|
||||
"""Load a dataset from a URL."""
|
||||
dataset_type = get_dataset_type(dataset_config)
|
||||
return load_dataset(
|
||||
dataset_type,
|
||||
data_files=dataset_config.path,
|
||||
**load_dataset_kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _load_from_data_files(
|
||||
dataset_config: DictDefault, load_dataset_kwargs: dict
|
||||
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
|
||||
"""Load a dataset from data files."""
|
||||
file_path = None
|
||||
|
||||
if isinstance(dataset_config.data_files, str):
|
||||
file_path = hf_hub_download(
|
||||
repo_id=dataset_config.path,
|
||||
repo_type="dataset",
|
||||
filename=dataset_config.data_files,
|
||||
revision=dataset_config.revision,
|
||||
)
|
||||
elif isinstance(dataset_config.data_files, list):
|
||||
file_path = [
|
||||
hf_hub_download(
|
||||
repo_id=dataset_config.path,
|
||||
repo_type="dataset",
|
||||
filename=file,
|
||||
revision=dataset_config.revision,
|
||||
)
|
||||
for file in dataset_config.data_files
|
||||
]
|
||||
else:
|
||||
raise ValueError("data_files must be either a string or list of strings")
|
||||
|
||||
return load_dataset("json", data_files=file_path, **load_dataset_kwargs)
|
||||
|
||||
|
||||
def generate_split_fingerprints(
|
||||
dataset: Dataset, val_set_size: int | float, seed: int
|
||||
) -> tuple[str, str]:
|
||||
"""Generate consistent fingerprints for train/test splits."""
|
||||
fingerprint = dataset._fingerprint # pylint: disable=protected-access
|
||||
|
||||
train_hash_input = f"{fingerprint}|{val_set_size}|train|{seed}"
|
||||
test_hash_input = f"{fingerprint}|{val_set_size}|test|{seed}"
|
||||
|
||||
train_fingerprint = md5(train_hash_input)
|
||||
test_fingerprint = md5(test_hash_input)
|
||||
|
||||
return train_fingerprint, test_fingerprint
|
||||
|
||||
|
||||
def get_prepared_dataset_path(cfg: DictDefault, dataset_hash: str) -> Path:
|
||||
"""Get standardized path for prepared datasets.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object.
|
||||
dataset_hash: Hash identifying the specific dataset configuration.
|
||||
|
||||
Returns:
|
||||
Path where the prepared dataset should be stored.
|
||||
"""
|
||||
base_path = cfg.dataset_prepared_path or DEFAULT_DATASET_PREPARED_PATH
|
||||
return Path(base_path) / dataset_hash
|
||||
|
||||
|
||||
def create_train_validation_split(
|
||||
dataset: Dataset, cfg: DictDefault, val_set_size: int | float
|
||||
) -> tuple[Dataset, Dataset]:
|
||||
"""Create train/validation split with consistent fingerprinting.
|
||||
|
||||
Args:
|
||||
dataset: Dataset to split.
|
||||
cfg: Configuration object containing seed and other settings.
|
||||
val_set_size: Size of validation set (absolute number or fraction).
|
||||
|
||||
Returns:
|
||||
Tuple of (train_dataset, eval_dataset).
|
||||
"""
|
||||
train_fingerprint, test_fingerprint = generate_split_fingerprints(
|
||||
dataset, val_set_size, cfg.seed
|
||||
)
|
||||
|
||||
# Apply deduplication before splitting if configured
|
||||
if cfg.dataset_exact_deduplication:
|
||||
dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
|
||||
|
||||
split_dataset = dataset.train_test_split(
|
||||
test_size=val_set_size,
|
||||
shuffle=False,
|
||||
seed=cfg.seed,
|
||||
train_new_fingerprint=train_fingerprint,
|
||||
test_new_fingerprint=test_fingerprint,
|
||||
)
|
||||
|
||||
return split_dataset["train"], split_dataset["test"]
|
||||
|
||||
|
||||
def _generate_from_iterable_dataset(
|
||||
dataset: IterableDataset, worker_id: list[int], num_workers: list[int]
|
||||
) -> Generator[Any, None, None]:
|
||||
"""Generator function to correctly split the dataset for each worker"""
|
||||
for i, item in enumerate(dataset):
|
||||
if i % num_workers[0] == worker_id[0]:
|
||||
yield item
|
||||
|
||||
|
||||
def save_preprocessed_dataset(
|
||||
cfg: DictDefault,
|
||||
dataset: Dataset,
|
||||
dataset_hash: str,
|
||||
split: str,
|
||||
) -> None:
|
||||
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
|
||||
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
|
||||
if isinstance(dataset, IterableDataset):
|
||||
num_workers = cfg.dataset_processes
|
||||
|
||||
ds_from_iter = Dataset.from_generator(
|
||||
functools.partial(_generate_from_iterable_dataset, dataset),
|
||||
features=dataset.features,
|
||||
num_proc=num_workers,
|
||||
split=split,
|
||||
gen_kwargs={
|
||||
"worker_id": list(range(num_workers)),
|
||||
"num_workers": [num_workers] * num_workers,
|
||||
},
|
||||
)
|
||||
ds_from_iter.save_to_disk(str(prepared_ds_path))
|
||||
else:
|
||||
os.makedirs(prepared_ds_path, exist_ok=True)
|
||||
dataset.save_to_disk(str(prepared_ds_path))
|
||||
if cfg.push_dataset_to_hub:
|
||||
LOG.info(
|
||||
"Pushing merged prepared dataset to Huggingface hub at "
|
||||
f"{cfg.push_dataset_to_hub} (version {dataset_hash})...",
|
||||
main_process_only=False,
|
||||
)
|
||||
dataset.push_to_hub(
|
||||
cfg.push_dataset_to_hub,
|
||||
dataset_hash,
|
||||
private=True,
|
||||
)
|
||||
|
||||
|
||||
def load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Dataset | None:
|
||||
"""Load preprocessed dataset from disk if available.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object.
|
||||
dataset_hash: Hash identifying the dataset configuration.
|
||||
|
||||
Returns:
|
||||
Loaded dataset if found and conditions are met, None otherwise.
|
||||
"""
|
||||
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
|
||||
|
||||
if (
|
||||
cfg.dataset_prepared_path
|
||||
and any(prepared_ds_path.glob("*"))
|
||||
and not cfg.skip_prepare_dataset
|
||||
and not cfg.is_preprocess
|
||||
):
|
||||
LOG.info(
|
||||
f"Loading prepared dataset from disk at {prepared_ds_path}...",
|
||||
main_process_only=False,
|
||||
)
|
||||
return load_from_disk(str(prepared_ds_path))
|
||||
|
||||
LOG.info(
|
||||
f"Unable to find prepared dataset in {prepared_ds_path}",
|
||||
main_process_only=False,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def try_load_from_hub(
|
||||
cfg: DictDefault, dataset_hash: str, split: str
|
||||
) -> Dataset | None:
|
||||
"""Try to load the prepared dataset from HuggingFace Hub."""
|
||||
try:
|
||||
LOG.info(
|
||||
"Attempting to load prepared dataset from HuggingFace Hub at "
|
||||
f"{cfg.push_dataset_to_hub} (version {dataset_hash})..."
|
||||
)
|
||||
dataset = load_dataset(
|
||||
cfg.push_dataset_to_hub,
|
||||
dataset_hash,
|
||||
token=cfg.hf_use_auth_token,
|
||||
)
|
||||
return dataset[split]
|
||||
except Exception: # pylint: disable=broad-except # nosec
|
||||
LOG.info("Unable to find prepared dataset in HuggingFace Hub")
|
||||
return None
|
||||
if remote_file_system and remote_file_system.exists(config_dataset.path):
|
||||
ds_from_cloud = True
|
||||
except (FileNotFoundError, ConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
def generate_dataset_hash_from_config(
|
||||
cfg: DictDefault, cfg_datasets: list, tokenizer_name: str
|
||||
) -> str:
|
||||
"""Generate a hash to uniquely identify a dataset configuration for SFT.
|
||||
|
||||
Args:
|
||||
cfg: Main configuration object.
|
||||
cfg_datasets: List of dataset configurations.
|
||||
tokenizer_name: Name of the tokenizer being used.
|
||||
|
||||
Returns:
|
||||
MD5 hash string representing the configuration.
|
||||
"""
|
||||
config_str = (
|
||||
f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@"
|
||||
f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}|"
|
||||
f"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}"
|
||||
f"|{tokenizer_name}"
|
||||
)
|
||||
return str(md5(config_str))
|
||||
|
||||
|
||||
def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
|
||||
"""Merge multiple datasets into one with optional shuffling.
|
||||
|
||||
Args:
|
||||
datasets: List of datasets to merge.
|
||||
cfg: Configuration object containing shuffle settings.
|
||||
|
||||
Returns:
|
||||
Merged dataset.
|
||||
"""
|
||||
if len(datasets) == 1:
|
||||
return datasets[0]
|
||||
|
||||
LOG.info("Merging datasets...")
|
||||
merged_dataset = concatenate_datasets(datasets)
|
||||
|
||||
if cfg.shuffle_merged_datasets:
|
||||
LOG.debug("Shuffling merged datasets...")
|
||||
merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
|
||||
# gather extra args from the config
|
||||
load_ds_kwargs = {}
|
||||
if config_dataset.split:
|
||||
load_ds_kwargs["split"] = config_dataset.split
|
||||
else:
|
||||
LOG.debug("Not shuffling merged datasets.")
|
||||
load_ds_kwargs["split"] = None
|
||||
|
||||
return merged_dataset
|
||||
# prefer local dataset, even if hub exists
|
||||
local_path = Path(config_dataset.path)
|
||||
if local_path.exists():
|
||||
if local_path.is_dir():
|
||||
if config_dataset.data_files:
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
ds = load_dataset( # pylint: disable=invalid-name
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.data_files,
|
||||
streaming=streaming,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
ds = load_from_disk(
|
||||
config_dataset.path
|
||||
) # pylint: disable=invalid-name
|
||||
except FileNotFoundError:
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=False,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
elif local_path.is_file():
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
|
||||
ds = load_dataset( # pylint: disable=invalid-name
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=False,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"unhandled dataset load: local path exists, but is neither a directory or a file"
|
||||
)
|
||||
elif ds_from_hub:
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=streaming,
|
||||
data_files=config_dataset.data_files,
|
||||
token=use_auth_token,
|
||||
revision=config_dataset.revision,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
elif ds_from_cloud and remote_file_system:
|
||||
if remote_file_system.isdir(config_dataset.path):
|
||||
ds = load_from_disk(
|
||||
config_dataset.path,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
elif remote_file_system.isfile(config_dataset.path):
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=streaming,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
elif config_dataset.path.startswith("https://"):
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=streaming,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
elif config_dataset.data_files:
|
||||
fp: str | list[str] | None = None
|
||||
if isinstance(config_dataset.data_files, str):
|
||||
fp = hf_hub_download(
|
||||
repo_id=config_dataset.path,
|
||||
repo_type="dataset",
|
||||
filename=config_dataset.data_files,
|
||||
revision=config_dataset.revision,
|
||||
)
|
||||
elif isinstance(config_dataset.data_files, list):
|
||||
fp = []
|
||||
for file in config_dataset.data_files:
|
||||
fp.append(
|
||||
hf_hub_download(
|
||||
repo_id=config_dataset.path,
|
||||
repo_type="dataset",
|
||||
filename=file,
|
||||
revision=config_dataset.revision,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError("data_files must be either a string or list of strings")
|
||||
ds = load_dataset(
|
||||
"json",
|
||||
name=config_dataset.name,
|
||||
data_files=fp,
|
||||
streaming=streaming,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
if not ds:
|
||||
raise ValueError(
|
||||
"The dataset could not be loaded. This could be due to a misconfigured dataset path "
|
||||
f"({config_dataset.path}). Try double-check your path / name / data_files. "
|
||||
"This is not caused by the dataset type."
|
||||
)
|
||||
|
||||
return ds
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
"""Data handling helpers"""
|
||||
"""data handling helpers"""
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import hashlib
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
@@ -21,7 +19,9 @@ LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class RetryStrategy(Enum):
|
||||
"""Enum for retry strategies."""
|
||||
"""
|
||||
Enum for retry strategies.
|
||||
"""
|
||||
|
||||
CONSTANT = 1
|
||||
LINEAR = 2
|
||||
@@ -30,18 +30,7 @@ class RetryStrategy(Enum):
|
||||
|
||||
def retry_on_request_exceptions(
|
||||
max_retries=3, delay=1, retry_strategy: RetryStrategy = RetryStrategy.LINEAR
|
||||
) -> Callable:
|
||||
"""Decorator that retries function calls on specific request exceptions.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of retry attempts.
|
||||
delay: Base delay between retries in seconds.
|
||||
retry_strategy: Strategy for calculating retry delays.
|
||||
|
||||
Returns:
|
||||
Decorated function with retry logic.
|
||||
"""
|
||||
|
||||
):
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
|
||||
@@ -51,7 +40,6 @@ def retry_on_request_exceptions(
|
||||
except (
|
||||
requests.exceptions.ReadTimeout,
|
||||
requests.exceptions.ConnectionError,
|
||||
requests.exceptions.HTTPError,
|
||||
huggingface_hub.errors.HfHubHTTPError,
|
||||
) as exc:
|
||||
if attempt < max_retries - 1:
|
||||
@@ -71,7 +59,6 @@ def retry_on_request_exceptions(
|
||||
|
||||
|
||||
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
"""Generate MD5 hash of a string."""
|
||||
try:
|
||||
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
|
||||
except TypeError:
|
||||
@@ -79,89 +66,102 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
|
||||
|
||||
def sha256(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
"""Generate SHA256 hash of a string."""
|
||||
return hashlib.sha256(to_hash.encode(encoding)).hexdigest()
|
||||
|
||||
|
||||
def _deduplicate_dataset(
|
||||
dataset: Dataset,
|
||||
seen_hashes: set[str] | None = None,
|
||||
) -> tuple[Dataset, set[str]]:
|
||||
"""Remove duplicate rows from a dataset using SHA256 hashes.
|
||||
|
||||
Args:
|
||||
dataset: Dataset to deduplicate.
|
||||
seen_hashes: Set of previously seen row hashes (for cross-deduplication).
|
||||
|
||||
Returns:
|
||||
Tuple of deduplicated dataset and the set of seen hashes.
|
||||
"""
|
||||
if seen_hashes is None:
|
||||
seen_hashes = set()
|
||||
|
||||
def deduplicate_dataset(
|
||||
dataset: Dataset, seen_hashes: dict[str, list[int]], other_dataset: Dataset = None
|
||||
) -> Dataset:
|
||||
unique_indices = []
|
||||
for idx, row in enumerate(dataset):
|
||||
row_hash = sha256(str(row)) # Using SHA256 for collision resistance
|
||||
if row_hash not in seen_hashes:
|
||||
seen_hashes.add(row_hash)
|
||||
unique_indices.append(idx)
|
||||
|
||||
return dataset.select(unique_indices), seen_hashes
|
||||
for idx, row in enumerate(dataset):
|
||||
row_hash = sha256(str(row)) # Using SHA256 for collision resistance.
|
||||
if row_hash not in seen_hashes:
|
||||
seen_hashes[row_hash] = [idx]
|
||||
unique_indices.append(idx)
|
||||
else:
|
||||
# Check for collision by looking up the original dataset indices
|
||||
original_indices = seen_hashes[row_hash]
|
||||
is_duplicate = False
|
||||
for original_idx in original_indices:
|
||||
if (
|
||||
not idx == original_idx
|
||||
and original_idx < len(dataset)
|
||||
and str(dataset[original_idx]) == str(row)
|
||||
):
|
||||
is_duplicate = True
|
||||
break
|
||||
# Check in the other dataset if provided
|
||||
if other_dataset is not None:
|
||||
if original_idx < len(other_dataset) and str(
|
||||
other_dataset[original_idx]
|
||||
) == str(row):
|
||||
is_duplicate = True
|
||||
break
|
||||
if not is_duplicate:
|
||||
seen_hashes[row_hash].append(idx)
|
||||
unique_indices.append(idx)
|
||||
continue
|
||||
return dataset.select(unique_indices)
|
||||
|
||||
|
||||
def deduplicate_and_log_datasets(
|
||||
dataset: Dataset,
|
||||
other_dataset: Dataset | None = None,
|
||||
dataset_name: str | None = "train",
|
||||
other_name: str | None = "eval",
|
||||
) -> tuple[Dataset, Dataset | None]:
|
||||
"""Deduplicate datasets, with optional cross-dataset deduplication.
|
||||
|
||||
Args:
|
||||
dataset: Primary dataset to deduplicate.
|
||||
other_dataset: Optional second dataset to deduplicate against the first.
|
||||
dataset_name: Name for the primary dataset (for logging).
|
||||
other_name: Name for the second dataset (for logging).
|
||||
*,
|
||||
train_dataset: Dataset = None,
|
||||
eval_dataset: Dataset = None,
|
||||
dataset: Dataset = None,
|
||||
) -> tuple[Dataset, Dataset, Dataset]:
|
||||
"""
|
||||
Deduplicates train, eval, and an optional dataset if provided, logging original and new sizes.
|
||||
|
||||
Returns:
|
||||
Tuple of (deduplicated_dataset, deduplicated_other_dataset).
|
||||
tuple: Deduplicated train, eval, and additional datasets.
|
||||
"""
|
||||
# Deduplicate primary dataset
|
||||
LOG.info(
|
||||
f"Starting deduplication for {dataset_name} dataset. Original size: {len(dataset)}"
|
||||
)
|
||||
dataset, seen_rows = _deduplicate_dataset(dataset)
|
||||
LOG.info(
|
||||
f"Deduplication complete for {dataset_name} dataset. New size: {len(dataset)}"
|
||||
)
|
||||
seen_hashes: dict[str, list[int]] = {}
|
||||
|
||||
# Deduplicate second dataset if provided
|
||||
if other_dataset is not None:
|
||||
# Handle cases where datasets are None
|
||||
if train_dataset is not None:
|
||||
LOG.info(
|
||||
f"Starting deduplication for {other_name} dataset. Original size: {len(other_dataset)}"
|
||||
f"Starting deduplication for train dataset. Original size: {len(train_dataset)}"
|
||||
)
|
||||
train_dataset = deduplicate_dataset(
|
||||
dataset=train_dataset, seen_hashes=seen_hashes
|
||||
)
|
||||
other_dataset, _ = _deduplicate_dataset(other_dataset, seen_rows)
|
||||
LOG.info(
|
||||
f"Deduplication complete for {other_name} dataset. New size: {len(other_dataset)}"
|
||||
f"Deduplication complete for train dataset. New size: {len(train_dataset)}"
|
||||
)
|
||||
else:
|
||||
LOG.info("Train dataset is None. Skipping deduplication.")
|
||||
|
||||
if eval_dataset is not None:
|
||||
LOG.info(
|
||||
f"Starting deduplication for eval dataset. Original size: {len(eval_dataset)}"
|
||||
)
|
||||
eval_dataset = deduplicate_dataset(
|
||||
dataset=eval_dataset, seen_hashes=seen_hashes, other_dataset=train_dataset
|
||||
)
|
||||
LOG.info(
|
||||
f"Deduplication complete for eval dataset. New size: {len(eval_dataset)}"
|
||||
)
|
||||
else:
|
||||
LOG.info("Eval dataset is None. Skipping deduplication.")
|
||||
|
||||
if dataset is not None and (eval_dataset is None and train_dataset is None):
|
||||
LOG.info(
|
||||
f"Starting deduplication for combined dataset. Original size: {len(dataset)}"
|
||||
)
|
||||
dataset = deduplicate_dataset(dataset=dataset, seen_hashes=seen_hashes)
|
||||
LOG.info(
|
||||
f"Deduplication complete for combined dataset. New size: {len(dataset)}"
|
||||
)
|
||||
|
||||
return dataset, other_dataset
|
||||
return train_dataset, eval_dataset, dataset
|
||||
|
||||
|
||||
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
|
||||
"""Remove sequences longer than configured maximum from dataset.
|
||||
|
||||
Args:
|
||||
dataset: Dataset to filter.
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
Returns:
|
||||
Filtered dataset with long sequences removed.
|
||||
"""
|
||||
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
||||
if "input_ids" not in dataset.column_names:
|
||||
LOG.warning(
|
||||
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
|
||||
"expected for reward modeling."
|
||||
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling."
|
||||
)
|
||||
return dataset
|
||||
|
||||
@@ -171,14 +171,20 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
|
||||
min_sequence_len=cfg.min_sample_len,
|
||||
)
|
||||
|
||||
with contextlib.suppress(AttributeError):
|
||||
try:
|
||||
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
||||
min_input_len = np.min(ds_lengths)
|
||||
LOG.info(f"min_input_len: {min_input_len}")
|
||||
max_input_len = np.max(ds_lengths)
|
||||
LOG.info(f"max_input_len: {max_input_len}")
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
|
||||
try:
|
||||
prior_len = len(dataset)
|
||||
except TypeError:
|
||||
# handle iterable datasets case
|
||||
prior_len = None
|
||||
|
||||
filter_map_kwargs = {}
|
||||
if not isinstance(dataset, IterableDataset):
|
||||
|
||||
@@ -1,425 +0,0 @@
|
||||
"""Data handling specific to SFT."""
|
||||
|
||||
import logging
|
||||
from typing import Any, NoReturn, cast
|
||||
|
||||
from datasets import (
|
||||
Dataset,
|
||||
IterableDataset,
|
||||
Sequence,
|
||||
Value,
|
||||
)
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
from axolotl.datasets import TokenizedPromptDataset, wrap_dataset_for_tokenized_prompt
|
||||
from axolotl.prompt_strategies import load
|
||||
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaMultipleChoicePromptTokenizingStrategy,
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
AlpacaReflectionPTStrategy,
|
||||
DatasetWrappingStrategy,
|
||||
GPTeacherPromptTokenizingStrategy,
|
||||
JeopardyPromptTokenizingStrategy,
|
||||
OpenAssistantPromptTokenizingStrategy,
|
||||
PromptTokenizingStrategy,
|
||||
SummarizeTLDRPromptTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompters import (
|
||||
AlpacaPrompter,
|
||||
GPTeacherPrompter,
|
||||
JeopardyPrompter,
|
||||
MultipleChoiceConcisePrompter,
|
||||
MultipleChoiceExplainPrompter,
|
||||
Prompter,
|
||||
ReflectAlpacaPrompter,
|
||||
SummarizeTLDRPrompter,
|
||||
UnsupportedPrompter,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def handle_unknown_dataset_strategy(dataset_config: DictDefault) -> NoReturn:
|
||||
"""Raise error for unknown dataset strategy."""
|
||||
ds_type = dataset_config.type
|
||||
suffix = ""
|
||||
if ":load_" in ds_type:
|
||||
suffix = f"Did you mean {ds_type.replace(':load_', '.load_')}?"
|
||||
|
||||
error_message = f"unhandled prompt tokenization strategy: {ds_type}. {suffix}"
|
||||
LOG.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
|
||||
# pylint: disable=too-many-return-statements
|
||||
def get_dataset_wrapper(
|
||||
dataset_config: DictDefault,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
cfg: DictDefault,
|
||||
dataset_base_type: str | None,
|
||||
dataset: Dataset | IterableDataset,
|
||||
dataset_prompt_style: str | None = None,
|
||||
processor: ProcessorMixin | None = None, # pylint: disable=unused-argument
|
||||
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
||||
"""Create an appropriate dataset wrapper and prompter based on dataset
|
||||
configuration.
|
||||
|
||||
Args:
|
||||
dataset_config: Configuration for the dataset.
|
||||
tokenizer: Tokenizer to use for processing text.
|
||||
cfg: Global configuration object.
|
||||
dataset_base_type: The base type of the dataset.
|
||||
dataset: The actual dataset object.
|
||||
dataset_prompt_style: Optional prompt style specification.
|
||||
processor: Optional processor for multimodal datasets.
|
||||
|
||||
Returns:
|
||||
tuple of (dataset_wrapper, dataset_prompter).
|
||||
"""
|
||||
# Common parameters for dataset wrapping
|
||||
dataset_kwargs: dict[str, Any] = {
|
||||
"process_count": cfg.dataset_processes,
|
||||
"keep_in_memory": cfg.dataset_keep_in_memory is True,
|
||||
}
|
||||
|
||||
LOG.info(
|
||||
f"Loading dataset: {dataset_config['path']} with base_type: "
|
||||
f"{dataset_base_type} and prompt_style: {dataset_prompt_style}"
|
||||
)
|
||||
|
||||
# Dataset is already tokenized
|
||||
if _is_dataset_already_tokenized(dataset):
|
||||
return dataset, UnsupportedPrompter()
|
||||
|
||||
# Custom dataset type definition
|
||||
if isinstance(dataset_config.type, DictDefault):
|
||||
return _handle_custom_dataset_type(
|
||||
dataset_config, tokenizer, cfg, dataset, dataset_kwargs
|
||||
)
|
||||
|
||||
# Skip preparation if configured
|
||||
if cfg.skip_prepare_dataset:
|
||||
return dataset, None
|
||||
|
||||
# Bradley-Terry dataset
|
||||
if dataset_config.type.startswith("bradley_terry"):
|
||||
return _handle_bradley_terry_dataset(
|
||||
dataset_config, tokenizer, cfg, dataset, dataset_kwargs
|
||||
)
|
||||
|
||||
# Stepwise supervised dataset
|
||||
if dataset_config.type.startswith("stepwise_supervised"):
|
||||
return _handle_stepwise_supervised_dataset(
|
||||
dataset_config, tokenizer, cfg, dataset, dataset_kwargs
|
||||
)
|
||||
|
||||
# Try to load prompt tokenizer / dataset wrapper strategy from registry
|
||||
dataset_strategy = load(
|
||||
dataset_config.type, tokenizer, cfg, dataset_config, processor=processor
|
||||
)
|
||||
if dataset_strategy:
|
||||
return _handle_loaded_strategy(dataset_strategy, dataset, dataset_kwargs)
|
||||
|
||||
# Known dataset types with specific handling
|
||||
if dataset_base_type in DATASET_HANDLERS:
|
||||
handler = DATASET_HANDLERS[dataset_base_type]
|
||||
return handler(dataset_prompt_style, tokenizer, cfg, dataset, dataset_kwargs)
|
||||
|
||||
# Unhandled dataset type
|
||||
handle_unknown_dataset_strategy(dataset_config)
|
||||
|
||||
|
||||
def _is_dataset_already_tokenized(dataset: Dataset | IterableDataset) -> bool:
|
||||
"""Check if the dataset is already tokenized."""
|
||||
return (
|
||||
isinstance(dataset, Dataset)
|
||||
and "input_ids" in dataset.features
|
||||
and "attention_mask" in dataset.features
|
||||
and "labels" in dataset.features
|
||||
)
|
||||
|
||||
|
||||
def _handle_custom_dataset_type(
|
||||
dataset_config: DictDefault,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
cfg: DictDefault,
|
||||
dataset: Dataset | IterableDataset,
|
||||
dataset_kwargs: dict[str, Any],
|
||||
) -> tuple[Dataset | IterableDataset, Prompter]:
|
||||
"""Handle a custom dataset type defined in the configuration."""
|
||||
dataset_strategy = cast(
|
||||
PromptTokenizingStrategy,
|
||||
load("user_defined", tokenizer, cfg, dataset_config.type.to_dict()),
|
||||
)
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
dataset_strategy,
|
||||
dataset,
|
||||
**dataset_kwargs,
|
||||
)
|
||||
return dataset_wrapper, dataset_prompter
|
||||
|
||||
|
||||
def _handle_bradley_terry_dataset(
|
||||
dataset_config: DictDefault,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
cfg: DictDefault,
|
||||
dataset: Dataset | IterableDataset,
|
||||
dataset_kwargs: dict[str, Any],
|
||||
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
||||
"""Handle a Bradley-Terry dataset."""
|
||||
bt_type = dataset_config.type.split(".", 1)[1]
|
||||
dataset_strategy = bradley_terry_load(bt_type, tokenizer, cfg, dataset_config)
|
||||
|
||||
if not dataset_strategy:
|
||||
handle_unknown_dataset_strategy(dataset_config)
|
||||
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
dataset_strategy,
|
||||
dataset,
|
||||
**dataset_kwargs,
|
||||
)
|
||||
|
||||
return dataset_wrapper, dataset_prompter
|
||||
|
||||
|
||||
def _handle_stepwise_supervised_dataset(
|
||||
dataset_config: DictDefault,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
cfg: DictDefault,
|
||||
dataset: Dataset | IterableDataset,
|
||||
dataset_kwargs: dict[str, Any],
|
||||
) -> tuple[Dataset | IterableDataset, Prompter]:
|
||||
"""Handle a stepwise supervised dataset."""
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_strategy = load(dataset_config.type, tokenizer, cfg, dataset_config)
|
||||
|
||||
# We need to explicitly cast boolean labels to int
|
||||
# for compatibility with how trl's PRMTrainer works
|
||||
if isinstance(dataset, Dataset):
|
||||
dataset = dataset.cast_column("labels", Sequence(Value("int64")))
|
||||
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
dataset_strategy,
|
||||
dataset,
|
||||
**dataset_kwargs,
|
||||
)
|
||||
return dataset_wrapper, dataset_prompter
|
||||
|
||||
|
||||
def _handle_loaded_strategy(
|
||||
dataset_strategy: PromptTokenizingStrategy | DatasetWrappingStrategy,
|
||||
dataset: Dataset | IterableDataset,
|
||||
dataset_kwargs: dict[str, Any],
|
||||
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
||||
"""Handle a dataset with a strategy loaded from the registry."""
|
||||
if isinstance(dataset_strategy, DatasetWrappingStrategy):
|
||||
return dataset_strategy.wrap_dataset(dataset, **dataset_kwargs), None
|
||||
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
dataset_strategy,
|
||||
dataset,
|
||||
**dataset_kwargs,
|
||||
)
|
||||
return dataset_wrapper, dataset_prompter
|
||||
|
||||
|
||||
def _handle_alpaca_dataset(
|
||||
dataset_prompt_style: str | None,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
cfg: DictDefault,
|
||||
dataset: Dataset | IterableDataset,
|
||||
dataset_kwargs: dict[str, Any],
|
||||
) -> tuple[Dataset | IterableDataset, Prompter]:
|
||||
"""Handle an Alpaca dataset."""
|
||||
dataset_prompter = AlpacaPrompter(dataset_prompt_style)
|
||||
dataset_strategy = AlpacaPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
dataset_strategy,
|
||||
dataset,
|
||||
**dataset_kwargs,
|
||||
)
|
||||
return dataset_wrapper, dataset_prompter
|
||||
|
||||
|
||||
def _handle_explainchoice_dataset(
|
||||
dataset_prompt_style: str | None,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
cfg: DictDefault,
|
||||
dataset: Dataset | IterableDataset,
|
||||
dataset_kwargs: dict[str, Any],
|
||||
) -> tuple[Dataset | IterableDataset, Prompter]:
|
||||
"""Handle an ExplainChoice dataset."""
|
||||
dataset_prompter = MultipleChoiceExplainPrompter(dataset_prompt_style)
|
||||
dataset_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
dataset_strategy,
|
||||
dataset,
|
||||
**dataset_kwargs,
|
||||
)
|
||||
return dataset_wrapper, dataset_prompter
|
||||
|
||||
|
||||
def _handle_concisechoice_dataset(
|
||||
dataset_prompt_style: str | None,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
cfg: DictDefault,
|
||||
dataset: Dataset | IterableDataset,
|
||||
dataset_kwargs: dict[str, Any],
|
||||
) -> tuple[Dataset | IterableDataset, Prompter]:
|
||||
"""Handle a ConciseChoice dataset."""
|
||||
dataset_prompter = MultipleChoiceConcisePrompter(dataset_prompt_style)
|
||||
dataset_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
dataset_strategy,
|
||||
dataset,
|
||||
**dataset_kwargs,
|
||||
)
|
||||
return dataset_wrapper, dataset_prompter
|
||||
|
||||
|
||||
def _handle_summarizetldr_dataset(
|
||||
dataset_prompt_style: str | None,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
cfg: DictDefault,
|
||||
dataset: Dataset | IterableDataset,
|
||||
dataset_kwargs: dict[str, Any],
|
||||
) -> tuple[Dataset | IterableDataset, Prompter]:
|
||||
"""Handle a SummarizeTLDR dataset."""
|
||||
dataset_prompter = SummarizeTLDRPrompter(dataset_prompt_style)
|
||||
dataset_strategy = SummarizeTLDRPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
dataset_strategy,
|
||||
dataset,
|
||||
**dataset_kwargs,
|
||||
)
|
||||
return dataset_wrapper, dataset_prompter
|
||||
|
||||
|
||||
def _handle_jeopardy_dataset(
|
||||
dataset_prompt_style: str | None,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
cfg: DictDefault,
|
||||
dataset: Dataset | IterableDataset,
|
||||
dataset_kwargs: dict[str, Any],
|
||||
) -> tuple[Dataset | IterableDataset, Prompter]:
|
||||
"""Handle a Jeopardy dataset."""
|
||||
dataset_prompter = JeopardyPrompter(dataset_prompt_style)
|
||||
dataset_strategy = JeopardyPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
dataset_strategy,
|
||||
dataset,
|
||||
**dataset_kwargs,
|
||||
)
|
||||
return dataset_wrapper, dataset_prompter
|
||||
|
||||
|
||||
def _handle_oasst_dataset(
|
||||
dataset_prompt_style: str | None,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
cfg: DictDefault,
|
||||
dataset: Dataset | IterableDataset,
|
||||
dataset_kwargs: dict[str, Any],
|
||||
) -> tuple[Dataset | IterableDataset, Prompter]:
|
||||
"""Handle an OpenAssistant dataset."""
|
||||
dataset_prompter = AlpacaPrompter(dataset_prompt_style)
|
||||
dataset_strategy = OpenAssistantPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
dataset_strategy,
|
||||
dataset,
|
||||
**dataset_kwargs,
|
||||
)
|
||||
return dataset_wrapper, dataset_prompter
|
||||
|
||||
|
||||
def _handle_gpteacher_dataset(
|
||||
dataset_prompt_style: str | None,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
cfg: DictDefault,
|
||||
dataset: Dataset | IterableDataset,
|
||||
dataset_kwargs: dict[str, Any],
|
||||
) -> tuple[Dataset | IterableDataset, Prompter]:
|
||||
"""Handle a GPTeacher dataset."""
|
||||
dataset_prompter = GPTeacherPrompter(dataset_prompt_style)
|
||||
dataset_strategy = GPTeacherPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
dataset_strategy,
|
||||
dataset,
|
||||
**dataset_kwargs,
|
||||
)
|
||||
return dataset_wrapper, dataset_prompter
|
||||
|
||||
|
||||
def _handle_reflection_dataset(
|
||||
dataset_prompt_style: str | None,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
cfg: DictDefault,
|
||||
dataset: Dataset | IterableDataset,
|
||||
dataset_kwargs: dict[str, Any],
|
||||
) -> tuple[Dataset | IterableDataset, Prompter]:
|
||||
"""Handle a Reflection dataset."""
|
||||
dataset_prompter = ReflectAlpacaPrompter(dataset_prompt_style)
|
||||
dataset_strategy = AlpacaReflectionPTStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
dataset_strategy,
|
||||
dataset,
|
||||
**dataset_kwargs,
|
||||
)
|
||||
return dataset_wrapper, dataset_prompter
|
||||
|
||||
|
||||
DATASET_HANDLERS = {
|
||||
"alpaca": _handle_alpaca_dataset,
|
||||
"explainchoice": _handle_explainchoice_dataset,
|
||||
"concisechoice": _handle_concisechoice_dataset,
|
||||
"summarizetldr": _handle_summarizetldr_dataset,
|
||||
"jeopardy": _handle_jeopardy_dataset,
|
||||
"oasst": _handle_oasst_dataset,
|
||||
"gpteacher": _handle_gpteacher_dataset,
|
||||
"reflection": _handle_reflection_dataset,
|
||||
}
|
||||
@@ -1,567 +0,0 @@
|
||||
"""Wrapper for MistralTokenizer from mistral-common"""
|
||||
|
||||
import math
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
from torch import Tensor
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
from axolotl.utils.collators.core import IGNORE_INDEX
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
|
||||
|
||||
def _get_file_path(path_or_repo_id: str, filename: str) -> str:
|
||||
"""Get the file path from local or HF Hub"""
|
||||
if os.path.exists(path_or_repo_id):
|
||||
maybe_file_path = os.path.join(path_or_repo_id, filename)
|
||||
if os.path.exists(maybe_file_path):
|
||||
return maybe_file_path
|
||||
|
||||
raise FileNotFoundError(f"File not found at {path_or_repo_id}")
|
||||
|
||||
return hf_hub_download(repo_id=path_or_repo_id, filename=filename)
|
||||
|
||||
|
||||
class HFMistralTokenizer:
|
||||
"""
|
||||
Wraps mistral_common.tokens.tokenizers.mistral.MistralTokenizer
|
||||
and exposes HuggingFace API for special tokens.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, mistral: MistralTokenizer, name_or_path: str, tokenizer_path: str
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
mistral: The mistral-common tokenizer to wrap.
|
||||
name_or_path: The name or path to the tokenizer files or the repo id.
|
||||
"""
|
||||
self._mistral = mistral
|
||||
self._padding_side = "right"
|
||||
self._name_or_path = name_or_path
|
||||
self._tokenizer_path = tokenizer_path
|
||||
|
||||
# Manual set to training mode
|
||||
from mistral_common.protocol.instruct.validator import (
|
||||
MistralRequestValidator,
|
||||
ValidationMode,
|
||||
)
|
||||
|
||||
# Check if MistralRequestValidator has a _mode attribute.
|
||||
# This is a private API and may change in the future.
|
||||
# pylint: disable=protected-access
|
||||
if not (
|
||||
hasattr(self._mistral, "_chat_completion_request_validator")
|
||||
and isinstance(
|
||||
self._mistral._chat_completion_request_validator,
|
||||
MistralRequestValidator,
|
||||
)
|
||||
and hasattr(self._mistral._chat_completion_request_validator, "_mode")
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Unable to switch mistral tokenizer to finetuning mode – "
|
||||
"private API `_chat_completion_request_validator._mode` missing."
|
||||
)
|
||||
|
||||
self._mistral._chat_completion_request_validator._mode = (
|
||||
ValidationMode.finetuning
|
||||
)
|
||||
|
||||
def _load_system_prompt(self, path_or_repo_id: str) -> str:
|
||||
"""Load system prompt from local or HF Hub.
|
||||
|
||||
Note: Unused for now as we don't want to explicitly set the system prompt if a user does
|
||||
not provide one.
|
||||
|
||||
Args:
|
||||
path_or_repo_id: The path to the tokenizer files or the repo id.
|
||||
|
||||
Returns:
|
||||
The system prompt.
|
||||
"""
|
||||
file_path = _get_file_path(path_or_repo_id, "SYSTEM_PROMPT.txt")
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"System prompt file not found at {file_path}")
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
return file.read()
|
||||
|
||||
@property
|
||||
def bos_token_id(self) -> int:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.bos_id
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> int:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.eos_id
|
||||
|
||||
@property
|
||||
def pad_token_id(self) -> int:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.pad_id
|
||||
|
||||
@property
|
||||
def unk_token_id(self) -> int:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.unk_id
|
||||
|
||||
@property
|
||||
def bos_token(self) -> str:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.bos_token_id)
|
||||
|
||||
@property
|
||||
def eos_token(self) -> str:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.eos_token_id)
|
||||
|
||||
@property
|
||||
def pad_token(self) -> str:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.pad_token_id)
|
||||
|
||||
@property
|
||||
def unk_token(self) -> str:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.id_to_piece(self.unk_token_id)
|
||||
|
||||
@property
|
||||
def padding_side(self) -> str:
|
||||
return self._padding_side
|
||||
|
||||
@property
|
||||
def name_or_path(self) -> str:
|
||||
return self._name_or_path
|
||||
|
||||
@property
|
||||
def chat_template(self) -> str | None:
|
||||
"""Chat template is not supported. Dummy method to satisfy HuggingFace API."""
|
||||
return None
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.n_words
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
name_or_path: str,
|
||||
*,
|
||||
revision: Optional[str] = None,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> "HFMistralTokenizer":
|
||||
"""
|
||||
Load a mistral tekken tokenizer from a local file or HF Hub and wrap it.
|
||||
|
||||
Args:
|
||||
path_or_repo_id: The path to the tokenizer files or the repo id.
|
||||
revision: The revision of the tokenizer to download.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
A HFMistralTokenizer instance.
|
||||
"""
|
||||
if revision:
|
||||
raise NotImplementedError(
|
||||
"Revision not supported yet for mistral-common tokenizer"
|
||||
)
|
||||
|
||||
# only support Tekken tokenizer for now
|
||||
# downloads from HF Hub if not local
|
||||
tokenizer_path = _get_file_path(name_or_path, "tekken.json")
|
||||
|
||||
base = MistralTokenizer.from_file(tokenizer_path)
|
||||
|
||||
return cls(
|
||||
base,
|
||||
name_or_path=name_or_path,
|
||||
tokenizer_path=tokenizer_path,
|
||||
)
|
||||
|
||||
def save_pretrained(self, save_directory: str) -> None:
|
||||
"""
|
||||
Save the Tekken/SentencePiece model file so that from_pretrained can pick it up again.
|
||||
|
||||
Only Tekken models are supported.
|
||||
|
||||
Args:
|
||||
save_directory: The directory to save the tokenizer files.
|
||||
"""
|
||||
inner = self._mistral.instruct_tokenizer.tokenizer
|
||||
if isinstance(inner, Tekkenizer):
|
||||
# Create the directory and save the model
|
||||
try:
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
# Verify directory was created
|
||||
if not os.path.exists(save_directory):
|
||||
raise RuntimeError(f"Failed to create directory: {save_directory}")
|
||||
|
||||
# Verify source file exists
|
||||
if not os.path.exists(self._tokenizer_path):
|
||||
raise FileNotFoundError(
|
||||
f"Source tokenizer file not found: {self._tokenizer_path}"
|
||||
)
|
||||
|
||||
destination_path = os.path.join(save_directory, "tekken.json")
|
||||
copyfile(self._tokenizer_path, destination_path)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to save tokenizer to {save_directory}: {e}. "
|
||||
f"Source path: {self._tokenizer_path}, "
|
||||
f"Directory exists: {os.path.exists(save_directory)}"
|
||||
) from e
|
||||
|
||||
else:
|
||||
raise RuntimeError(f"Unknown tokenizer type: {type(inner)}")
|
||||
|
||||
def encode(self, text: str, add_special_tokens: bool = True) -> list[int]:
|
||||
"""
|
||||
Encode a text string into a list of token IDs.
|
||||
|
||||
Args:
|
||||
text: The text string to encode.
|
||||
add_special_tokens: Whether to add special tokens to the encoded tokens.
|
||||
|
||||
Returns:
|
||||
A list of token IDs.
|
||||
"""
|
||||
return self._mistral.instruct_tokenizer.tokenizer.encode(
|
||||
text,
|
||||
bos=add_special_tokens,
|
||||
eos=add_special_tokens,
|
||||
)
|
||||
|
||||
def decode(
|
||||
self, token_ids: int | list[int], skip_special_tokens: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
Decode a list of token IDs into a text string.
|
||||
|
||||
Args:
|
||||
token_ids: The int or list of token IDs to decode.
|
||||
skip_special_tokens: Whether to skip special tokens in the decoded text.
|
||||
|
||||
Returns:
|
||||
The decoded text string.
|
||||
"""
|
||||
if isinstance(token_ids, int):
|
||||
token_ids = [token_ids]
|
||||
|
||||
if skip_special_tokens:
|
||||
return self._mistral.instruct_tokenizer.tokenizer.decode(token_ids)
|
||||
|
||||
# to_string returns a string with special tokens
|
||||
return self._mistral.instruct_tokenizer.tokenizer.to_string(token_ids)
|
||||
|
||||
def _create_mistral_chat_completion_request(
|
||||
self, conversation: list[dict], tools: list[dict] | None = None
|
||||
) -> "ChatCompletionRequest":
|
||||
from mistral_common.protocol.instruct.messages import (
|
||||
AssistantMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.protocol.instruct.tool_calls import Function, Tool
|
||||
|
||||
messages: list[UserMessage | AssistantMessage | ToolMessage | SystemMessage] = (
|
||||
[]
|
||||
)
|
||||
for turn in conversation:
|
||||
role = turn.get("role")
|
||||
|
||||
if role == "user":
|
||||
messages.append(UserMessage(content=turn["content"]))
|
||||
elif role == "assistant":
|
||||
messages.append(
|
||||
AssistantMessage(
|
||||
content=turn.get("content"),
|
||||
tool_calls=turn.get("tool_calls"),
|
||||
)
|
||||
)
|
||||
elif role == "tool":
|
||||
messages.append(
|
||||
ToolMessage(
|
||||
content=turn.get("content"),
|
||||
tool_call_id=turn.get("tool_call_id"),
|
||||
name=turn.get("name"),
|
||||
)
|
||||
)
|
||||
elif role == "system":
|
||||
messages.append(SystemMessage(content=turn["content"]))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown role for use with mistral-common tokenizer: {turn['role']}"
|
||||
)
|
||||
|
||||
tool_calls: list[Tool] = []
|
||||
if tools:
|
||||
# convert to Tool
|
||||
for tool in tools:
|
||||
if tool["type"] != "function":
|
||||
continue
|
||||
|
||||
function = tool["function"]
|
||||
|
||||
tool_calls.append(
|
||||
Tool(
|
||||
function=Function(
|
||||
name=function["name"],
|
||||
description=function["description"],
|
||||
# set parameters to empty dict if not provided
|
||||
parameters=function.get("parameters", {}),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
chat_completion: ChatCompletionRequest = ChatCompletionRequest(
|
||||
messages=messages,
|
||||
tools=tool_calls,
|
||||
)
|
||||
|
||||
return chat_completion
|
||||
|
||||
def apply_chat_template(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tokenize: bool = True,
|
||||
tools: list[dict] | None = None,
|
||||
chat_template: str | None = None, # pylint: disable=unused-argument
|
||||
add_generation_prompt: bool = False, # pylint: disable=unused-argument
|
||||
) -> list[int] | str:
|
||||
if chat_template:
|
||||
raise NotImplementedError("chat_template not supported yet")
|
||||
|
||||
if add_generation_prompt:
|
||||
raise NotImplementedError("add_generation_prompt not supported yet")
|
||||
|
||||
chat_completion: ChatCompletionRequest = (
|
||||
self._create_mistral_chat_completion_request(messages, tools)
|
||||
)
|
||||
|
||||
tokens: list[int] = self._mistral.encode_chat_completion(chat_completion).tokens
|
||||
|
||||
if tokenize:
|
||||
return tokens
|
||||
|
||||
return self.decode(tokens)
|
||||
|
||||
def pad(
|
||||
self,
|
||||
features: list[dict[str, list[int] | np.ndarray]],
|
||||
*,
|
||||
padding: bool | str | PaddingStrategy = True,
|
||||
max_length: int | None = None,
|
||||
pad_to_multiple_of: int | None = None,
|
||||
return_tensors: str | None = None, # "np", "pt", or "tf"
|
||||
) -> dict[str, np.ndarray | Tensor]:
|
||||
"""
|
||||
HF-style pad method that properly handles all sequence-related features:
|
||||
- pad 'input_ids' & 'labels' to the longest (or to max_length)
|
||||
"""
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
# Check for unsupported fields
|
||||
if any("token_type_ids" in f for f in features):
|
||||
raise ValueError("token_type_ids is not supported by this tokenizer")
|
||||
|
||||
# Determine desired sequence length
|
||||
lengths = [len(f["input_ids"]) for f in features]
|
||||
if padding in (True, "longest", PaddingStrategy.LONGEST):
|
||||
target_length = max(lengths)
|
||||
elif padding in ("max_length", PaddingStrategy.MAX_LENGTH):
|
||||
if max_length is None:
|
||||
raise ValueError("max_length must be set for 'max_length' padding")
|
||||
target_length = max_length
|
||||
elif padding in (False, "do_not_pad", PaddingStrategy.DO_NOT_PAD):
|
||||
target_length = None
|
||||
else:
|
||||
raise ValueError(f"Unknown padding strategy: {padding}")
|
||||
|
||||
# Apply pad_to_multiple_of
|
||||
if target_length is not None and pad_to_multiple_of is not None:
|
||||
target_length = (
|
||||
math.ceil(target_length / pad_to_multiple_of) * pad_to_multiple_of
|
||||
)
|
||||
|
||||
# If no padding requested, just stack tensors
|
||||
do_pad = target_length is not None
|
||||
|
||||
# Pad sequences using torch.nn.utils.rnn.pad_sequence
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
[torch.tensor(x["input_ids"], dtype=torch.long) for x in features],
|
||||
batch_first=True,
|
||||
padding_value=self.pad_token_id if self.pad_token_id is not None else 0,
|
||||
)
|
||||
|
||||
labels = torch.nn.utils.rnn.pad_sequence(
|
||||
[torch.tensor(x["labels"], dtype=torch.long) for x in features],
|
||||
batch_first=True,
|
||||
padding_value=IGNORE_INDEX,
|
||||
)
|
||||
|
||||
attention_mask = torch.nn.utils.rnn.pad_sequence(
|
||||
[torch.tensor(x["attention_mask"], dtype=torch.long) for x in features],
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)
|
||||
|
||||
# Handle position_ids - pad with sequential values for right padding, 0s for left padding
|
||||
if "position_ids" in features[0]:
|
||||
if self.padding_side == "left":
|
||||
# Likely not needed, but keeping for now
|
||||
# For left padding, we'll pad with 0s using pad_sequence, then handle manually
|
||||
position_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
[
|
||||
torch.tensor(x["position_ids"], dtype=torch.long)
|
||||
for x in features
|
||||
],
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)
|
||||
else:
|
||||
# For right padding, continue the sequence
|
||||
max_pos_len = max(len(f["position_ids"]) for f in features)
|
||||
position_ids_list = []
|
||||
for f in features:
|
||||
pos_seq = torch.tensor(f["position_ids"], dtype=torch.long)
|
||||
if len(pos_seq) < max_pos_len:
|
||||
# Continue the sequence
|
||||
last_pos = pos_seq[-1].item() if len(pos_seq) > 0 else -1
|
||||
pad_len = max_pos_len - len(pos_seq)
|
||||
pad_positions = torch.arange(
|
||||
last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
|
||||
)
|
||||
pos_seq = torch.cat([pos_seq, pad_positions])
|
||||
position_ids_list.append(pos_seq)
|
||||
position_ids = torch.stack(position_ids_list)
|
||||
else:
|
||||
# Create position_ids if not present
|
||||
seq_len = input_ids.size(1)
|
||||
position_ids = (
|
||||
torch.arange(seq_len, dtype=torch.long)
|
||||
.unsqueeze(0)
|
||||
.expand(input_ids.size(0), -1)
|
||||
)
|
||||
|
||||
# Ensure all tensors have the same sequence length
|
||||
max_seq_len = max(
|
||||
input_ids.size(1),
|
||||
labels.size(1),
|
||||
attention_mask.size(1),
|
||||
position_ids.size(1),
|
||||
)
|
||||
|
||||
# TODO: check if trimming is needed? and correct.
|
||||
|
||||
if do_pad and target_length is not None:
|
||||
max_seq_len = target_length
|
||||
|
||||
# Pad all tensors to the same length
|
||||
if input_ids.size(1) < max_seq_len:
|
||||
pad_len = max_seq_len - input_ids.size(1)
|
||||
if self.padding_side == "right":
|
||||
input_ids = F.pad(
|
||||
input_ids,
|
||||
(0, pad_len),
|
||||
value=self.pad_token_id if self.pad_token_id is not None else 0,
|
||||
)
|
||||
else:
|
||||
input_ids = F.pad(
|
||||
input_ids,
|
||||
(pad_len, 0),
|
||||
value=self.pad_token_id if self.pad_token_id is not None else 0,
|
||||
)
|
||||
elif input_ids.size(1) > max_seq_len:
|
||||
input_ids = input_ids[:, :max_seq_len]
|
||||
|
||||
if labels.size(1) < max_seq_len:
|
||||
pad_len = max_seq_len - labels.size(1)
|
||||
if self.padding_side == "right":
|
||||
labels = F.pad(labels, (0, pad_len), value=IGNORE_INDEX)
|
||||
else:
|
||||
labels = F.pad(labels, (pad_len, 0), value=IGNORE_INDEX)
|
||||
elif labels.size(1) > max_seq_len:
|
||||
labels = labels[:, :max_seq_len]
|
||||
|
||||
if attention_mask.size(1) < max_seq_len:
|
||||
pad_len = max_seq_len - attention_mask.size(1)
|
||||
if self.padding_side == "right":
|
||||
attention_mask = F.pad(attention_mask, (0, pad_len), value=0)
|
||||
else:
|
||||
attention_mask = F.pad(attention_mask, (pad_len, 0), value=0)
|
||||
elif attention_mask.size(1) > max_seq_len:
|
||||
attention_mask = attention_mask[:, :max_seq_len]
|
||||
|
||||
if position_ids.size(1) < max_seq_len:
|
||||
pad_len = max_seq_len - position_ids.size(1)
|
||||
if self.padding_side == "right":
|
||||
batch_size = position_ids.size(0)
|
||||
new_position_ids = []
|
||||
for i in range(batch_size):
|
||||
seq = position_ids[i]
|
||||
if len(seq) > 0:
|
||||
# get last position and pad with sequential values
|
||||
last_pos = seq[-1].item()
|
||||
pad_positions = torch.arange(
|
||||
last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long
|
||||
)
|
||||
new_seq = torch.cat([seq, pad_positions])
|
||||
else:
|
||||
new_seq = torch.arange(pad_len, dtype=torch.long)
|
||||
new_position_ids.append(new_seq)
|
||||
position_ids = torch.stack(new_position_ids)
|
||||
else:
|
||||
position_ids = F.pad(position_ids, (pad_len, 0), value=0)
|
||||
elif position_ids.size(1) > max_seq_len:
|
||||
position_ids = position_ids[:, :max_seq_len]
|
||||
|
||||
final_batch = {
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
# Handle non-sequence fields (raise error)
|
||||
sequence_fields = {"input_ids", "labels", "attention_mask", "position_ids"}
|
||||
for f in features:
|
||||
for key in f.keys():
|
||||
if key not in sequence_fields:
|
||||
raise NotImplementedError(
|
||||
f"Non-sequence field {key} not handled yet"
|
||||
)
|
||||
|
||||
# Convert to requested tensor type
|
||||
if return_tensors is None or return_tensors == "np":
|
||||
result = {}
|
||||
for k, v in final_batch.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
result[k] = v.numpy().astype(np.long)
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
if return_tensors == "pt":
|
||||
return final_batch
|
||||
|
||||
raise ValueError(f"Unsupported return_tensors='{return_tensors}'")
|
||||
|
||||
def convert_ids_to_tokens(self, ids: list[int]) -> list[str]:
|
||||
"""
|
||||
Convert a list of token IDs to a list of tokens.
|
||||
|
||||
Args:
|
||||
ids: The list of token IDs to convert.
|
||||
|
||||
Returns:
|
||||
The list of tokens.
|
||||
"""
|
||||
return [
|
||||
self._mistral.instruct_tokenizer.tokenizer.id_to_piece(id) for id in ids
|
||||
]
|
||||
@@ -3,7 +3,6 @@ Multipack Batch Sampler - An efficient batch sampler for packing variable-length
|
||||
into fixed-capacity batches to optimize memory usage and training throughput.
|
||||
"""
|
||||
|
||||
import gc
|
||||
import math
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from multiprocessing import cpu_count, get_context
|
||||
@@ -146,7 +145,7 @@ def pack_parallel(
|
||||
"""
|
||||
num_items = len(sequence_lengths)
|
||||
if num_processes is None:
|
||||
num_processes = max(1, min(num_items // group_size, cpu_count(), 16))
|
||||
num_processes = max(1, min(num_items // group_size, cpu_count()))
|
||||
|
||||
# Create tasks for parallel processing
|
||||
tasks = []
|
||||
@@ -259,8 +258,8 @@ class MultipackBatchSampler(BatchSampler):
|
||||
batch_max_len: int, # Maximum sequence length (bin capacity)
|
||||
lengths: np.ndarray, # Sequence lengths
|
||||
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
|
||||
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
|
||||
num_count_samples: int = 8, # Number of times to estimate batch count
|
||||
drop_last: bool = False, # Whether to drop final batches (might be incomplete)
|
||||
num_count_samples: int = 16, # Number of times to estimate batch count
|
||||
sequential: bool = False, # Whether to use sequential packing
|
||||
group_size: int = 100_000, # Size of groups for parallel packing
|
||||
bin_size: int = 200, # The max number of samples that can be packed in a single bin
|
||||
@@ -350,7 +349,6 @@ class MultipackBatchSampler(BatchSampler):
|
||||
# Calculate efficiency statistics
|
||||
total_used = lengths.sum()
|
||||
total_slots = len(all_bins) * self.batch_max_len
|
||||
del all_bins
|
||||
|
||||
# Group bins into batches (each batch contains batch_size bins)
|
||||
batches = [
|
||||
@@ -370,7 +368,6 @@ class MultipackBatchSampler(BatchSampler):
|
||||
self.total_token_slots += total_slots
|
||||
|
||||
self._batches = batches
|
||||
gc.collect()
|
||||
return batches
|
||||
|
||||
def __iter__(self) -> Iterator[list[list[int]]]:
|
||||
@@ -446,18 +443,10 @@ class MultipackBatchSampler(BatchSampler):
|
||||
|
||||
if self._len_across_ranks is None:
|
||||
# Sample multiple times to get stable estimate
|
||||
_sampled_lens = []
|
||||
for _ in range(self.num_count_samples):
|
||||
self._batches = None # Reset cached batches
|
||||
_sampled_lens.append(len(self.generate_batches(set_stats=False)))
|
||||
len_batches = min(_sampled_lens)
|
||||
|
||||
len_batches = min( # pylint: disable=consider-using-generator
|
||||
[len(self._batches) for _ in range(self.num_count_samples)]
|
||||
)
|
||||
# Gather minimum across all ranks
|
||||
if self._len_across_ranks is None:
|
||||
self._len_across_ranks = self.gather_len_batches(len_batches)
|
||||
else:
|
||||
self._len_across_ranks = min(
|
||||
self._len_across_ranks, self.gather_len_batches(len_batches)
|
||||
)
|
||||
self._len_across_ranks = self.gather_len_batches(len_batches)
|
||||
|
||||
return self._len_across_ranks
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,6 @@
|
||||
"""Pydantic models for datasets-related configuration"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from axolotl.utils.schemas.enums import ChatTemplate
|
||||
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
|
||||
@@ -11,178 +9,56 @@ from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
|
||||
class UserDefinedPrompterType(BaseModel):
|
||||
"""Structure for user defined prompt types"""
|
||||
|
||||
system_prompt: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Custom user instruction prompt"},
|
||||
)
|
||||
system_format: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Use {system} as key to be replaced"},
|
||||
)
|
||||
system_prompt: str | None = None
|
||||
system_format: str | None = None
|
||||
field_system: str | None = None
|
||||
field_instruction: str | None = None
|
||||
field_input: str | None = None
|
||||
field_output: str | None = None
|
||||
|
||||
format: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Customizable to be single line or multi-line. Use {instruction}/{input} as key to be replaced. 'format' can include {input}"
|
||||
},
|
||||
)
|
||||
no_input_format: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "'no_input_format' cannot include {input}"},
|
||||
)
|
||||
field: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "For `completion` datsets only, uses the provided field instead of `text` column"
|
||||
},
|
||||
)
|
||||
format: str | None = None
|
||||
no_input_format: str | None = None
|
||||
field: str | None = None
|
||||
|
||||
|
||||
class SFTDataset(BaseModel):
|
||||
"""SFT configuration subset"""
|
||||
|
||||
path: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "HuggingFace dataset repo | s3:// | gs:// | path to local file or directory"
|
||||
},
|
||||
)
|
||||
split: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "name of dataset split to load from"},
|
||||
)
|
||||
type: str | UserDefinedPrompterType | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]"
|
||||
},
|
||||
)
|
||||
path: str | None = None
|
||||
split: str | None = None
|
||||
type: str | UserDefinedPrompterType | None = None
|
||||
input_transform: str | None = None
|
||||
shards: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "split dataset into N pieces (use with shards_idx)"
|
||||
},
|
||||
)
|
||||
shards_idx: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "the index of sharded dataset to use"},
|
||||
)
|
||||
preprocess_shards: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)"
|
||||
},
|
||||
)
|
||||
shards: int | None = None
|
||||
shards_idx: int | None = None
|
||||
preprocess_shards: int | None = None
|
||||
conversation: str | None = None
|
||||
# Do not make this too strict or it will break the validator to choose different dataset class
|
||||
chat_template: ChatTemplate | str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The name of the chat template to use for training, following values are supported: tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default. alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py. tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml. jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field."
|
||||
},
|
||||
)
|
||||
chat_template_jinja: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Custom jinja chat template. Used only if `chat_template: jinja` or empty."
|
||||
},
|
||||
)
|
||||
data_files: str | list[str] | None = Field(
|
||||
default=None, json_schema_extra={"description": "path to source data files"}
|
||||
)
|
||||
chat_template: ChatTemplate | str | None = None
|
||||
chat_template_jinja: str | None = None
|
||||
data_files: str | list[str] | None = None
|
||||
input_format: str | None = None
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "name of dataset configuration to load"},
|
||||
)
|
||||
ds_type: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "defines the datatype when path is a file"},
|
||||
)
|
||||
name: str | None = None
|
||||
ds_type: str | None = None
|
||||
field: str | None = None
|
||||
field_human: str | None = None
|
||||
field_model: str | None = None
|
||||
field_messages: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": 'Key containing the messages (default: "messages")'
|
||||
},
|
||||
)
|
||||
field_tools: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": 'Key containing the tools (default: "tools"). Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).'
|
||||
},
|
||||
)
|
||||
field_messages: str | None = None
|
||||
# deprecated, use message_property_mappings
|
||||
message_field_role: str | None = None
|
||||
# deprecated, use message_property_mappings
|
||||
message_field_content: str | None = None
|
||||
message_property_mappings: dict[str, str] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Mapping of properties from the input dataset to the chat template. (default: message_property_mappings={'role':'role', 'content':'content'}) If a property exists in the template but not in this mapping, the system will attempt to load it directly from the message using the property name as the key. Example: In the mapping below, 'from' is loaded from input dataset and used as 'role', while 'value' is loaded and used as 'content' in the chat template."
|
||||
},
|
||||
)
|
||||
message_field_training: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`."
|
||||
},
|
||||
)
|
||||
message_field_training_detail: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn. The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train)."
|
||||
},
|
||||
)
|
||||
split_thinking: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "(for Qwen3 template only) Whether to split the assistant content based on a reasoning trace inside delimited tags"
|
||||
},
|
||||
)
|
||||
message_property_mappings: dict[str, str] | None = None
|
||||
message_field_training: str | None = None
|
||||
message_field_training_detail: str | None = None
|
||||
split_thinking: bool | None = None
|
||||
logprobs_field: str | None = None
|
||||
temperature: float | None = None
|
||||
roles_to_train: list[str] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Roles to train on. The tokens from these roles will be considered for the loss."
|
||||
},
|
||||
)
|
||||
train_on_eos: Literal["all", "turn", "last"] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Which EOS tokens to train on in the conversation. Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation"
|
||||
},
|
||||
)
|
||||
roles: dict[str, list[str]] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": 'Roles mapping in the messages. The format is {target_role: [source_roles]}. All source roles will be mapped to the target role. The default is: user: ["human", "user"], assistant: ["gpt", "assistant"], system: ["system"], tool: ["tool"]'
|
||||
},
|
||||
)
|
||||
drop_system_message: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Whether to drop the system turn from the dataset. Only works with chat_template. This does not drop the default system message from chat_template if it exists. If you wish to, we recommend using a custom jinja template with the default system message removed or adding a system turn with empty content."
|
||||
},
|
||||
)
|
||||
trust_remote_code: bool | None = Field(
|
||||
default=False,
|
||||
json_schema_extra={"description": "Trust remote code for untrusted source"},
|
||||
)
|
||||
revision: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets."
|
||||
},
|
||||
)
|
||||
roles_to_train: list[str] | None = None
|
||||
train_on_eos: str | None = None
|
||||
roles: dict[str, list[str]] | None = None
|
||||
drop_system_message: bool | None = None
|
||||
trust_remote_code: bool | None = False
|
||||
revision: str | None = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
||||
@@ -60,30 +60,10 @@ class RemappedParameters(BaseModel):
|
||||
"""Parameters that have been remapped to other names"""
|
||||
|
||||
overrides_of_model_config: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
alias="model_config",
|
||||
json_schema_extra={
|
||||
"description": "optional overrides to the base model configuration"
|
||||
},
|
||||
default=None, alias="model_config"
|
||||
)
|
||||
overrides_of_model_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
alias="model_kwargs",
|
||||
json_schema_extra={
|
||||
"description": "optional overrides the base model loading from_pretrained"
|
||||
},
|
||||
)
|
||||
type_of_model: str | None = Field(
|
||||
default=None,
|
||||
alias="model_type",
|
||||
json_schema_extra={
|
||||
"description": "If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too"
|
||||
},
|
||||
)
|
||||
revision_of_model: str | None = Field(
|
||||
default=None,
|
||||
alias="model_revision",
|
||||
json_schema_extra={
|
||||
"description": "You can specify to choose a specific model revision from huggingface hub"
|
||||
},
|
||||
default=None, alias="model_kwargs"
|
||||
)
|
||||
type_of_model: str | None = Field(default=None, alias="model_type")
|
||||
revision_of_model: str | None = Field(default=None, alias="model_revision")
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Enums for Axolotl input config"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
@@ -10,81 +8,81 @@ import torch
|
||||
class TorchIntDType(Enum):
|
||||
"""Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4"""
|
||||
|
||||
uint1 = getattr(torch, "uint1", None)
|
||||
uint2 = getattr(torch, "uint2", None)
|
||||
uint3 = getattr(torch, "uint3", None)
|
||||
uint4 = getattr(torch, "uint4", None)
|
||||
uint5 = getattr(torch, "uint5", None)
|
||||
uint6 = getattr(torch, "uint6", None)
|
||||
uint7 = getattr(torch, "uint7", None)
|
||||
int4 = getattr(torch, "int4", None)
|
||||
int8 = getattr(torch, "int8", None)
|
||||
uint1 = getattr(torch, "uint1", None) # pylint: disable=invalid-name
|
||||
uint2 = getattr(torch, "uint2", None) # pylint: disable=invalid-name
|
||||
uint3 = getattr(torch, "uint3", None) # pylint: disable=invalid-name
|
||||
uint4 = getattr(torch, "uint4", None) # pylint: disable=invalid-name
|
||||
uint5 = getattr(torch, "uint5", None) # pylint: disable=invalid-name
|
||||
uint6 = getattr(torch, "uint6", None) # pylint: disable=invalid-name
|
||||
uint7 = getattr(torch, "uint7", None) # pylint: disable=invalid-name
|
||||
int4 = getattr(torch, "int4", None) # pylint: disable=invalid-name
|
||||
int8 = getattr(torch, "int8", None) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class RLType(str, Enum):
|
||||
"""RL trainer type configuration subset"""
|
||||
|
||||
DPO = "dpo"
|
||||
GRPO = "grpo"
|
||||
IPO = "ipo"
|
||||
ORPO = "orpo"
|
||||
KTO = "kto"
|
||||
SIMPO = "simpo"
|
||||
DPO = "dpo" # pylint: disable=invalid-name
|
||||
GRPO = "grpo" # pylint: disable=invalid-name
|
||||
IPO = "ipo" # pylint: disable=invalid-name
|
||||
ORPO = "orpo" # pylint: disable=invalid-name
|
||||
KTO = "kto" # pylint: disable=invalid-name
|
||||
SIMPO = "simpo" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ChatTemplate(str, Enum):
|
||||
"""Chat templates configuration subset"""
|
||||
|
||||
alpaca = "alpaca"
|
||||
chatml = "chatml"
|
||||
mistral_v1 = "mistral_v1"
|
||||
mistral_v2v3 = "mistral_v2v3"
|
||||
mistral_v3_tekken = "mistral_v3_tekken"
|
||||
mistral_v7_tekken = "mistral_v7_tekken"
|
||||
gemma = "gemma"
|
||||
cohere = "cohere"
|
||||
llama3 = "llama3"
|
||||
llama3_2_vision = "llama3_2_vision"
|
||||
llama4 = "llama4"
|
||||
phi_3 = "phi_3"
|
||||
phi_35 = "phi_35"
|
||||
deepseek_v2 = "deepseek_v2"
|
||||
deepseek_v3 = "deepseek_v3"
|
||||
jamba = "jamba"
|
||||
jinja = "jinja"
|
||||
qwen_25 = "qwen_25"
|
||||
qwen3 = "qwen3"
|
||||
tokenizer_default = "tokenizer_default"
|
||||
exaone = "exaone"
|
||||
metharme = "metharme"
|
||||
pixtral = "pixtral"
|
||||
llava = "llava"
|
||||
qwen2_vl = "qwen2_vl"
|
||||
gemma3 = "gemma3"
|
||||
command_a = "command_a"
|
||||
command_a_tool_use = "command_a_tool_use"
|
||||
command_a_rag = "command_a_rag"
|
||||
aya = "aya"
|
||||
alpaca = "alpaca" # pylint: disable=invalid-name
|
||||
chatml = "chatml" # pylint: disable=invalid-name
|
||||
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
|
||||
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
|
||||
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
|
||||
mistral_v7_tekken = "mistral_v7_tekken" # pylint: disable=invalid-name
|
||||
gemma = "gemma" # pylint: disable=invalid-name
|
||||
cohere = "cohere" # pylint: disable=invalid-name
|
||||
llama3 = "llama3" # pylint: disable=invalid-name
|
||||
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
||||
llama4 = "llama4" # pylint: disable=invalid-name
|
||||
phi_3 = "phi_3" # pylint: disable=invalid-name
|
||||
phi_35 = "phi_35" # pylint: disable=invalid-name
|
||||
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
||||
deepseek_v3 = "deepseek_v3" # pylint: disable=invalid-name
|
||||
jamba = "jamba" # pylint: disable=invalid-name
|
||||
jinja = "jinja" # pylint: disable=invalid-name
|
||||
qwen_25 = "qwen_25" # pylint: disable=invalid-name
|
||||
qwen3 = "qwen3" # pylint: disable=invalid-name
|
||||
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
||||
exaone = "exaone" # pylint: disable=invalid-name
|
||||
metharme = "metharme" # pylint: disable=invalid-name
|
||||
pixtral = "pixtral" # pylint: disable=invalid-name
|
||||
llava = "llava" # pylint: disable=invalid-name
|
||||
qwen2_vl = "qwen2_vl" # pylint: disable=invalid-name
|
||||
gemma3 = "gemma3" # pylint: disable=invalid-name
|
||||
command_a = "command_a" # pylint: disable=invalid-name
|
||||
command_a_tool_use = "command_a_tool_use" # pylint: disable=invalid-name
|
||||
command_a_rag = "command_a_rag" # pylint: disable=invalid-name
|
||||
aya = "aya" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class CustomSupportedOptimizers(str, Enum):
|
||||
"""Custom supported optimizers"""
|
||||
|
||||
optimi_adamw = "optimi_adamw"
|
||||
ao_adamw_4bit = "ao_adamw_4bit"
|
||||
ao_adamw_8bit = "ao_adamw_8bit"
|
||||
ao_adamw_fp8 = "ao_adamw_fp8"
|
||||
adopt_adamw = "adopt_adamw"
|
||||
came_pytorch = "came_pytorch"
|
||||
muon = "muon"
|
||||
optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name
|
||||
ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name
|
||||
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
|
||||
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
|
||||
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
|
||||
came_pytorch = "came_pytorch" # pylint: disable=invalid-name
|
||||
muon = "muon" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class RingAttnFunc(str, Enum):
|
||||
"""Enum class for supported `ring-flash-attn` implementations"""
|
||||
|
||||
VARLEN_LLAMA3 = "varlen_llama3"
|
||||
BATCH_RING = "batch_ring"
|
||||
# VARLEN_RING = "varlen_ring"
|
||||
# VARLEN_ZIGZAG = "varlen_zigzag"
|
||||
VARLEN_LLAMA3 = "varlen_llama3"
|
||||
BATCH_RING = "batch_ring"
|
||||
# BATCH_ZIGZAG = "batch_zigzag"
|
||||
# BATCH_STRIPE = "batch_stripe"
|
||||
|
||||
@@ -13,21 +13,10 @@ class MLFlowConfig(BaseModel):
|
||||
"""MLFlow configuration subset"""
|
||||
|
||||
use_mlflow: bool | None = None
|
||||
mlflow_tracking_uri: str | None = Field(
|
||||
default=None, json_schema_extra={"description": "URI to mlflow"}
|
||||
)
|
||||
mlflow_experiment_name: str | None = Field(
|
||||
default=None, json_schema_extra={"description": "Your experiment name"}
|
||||
)
|
||||
mlflow_run_name: str | None = Field(
|
||||
default=None, json_schema_extra={"description": "Your run name"}
|
||||
)
|
||||
hf_mlflow_log_artifacts: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "set to true to copy each saved checkpoint on each save to mlflow artifact registry"
|
||||
},
|
||||
)
|
||||
mlflow_tracking_uri: str | None = None
|
||||
mlflow_experiment_name: str | None = None
|
||||
mlflow_run_name: str | None = None
|
||||
hf_mlflow_log_artifacts: bool | None = None
|
||||
|
||||
|
||||
class LISAConfig(BaseModel):
|
||||
@@ -51,33 +40,13 @@ class WandbConfig(BaseModel):
|
||||
"""Wandb configuration subset"""
|
||||
|
||||
use_wandb: bool | None = None
|
||||
wandb_name: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Set the name of your wandb run"},
|
||||
)
|
||||
wandb_run_id: str | None = Field(
|
||||
default=None, json_schema_extra={"description": "Set the ID of your wandb run"}
|
||||
)
|
||||
wandb_mode: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": '"offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb'
|
||||
},
|
||||
)
|
||||
wandb_project: str | None = Field(
|
||||
default=None, json_schema_extra={"description": "Your wandb project name"}
|
||||
)
|
||||
wandb_entity: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "A wandb Team name if using a Team"},
|
||||
)
|
||||
wandb_name: str | None = None
|
||||
wandb_run_id: str | None = None
|
||||
wandb_mode: str | None = None
|
||||
wandb_project: str | None = None
|
||||
wandb_entity: str | None = None
|
||||
wandb_watch: str | None = None
|
||||
wandb_log_model: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": '"checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training'
|
||||
},
|
||||
)
|
||||
wandb_log_model: str | None = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -95,52 +64,14 @@ class WandbConfig(BaseModel):
|
||||
class CometConfig(BaseModel):
|
||||
"""Comet configuration subset"""
|
||||
|
||||
use_comet: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Enable or disable Comet integration."},
|
||||
)
|
||||
comet_api_key: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "API key for Comet. Recommended to set via `comet login`."
|
||||
},
|
||||
)
|
||||
comet_workspace: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Workspace name in Comet. Defaults to the user's default workspace."
|
||||
},
|
||||
)
|
||||
comet_project_name: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Project name in Comet. Defaults to Uncategorized."
|
||||
},
|
||||
)
|
||||
comet_experiment_key: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key."
|
||||
},
|
||||
)
|
||||
comet_mode: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": 'Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.'
|
||||
},
|
||||
)
|
||||
comet_online: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Set to True to log data to Comet server, or False for offline storage. Default is True."
|
||||
},
|
||||
)
|
||||
comet_experiment_config: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Dictionary for additional configuration settings, see the doc for more details."
|
||||
},
|
||||
)
|
||||
use_comet: bool | None = None
|
||||
comet_api_key: str | None = None
|
||||
comet_workspace: str | None = None
|
||||
comet_project_name: str | None = None
|
||||
comet_experiment_key: str | None = None
|
||||
comet_mode: str | None = None
|
||||
comet_online: bool | None = None
|
||||
comet_experiment_config: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class GradioConfig(BaseModel):
|
||||
|
||||
@@ -12,55 +12,19 @@ class ModelInputConfig(BaseModel):
|
||||
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
base_model: str = Field(
|
||||
json_schema_extra={
|
||||
"description": "This is the huggingface model that contains *.pt, *.safetensors, or *.bin files. This can also be a relative path to a model on disk"
|
||||
}
|
||||
)
|
||||
base_model_config: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "If the base_model repo on hf hub doesn't include configuration .json files, You can set that here, or leave this empty to default to base_model"
|
||||
},
|
||||
)
|
||||
base_model: str
|
||||
base_model_config: str | None = None
|
||||
cls_model_config: str | None = None
|
||||
tokenizer_config: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Optional tokenizer configuration path in case you want to use a different tokenizer than the one defined in the base model"
|
||||
},
|
||||
)
|
||||
tokenizer_use_fast: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "use_fast option for tokenizer loading from_pretrained, default to True"
|
||||
},
|
||||
)
|
||||
tokenizer_legacy: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Whether to use the legacy tokenizer setting, defaults to True"
|
||||
},
|
||||
)
|
||||
tokenizer_use_mistral_common: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Whether to use mistral-common tokenizer. If set to True, it will use the mistral-common tokenizer."
|
||||
},
|
||||
)
|
||||
tokenizer_config: str | None = None
|
||||
tokenizer_use_fast: bool | None = None
|
||||
tokenizer_legacy: bool | None = None
|
||||
tokenizer_type: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Corresponding tokenizer for the model AutoTokenizer is a good choice"
|
||||
},
|
||||
default=None, json_schema_extra={"description": "transformers tokenizer class"}
|
||||
)
|
||||
processor_type: str | None = Field(
|
||||
default=None, json_schema_extra={"description": "transformers processor class"}
|
||||
)
|
||||
trust_remote_code: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Trust remote code for untrusted source"},
|
||||
)
|
||||
trust_remote_code: bool | None = None
|
||||
|
||||
@field_validator("trust_remote_code")
|
||||
@classmethod
|
||||
@@ -75,23 +39,10 @@ class ModelInputConfig(BaseModel):
|
||||
class ModelOutputConfig(BaseModel):
|
||||
"""model save configuration subset"""
|
||||
|
||||
output_dir: str = Field(
|
||||
default="./model-out",
|
||||
json_schema_extra={"description": "Where to save the full-finetuned model to"},
|
||||
)
|
||||
hub_model_id: str | None = Field(
|
||||
default=None, json_schema_extra={"description": "push checkpoints to hub"}
|
||||
)
|
||||
hub_strategy: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "how to push checkpoints to hub"},
|
||||
)
|
||||
save_safetensors: bool | None = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"description": "Save model as safetensors (require safetensors package). Default True"
|
||||
},
|
||||
)
|
||||
output_dir: str = Field(default="./model-out")
|
||||
hub_model_id: str | None = None
|
||||
hub_strategy: str | None = None
|
||||
save_safetensors: bool | None = True
|
||||
|
||||
|
||||
class SpecialTokensConfig(BaseModel):
|
||||
|
||||
@@ -9,7 +9,7 @@ class LoftQConfig(BaseModel):
|
||||
"""LoftQ configuration subset"""
|
||||
|
||||
loftq_bits: int = Field(
|
||||
default=4, json_schema_extra={"description": "typically 4 bits"}
|
||||
default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
|
||||
)
|
||||
# loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
|
||||
|
||||
@@ -17,78 +17,31 @@ class LoftQConfig(BaseModel):
|
||||
class PeftConfig(BaseModel):
|
||||
"""peftq configuration subset"""
|
||||
|
||||
loftq_config: LoftQConfig | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Configuration options for loftq initialization for LoRA"
|
||||
},
|
||||
)
|
||||
loftq_config: LoftQConfig | None = None
|
||||
|
||||
|
||||
class LoraConfig(BaseModel):
|
||||
"""Peft / LoRA configuration subset"""
|
||||
|
||||
load_in_8bit: bool | None = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": "This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer"
|
||||
},
|
||||
)
|
||||
load_in_4bit: bool | None = Field(
|
||||
default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"}
|
||||
)
|
||||
load_in_8bit: bool | None = Field(default=False)
|
||||
load_in_4bit: bool | None = Field(default=False)
|
||||
|
||||
adapter: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model"
|
||||
},
|
||||
)
|
||||
lora_model_dir: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "If you already have a lora model trained that you want to load, put that here. This means after training, if you want to test the model, you should set this to the value of `output_dir`. Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`."
|
||||
},
|
||||
)
|
||||
adapter: str | None = None
|
||||
lora_model_dir: str | None = None
|
||||
lora_r: int | None = None
|
||||
lora_alpha: int | None = None
|
||||
lora_fan_in_fan_out: bool | None = None
|
||||
lora_target_modules: str | list[str] | None = None
|
||||
lora_target_linear: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "If true, will target all linear modules"},
|
||||
)
|
||||
lora_modules_to_save: list[str] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens. For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models. `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities."
|
||||
},
|
||||
)
|
||||
lora_target_linear: bool | None = None
|
||||
lora_modules_to_save: list[str] | None = None
|
||||
lora_dropout: float | None = 0.0
|
||||
peft_layers_to_transform: list[int] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The layer indices to transform, otherwise, apply to all layers"
|
||||
},
|
||||
)
|
||||
peft_layers_to_transform: list[int] | None = None
|
||||
peft_layers_pattern: list[str] | None = None
|
||||
peft: PeftConfig | None = None
|
||||
peft_use_dora: bool | None = Field(
|
||||
default=None, json_schema_extra={"description": "Whether to use DoRA."}
|
||||
)
|
||||
peft_use_rslora: bool | None = Field(
|
||||
default=None, json_schema_extra={"description": "Whether to use RSLoRA."}
|
||||
)
|
||||
peft_layer_replication: list[tuple[int, int]] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "List of layer indices to replicate."},
|
||||
)
|
||||
peft_init_lora_weights: bool | str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "How to initialize LoRA weights. Default to True which is MS original implementation."
|
||||
},
|
||||
)
|
||||
peft_use_dora: bool | None = None
|
||||
peft_use_rslora: bool | None = None
|
||||
peft_layer_replication: list[tuple[int, int]] | None = None
|
||||
peft_init_lora_weights: bool | str | None = None
|
||||
|
||||
qlora_sharded_model_loading: bool | None = Field(
|
||||
default=False,
|
||||
@@ -96,24 +49,9 @@ class LoraConfig(BaseModel):
|
||||
"description": "load qlora model in sharded format for FSDP using answer.ai technique."
|
||||
},
|
||||
)
|
||||
lora_on_cpu: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge"
|
||||
},
|
||||
)
|
||||
gptq: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Whether you are training a 4-bit GPTQ quantized model"
|
||||
},
|
||||
)
|
||||
bnb_config_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "optional overrides to the bnb 4bit quantization configuration"
|
||||
},
|
||||
)
|
||||
lora_on_cpu: bool | None = None
|
||||
gptq: bool | None = None
|
||||
bnb_config_kwargs: dict[str, Any] | None = None
|
||||
|
||||
loraplus_lr_ratio: float | None = Field(
|
||||
default=None,
|
||||
@@ -124,7 +62,7 @@ class LoraConfig(BaseModel):
|
||||
loraplus_lr_embedding: float | None = Field(
|
||||
default=1e-6,
|
||||
json_schema_extra={
|
||||
"description": "loraplus learning rate for lora embedding layers. Default value is 1e-6."
|
||||
"description": "loraplus learning rate for lora embedding layers."
|
||||
},
|
||||
)
|
||||
|
||||
@@ -187,29 +125,8 @@ class LoraConfig(BaseModel):
|
||||
class ReLoRAConfig(BaseModel):
|
||||
"""ReLoRA configuration subset"""
|
||||
|
||||
relora_steps: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Number of steps per ReLoRA restart"},
|
||||
)
|
||||
relora_warmup_steps: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Number of per-restart warmup steps"},
|
||||
)
|
||||
relora_anneal_steps: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of anneal steps for each relora cycle"
|
||||
},
|
||||
)
|
||||
relora_prune_ratio: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "threshold for optimizer magnitude when pruning"
|
||||
},
|
||||
)
|
||||
relora_cpu_offload: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "True to perform lora weight merges on cpu during restarts, for modest gpu memory savings"
|
||||
},
|
||||
)
|
||||
relora_steps: int | None = None
|
||||
relora_warmup_steps: int | None = None
|
||||
relora_anneal_steps: int | None = None
|
||||
relora_prune_ratio: float | None = None
|
||||
relora_cpu_offload: bool | None = None
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user