Compare commits
21 Commits
pre-commit
...
feat/soap-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1a7f048c6b | ||
|
|
76d26366ad | ||
|
|
64fe284765 | ||
|
|
cf0c79d52e | ||
|
|
4ba80a0e5a | ||
|
|
c49682132b | ||
|
|
e46239f8d3 | ||
|
|
05f03b541a | ||
|
|
a4e430e7c4 | ||
|
|
6cdcb8ddd5 | ||
|
|
a7811ad4a0 | ||
|
|
e2da821e67 | ||
|
|
2c34a4634e | ||
|
|
a9b0733f2c | ||
|
|
9f00465a5c | ||
|
|
86bac48d14 | ||
|
|
e44953d50c | ||
|
|
23f0c51d88 | ||
|
|
113e9cd193 | ||
|
|
61825a464a | ||
|
|
c907ac173e |
7
.github/workflows/docs.yml
vendored
7
.github/workflows/docs.yml
vendored
@@ -20,9 +20,12 @@ jobs:
|
|||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.11'
|
python-version: '3.11'
|
||||||
- name: install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install jupyter
|
python3 -m pip install jupyter quartodoc
|
||||||
|
python3 -m pip install -e . --no-deps
|
||||||
|
- name: Build autodoc
|
||||||
|
run: quartodoc build
|
||||||
- name: Publish to GitHub Pages (and render)
|
- name: Publish to GitHub Pages (and render)
|
||||||
uses: quarto-dev/quarto-actions/publish@v2
|
uses: quarto-dev/quarto-actions/publish@v2
|
||||||
with:
|
with:
|
||||||
|
|||||||
2
.github/workflows/tests-nightly.yml
vendored
2
.github/workflows/tests-nightly.yml
vendored
@@ -136,4 +136,4 @@ jobs:
|
|||||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
run: |
|
run: |
|
||||||
modal run cicd.tests
|
modal run cicd.e2e_tests
|
||||||
|
|||||||
17
.github/workflows/tests.yml
vendored
17
.github/workflows/tests.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
|||||||
path: |
|
path: |
|
||||||
/home/runner/.cache/huggingface/hub/datasets--*
|
/home/runner/.cache/huggingface/hub/datasets--*
|
||||||
/home/runner/.cache/huggingface/hub/models--*
|
/home/runner/.cache/huggingface/hub/models--*
|
||||||
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
|
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
@@ -98,8 +98,9 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
pytest -v tests/patched/
|
pytest -v tests/patched/
|
||||||
|
pytest -v tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
@@ -136,7 +137,7 @@ jobs:
|
|||||||
path: |
|
path: |
|
||||||
/home/runner/.cache/huggingface/hub/datasets--*
|
/home/runner/.cache/huggingface/hub/datasets--*
|
||||||
/home/runner/.cache/huggingface/hub/models--*
|
/home/runner/.cache/huggingface/hub/models--*
|
||||||
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
|
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
@@ -170,10 +171,14 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
axolotl --help
|
axolotl --help
|
||||||
|
|
||||||
|
- name: Show HF cache
|
||||||
|
run: huggingface-cli scan-cache
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
pytest -v tests/patched/
|
pytest -v tests/patched/
|
||||||
|
pytest -v tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
@@ -227,7 +232,7 @@ jobs:
|
|||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
run: |
|
run: |
|
||||||
modal run cicd.tests
|
modal run cicd.e2e_tests
|
||||||
|
|
||||||
docker-e2e-tests:
|
docker-e2e-tests:
|
||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
@@ -274,4 +279,4 @@ jobs:
|
|||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
run: |
|
run: |
|
||||||
modal run cicd.tests
|
modal run cicd.e2e_tests
|
||||||
|
|||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -181,6 +181,10 @@ prepared-datasets/
|
|||||||
submit.sh
|
submit.sh
|
||||||
*.out*
|
*.out*
|
||||||
|
|
||||||
|
# Quartodoc generated files
|
||||||
|
objects.json
|
||||||
|
site_libs/
|
||||||
|
|
||||||
typings/
|
typings/
|
||||||
out/
|
out/
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
[settings]
|
[settings]
|
||||||
profile=black
|
profile=black
|
||||||
known_third_party=wandb,comet_ml
|
known_third_party=wandb,comet_ml
|
||||||
|
known_local_folder=src,tests
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github
|
|||||||
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
|
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
|
||||||
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
|
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
|
||||||
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
|
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
|
||||||
|
- [API Reference](https://axolotl-ai-cloud.github.io/axolotl/docs/api/) - Auto-generated code documentation
|
||||||
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
|
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
|
||||||
|
|
||||||
## 🤝 Getting Help
|
## 🤝 Getting Help
|
||||||
|
|||||||
194
_quarto.yml
194
_quarto.yml
@@ -1,6 +1,179 @@
|
|||||||
project:
|
project:
|
||||||
type: website
|
type: website
|
||||||
|
|
||||||
|
quartodoc:
|
||||||
|
dir: docs/api
|
||||||
|
package: axolotl
|
||||||
|
title: API Reference
|
||||||
|
parser: google
|
||||||
|
|
||||||
|
sections:
|
||||||
|
- title: Core
|
||||||
|
desc: Core functionality for training
|
||||||
|
contents:
|
||||||
|
- train
|
||||||
|
- evaluate
|
||||||
|
- datasets
|
||||||
|
- convert
|
||||||
|
- prompt_tokenizers
|
||||||
|
- logging_config
|
||||||
|
- core.trainer_builder
|
||||||
|
- core.training_args
|
||||||
|
- core.chat.messages
|
||||||
|
- core.chat.format.chatml
|
||||||
|
- core.chat.format.llama3x
|
||||||
|
- core.chat.format.shared
|
||||||
|
- core.datasets.chat
|
||||||
|
- core.datasets.transforms.chat_builder
|
||||||
|
- title: CLI
|
||||||
|
desc: Command-line interface
|
||||||
|
contents:
|
||||||
|
- cli.main
|
||||||
|
- cli.train
|
||||||
|
- cli.evaluate
|
||||||
|
- cli.args
|
||||||
|
- cli.checks
|
||||||
|
- cli.config
|
||||||
|
- cli.inference
|
||||||
|
- cli.merge_lora
|
||||||
|
- cli.merge_sharded_fsdp_weights
|
||||||
|
- cli.preprocess
|
||||||
|
- cli.sweeps
|
||||||
|
- cli.utils
|
||||||
|
- cli.cloud.base
|
||||||
|
- cli.cloud.modal_
|
||||||
|
- title: Trainers
|
||||||
|
desc: Training implementations
|
||||||
|
contents:
|
||||||
|
- core.trainers.base
|
||||||
|
- core.trainers.trl
|
||||||
|
- core.trainers.dpo.trainer
|
||||||
|
- core.trainers.grpo.trainer
|
||||||
|
- title: Prompt Strategies
|
||||||
|
desc: Prompt formatting strategies
|
||||||
|
contents:
|
||||||
|
- prompt_strategies.base
|
||||||
|
- prompt_strategies.chat_template
|
||||||
|
- prompt_strategies.alpaca_chat
|
||||||
|
- prompt_strategies.alpaca_instruct
|
||||||
|
- prompt_strategies.alpaca_w_system
|
||||||
|
- prompt_strategies.user_defined
|
||||||
|
- prompt_strategies.llama2_chat
|
||||||
|
- prompt_strategies.completion
|
||||||
|
- prompt_strategies.input_output
|
||||||
|
- prompt_strategies.stepwise_supervised
|
||||||
|
- prompt_strategies.metharme
|
||||||
|
- prompt_strategies.orcamini
|
||||||
|
- prompt_strategies.pygmalion
|
||||||
|
- prompt_strategies.messages.chat
|
||||||
|
- prompt_strategies.dpo.chat_template
|
||||||
|
- prompt_strategies.dpo.llama3
|
||||||
|
- prompt_strategies.dpo.chatml
|
||||||
|
- prompt_strategies.dpo.zephyr
|
||||||
|
- prompt_strategies.dpo.user_defined
|
||||||
|
- prompt_strategies.dpo.passthrough
|
||||||
|
- prompt_strategies.kto.llama3
|
||||||
|
- prompt_strategies.kto.chatml
|
||||||
|
- prompt_strategies.kto.user_defined
|
||||||
|
- prompt_strategies.orpo.chat_template
|
||||||
|
- prompt_strategies.bradley_terry.llama3
|
||||||
|
- title: Kernels
|
||||||
|
desc: Low-level performance optimizations
|
||||||
|
contents:
|
||||||
|
- kernels.lora
|
||||||
|
- kernels.geglu
|
||||||
|
- kernels.swiglu
|
||||||
|
- kernels.quantize
|
||||||
|
- kernels.utils
|
||||||
|
- title: MonkeyPatches
|
||||||
|
desc: Runtime patches for model optimizations
|
||||||
|
contents:
|
||||||
|
- monkeypatch.llama_attn_hijack_flash
|
||||||
|
- monkeypatch.llama_attn_hijack_xformers
|
||||||
|
- monkeypatch.mistral_attn_hijack_flash
|
||||||
|
- monkeypatch.multipack
|
||||||
|
- monkeypatch.relora
|
||||||
|
- monkeypatch.llama_expand_mask
|
||||||
|
- monkeypatch.lora_kernels
|
||||||
|
- monkeypatch.utils
|
||||||
|
- monkeypatch.btlm_attn_hijack_flash
|
||||||
|
- monkeypatch.llama_patch_multipack
|
||||||
|
- monkeypatch.stablelm_attn_hijack_flash
|
||||||
|
- monkeypatch.trainer_fsdp_optim
|
||||||
|
- monkeypatch.transformers_fa_utils
|
||||||
|
- monkeypatch.unsloth_
|
||||||
|
- monkeypatch.attention.mllama
|
||||||
|
- monkeypatch.data.batch_dataset_fetcher
|
||||||
|
- monkeypatch.mixtral
|
||||||
|
- title: Utils
|
||||||
|
desc: Utility functions
|
||||||
|
contents:
|
||||||
|
- utils.models
|
||||||
|
- utils.tokenization
|
||||||
|
- utils.chat_templates
|
||||||
|
- utils.lora
|
||||||
|
- utils.lora_embeddings
|
||||||
|
- utils.model_shard_quant
|
||||||
|
- utils.bench
|
||||||
|
- utils.freeze
|
||||||
|
- utils.trainer
|
||||||
|
- utils.schedulers
|
||||||
|
- utils.distributed
|
||||||
|
- utils.dict
|
||||||
|
- utils.optimizers.adopt
|
||||||
|
- utils.data.pretraining
|
||||||
|
- utils.data.sft
|
||||||
|
- utils.gradient_checkpointing.unsloth
|
||||||
|
- title: Schemas
|
||||||
|
desc: Pydantic data models for Axolotl config
|
||||||
|
contents:
|
||||||
|
- utils.schemas.config
|
||||||
|
- utils.schemas.model
|
||||||
|
- utils.schemas.training
|
||||||
|
- utils.schemas.datasets
|
||||||
|
- utils.schemas.peft
|
||||||
|
- utils.schemas.trl
|
||||||
|
- utils.schemas.multimodal
|
||||||
|
- utils.schemas.integrations
|
||||||
|
- utils.schemas.enums
|
||||||
|
- utils.schemas.utils
|
||||||
|
- title: Integrations
|
||||||
|
desc: Third-party integrations and extensions
|
||||||
|
contents:
|
||||||
|
- integrations.base
|
||||||
|
- integrations.cut_cross_entropy.args
|
||||||
|
- integrations.grokfast.optimizer
|
||||||
|
- integrations.kd.trainer
|
||||||
|
- integrations.liger.args
|
||||||
|
- integrations.lm_eval.args
|
||||||
|
- integrations.spectrum.args
|
||||||
|
- title: Common
|
||||||
|
desc: Common utilities and shared functionality
|
||||||
|
contents:
|
||||||
|
- common.architectures
|
||||||
|
- common.const
|
||||||
|
- common.datasets
|
||||||
|
- title: Models
|
||||||
|
desc: Custom model implementations
|
||||||
|
contents:
|
||||||
|
- models.mamba.modeling_mamba
|
||||||
|
- title: Data Processing
|
||||||
|
desc: Data processing utilities
|
||||||
|
contents:
|
||||||
|
- utils.collators.core
|
||||||
|
- utils.collators.batching
|
||||||
|
- utils.collators.mamba
|
||||||
|
- utils.collators.mm_chat
|
||||||
|
- utils.samplers.multipack
|
||||||
|
- title: Callbacks
|
||||||
|
desc: Training callbacks
|
||||||
|
contents:
|
||||||
|
- utils.callbacks.perplexity
|
||||||
|
- utils.callbacks.profiler
|
||||||
|
- utils.callbacks.lisa
|
||||||
|
- utils.callbacks.mlflow_
|
||||||
|
- utils.callbacks.comet_
|
||||||
|
|
||||||
website:
|
website:
|
||||||
title: "Axolotl"
|
title: "Axolotl"
|
||||||
description: "We make fine-tuning accessible, scalable, and fun"
|
description: "We make fine-tuning accessible, scalable, and fun"
|
||||||
@@ -35,6 +208,8 @@ website:
|
|||||||
- docs/inference.qmd
|
- docs/inference.qmd
|
||||||
- docs/cli.qmd
|
- docs/cli.qmd
|
||||||
- docs/config.qmd
|
- docs/config.qmd
|
||||||
|
- text: "API Reference"
|
||||||
|
href: docs/api
|
||||||
|
|
||||||
- section: "Dataset Formats"
|
- section: "Dataset Formats"
|
||||||
contents: docs/dataset-formats/*
|
contents: docs/dataset-formats/*
|
||||||
@@ -80,3 +255,22 @@ format:
|
|||||||
theme: darkly
|
theme: darkly
|
||||||
css: styles.css
|
css: styles.css
|
||||||
toc: true
|
toc: true
|
||||||
|
# Enable better handling of line breaks in markdown
|
||||||
|
preserve-tabs: true
|
||||||
|
html-math-method: mathjax
|
||||||
|
# Improved markdown processing options
|
||||||
|
md-extensions:
|
||||||
|
- markdown_it
|
||||||
|
- def_list
|
||||||
|
- attr_list
|
||||||
|
- fenced_divs
|
||||||
|
- tables
|
||||||
|
- html_admonition
|
||||||
|
- lineblocks
|
||||||
|
- fancy_lists
|
||||||
|
# Control whitespace handling
|
||||||
|
whitespace: preserve
|
||||||
|
# Process newlines in paragraphs
|
||||||
|
wrap: preserve
|
||||||
|
# Better line break handling
|
||||||
|
preserve-linebreaks: true
|
||||||
|
|||||||
@@ -33,9 +33,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|||||||
|
|
||||||
RUN pip install packaging==23.2 setuptools==75.8.0
|
RUN pip install packaging==23.2 setuptools==75.8.0
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN python scripts/unsloth_install.py | sh
|
RUN python scripts/unsloth_install.py | sh
|
||||||
|
|||||||
@@ -3,9 +3,10 @@ set -e
|
|||||||
|
|
||||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||||
|
|
||||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/
|
||||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
|
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
|
||||||
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
|
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
|
||||||
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
||||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
||||||
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
pytest -v --durations=10 /workspace/axolotl/tests/cli
|
||||||
|
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/
|
||||||
|
|||||||
2
docs/.gitignore
vendored
2
docs/.gitignore
vendored
@@ -1,2 +1,4 @@
|
|||||||
/.quarto/
|
/.quarto/
|
||||||
_site/
|
_site/
|
||||||
|
/api/*.qmd
|
||||||
|
/api/*.html
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
---
|
---
|
||||||
title: "CLI Reference"
|
title: "Command Line Interface (CLI)"
|
||||||
format:
|
format:
|
||||||
html:
|
html:
|
||||||
toc: true
|
toc: true
|
||||||
|
|||||||
@@ -32,6 +32,9 @@ tokenizer_legacy:
|
|||||||
resize_token_embeddings_to_32x:
|
resize_token_embeddings_to_32x:
|
||||||
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
||||||
shrink_embeddings:
|
shrink_embeddings:
|
||||||
|
# Whether to load the model with randomly initialized weights. Useful for
|
||||||
|
# pre-training a model from scratch or debugging purposes.
|
||||||
|
random_init_weights:
|
||||||
|
|
||||||
# (Internal use only)
|
# (Internal use only)
|
||||||
# Used to identify which the model is based on
|
# Used to identify which the model is based on
|
||||||
@@ -463,6 +466,7 @@ auto_find_batch_size: # Optional[bool]
|
|||||||
|
|
||||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||||
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||||
|
do_causal_lm_eval: # Whether to run causal language model evaluation for metrics in `eval_causal_lm_metrics`.
|
||||||
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
|
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
|
||||||
|
|
||||||
profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
|
profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
|
||||||
@@ -503,36 +507,58 @@ lr_div_factor: # Learning rate div factor
|
|||||||
|
|
||||||
# Specify optimizer
|
# Specify optimizer
|
||||||
# Valid values are driven by the Transformers OptimizerNames class, see:
|
# Valid values are driven by the Transformers OptimizerNames class, see:
|
||||||
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
|
# https://github.com/huggingface/transformers/blob/cbf924b76c03828101a34069a96d209314114fd5/src/transformers/training_args.py#L144-L189
|
||||||
#
|
#
|
||||||
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
|
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
|
||||||
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
|
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
|
||||||
# in the examples/ for your model and fine-tuning use case.
|
# in the examples/ for your model and fine-tuning use case.
|
||||||
#
|
#
|
||||||
# Valid values for 'optimizer' include:
|
# Valid values for 'optimizer' include:
|
||||||
# - adamw_hf
|
|
||||||
# - adamw_torch
|
# - adamw_torch
|
||||||
# - adamw_torch_fused
|
# - adamw_torch_fused
|
||||||
# - adamw_torch_xla
|
# - adamw_torch_xla
|
||||||
|
# - adamw_torch_npu_fused
|
||||||
# - adamw_apex_fused
|
# - adamw_apex_fused
|
||||||
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
|
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
|
||||||
# - adafactor
|
# - adafactor
|
||||||
# - adamw_anyprecision
|
# - adamw_anyprecision
|
||||||
|
# - adamw_torch_4bit
|
||||||
|
# - ademamix
|
||||||
# - sgd
|
# - sgd
|
||||||
# - adagrad
|
# - adagrad
|
||||||
# - adamw_bnb_8bit
|
# - adamw_bnb_8bit
|
||||||
|
# - adamw_8bit # alias for adamw_bnb_8bit
|
||||||
|
# - ademamix_8bit
|
||||||
# - lion_8bit
|
# - lion_8bit
|
||||||
# - lion_32bit
|
# - lion_32bit
|
||||||
# - paged_adamw_32bit
|
# - paged_adamw_32bit
|
||||||
# - paged_adamw_8bit
|
# - paged_adamw_8bit
|
||||||
|
# - paged_ademamix_32bit
|
||||||
|
# - paged_ademamix_8bit
|
||||||
# - paged_lion_32bit
|
# - paged_lion_32bit
|
||||||
# - paged_lion_8bit
|
# - paged_lion_8bit
|
||||||
|
# - rmsprop
|
||||||
|
# - rmsprop_bnb
|
||||||
|
# - rmsprop_bnb_8bit
|
||||||
|
# - rmsprop_bnb_32bit
|
||||||
# - galore_adamw
|
# - galore_adamw
|
||||||
# - galore_adamw_8bit
|
# - galore_adamw_8bit
|
||||||
# - galore_adafactor
|
# - galore_adafactor
|
||||||
# - galore_adamw_layerwise
|
# - galore_adamw_layerwise
|
||||||
# - galore_adamw_8bit_layerwise
|
# - galore_adamw_8bit_layerwise
|
||||||
# - galore_adafactor_layerwise
|
# - galore_adafactor_layerwise
|
||||||
|
# - lomo
|
||||||
|
# - adalomo
|
||||||
|
# - grokadamw
|
||||||
|
# - schedule_free_adamw
|
||||||
|
# - schedule_free_sgd
|
||||||
|
# - apollo_adamw
|
||||||
|
# - apollo_adamw_layerwise
|
||||||
|
#
|
||||||
|
# Additional custom optimizers include:
|
||||||
|
# - optimi_adamw
|
||||||
|
# - ao_adamw_8bit
|
||||||
|
# - ao_adamw_fp8
|
||||||
optimizer:
|
optimizer:
|
||||||
# Dictionary of arguments to pass to the optimizer
|
# Dictionary of arguments to pass to the optimizer
|
||||||
optim_args:
|
optim_args:
|
||||||
@@ -584,6 +610,14 @@ resume_from_checkpoint:
|
|||||||
# Be careful with this being turned on between different models.
|
# Be careful with this being turned on between different models.
|
||||||
auto_resume_from_checkpoints: false
|
auto_resume_from_checkpoints: false
|
||||||
|
|
||||||
|
## Multimodal section
|
||||||
|
# int | tuple[int, int] | None . Size to resize images to, width x height.
|
||||||
|
# Will read from model/processor config if not set.
|
||||||
|
image_size:
|
||||||
|
# str. Algorithm to use for image resizing. "bilinear", "bicubic", "lanczos". Default is "bilinear".
|
||||||
|
image_resize_algorithm: 'bilinear'
|
||||||
|
## End of multimodal section
|
||||||
|
|
||||||
# Don't mess with this, it's here for accelerate and torchrun
|
# Don't mess with this, it's here for accelerate and torchrun
|
||||||
local_rank:
|
local_rank:
|
||||||
|
|
||||||
@@ -617,6 +651,14 @@ ddp_timeout:
|
|||||||
ddp_bucket_cap_mb:
|
ddp_bucket_cap_mb:
|
||||||
ddp_broadcast_buffers:
|
ddp_broadcast_buffers:
|
||||||
|
|
||||||
|
# Sequence parallelism
|
||||||
|
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
|
||||||
|
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
|
||||||
|
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
|
||||||
|
# subsequences, or set to 4 to split into four equal-sized subsequences.
|
||||||
|
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
|
||||||
|
sequence_parallel_degree:
|
||||||
|
|
||||||
# Path to torch distx for optim 'adamw_anyprecision'
|
# Path to torch distx for optim 'adamw_anyprecision'
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ description: How datasets are processed
|
|||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
|
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
|
||||||
the [dataset format](docs/dataset-formats) and prompt strategies to:
|
the [dataset format](dataset-formats) and prompt strategies to:
|
||||||
|
|
||||||
- parse the dataset based on the *dataset format*
|
- parse the dataset based on the *dataset format*
|
||||||
- transform the dataset to how you would interact with the model based on the *prompt strategy*
|
- transform the dataset to how you would interact with the model based on the *prompt strategy*
|
||||||
|
|||||||
@@ -103,8 +103,7 @@ This uses the same tags as the [`main` image](#sec-main-tags).
|
|||||||
|
|
||||||
- `JUPYTER_DISABLE`: Disable Jupyter lab.
|
- `JUPYTER_DISABLE`: Disable Jupyter lab.
|
||||||
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
|
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
|
||||||
- `PUBLIC_KEY`: Add a public key for the SSH service.
|
- `PUBLIC_KEY` / `SSH_KEY`: Add a public key for the SSH service.
|
||||||
- `SSH_KEY`: Add a private key for the SSH service.
|
|
||||||
|
|
||||||
#### Volume mounts
|
#### Volume mounts
|
||||||
|
|
||||||
|
|||||||
@@ -37,6 +37,10 @@ description: Frequently asked questions
|
|||||||
|
|
||||||
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
|
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
|
||||||
|
|
||||||
|
**Q: How to know the value to use for `fsdp_transformer_layer_cls_to_wrap`?**
|
||||||
|
|
||||||
|
> A: This is the class name of the transformer layer to wrap with FSDP. For example, for `LlamaForCausalLM`, the value is `LlamaDecoderLayer`. To find this for a specific model, check the model's `PreTrainedModel` definition and look for `_no_split_modules` variable in the `modeling_<model_name>.py` file within `transformers` library.
|
||||||
|
|
||||||
### Chat templates
|
### Chat templates
|
||||||
|
|
||||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||||
|
|||||||
@@ -1,28 +1,171 @@
|
|||||||
# MultiModal / Vision Language Models (BETA)
|
---
|
||||||
|
title: MultiModal / Vision Language Models (BETA)
|
||||||
|
format:
|
||||||
|
html:
|
||||||
|
toc: true
|
||||||
|
toc-depth: 3
|
||||||
|
---
|
||||||
|
|
||||||
### Supported Models
|
## Supported Models
|
||||||
|
|
||||||
- Mllama, i.e. llama with vision models
|
- [Mllama](#sec-mllama)
|
||||||
|
- [Pixtral](#sec-pixtral)
|
||||||
|
- [Llava-1.5](#sec-llava-15)
|
||||||
|
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
||||||
|
- [Gemma-3](#sec-gemma-3)
|
||||||
|
- [Qwen2-VL](#sec-qwen2-vl)
|
||||||
|
- [Qwen2.5-VL](#sec-qwen25-vl)
|
||||||
|
|
||||||
### Usage
|
## Usage
|
||||||
|
|
||||||
Currently multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama w/ LoRA,
|
Multimodal support is limited and doesn't have full feature parity.
|
||||||
you'll need to use the following in YAML in combination with the rest of the required hyperparams.
|
|
||||||
|
Here are the hyperparams you'll need to use to finetune a multimodal model.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
|
|
||||||
processor_type: AutoProcessor
|
processor_type: AutoProcessor
|
||||||
skip_prepare_dataset: true
|
|
||||||
|
|
||||||
chat_template: llama3_2_vision
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
|
||||||
|
sample_packing: false # not yet supported with multimodal
|
||||||
|
|
||||||
|
chat_template: # see in next section
|
||||||
|
|
||||||
|
# example dataset
|
||||||
datasets:
|
datasets:
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
type: chat_template
|
type: chat_template
|
||||||
split: train[:1%]
|
split: train[:1%]
|
||||||
field_messages: messages
|
field_messages: messages
|
||||||
remove_unused_columns: false
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
# only finetune the Language model, leave the vision model and vision tower frozen
|
# (optional) if doing lora, only finetune the Language model,
|
||||||
|
# leave the vision model and vision tower frozen
|
||||||
|
# load_in_8bit: true
|
||||||
|
adapter: lora
|
||||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
# (optional) if you want to resize images to a set size
|
||||||
|
image_size: 512
|
||||||
|
image_resize_algorithm: bilinear
|
||||||
|
```
|
||||||
|
|
||||||
|
Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.
|
||||||
|
|
||||||
|
::: {.callout-warning}
|
||||||
|
Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
|
||||||
|
:::
|
||||||
|
|
||||||
|
### Mllama {#sec-mllama}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||||
|
|
||||||
|
chat_template: llama3_2_vision
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pixtral {#sec-pixtral}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: mistralai/Pixtral-12B-2409
|
||||||
|
|
||||||
|
chat_template: pixtral
|
||||||
|
```
|
||||||
|
|
||||||
|
### Llava-1.5 {#sec-llava-15}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: llava-hf/llava-1.5-7b-hf
|
||||||
|
|
||||||
|
chat_template: llava
|
||||||
|
```
|
||||||
|
|
||||||
|
### Mistral-Small-3.1 {#sec-mistral-small-31}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||||
|
|
||||||
|
chat_template: mistral_v7_tekken
|
||||||
|
```
|
||||||
|
|
||||||
|
### Gemma-3 {#sec-gemma-3}
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
The Gemma3-1B model is a text-only model, so please train as regular text model.
|
||||||
|
:::
|
||||||
|
|
||||||
|
For multi-modal 4B/12B/27B models, use the following config:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: google/gemma-3-4b-it
|
||||||
|
|
||||||
|
chat_template: gemma3
|
||||||
|
```
|
||||||
|
|
||||||
|
### Qwen2-VL {#sec-qwen2-vl}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: Qwen/Qwen2-VL-7B-Instruct
|
||||||
|
|
||||||
|
chat_template: qwen2_vl
|
||||||
|
```
|
||||||
|
|
||||||
|
### Qwen2.5-VL {#sec-qwen25-vl}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: Qwen/Qwen2.5-VL-7B-Instruct
|
||||||
|
|
||||||
|
chat_template: qwen2_vl # same as qwen2-vl
|
||||||
|
```
|
||||||
|
|
||||||
|
## Dataset Format
|
||||||
|
|
||||||
|
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
|
||||||
|
|
||||||
|
- A message is a list of `role` and `content`.
|
||||||
|
- `role` can be `system`, `user`, `assistant`, etc.
|
||||||
|
- `content` is a list of `type` and (`text` or `image` or `path` or `url` or `base64`).
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
For backwards compatibility:
|
||||||
|
|
||||||
|
- If the dataset has a `images` or `image` column of `list[Image]`, it will be appended to the first `content` list as `{"type": "image", "image": ...}`. However, if the content already has a `{"type": "image"}` but no `image` key, it will be set the `image` key.
|
||||||
|
- If `content` is a string, it will be converted to a list with `type` as `text`.
|
||||||
|
:::
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
For image loading, you can use the following keys within `content` alongside `"type": "image"`:
|
||||||
|
|
||||||
|
- `"path": "/path/to/image.jpg"`
|
||||||
|
- `"url": "https://example.com/image.jpg"`
|
||||||
|
- `"base64": "..."`
|
||||||
|
- `"image": PIL.Image`
|
||||||
|
:::
|
||||||
|
|
||||||
|
Here is an example of a multi-modal dataset:
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "You are a helpful assistant."}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
|
||||||
|
{"type": "text", "text": "Describe this image in detail."}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "The image is a bee."}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
```
|
```
|
||||||
|
|||||||
90
docs/sequence_parallelism.qmd
Normal file
90
docs/sequence_parallelism.qmd
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
---
|
||||||
|
title: Sequence Parallelism
|
||||||
|
description: Train with long sequences split across multiple GPUs.
|
||||||
|
---
|
||||||
|
|
||||||
|
# Sequence Parallelism
|
||||||
|
|
||||||
|
Sequence parallelism is a technique that splits sequences across multiple GPUs,
|
||||||
|
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
|
||||||
|
GPU processes a different portion of the sequence, and the results are aggregated
|
||||||
|
through a ring communication pattern.
|
||||||
|
|
||||||
|
## When to Use Sequence Parallelism
|
||||||
|
|
||||||
|
Use sequence parallelism when:
|
||||||
|
|
||||||
|
- You need to train with sequence lengths that don't fit into a single GPU's memory
|
||||||
|
- You have multiple GPUs available
|
||||||
|
- You're experiencing OOM (Out Of Memory) errors with long sequences
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
To enable sequence parallelism, add the following to your configuration file:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Set to a divisor (> 1) of the number of GPUs available
|
||||||
|
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||||
|
```
|
||||||
|
|
||||||
|
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
||||||
|
|
||||||
|
- With 8 GPUs, valid values would be 2, 4, or 8
|
||||||
|
- With 4 GPUs, valid values would be 2 or 4
|
||||||
|
|
||||||
|
## Implementation Details
|
||||||
|
|
||||||
|
When sequence parallelism is enabled:
|
||||||
|
|
||||||
|
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
|
||||||
|
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
|
||||||
|
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
|
||||||
|
4. The trainer uses special ring communication patterns for attention operations
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
To use sequence parallelism, you need:
|
||||||
|
|
||||||
|
- Multiple GPUs (at least 2)
|
||||||
|
- The `ring-flash-attn` package. Install with:
|
||||||
|
- `pip install axolotl[ring-flash-attn]` (preferred)
|
||||||
|
- `pip install ring-flash-attn>=0.1.4`
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
|
||||||
|
- May have a small performance overhead due to communication between GPUs
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Example config with sequence parallelism
|
||||||
|
base_model: meta-llama/Llama-3-8B-Instruct
|
||||||
|
sequence_len: 8192
|
||||||
|
sequence_parallel_degree: 2 # Split each sequence into 4 parts
|
||||||
|
flash_attention: true # Required with sequence parallelism
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
This will train the Llama 3 8B model with 8K context length, with each sequence split
|
||||||
|
into 2 subsequences of length 4096 across 2 GPUs.
|
||||||
|
|
||||||
|
## Sample Packing with Sequence Parallelism
|
||||||
|
|
||||||
|
Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
|
||||||
|
|
||||||
|
1. Samples are first packed together
|
||||||
|
2. The packed sequences are then divided across GPUs in the sequence parallel group
|
||||||
|
3. Position IDs are automatically adjusted to maintain proper relative positions
|
||||||
|
|
||||||
|
## Effect on Batch Size
|
||||||
|
|
||||||
|
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
||||||
|
|
||||||
|
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
||||||
|
- The number of batches processed per step decreases
|
||||||
|
|
||||||
|
For example:
|
||||||
|
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
||||||
|
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||||
|
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
||||||
71
examples/cohere/command-r-7b-qlora.yml
Normal file
71
examples/cohere/command-r-7b-qlora.yml
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
base_model: CohereForAI/c4ai-command-r7b-12-2024
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# huggingface repo
|
||||||
|
chat_template: cohere
|
||||||
|
datasets:
|
||||||
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
|
type: chat_template
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
eval_sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch:
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
74
examples/gemma3/gemma-3-1b-qlora.yml
Normal file
74
examples/gemma3/gemma-3-1b-qlora.yml
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
base_model: google/gemma-3-1b-it
|
||||||
|
# optionally might have model_type or tokenizer_type
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# huggingface repo
|
||||||
|
chat_template: gemma3
|
||||||
|
datasets:
|
||||||
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
|
type: chat_template
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
eval_sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch:
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
63
examples/gemma3/gemma-3-4b-lora.yml
Normal file
63
examples/gemma3/gemma-3-4b-lora.yml
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
base_model: google/gemma-3-4b-it
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: gemma3
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.01
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
@@ -19,7 +19,6 @@ val_set_size: 0.0
|
|||||||
output_dir: ./outputs/lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
dataset_exact_deduplication: true
|
dataset_exact_deduplication: true
|
||||||
test_value: true
|
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
63
examples/llava/lora-7b.yaml
Normal file
63
examples/llava/lora-7b.yaml
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
base_model: llava-hf/llava-1.5-7b-hf
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: llava
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
66
examples/mistral/mistral-small-3.1-24B-lora.yml
Normal file
66
examples/mistral/mistral-small-3.1-24B-lora.yml
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: mistral_v7_tekken
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.01
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet.
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
65
examples/pixtral/lora-12b.yml
Normal file
65
examples/pixtral/lora-12b.yml
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
base_model: mistral-community/pixtral-12b
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: pixtral
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <pad>
|
||||||
63
examples/qwen2-vl/lora-7b.yaml
Normal file
63
examples/qwen2-vl/lora-7b.yaml
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
base_model: Qwen/Qwen2-VL-7B-Instruct
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: qwen2_vl
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
@@ -2,3 +2,5 @@ pre-commit
|
|||||||
black
|
black
|
||||||
mypy
|
mypy
|
||||||
types-requests
|
types-requests
|
||||||
|
quartodoc
|
||||||
|
jupyter
|
||||||
|
|||||||
@@ -4,19 +4,18 @@
|
|||||||
bitsandbytes==0.45.3
|
bitsandbytes==0.45.3
|
||||||
triton>=3.0.0
|
triton>=3.0.0
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
flash-attn==2.7.4.post1
|
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
autoawq==0.2.7.post3
|
autoawq==0.2.7.post3
|
||||||
liger-kernel==0.5.3
|
liger-kernel==0.5.5
|
||||||
# END section
|
# END section
|
||||||
|
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
|
|
||||||
peft==0.15.0
|
peft==0.15.0
|
||||||
transformers==4.49.0
|
transformers==4.50.0
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
accelerate==1.5.2
|
accelerate==1.5.2
|
||||||
datasets==3.4.1
|
datasets==3.5.0
|
||||||
deepspeed==0.16.4
|
deepspeed==0.16.4
|
||||||
trl==0.15.1
|
trl==0.15.1
|
||||||
|
|
||||||
@@ -36,6 +35,7 @@ einops
|
|||||||
colorama
|
colorama
|
||||||
numba
|
numba
|
||||||
numpy>=1.24.4,<=2.0.1
|
numpy>=1.24.4,<=2.0.1
|
||||||
|
|
||||||
# qlora things
|
# qlora things
|
||||||
evaluate==0.4.1
|
evaluate==0.4.1
|
||||||
scipy
|
scipy
|
||||||
|
|||||||
@@ -1,315 +0,0 @@
|
|||||||
accelerate==0.34.1
|
|
||||||
addict==2.4.0
|
|
||||||
aiofiles==23.2.1
|
|
||||||
aiohttp==3.9.0
|
|
||||||
aiosignal==1.3.1
|
|
||||||
aiostream==0.5.2
|
|
||||||
alembic==1.13.1
|
|
||||||
annotated-types==0.6.0
|
|
||||||
annoy==1.17.3
|
|
||||||
ansible==6.7.0
|
|
||||||
ansible-core==2.13.13
|
|
||||||
ansible-vault==2.1.0
|
|
||||||
anyio==3.7.1
|
|
||||||
appdirs==1.4.4
|
|
||||||
art==6.0
|
|
||||||
asgiref==3.7.2
|
|
||||||
async-timeout==4.0.2
|
|
||||||
attrdict==2.0.1
|
|
||||||
attrs==22.2.0
|
|
||||||
awscli==1.32.75
|
|
||||||
-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
|
|
||||||
backoff==2.2.1
|
|
||||||
base58==2.1.1
|
|
||||||
beartype==0.17.2
|
|
||||||
bitnet==0.2.1
|
|
||||||
bitsandbytes==0.42.0
|
|
||||||
bittensor==6.7.0
|
|
||||||
black==23.7.0
|
|
||||||
blinker==1.7.0
|
|
||||||
boto3==1.34.75
|
|
||||||
botocore==1.34.75
|
|
||||||
cachetools==5.3.3
|
|
||||||
cachy==0.1.1
|
|
||||||
certifi==2023.7.22
|
|
||||||
cffi==1.16.0
|
|
||||||
cfgv==3.3.1
|
|
||||||
chai-guanaco==1.2.4
|
|
||||||
charset-normalizer==3.2.0
|
|
||||||
cleo==0.6.8
|
|
||||||
click==8.1.7
|
|
||||||
cloudpickle==2.0.0
|
|
||||||
cohere==4.11.2
|
|
||||||
colorama==0.4.4
|
|
||||||
coloredlogs==15.0.1
|
|
||||||
CoLT5-attention==0.10.20
|
|
||||||
contextlib2==21.6.0
|
|
||||||
contourpy==1.2.0
|
|
||||||
cryptography==41.0.3
|
|
||||||
cycler==0.12.1
|
|
||||||
cytoolz==0.12.3
|
|
||||||
databricks-cli==0.18.0
|
|
||||||
dataclasses-json==0.5.7
|
|
||||||
datasets==2.11.0
|
|
||||||
ddt==1.6.0
|
|
||||||
decorator==5.1.1
|
|
||||||
deepspeed==0.15.0
|
|
||||||
# Editable Git install with no remote (dialogpt==0.1)
|
|
||||||
-e /Users/wing/Projects/ml/dialogpt/src
|
|
||||||
dill==0.3.6
|
|
||||||
distlib==0.3.6
|
|
||||||
docker==7.0.0
|
|
||||||
docker-pycreds==0.4.0
|
|
||||||
docstring-parser==0.15
|
|
||||||
docutils==0.16
|
|
||||||
ecdsa==0.18.0
|
|
||||||
einops==0.7.0
|
|
||||||
einops-exts==0.0.4
|
|
||||||
einx==0.1.3
|
|
||||||
entrypoints==0.4
|
|
||||||
eth-hash==0.6.0
|
|
||||||
eth-keys==0.5.0
|
|
||||||
eth-typing==4.0.0
|
|
||||||
eth-utils==2.3.1
|
|
||||||
evaluate==0.4.0
|
|
||||||
exceptiongroup==1.1.1
|
|
||||||
fastapi==0.109.2
|
|
||||||
fastcore==1.5.29
|
|
||||||
ffmpy==0.4.0
|
|
||||||
filelock==3.12.2
|
|
||||||
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
|
|
||||||
fire==0.5.0
|
|
||||||
first==2.0.2
|
|
||||||
flake8==7.0.0
|
|
||||||
Flask==3.0.1
|
|
||||||
fonttools==4.47.2
|
|
||||||
frozendict==2.4.1
|
|
||||||
frozenlist==1.3.3
|
|
||||||
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
|
||||||
fsspec==2023.6.0
|
|
||||||
fuzzywuzzy==0.18.0
|
|
||||||
gitdb==4.0.10
|
|
||||||
GitPython==3.1.31
|
|
||||||
google-pasta==0.2.0
|
|
||||||
gradio==4.42.0
|
|
||||||
gradio_client==1.3.0
|
|
||||||
greenlet==2.0.2
|
|
||||||
grpclib==0.4.7
|
|
||||||
gunicorn==21.2.0
|
|
||||||
h11==0.14.0
|
|
||||||
h2==4.1.0
|
|
||||||
hpack==4.0.0
|
|
||||||
httpcore==0.17.3
|
|
||||||
httpx==0.24.1
|
|
||||||
huggingface-hub==0.23.4
|
|
||||||
humanfriendly==10.0
|
|
||||||
hyperframe==6.0.1
|
|
||||||
identify==2.5.24
|
|
||||||
idna==3.4
|
|
||||||
immutables==0.20
|
|
||||||
importlib-metadata==6.7.0
|
|
||||||
importlib-resources==6.1.1
|
|
||||||
inflection==0.5.1
|
|
||||||
iniconfig==2.0.0
|
|
||||||
itsdangerous==2.1.2
|
|
||||||
Jinja2==3.1.2
|
|
||||||
jmespath==1.0.1
|
|
||||||
joblib==1.3.2
|
|
||||||
jsonlines==3.1.0
|
|
||||||
jsonschema==2.6.0
|
|
||||||
kiwisolver==1.4.5
|
|
||||||
langchain==0.0.144
|
|
||||||
Levenshtein==0.24.0
|
|
||||||
libcst==1.1.0
|
|
||||||
liger-kernel==0.0.0
|
|
||||||
lion-pytorch==0.1.2
|
|
||||||
llama-cpp-python==0.1.36
|
|
||||||
llvmlite==0.40.1
|
|
||||||
local-attention==1.9.0
|
|
||||||
loguru==0.7.0
|
|
||||||
Mako==1.3.2
|
|
||||||
Markdown==3.5.2
|
|
||||||
markdown-it-py==3.0.0
|
|
||||||
markdown2==2.4.10
|
|
||||||
MarkupSafe==2.1.2
|
|
||||||
marshmallow==3.19.0
|
|
||||||
marshmallow-enum==1.5.1
|
|
||||||
matplotlib==3.8.2
|
|
||||||
mccabe==0.7.0
|
|
||||||
mdurl==0.1.2
|
|
||||||
MEGABYTE-pytorch==0.0.7
|
|
||||||
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
|
|
||||||
mlflow==2.10.0
|
|
||||||
modal==0.62.77
|
|
||||||
more-itertools==10.2.0
|
|
||||||
mpmath==1.2.1
|
|
||||||
msgpack==1.0.7
|
|
||||||
msgpack-numpy-opentensor==0.5.0
|
|
||||||
multidict==6.0.4
|
|
||||||
multiprocess==0.70.14
|
|
||||||
munch==2.5.0
|
|
||||||
mypy==1.3.0
|
|
||||||
mypy-extensions==1.0.0
|
|
||||||
nest-asyncio==1.6.0
|
|
||||||
netaddr==0.10.1
|
|
||||||
networkx==3.0rc1
|
|
||||||
nh3==0.2.14
|
|
||||||
nodeenv==1.8.0
|
|
||||||
nomic==2.0.2
|
|
||||||
numba==0.57.1
|
|
||||||
numexpr==2.8.4
|
|
||||||
numpy==1.24.4
|
|
||||||
oauthlib==3.2.2
|
|
||||||
openai==0.27.4
|
|
||||||
openapi==1.1.0
|
|
||||||
openapi-schema-pydantic==1.2.4
|
|
||||||
optimum==1.8.6
|
|
||||||
orjson==3.10.7
|
|
||||||
packaging==23.1
|
|
||||||
pandas==2.0.0
|
|
||||||
parameterized==0.9.0
|
|
||||||
password-strength==0.0.3.post2
|
|
||||||
pastel==0.1.1
|
|
||||||
pathos==0.3.0
|
|
||||||
pathspec==0.11.1
|
|
||||||
pathtools==0.1.2
|
|
||||||
peft==0.11.1
|
|
||||||
pendulum==3.0.0
|
|
||||||
Pillow==9.5.0
|
|
||||||
pip-tools==1.11.0
|
|
||||||
platformdirs==3.2.0
|
|
||||||
pluggy==1.4.0
|
|
||||||
poetry==0.7.1
|
|
||||||
pox==0.3.2
|
|
||||||
ppft==1.7.6.6
|
|
||||||
pre-commit==3.3.2
|
|
||||||
prettytable==3.10.0
|
|
||||||
prompt-toolkit==3.0.39
|
|
||||||
protobuf==3.20.2
|
|
||||||
protobuf3-to-dict==0.1.5
|
|
||||||
psutil==5.9.5
|
|
||||||
psycopg==3.1.18
|
|
||||||
PuLP==2.8.0
|
|
||||||
py==1.11.0
|
|
||||||
py-bip39-bindings==0.1.11
|
|
||||||
py-cpuinfo==9.0.0
|
|
||||||
py-ed25519-zebra-bindings==1.0.1
|
|
||||||
py-sr25519-bindings==0.2.0
|
|
||||||
pyarrow==11.0.0
|
|
||||||
pyasn1==0.6.0
|
|
||||||
pycodestyle==2.11.1
|
|
||||||
pycparser==2.21
|
|
||||||
pycryptodome==3.20.0
|
|
||||||
pydantic==2.5.3
|
|
||||||
pydantic_core==2.14.6
|
|
||||||
pydub==0.25.1
|
|
||||||
pyfiglet==0.8.post1
|
|
||||||
pyflakes==3.2.0
|
|
||||||
Pygments==2.15.1
|
|
||||||
PyJWT==2.8.0
|
|
||||||
pylev==1.4.0
|
|
||||||
PyNaCl==1.5.0
|
|
||||||
pynvml==11.5.0
|
|
||||||
pyparsing==2.4.7
|
|
||||||
pyrsistent==0.14.11
|
|
||||||
pytest==8.0.2
|
|
||||||
pytest-asyncio==0.23.4
|
|
||||||
python-dateutil==2.8.2
|
|
||||||
python-dotenv==1.0.1
|
|
||||||
python-Levenshtein==0.24.0
|
|
||||||
python-multipart==0.0.9
|
|
||||||
pytz==2023.3
|
|
||||||
PyYAML==6.0.1
|
|
||||||
querystring-parser==1.2.4
|
|
||||||
rapidfuzz==3.6.1
|
|
||||||
regex==2023.6.3
|
|
||||||
requests==2.31.0
|
|
||||||
requests-toolbelt==0.8.0
|
|
||||||
resolvelib==0.8.1
|
|
||||||
responses==0.18.0
|
|
||||||
retry==0.9.2
|
|
||||||
rich==13.7.0
|
|
||||||
rsa==4.7.2
|
|
||||||
ruff==0.6.3
|
|
||||||
s3transfer==0.10.1
|
|
||||||
safetensors==0.4.5
|
|
||||||
sagemaker==2.148.0
|
|
||||||
scalecodec==1.2.7
|
|
||||||
schedulefree==1.2.1
|
|
||||||
schema==0.7.5
|
|
||||||
scikit-learn==1.4.0
|
|
||||||
scipy==1.9.3
|
|
||||||
seaborn==0.13.2
|
|
||||||
semantic-version==2.10.0
|
|
||||||
sentencepiece==0.2.0
|
|
||||||
sentry-sdk==1.19.1
|
|
||||||
setproctitle==1.3.2
|
|
||||||
shellingham==1.5.4
|
|
||||||
shortuuid==1.0.11
|
|
||||||
shtab==1.6.5
|
|
||||||
sigtools==4.0.1
|
|
||||||
six==1.16.0
|
|
||||||
skypilot==0.4.1
|
|
||||||
smdebug-rulesconfig==1.0.1
|
|
||||||
smmap==5.0.0
|
|
||||||
sniffio==1.3.0
|
|
||||||
SQLAlchemy==1.4.47
|
|
||||||
sqlparse==0.4.4
|
|
||||||
starlette==0.36.3
|
|
||||||
substrate-interface==1.5.2
|
|
||||||
svgwrite==1.4.3
|
|
||||||
sympy==1.11.1
|
|
||||||
synchronicity==0.6.7
|
|
||||||
tabulate==0.9.0
|
|
||||||
tblib==1.7.0
|
|
||||||
tenacity==8.2.2
|
|
||||||
tensor-parallel==2.0.0
|
|
||||||
termcolor==2.2.0
|
|
||||||
text2art==0.2.0
|
|
||||||
threadpoolctl==3.2.0
|
|
||||||
tiktoken==0.6.0
|
|
||||||
time-machine==2.14.1
|
|
||||||
timm==0.9.16
|
|
||||||
tokenizers==0.19.1
|
|
||||||
tokenmonster==1.1.12
|
|
||||||
toml==0.9.6
|
|
||||||
tomli==2.0.1
|
|
||||||
tomlkit==0.12.0
|
|
||||||
toolz==0.12.1
|
|
||||||
torch==2.2.0
|
|
||||||
torchdata==0.6.1
|
|
||||||
torchdiffeq==0.2.3
|
|
||||||
TorchFix==0.4.0
|
|
||||||
torchtext==0.15.2
|
|
||||||
torchvision==0.17.0
|
|
||||||
tqdm==4.66.2
|
|
||||||
transformers==4.44.2
|
|
||||||
trl==0.9.6
|
|
||||||
typer==0.12.5
|
|
||||||
types-certifi==2021.10.8.3
|
|
||||||
types-requests==2.31.0.20240125
|
|
||||||
types-setuptools==69.0.0.20240125
|
|
||||||
types-toml==0.10.8.7
|
|
||||||
typing==3.7.4.3
|
|
||||||
typing-inspect==0.8.0
|
|
||||||
typing_extensions==4.9.0
|
|
||||||
tyro==0.5.18
|
|
||||||
tzdata==2023.3
|
|
||||||
unique-names-generator==1.0.2
|
|
||||||
urllib3==2.2.2
|
|
||||||
uvicorn==0.22.0
|
|
||||||
vector_quantize_pytorch==1.14.1
|
|
||||||
virtualenv==20.23.0
|
|
||||||
voyager==2.0.2
|
|
||||||
wandb==0.16.2
|
|
||||||
watchfiles==0.21.0
|
|
||||||
wavedrom==2.0.3.post3
|
|
||||||
wcwidth==0.2.6
|
|
||||||
websocket-client==1.7.0
|
|
||||||
websockets==12.0
|
|
||||||
Werkzeug==3.0.1
|
|
||||||
wonderwords==2.2.0
|
|
||||||
xxhash==3.2.0
|
|
||||||
yarl==1.8.2
|
|
||||||
zetascale==2.2.7
|
|
||||||
zipp==3.15.0
|
|
||||||
22
setup.py
22
setup.py
@@ -16,13 +16,7 @@ def parse_requirements():
|
|||||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||||
lines = [r.strip() for r in requirements_file.readlines()]
|
lines = [r.strip() for r in requirements_file.readlines()]
|
||||||
for line in lines:
|
for line in lines:
|
||||||
is_extras = (
|
is_extras = "deepspeed" in line or "mamba-ssm" in line
|
||||||
"flash-attn" in line
|
|
||||||
or "flash-attention" in line
|
|
||||||
or "deepspeed" in line
|
|
||||||
or "mamba-ssm" in line
|
|
||||||
or "lion-pytorch" in line
|
|
||||||
)
|
|
||||||
if line.startswith("--extra-index-url"):
|
if line.startswith("--extra-index-url"):
|
||||||
# Handle custom index URLs
|
# Handle custom index URLs
|
||||||
_, url = line.split()
|
_, url = line.split()
|
||||||
@@ -39,7 +33,6 @@ def parse_requirements():
|
|||||||
"bitsandbytes",
|
"bitsandbytes",
|
||||||
"triton",
|
"triton",
|
||||||
"mamba-ssm",
|
"mamba-ssm",
|
||||||
"flash-attn",
|
|
||||||
"xformers",
|
"xformers",
|
||||||
"autoawq",
|
"autoawq",
|
||||||
"liger-kernel",
|
"liger-kernel",
|
||||||
@@ -124,9 +117,8 @@ setup(
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": ["flash-attn==2.7.4.post1"],
|
||||||
"flash-attn==2.7.4.post1",
|
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
|
||||||
],
|
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed==0.16.4",
|
"deepspeed==0.16.4",
|
||||||
"deepspeed-kernels",
|
"deepspeed-kernels",
|
||||||
@@ -141,15 +133,15 @@ setup(
|
|||||||
"mlflow": [
|
"mlflow": [
|
||||||
"mlflow",
|
"mlflow",
|
||||||
],
|
],
|
||||||
"lion-pytorch": [
|
|
||||||
"lion-pytorch==0.1.2",
|
|
||||||
],
|
|
||||||
"galore": [
|
"galore": [
|
||||||
"galore_torch",
|
"galore_torch",
|
||||||
],
|
],
|
||||||
|
"apollo": [
|
||||||
|
"apollo-torch",
|
||||||
|
],
|
||||||
"optimizers": [
|
"optimizers": [
|
||||||
"galore_torch",
|
"galore_torch",
|
||||||
"lion-pytorch==0.1.2",
|
"apollo-torch",
|
||||||
"lomo-optim==0.1.1",
|
"lomo-optim==0.1.1",
|
||||||
"torch-optimi==0.2.1",
|
"torch-optimi==0.2.1",
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ def do_inference(
|
|||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
cli_args: Inference-specific CLI arguments.
|
cli_args: Inference-specific CLI arguments.
|
||||||
"""
|
"""
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
|
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)
|
||||||
prompter = cli_args.prompter
|
prompter = cli_args.prompter
|
||||||
|
|
||||||
prompter_module = None
|
prompter_module = None
|
||||||
@@ -151,7 +151,7 @@ def do_inference_gradio(
|
|||||||
"""
|
"""
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
|
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)
|
||||||
prompter = cli_args.prompter
|
prompter = cli_args.prompter
|
||||||
|
|
||||||
prompter_module = None
|
prompter_module = None
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from axolotl.cli.utils import (
|
|||||||
)
|
)
|
||||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
|||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
|
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
|
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
|
||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
LOG.info("Running merge of LoRA with base model...")
|
LOG.info("Running merge of LoRA with base model...")
|
||||||
@@ -44,6 +44,9 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
|||||||
)
|
)
|
||||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||||
|
|
||||||
|
if processor:
|
||||||
|
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -17,13 +17,14 @@ from axolotl.cli.config import load_cfg
|
|||||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
|
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||||
from axolotl.utils.config import normalize_config, resolve_dtype
|
from axolotl.utils.config import normalize_config, resolve_dtype
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||||
"""
|
"""
|
||||||
Trains a `transformers` model by first loading the dataset(s) specified in the
|
Trains a `transformers` model by first loading the dataset(s) specified in the
|
||||||
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
|
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
|
||||||
@@ -33,6 +34,9 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
cli_args: Training-specific CLI arguments.
|
cli_args: Training-specific CLI arguments.
|
||||||
"""
|
"""
|
||||||
|
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
|
set_pytorch_cuda_alloc_conf()
|
||||||
|
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
@@ -44,16 +48,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
del model, tokenizer, trainer
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
|
||||||
del model
|
|
||||||
del tokenizer
|
|
||||||
del trainer
|
|
||||||
|
|
||||||
plugin_manager.post_train_unload(cfg)
|
plugin_manager.post_train_unload(cfg)
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
"""
|
"""
|
||||||
Parses `axolotl` config, CLI args, and calls `do_train`.
|
Parses `axolotl` config, CLI args, and calls `do_train`.
|
||||||
|
|
||||||
|
|||||||
@@ -13,11 +13,16 @@ from typing import Any, Callable, Type, Union, get_args, get_origin
|
|||||||
import click
|
import click
|
||||||
import requests
|
import requests
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
|
from transformers import (
|
||||||
|
PreTrainedModel,
|
||||||
|
PreTrainedTokenizer,
|
||||||
|
PreTrainedTokenizerFast,
|
||||||
|
ProcessorMixin,
|
||||||
|
)
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||||
|
|
||||||
configure_logging()
|
configure_logging()
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@@ -295,9 +300,13 @@ def load_model_and_tokenizer(
|
|||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
inference: bool = False,
|
inference: bool = False,
|
||||||
) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
|
) -> tuple[
|
||||||
|
PreTrainedModel,
|
||||||
|
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
|
||||||
|
ProcessorMixin | None,
|
||||||
|
]:
|
||||||
"""
|
"""
|
||||||
Helper function for loading a model and tokenizer specified in the given `axolotl`
|
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
|
||||||
config.
|
config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -305,7 +314,7 @@ def load_model_and_tokenizer(
|
|||||||
inference: Boolean denoting inference mode.
|
inference: Boolean denoting inference mode.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`transformers` model and tokenizer.
|
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
|
||||||
"""
|
"""
|
||||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
@@ -313,4 +322,9 @@ def load_model_and_tokenizer(
|
|||||||
LOG.info("loading model...")
|
LOG.info("loading model...")
|
||||||
model, _ = load_model(cfg, tokenizer, inference=inference)
|
model, _ = load_model(cfg, tokenizer, inference=inference)
|
||||||
|
|
||||||
return model, tokenizer
|
processor = None
|
||||||
|
if cfg.is_multimodal:
|
||||||
|
LOG.info("loading processor...")
|
||||||
|
processor = load_processor(cfg, tokenizer)
|
||||||
|
|
||||||
|
return model, tokenizer, processor
|
||||||
|
|||||||
@@ -13,9 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
"""
|
"""Builder for the training args and trainer"""
|
||||||
Builder for the training args and trainer
|
|
||||||
"""
|
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import importlib
|
import importlib
|
||||||
@@ -38,7 +36,7 @@ from transformers import (
|
|||||||
from transformers.training_args import OptimizerNames
|
from transformers.training_args import OptimizerNames
|
||||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||||
|
|
||||||
from axolotl.core.trainers.base import (
|
from axolotl.core.trainers import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
AxolotlKTOTrainer,
|
AxolotlKTOTrainer,
|
||||||
AxolotlMambaTrainer,
|
AxolotlMambaTrainer,
|
||||||
@@ -62,6 +60,7 @@ from axolotl.core.training_args import (
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||||
|
from axolotl.processing_strategies import get_processing_strategy
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
@@ -85,8 +84,8 @@ from axolotl.utils.collators import (
|
|||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
|
|
||||||
from axolotl.utils.models import ensure_dtype
|
from axolotl.utils.models import ensure_dtype
|
||||||
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||||
@@ -664,6 +663,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
optimizer_cls = MuonOptimizerFactory
|
optimizer_cls = MuonOptimizerFactory
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "soap":
|
||||||
|
from axolotl.utils.optimizers.soap import SOAP
|
||||||
|
|
||||||
|
optimizer_cls = SOAP
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
elif self.cfg.optimizer == "optimi_adamw":
|
elif self.cfg.optimizer == "optimi_adamw":
|
||||||
from optimi import AdamW
|
from optimi import AdamW
|
||||||
|
|
||||||
@@ -749,6 +753,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.accelerator_config
|
self.cfg.accelerator_config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.cfg.image_size:
|
||||||
|
training_arguments_kwargs["image_size"] = self.cfg.image_size
|
||||||
|
if self.cfg.image_resize_algorithm:
|
||||||
|
training_arguments_kwargs["image_resize_algorithm"] = (
|
||||||
|
self.cfg.image_resize_algorithm
|
||||||
|
)
|
||||||
if self.cfg.kd_ce_alpha is not None:
|
if self.cfg.kd_ce_alpha is not None:
|
||||||
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
||||||
if self.cfg.kd_alpha is not None:
|
if self.cfg.kd_alpha is not None:
|
||||||
@@ -764,6 +774,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.kd_top_k_before_softmax
|
self.cfg.kd_top_k_before_softmax
|
||||||
)
|
)
|
||||||
|
|
||||||
|
training_arguments_kwargs["sequence_parallel_degree"] = (
|
||||||
|
self.cfg.sequence_parallel_degree
|
||||||
|
)
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_args_cls = AxolotlRewardConfig
|
training_args_cls = AxolotlRewardConfig
|
||||||
elif self.cfg.process_reward_model:
|
elif self.cfg.process_reward_model:
|
||||||
@@ -847,9 +861,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||||
):
|
):
|
||||||
if training_args.pretraining:
|
if training_args.pretraining:
|
||||||
if self.cfg.pretraining_sample_concatenation is False:
|
if (
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
self.cfg.pretraining_sample_concatenation is False
|
||||||
if self.cfg.micro_batch_size > 1:
|
or self.cfg.micro_batch_size > 1
|
||||||
|
):
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -877,9 +892,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if "max_length" in kwargs:
|
if "max_length" in kwargs:
|
||||||
kwargs.pop("max_length")
|
kwargs.pop("max_length")
|
||||||
elif use_batch_sampler_collator:
|
elif use_batch_sampler_collator:
|
||||||
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or (
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
|
||||||
elif (
|
|
||||||
self.cfg.model_config_type in ["llama"]
|
self.cfg.model_config_type in ["llama"]
|
||||||
and self.cfg.flash_attention is not True
|
and self.cfg.flash_attention is not True
|
||||||
):
|
):
|
||||||
@@ -889,8 +902,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
if self.cfg.processor_type and self.processor:
|
if self.cfg.processor_type and self.processor:
|
||||||
collator = MultiModalChatDataCollator
|
collator = MultiModalChatDataCollator
|
||||||
kwargs["processor"] = self.processor
|
kwargs["processing_strategy"] = get_processing_strategy(
|
||||||
kwargs["chat_template"] = training_args.chat_template
|
self.processor,
|
||||||
|
training_args.chat_template,
|
||||||
|
self.cfg.chat_template,
|
||||||
|
image_size=training_args.image_size,
|
||||||
|
image_resize_algorithm=training_args.image_resize_algorithm,
|
||||||
|
)
|
||||||
elif self.cfg.batch_flattening:
|
elif self.cfg.batch_flattening:
|
||||||
collator = DataCollatorWithFlattening
|
collator = DataCollatorWithFlattening
|
||||||
collator_args.pop(0)
|
collator_args.pop(0)
|
||||||
@@ -910,6 +928,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
collator = DataCollatorForSeq2Seq
|
collator = DataCollatorForSeq2Seq
|
||||||
|
|
||||||
kwargs["return_tensors"] = "pt"
|
kwargs["return_tensors"] = "pt"
|
||||||
|
if issubclass(collator, DataCollatorForSeq2Seq):
|
||||||
|
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
|
||||||
|
|
||||||
return collator(
|
return collator(
|
||||||
*collator_args,
|
*collator_args,
|
||||||
|
|||||||
@@ -0,0 +1,18 @@
|
|||||||
|
"""Init for axolotl.core.trainers"""
|
||||||
|
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
from .base import AxolotlTrainer
|
||||||
|
from .dpo.trainer import AxolotlDPOTrainer
|
||||||
|
from .grpo.trainer import AxolotlGRPOTrainer
|
||||||
|
from .mamba import AxolotlMambaTrainer
|
||||||
|
from .relora import ReLoRATrainer
|
||||||
|
from .trl import (
|
||||||
|
AxolotlCPOTrainer,
|
||||||
|
AxolotlKTOTrainer,
|
||||||
|
AxolotlORPOTrainer,
|
||||||
|
AxolotlPRMTrainer,
|
||||||
|
AxolotlRewardTrainer,
|
||||||
|
TRLPPOTrainer,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,365 +1,47 @@
|
|||||||
"""
|
"""Module for customized trainers"""
|
||||||
module for customized trainers
|
|
||||||
"""
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Dict, Literal, Optional
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.utils.data import (
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
BatchSampler,
|
||||||
|
DataLoader,
|
||||||
|
RandomSampler,
|
||||||
|
Sampler,
|
||||||
|
SequentialSampler,
|
||||||
|
)
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
|
||||||
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from axolotl.integrations.base import BaseOptimizerFactory
|
from axolotl.core.trainers.mixins import (
|
||||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
OptimizerMixin,
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
SchedulerMixin,
|
||||||
from axolotl.utils.schedulers import (
|
SequenceParallelMixin,
|
||||||
RexLR,
|
|
||||||
get_cosine_schedule_with_min_lr,
|
|
||||||
get_cosine_schedule_with_quadratic_warmup,
|
|
||||||
get_cosine_schedule_with_warmup_decay_constant,
|
|
||||||
)
|
)
|
||||||
|
from axolotl.core.trainers.utils import (
|
||||||
|
sanitize_kwargs_for_ds_tagging,
|
||||||
|
sanitize_kwargs_for_tagging,
|
||||||
|
)
|
||||||
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
LOG = logging.getLogger(__name__)
|
||||||
import smdistributed.modelparallel.torch as smp
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
|
||||||
if isinstance(tag_names, str):
|
"""Extend the base Trainer for axolotl helpers"""
|
||||||
tag_names = [tag_names]
|
|
||||||
|
|
||||||
if kwargs is not None:
|
|
||||||
if "tags" not in kwargs:
|
|
||||||
kwargs["tags"] = tag_names
|
|
||||||
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
|
||||||
kwargs["tags"].extend(tag_names)
|
|
||||||
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
|
||||||
tag_names.append(kwargs["tags"])
|
|
||||||
kwargs["tags"] = tag_names
|
|
||||||
|
|
||||||
return kwargs
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
|
||||||
if isinstance(dataset_tags, str):
|
|
||||||
dataset_tags = [dataset_tags]
|
|
||||||
|
|
||||||
if (dataset_tags is not None) and (kwargs is not None):
|
|
||||||
if "dataset_tags" not in kwargs:
|
|
||||||
kwargs["dataset_tags"] = dataset_tags
|
|
||||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
|
||||||
kwargs["dataset_tags"].extend(dataset_tags)
|
|
||||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
|
||||||
dataset_tags.append(kwargs["dataset_tags"])
|
|
||||||
kwargs["dataset_tags"] = dataset_tags
|
|
||||||
|
|
||||||
return kwargs
|
|
||||||
|
|
||||||
|
|
||||||
class SchedulerMixin(Trainer):
|
|
||||||
"""
|
|
||||||
Mixin class for scheduler setup in CausalTrainer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
|
||||||
|
|
||||||
def create_scheduler(
|
|
||||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
|
||||||
passed as an argument.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_training_steps (int): The number of training steps to do.
|
|
||||||
optimizer (torch.optim.Optimizer): The training optimizer
|
|
||||||
"""
|
|
||||||
use_cosine_quadratic = (
|
|
||||||
self.args.lr_scheduler_type == "cosine"
|
|
||||||
and self.args.lr_quadratic_warmup is True
|
|
||||||
)
|
|
||||||
|
|
||||||
use_cosine_min_lr = (
|
|
||||||
self.args.lr_scheduler_type == "cosine"
|
|
||||||
and self.args.cosine_min_lr_ratio is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
|
||||||
# fmt: on
|
|
||||||
if self.args.alternate_lr_scheduler_type == "one_cycle":
|
|
||||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
|
||||||
pct_start = num_warmup_steps / num_training_steps
|
|
||||||
extra_lr_kwargs = {}
|
|
||||||
if "pct_start" not in self.args.lr_scheduler_kwargs:
|
|
||||||
extra_lr_kwargs["pct_start"] = pct_start
|
|
||||||
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
|
|
||||||
extra_lr_kwargs["anneal_strategy"] = "cos"
|
|
||||||
|
|
||||||
self.lr_scheduler = OneCycleLR(
|
|
||||||
optimizer,
|
|
||||||
max_lr=self.args.learning_rate,
|
|
||||||
total_steps=num_training_steps,
|
|
||||||
**extra_lr_kwargs,
|
|
||||||
**self.args.lr_scheduler_kwargs,
|
|
||||||
)
|
|
||||||
elif self.args.alternate_lr_scheduler_type == "rex":
|
|
||||||
if use_cosine_min_lr:
|
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
|
|
||||||
self.lr_scheduler = RexLR(
|
|
||||||
optimizer=optimizer,
|
|
||||||
max_lr=self.args.learning_rate,
|
|
||||||
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
|
||||||
total_steps=num_training_steps,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
)
|
|
||||||
elif use_cosine_quadratic:
|
|
||||||
if use_cosine_min_lr:
|
|
||||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
|
||||||
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
)
|
|
||||||
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
|
||||||
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
|
||||||
)
|
|
||||||
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
|
||||||
else:
|
|
||||||
if use_cosine_quadratic:
|
|
||||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
|
||||||
|
|
||||||
if use_cosine_min_lr:
|
|
||||||
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
|
||||||
|
|
||||||
return self.lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
class OptimizerMixin(Trainer):
|
|
||||||
"""
|
|
||||||
Mixin class for shared handling of building custom optimizers
|
|
||||||
"""
|
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
|
||||||
|
|
||||||
def create_optimizer_grouped_parameters(
|
|
||||||
self, opt_model, optimizer_kwargs
|
|
||||||
) -> list[dict]:
|
|
||||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
|
||||||
params: dict = {
|
|
||||||
"to_weight_decay": {}, # LayerNorm and bias
|
|
||||||
"embeddings": {}, # lm_head, embed_tokens,
|
|
||||||
"no_weight_decay": {},
|
|
||||||
}
|
|
||||||
lr_groups_lookup = {}
|
|
||||||
lr_groups_learning_rates = {}
|
|
||||||
if self.args.lr_groups:
|
|
||||||
for lr_group in self.args.lr_groups:
|
|
||||||
group_name = lr_group["name"]
|
|
||||||
group_modules = lr_group["modules"]
|
|
||||||
for module in group_modules:
|
|
||||||
lr_groups_lookup[module] = group_name
|
|
||||||
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
|
||||||
params[f"to_weight_decay_{group_name}"] = {}
|
|
||||||
|
|
||||||
for name, param in opt_model.named_parameters():
|
|
||||||
if not param.requires_grad:
|
|
||||||
continue
|
|
||||||
if name.endswith("modules_to_save.default.weight") or any(
|
|
||||||
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
|
||||||
):
|
|
||||||
params["embeddings"][name] = param
|
|
||||||
elif name in decay_parameters:
|
|
||||||
lr_group_modules = [
|
|
||||||
group_modules
|
|
||||||
for group_modules in lr_groups_lookup
|
|
||||||
if group_modules in name
|
|
||||||
]
|
|
||||||
if lr_groups_lookup and any(lr_group_modules):
|
|
||||||
lr_group_module = lr_group_modules[0]
|
|
||||||
group_name = lr_groups_lookup[lr_group_module]
|
|
||||||
params[f"to_weight_decay_{group_name}"][name] = param
|
|
||||||
else:
|
|
||||||
params["to_weight_decay"][name] = param
|
|
||||||
else:
|
|
||||||
params["no_weight_decay"][name] = param
|
|
||||||
optimizer_grouped_parameters = []
|
|
||||||
if params["to_weight_decay"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["to_weight_decay"].values()),
|
|
||||||
"weight_decay": self.args.weight_decay,
|
|
||||||
"lr": optimizer_kwargs["lr"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if params["embeddings"]:
|
|
||||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
|
||||||
if self.args.embedding_lr_scale:
|
|
||||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
|
||||||
elif self.args.embedding_lr:
|
|
||||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["embeddings"].values()),
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"lr": lr,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if params["no_weight_decay"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["no_weight_decay"].values()),
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"lr": optimizer_kwargs["lr"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
for group_name, group_lr in lr_groups_learning_rates.items():
|
|
||||||
if params[f"to_weight_decay_{group_name}"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(
|
|
||||||
params[f"to_weight_decay_{group_name}"].values()
|
|
||||||
),
|
|
||||||
"weight_decay": self.args.weight_decay,
|
|
||||||
"lr": group_lr,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return optimizer_grouped_parameters
|
|
||||||
|
|
||||||
def create_optimizer(self):
|
|
||||||
if (
|
|
||||||
self.args.loraplus_lr_ratio is None
|
|
||||||
and self.args.embedding_lr_scale is None
|
|
||||||
and self.args.embedding_lr is None
|
|
||||||
and self.args.lr_groups is None
|
|
||||||
and self.optimizer_cls_and_kwargs is None
|
|
||||||
):
|
|
||||||
return super().create_optimizer()
|
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
|
||||||
|
|
||||||
if (
|
|
||||||
not self.optimizer
|
|
||||||
and self.optimizer_cls_and_kwargs is not None
|
|
||||||
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
|
||||||
):
|
|
||||||
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
|
||||||
self.optimizer = optimizer_factory_cls()(
|
|
||||||
opt_model, self.args, **optimizer_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.optimizer:
|
|
||||||
if self.optimizer_cls_and_kwargs is not None:
|
|
||||||
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
|
||||||
else:
|
|
||||||
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
|
||||||
self.args, opt_model
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
|
||||||
opt_model, optimizer_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.loraplus_lr_ratio is not None:
|
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
|
||||||
loraplus_lr_embedding = getattr(
|
|
||||||
self.args, "loraplus_lr_embedding", 1e-6
|
|
||||||
)
|
|
||||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
opt_model,
|
|
||||||
optimizer_cls,
|
|
||||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
|
||||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
|
||||||
**optimizer_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
|
||||||
# e.g. for GaLore optimizer.
|
|
||||||
if "params" in optimizer_kwargs:
|
|
||||||
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
|
||||||
|
|
||||||
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
|
||||||
# e.g. for LOMO optimizer.
|
|
||||||
if "model" in optimizer_kwargs:
|
|
||||||
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
|
||||||
|
|
||||||
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
|
||||||
# to avoid arguments conflicts.
|
|
||||||
if "optimizer_dict" in optimizer_kwargs:
|
|
||||||
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
|
||||||
"optimizer_dict"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.optimizer = optimizer_cls(
|
|
||||||
optimizer_grouped_parameters, **optimizer_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if optimizer_cls.__name__ == "Adam8bit":
|
|
||||||
import bitsandbytes
|
|
||||||
|
|
||||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
|
||||||
|
|
||||||
skipped = 0
|
|
||||||
for module in opt_model.modules():
|
|
||||||
if isinstance(module, nn.Embedding):
|
|
||||||
skipped += sum(
|
|
||||||
{
|
|
||||||
p.data_ptr(): p.numel() for p in module.parameters()
|
|
||||||
}.values()
|
|
||||||
)
|
|
||||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
|
||||||
manager.register_module_override(
|
|
||||||
module, "weight", {"optim_bits": 32}
|
|
||||||
)
|
|
||||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
|
||||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
self.optimizer
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.optimizer
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|
||||||
"""
|
|
||||||
Extend the base Trainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
tag_names = ["axolotl"]
|
tag_names = ["axolotl"]
|
||||||
@@ -376,12 +58,18 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
self.eval_data_collator = eval_data_collator
|
self.eval_data_collator = eval_data_collator
|
||||||
self.dataset_tags = dataset_tags
|
self.dataset_tags = dataset_tags
|
||||||
self._signature_columns = None # workaround for pylint
|
self._signature_columns = None # workaround for pylint
|
||||||
|
|
||||||
super().__init__(*_args, **kwargs)
|
super().__init__(*_args, **kwargs)
|
||||||
|
|
||||||
self.train_data_collator = self.data_collator
|
self.train_data_collator = self.data_collator
|
||||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
|
|
||||||
|
# Initialize sequence parallelism if enabled
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
self._setup_sequence_parallel()
|
||||||
|
|
||||||
def _wrap_model(self, model, training=True, dataloader=None):
|
def _wrap_model(self, model, training=True, dataloader=None):
|
||||||
if self.args.torch_compile:
|
if self.args.torch_compile:
|
||||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||||
@@ -394,142 +82,247 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
)
|
)
|
||||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
def _create_multipack_sampler(
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
self, base_sampler: Sampler, dataset: Dataset
|
||||||
if self.args.multipack_real_batches:
|
) -> MultipackBatchSampler:
|
||||||
batch_size = self.args.per_device_train_batch_size
|
"""
|
||||||
batch_max_len = self.args.max_seq_length
|
Helper method to create a `MultipackBatchSampler` for multipacking sequences
|
||||||
else:
|
for training.
|
||||||
batch_size = 1
|
|
||||||
train_batch_size = (
|
|
||||||
self.state.train_batch_size or self.args.per_device_train_batch_size
|
|
||||||
)
|
|
||||||
batch_max_len = train_batch_size * self.args.max_seq_length
|
|
||||||
|
|
||||||
if self.args.curriculum_sampling:
|
Args:
|
||||||
sampler = SequentialSampler(self.train_dataset)
|
base_sampler: Sampler to wrap with `MultipackBatchSampler`.
|
||||||
else:
|
dataset: Dataset to sample from.
|
||||||
sampler = RandomSampler(self.train_dataset)
|
|
||||||
|
|
||||||
return MultipackBatchSampler(
|
Returns:
|
||||||
sampler,
|
Multipack (sample packing) batch sampler.
|
||||||
lengths=get_dataset_lengths(self.train_dataset),
|
"""
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
if self.args.multipack_real_batches:
|
||||||
batch_max_len=batch_max_len,
|
batch_size = self.args.per_device_train_batch_size
|
||||||
batch_size=batch_size,
|
batch_max_len = self.args.max_seq_length
|
||||||
group_size=self.args.sample_packing_group_size,
|
else:
|
||||||
bin_size=self.args.sample_packing_bin_size,
|
batch_size = 1
|
||||||
drop_last=True,
|
train_batch_size = (
|
||||||
|
self.state.train_batch_size or self.args.per_device_train_batch_size
|
||||||
)
|
)
|
||||||
if self.args.curriculum_sampling:
|
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||||
return SequentialSampler(self.train_dataset)
|
|
||||||
return super()._get_train_sampler()
|
|
||||||
|
|
||||||
def _get_eval_sampler(
|
return MultipackBatchSampler(
|
||||||
self, eval_dataset: Dataset
|
base_sampler,
|
||||||
) -> Optional[torch.utils.data.Sampler]:
|
lengths=get_dataset_lengths(dataset),
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
if self.args.multipack_real_batches:
|
batch_max_len=batch_max_len,
|
||||||
batch_size = self.args.per_device_eval_batch_size
|
batch_size=batch_size,
|
||||||
batch_max_len = self.args.max_seq_length
|
drop_last=True,
|
||||||
else:
|
)
|
||||||
batch_size = 1
|
|
||||||
batch_max_len = (
|
def _get_train_sampler(self) -> Sampler | None:
|
||||||
self.args.per_device_eval_batch_size * self.args.max_seq_length
|
"""
|
||||||
)
|
Helper method to get the sampler for training. Handles cases for sequence
|
||||||
return MultipackBatchSampler(
|
parallelism, sample packing, and curriculum sampling (sequential).
|
||||||
SequentialSampler(eval_dataset),
|
|
||||||
lengths=get_dataset_lengths(self.eval_dataset),
|
Returns:
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
If the dataset is non-empty, a sampler is returned, the type of which
|
||||||
batch_max_len=batch_max_len,
|
depends on the passed training args.
|
||||||
batch_size=batch_size,
|
"""
|
||||||
group_size=self.args.sample_packing_group_size,
|
use_sample_packing = self.args.sample_packing and not self.args.pretraining
|
||||||
bin_size=self.args.sample_packing_bin_size,
|
|
||||||
drop_last=True,
|
# Determine the base sampler first
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
base_sampler = self._sp_get_train_sampler(self.train_dataset)
|
||||||
|
elif self.args.curriculum_sampling:
|
||||||
|
base_sampler = SequentialSampler(self.train_dataset)
|
||||||
|
elif use_sample_packing:
|
||||||
|
base_sampler = RandomSampler(self.train_dataset)
|
||||||
|
else:
|
||||||
|
# Default to parent class implementation for standard random sampling
|
||||||
|
return super()._get_train_sampler()
|
||||||
|
|
||||||
|
# Apply multipack wrapper if needed
|
||||||
|
if use_sample_packing:
|
||||||
|
return self._create_multipack_sampler(
|
||||||
|
base_sampler=base_sampler,
|
||||||
|
dataset=self.train_dataset,
|
||||||
)
|
)
|
||||||
return super()._get_eval_sampler(eval_dataset)
|
|
||||||
|
|
||||||
def get_train_dataloader(self) -> DataLoader:
|
return base_sampler
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
|
||||||
train_dataset = self.train_dataset
|
|
||||||
if "length" in train_dataset.features.keys():
|
|
||||||
train_dataset = train_dataset.remove_columns(["length"])
|
|
||||||
data_collator = self.data_collator
|
|
||||||
dataloader_params = {
|
|
||||||
"batch_size": self._train_batch_size,
|
|
||||||
"collate_fn": data_collator,
|
|
||||||
"num_workers": self.args.dataloader_num_workers,
|
|
||||||
"pin_memory": self.args.dataloader_pin_memory,
|
|
||||||
}
|
|
||||||
if self.args.dataloader_prefetch_factor:
|
|
||||||
dataloader_params["prefetch_factor"] = (
|
|
||||||
self.args.dataloader_prefetch_factor
|
|
||||||
)
|
|
||||||
|
|
||||||
sampler = self._get_train_sampler()
|
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
|
||||||
|
"""
|
||||||
|
Helper method to get the sampler for evaluation. Handles sequence parallelism
|
||||||
|
and sample packing cases.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If the dataset is non-empty, a sampler is returned, the type of which
|
||||||
|
depends on the passed training args.
|
||||||
|
"""
|
||||||
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
|
|
||||||
|
# Multipacking enabled if training is enabled and eval is not explicitly disabled
|
||||||
|
use_multipack = (
|
||||||
|
self.args.sample_packing and self.args.eval_sample_packing is not False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine the base sampler
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
base_sampler = self._sp_get_eval_sampler(eval_dataset)
|
||||||
|
elif use_multipack:
|
||||||
|
base_sampler = SequentialSampler(eval_dataset)
|
||||||
|
else:
|
||||||
|
return super()._get_eval_sampler(eval_dataset)
|
||||||
|
|
||||||
|
# Apply multipack wrapper if needed
|
||||||
|
if use_multipack:
|
||||||
|
return self._create_multipack_sampler(
|
||||||
|
base_sampler=base_sampler,
|
||||||
|
dataset=eval_dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
return base_sampler
|
||||||
|
|
||||||
|
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
|
||||||
|
"""Create common dataloader parameters for train or eval."""
|
||||||
|
batch_size = custom_batch_size or (
|
||||||
|
self.args.eval_batch_size if is_eval else self._train_batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"collate_fn": self.data_collator,
|
||||||
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add persistent workers only for training
|
||||||
|
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
|
||||||
|
params["persistent_workers"] = self.args.dataloader_persistent_workers
|
||||||
|
|
||||||
|
# Add prefetch factor if specified
|
||||||
|
if self.args.dataloader_prefetch_factor:
|
||||||
|
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
def _prepare_dataloader(
|
||||||
|
self, dataset, sampler, is_eval=False, custom_batch_size=None
|
||||||
|
):
|
||||||
|
"""Prepare a dataloader with the given dataset and sampler."""
|
||||||
|
# Get base parameters
|
||||||
|
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
|
||||||
|
|
||||||
|
# Add sampler configuration
|
||||||
|
if not isinstance(dataset, torch.utils.data.IterableDataset):
|
||||||
if isinstance(sampler, BatchSampler):
|
if isinstance(sampler, BatchSampler):
|
||||||
|
# batch_size and batch_sampler are mutually exclusive
|
||||||
dataloader_params["batch_sampler"] = sampler
|
dataloader_params["batch_sampler"] = sampler
|
||||||
del dataloader_params["batch_size"]
|
del dataloader_params["batch_size"]
|
||||||
else:
|
else:
|
||||||
dataloader_params["sampler"] = sampler
|
dataloader_params["sampler"] = sampler
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
dataloader_params["worker_init_fn"] = seed_worker
|
|
||||||
|
|
||||||
|
if not is_eval:
|
||||||
|
dataloader_params["worker_init_fn"] = seed_worker
|
||||||
|
|
||||||
|
# Create the dataloader
|
||||||
|
dataloader = DataLoader(dataset, **dataloader_params)
|
||||||
|
|
||||||
|
if self.args.sample_packing and (
|
||||||
|
(not is_eval and not self.args.pretraining)
|
||||||
|
or (is_eval and self.args.eval_sample_packing is not False)
|
||||||
|
):
|
||||||
self.accelerator.even_batches = False
|
self.accelerator.even_batches = False
|
||||||
return self.accelerator.prepare_data_loader(
|
|
||||||
DataLoader(train_dataset, **dataloader_params)
|
|
||||||
)
|
|
||||||
return super().get_train_dataloader()
|
|
||||||
|
|
||||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
# Return unprepared dataloader if using sequence parallelism
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
return dataloader
|
||||||
|
|
||||||
|
# Otherwise prepare with accelerator
|
||||||
|
return self.accelerator.prepare_data_loader(dataloader)
|
||||||
|
|
||||||
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
|
"""Get dataloader for training"""
|
||||||
|
train_dataset = self.train_dataset
|
||||||
|
data_collator = self.data_collator # type: ignore
|
||||||
|
|
||||||
|
# Handle dataset preprocessing
|
||||||
|
if isinstance(train_dataset, datasets.Dataset):
|
||||||
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
|
train_dataset = train_dataset.remove_columns(["length"])
|
||||||
|
if not self.args.sample_packing or self.args.pretraining:
|
||||||
|
train_dataset = self._remove_unused_columns(
|
||||||
|
train_dataset, description="training"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
||||||
|
data_collator,
|
||||||
|
description="training",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get sampler and create dataloader
|
||||||
|
sampler = self._get_train_sampler()
|
||||||
|
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
|
||||||
|
|
||||||
|
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
|
||||||
|
"""Get dataloader for evaluation"""
|
||||||
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
|
|
||||||
|
# Handle special case: sample packing is enabled but eval_sample_packing is False
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.eval_data_collator
|
self.eval_data_collator
|
||||||
)
|
)
|
||||||
if eval_dataset:
|
if "length" in eval_dataset.column_names:
|
||||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||||
dataloader = super().get_eval_dataloader(eval_dataset)
|
dataloader = super().get_eval_dataloader(eval_dataset)
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.train_data_collator
|
self.train_data_collator
|
||||||
)
|
)
|
||||||
|
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
# Handle sample packing or sequence parallelism
|
||||||
eval_dataset = (
|
if (
|
||||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
self.args.sample_packing
|
||||||
|
and self.args.eval_sample_packing is not False
|
||||||
|
or self.args.sequence_parallel_degree > 1
|
||||||
|
):
|
||||||
|
# Get appropriate data collator
|
||||||
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.eval_data_collator
|
||||||
|
if hasattr(self, "eval_data_collator") and self.eval_data_collator
|
||||||
|
else self.data_collator
|
||||||
|
)
|
||||||
|
if "length" in eval_dataset.column_names:
|
||||||
|
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||||
|
|
||||||
|
# Handle dataset preprocessing for SP
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
if isinstance(eval_dataset, datasets.Dataset):
|
||||||
|
eval_dataset = self._remove_unused_columns(
|
||||||
|
eval_dataset, description="evaluation"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.data_collator, description="evaluation"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
|
||||||
|
batch_size = (
|
||||||
|
self.args.eval_batch_size
|
||||||
|
if self.args.sample_packing
|
||||||
|
else self.args.per_device_eval_batch_size
|
||||||
|
)
|
||||||
|
sampler = self._get_eval_sampler(eval_dataset)
|
||||||
|
dataloader = self._prepare_dataloader(
|
||||||
|
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
return dataloader
|
||||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
|
||||||
data_collator = self.data_collator
|
|
||||||
dataloader_params = {
|
|
||||||
"batch_size": self.args.eval_batch_size,
|
|
||||||
"collate_fn": data_collator,
|
|
||||||
"num_workers": self.args.dataloader_num_workers,
|
|
||||||
"pin_memory": self.args.dataloader_pin_memory,
|
|
||||||
}
|
|
||||||
if self.args.dataloader_prefetch_factor:
|
|
||||||
dataloader_params["prefetch_factor"] = (
|
|
||||||
self.args.dataloader_prefetch_factor
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(eval_sampler, BatchSampler):
|
|
||||||
dataloader_params["batch_sampler"] = eval_sampler
|
|
||||||
del dataloader_params["batch_size"]
|
|
||||||
else:
|
|
||||||
dataloader_params["sampler"] = eval_sampler
|
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
|
||||||
|
|
||||||
self.accelerator.even_batches = False
|
|
||||||
return self.accelerator.prepare_data_loader(
|
|
||||||
DataLoader(eval_dataset, **dataloader_params)
|
|
||||||
)
|
|
||||||
|
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
|
|
||||||
def _get_bench_sampler(
|
def _get_bench_sampler(
|
||||||
self, bench_dataset: Dataset
|
self, bench_dataset: Dataset
|
||||||
) -> Optional[torch.utils.data.Sampler]:
|
) -> torch.utils.data.Sampler | None:
|
||||||
if self.args.world_size <= 1:
|
if self.args.world_size <= 1:
|
||||||
return SequentialSampler(bench_dataset)
|
return SequentialSampler(bench_dataset)
|
||||||
return None
|
return None
|
||||||
@@ -554,6 +347,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
return DataLoader(bench_dataset, **dataloader_params)
|
return DataLoader(bench_dataset, **dataloader_params)
|
||||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||||
|
|
||||||
|
@override
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||||
):
|
):
|
||||||
@@ -570,6 +364,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
return_outputs=return_outputs,
|
return_outputs=return_outputs,
|
||||||
num_items_in_batch=num_items_in_batch,
|
num_items_in_batch=num_items_in_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
return super().compute_loss(
|
return super().compute_loss(
|
||||||
model,
|
model,
|
||||||
inputs,
|
inputs,
|
||||||
@@ -744,10 +539,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
"""
|
"""
|
||||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
kwargs = sanitize_kwargs_for_ds_tagging(
|
||||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
)
|
)
|
||||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
@@ -764,15 +559,13 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
Log `logs` on the various objects watching training, including stored metrics.
|
Log `logs` on the various objects watching training, including stored metrics.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logs (`Dict[str, float]`):
|
logs: The values to log.
|
||||||
The values to log.
|
start_time: The start of training.
|
||||||
start_time (`Optional[float]`):
|
|
||||||
The start of training.
|
|
||||||
"""
|
"""
|
||||||
# logs either has 'loss' or 'eval_loss'
|
# logs either has 'loss' or 'eval_loss'
|
||||||
train_eval = "train" if "loss" in logs else "eval"
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
@@ -784,7 +577,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
return super().log(logs, start_time)
|
return super().log(logs, start_time)
|
||||||
|
|
||||||
def store_metrics(
|
def store_metrics(
|
||||||
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||||
) -> None:
|
) -> None:
|
||||||
for key, value in metrics.items():
|
for key, value in metrics.items():
|
||||||
self._stored_metrics[train_eval][key].append(value)
|
self._stored_metrics[train_eval][key].append(value)
|
||||||
@@ -797,110 +590,26 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, **kwargs)
|
||||||
|
|
||||||
|
def training_step(
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
|
||||||
"""
|
|
||||||
Mamba specific trainer to handle loss calculation
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "mamba"]
|
|
||||||
|
|
||||||
def compute_loss(
|
|
||||||
self,
|
self,
|
||||||
model,
|
model: nn.Module,
|
||||||
inputs,
|
inputs: dict[str, torch.Tensor | Any],
|
||||||
return_outputs=False, # pylint: disable=unused-argument
|
num_items_in_batch: int | None = None,
|
||||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
) -> torch.Tensor:
|
||||||
):
|
"""
|
||||||
input_ids = inputs.pop("input_ids")
|
Perform a training step on a batch of inputs. Overrides the
|
||||||
lm_logits = model(input_ids).logits
|
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||||
|
enabled.
|
||||||
|
|
||||||
labels = input_ids.to(lm_logits.device)
|
Args:
|
||||||
shift_logits = lm_logits[:, :-1, :].contiguous()
|
model: Model to perform training step for.
|
||||||
labels = labels[:, 1:].contiguous()
|
inputs: Dictionary mapping.
|
||||||
|
"""
|
||||||
|
# Set up sequence parallelism for this step if enabled
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
self._update_ring_flash_attn_params(inputs)
|
||||||
|
|
||||||
loss_fct = torch.nn.CrossEntropyLoss()
|
# Proceed with normal training step
|
||||||
lm_loss = loss_fct(
|
loss = super().training_step(model, inputs, num_items_in_batch)
|
||||||
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
|
||||||
)
|
|
||||||
|
|
||||||
return lm_loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
class ReLoRATrainer(AxolotlTrainer):
|
|
||||||
"""
|
|
||||||
Trainer subclass that uses the OneCycleLR scheduler
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "relora"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.lr_scheduler = None
|
|
||||||
|
|
||||||
def create_scheduler(
|
|
||||||
self,
|
|
||||||
num_training_steps: int,
|
|
||||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
|
||||||
):
|
|
||||||
optimizer = self.optimizer if optimizer is None else optimizer
|
|
||||||
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
|
||||||
|
|
||||||
if self.args.relora_steps:
|
|
||||||
warmup_steps = (
|
|
||||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
|
||||||
)
|
|
||||||
anneal_steps = (
|
|
||||||
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
|
||||||
)
|
|
||||||
self.lr_scheduler = ReLoRAScheduler(
|
|
||||||
optimizer,
|
|
||||||
lr_scheduler,
|
|
||||||
self.args.relora_steps,
|
|
||||||
anneal_steps,
|
|
||||||
warmup_steps,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.lr_scheduler = lr_scheduler
|
|
||||||
|
|
||||||
return self.lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base ORPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "orpo"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base KTOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "kto"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base CPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "cpo"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base RewardTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "reward"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base trl.PRMTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "prm"]
|
|
||||||
|
|||||||
@@ -13,10 +13,10 @@ from transformers import Trainer
|
|||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
|
|
||||||
from axolotl.core.trainers.base import (
|
from axolotl.core.trainers.mixins import SchedulerMixin
|
||||||
SchedulerMixin,
|
from axolotl.core.trainers.utils import (
|
||||||
_sanitize_kwargs_for_ds_tagging,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
_sanitize_kwargs_for_tagging,
|
sanitize_kwargs_for_tagging,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
@@ -74,10 +74,10 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
"""
|
"""
|
||||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
kwargs = sanitize_kwargs_for_ds_tagging(
|
||||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
)
|
)
|
||||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import logging
|
|||||||
from trl.trainer.grpo_trainer import RewardFunc
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
|
|
||||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||||
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
|
from axolotl.utils.schemas.trl import TRLConfig
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|||||||
32
src/axolotl/core/trainers/mamba.py
Normal file
32
src/axolotl/core/trainers/mamba.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Module for mamba trainer"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
|
"""Mamba specific trainer to handle loss calculation"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "mamba"]
|
||||||
|
|
||||||
|
def compute_loss(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
inputs,
|
||||||
|
return_outputs=False, # pylint: disable=unused-argument
|
||||||
|
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
input_ids = inputs.pop("input_ids")
|
||||||
|
lm_logits = model(input_ids).logits
|
||||||
|
|
||||||
|
labels = input_ids.to(lm_logits.device)
|
||||||
|
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||||
|
labels = labels[:, 1:].contiguous()
|
||||||
|
|
||||||
|
loss_fct = torch.nn.CrossEntropyLoss()
|
||||||
|
lm_loss = loss_fct(
|
||||||
|
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
return lm_loss
|
||||||
8
src/axolotl/core/trainers/mixins/__init__.py
Normal file
8
src/axolotl/core/trainers/mixins/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
"""Init for axolotl.core.trainers.mixins"""
|
||||||
|
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
from .optimizer import OptimizerMixin
|
||||||
|
from .scheduler import SchedulerMixin
|
||||||
|
from .sequence_parallel import SequenceParallelMixin
|
||||||
201
src/axolotl/core/trainers/mixins/optimizer.py
Normal file
201
src/axolotl/core/trainers/mixins/optimizer.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
"""Module for Axolotl trainer optimizer mixin"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
|
from torch import nn
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
|
|
||||||
|
from axolotl.integrations.base import BaseOptimizerFactory
|
||||||
|
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizerMixin(Trainer):
|
||||||
|
"""Mixin class for shared handling of building custom optimizers"""
|
||||||
|
|
||||||
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|
||||||
|
def create_optimizer_grouped_parameters(
|
||||||
|
self, opt_model, optimizer_kwargs
|
||||||
|
) -> list[dict]:
|
||||||
|
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||||
|
params: dict = {
|
||||||
|
"to_weight_decay": {}, # LayerNorm and bias
|
||||||
|
"embeddings": {}, # lm_head, embed_tokens,
|
||||||
|
"no_weight_decay": {},
|
||||||
|
}
|
||||||
|
lr_groups_lookup = {}
|
||||||
|
lr_groups_learning_rates = {}
|
||||||
|
if self.args.lr_groups:
|
||||||
|
for lr_group in self.args.lr_groups:
|
||||||
|
group_name = lr_group["name"]
|
||||||
|
group_modules = lr_group["modules"]
|
||||||
|
for module in group_modules:
|
||||||
|
lr_groups_lookup[module] = group_name
|
||||||
|
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
||||||
|
params[f"to_weight_decay_{group_name}"] = {}
|
||||||
|
|
||||||
|
for name, param in opt_model.named_parameters():
|
||||||
|
if not param.requires_grad:
|
||||||
|
continue
|
||||||
|
if name.endswith("modules_to_save.default.weight") or any(
|
||||||
|
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||||
|
):
|
||||||
|
params["embeddings"][name] = param
|
||||||
|
elif name in decay_parameters:
|
||||||
|
lr_group_modules = [
|
||||||
|
group_modules
|
||||||
|
for group_modules in lr_groups_lookup
|
||||||
|
if group_modules in name
|
||||||
|
]
|
||||||
|
if lr_groups_lookup and any(lr_group_modules):
|
||||||
|
lr_group_module = lr_group_modules[0]
|
||||||
|
group_name = lr_groups_lookup[lr_group_module]
|
||||||
|
params[f"to_weight_decay_{group_name}"][name] = param
|
||||||
|
else:
|
||||||
|
params["to_weight_decay"][name] = param
|
||||||
|
else:
|
||||||
|
params["no_weight_decay"][name] = param
|
||||||
|
optimizer_grouped_parameters = []
|
||||||
|
if params["to_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["to_weight_decay"].values()),
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["embeddings"]:
|
||||||
|
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||||
|
if self.args.embedding_lr_scale:
|
||||||
|
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||||
|
elif self.args.embedding_lr:
|
||||||
|
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["embeddings"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": lr,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["no_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["no_weight_decay"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
for group_name, group_lr in lr_groups_learning_rates.items():
|
||||||
|
if params[f"to_weight_decay_{group_name}"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(
|
||||||
|
params[f"to_weight_decay_{group_name}"].values()
|
||||||
|
),
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
"lr": group_lr,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return optimizer_grouped_parameters
|
||||||
|
|
||||||
|
def create_optimizer(self):
|
||||||
|
if (
|
||||||
|
self.args.loraplus_lr_ratio is None
|
||||||
|
and self.args.embedding_lr_scale is None
|
||||||
|
and self.args.embedding_lr is None
|
||||||
|
and self.args.lr_groups is None
|
||||||
|
and self.optimizer_cls_and_kwargs is None
|
||||||
|
):
|
||||||
|
return super().create_optimizer()
|
||||||
|
|
||||||
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
|
|
||||||
|
if (
|
||||||
|
not self.optimizer
|
||||||
|
and self.optimizer_cls_and_kwargs is not None
|
||||||
|
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
||||||
|
):
|
||||||
|
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||||
|
self.optimizer = optimizer_factory_cls()(
|
||||||
|
opt_model, self.args, **optimizer_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.optimizer:
|
||||||
|
if self.optimizer_cls_and_kwargs is not None:
|
||||||
|
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||||
|
else:
|
||||||
|
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
||||||
|
self.args, opt_model
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
||||||
|
opt_model, optimizer_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.loraplus_lr_ratio is not None:
|
||||||
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
|
loraplus_lr_embedding = getattr(
|
||||||
|
self.args, "loraplus_lr_embedding", 1e-6
|
||||||
|
)
|
||||||
|
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
opt_model,
|
||||||
|
optimizer_cls,
|
||||||
|
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||||
|
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||||
|
**optimizer_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
|
# e.g. for GaLore optimizer.
|
||||||
|
if "params" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||||
|
|
||||||
|
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
|
# e.g. for LOMO optimizer.
|
||||||
|
if "model" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||||
|
|
||||||
|
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||||
|
# to avoid arguments conflicts.
|
||||||
|
if "optimizer_dict" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
||||||
|
"optimizer_dict"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.optimizer = optimizer_cls(
|
||||||
|
optimizer_grouped_parameters, **optimizer_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if optimizer_cls.__name__ == "Adam8bit":
|
||||||
|
import bitsandbytes
|
||||||
|
|
||||||
|
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||||
|
|
||||||
|
skipped = 0
|
||||||
|
for module in opt_model.modules():
|
||||||
|
if isinstance(module, nn.Embedding):
|
||||||
|
skipped += sum(
|
||||||
|
{
|
||||||
|
p.data_ptr(): p.numel() for p in module.parameters()
|
||||||
|
}.values()
|
||||||
|
)
|
||||||
|
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||||
|
manager.register_module_override(
|
||||||
|
module, "weight", {"optim_bits": 32}
|
||||||
|
)
|
||||||
|
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||||
|
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||||
|
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.optimizer
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.optimizer
|
||||||
113
src/axolotl/core/trainers/mixins/scheduler.py
Normal file
113
src/axolotl/core/trainers/mixins/scheduler.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""Module for Axolotl trainer scheduler mixin"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
|
from axolotl.utils.schedulers import (
|
||||||
|
RexLR,
|
||||||
|
get_cosine_schedule_with_min_lr,
|
||||||
|
get_cosine_schedule_with_quadratic_warmup,
|
||||||
|
get_cosine_schedule_with_warmup_decay_constant,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerMixin(Trainer):
|
||||||
|
"""
|
||||||
|
Mixin class for scheduler setup in CausalTrainer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|
||||||
|
def create_scheduler(
|
||||||
|
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||||
|
passed as an argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_training_steps (int): The number of training steps to do.
|
||||||
|
optimizer (torch.optim.Optimizer): The training optimizer
|
||||||
|
"""
|
||||||
|
use_cosine_quadratic = (
|
||||||
|
self.args.lr_scheduler_type == "cosine"
|
||||||
|
and self.args.lr_quadratic_warmup is True
|
||||||
|
)
|
||||||
|
|
||||||
|
use_cosine_min_lr = (
|
||||||
|
self.args.lr_scheduler_type == "cosine"
|
||||||
|
and self.args.cosine_min_lr_ratio is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||||
|
# fmt: on
|
||||||
|
if self.args.alternate_lr_scheduler_type == "one_cycle":
|
||||||
|
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||||
|
pct_start = num_warmup_steps / num_training_steps
|
||||||
|
extra_lr_kwargs = {}
|
||||||
|
if "pct_start" not in self.args.lr_scheduler_kwargs:
|
||||||
|
extra_lr_kwargs["pct_start"] = pct_start
|
||||||
|
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
|
||||||
|
extra_lr_kwargs["anneal_strategy"] = "cos"
|
||||||
|
|
||||||
|
self.lr_scheduler = OneCycleLR(
|
||||||
|
optimizer,
|
||||||
|
max_lr=self.args.learning_rate,
|
||||||
|
total_steps=num_training_steps,
|
||||||
|
**extra_lr_kwargs,
|
||||||
|
**self.args.lr_scheduler_kwargs,
|
||||||
|
)
|
||||||
|
elif self.args.alternate_lr_scheduler_type == "rex":
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
|
||||||
|
self.lr_scheduler = RexLR(
|
||||||
|
optimizer=optimizer,
|
||||||
|
max_lr=self.args.learning_rate,
|
||||||
|
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
||||||
|
total_steps=num_training_steps,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
)
|
||||||
|
elif use_cosine_quadratic:
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||||
|
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
)
|
||||||
|
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
|
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
||||||
|
)
|
||||||
|
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||||
|
else:
|
||||||
|
if use_cosine_quadratic:
|
||||||
|
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
|
return self.lr_scheduler
|
||||||
131
src/axolotl/core/trainers/mixins/sequence_parallel.py
Normal file
131
src/axolotl/core/trainers/mixins/sequence_parallel.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""Module for Axolotl trainer sequence parallelism mixin"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from datasets import Dataset
|
||||||
|
from torch.utils.data import DistributedSampler, Sampler
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ring_flash_attn import update_ring_flash_attn_params
|
||||||
|
except ImportError:
|
||||||
|
# We pass silently here, but raise an ImportError in our Axolotl config validation
|
||||||
|
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceParallelMixin:
|
||||||
|
"""
|
||||||
|
Mixin class for sequence parallelism support in trainers.
|
||||||
|
|
||||||
|
This mixin provides functionality for handling sequence parallelism,
|
||||||
|
including creating appropriate samplers, managing data partitioning,
|
||||||
|
and updating ring flash attention parameters during training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|
||||||
|
def _setup_sequence_parallel(self):
|
||||||
|
"""Set up sequence parallelism environment."""
|
||||||
|
self.ring_attn_group = get_ring_attn_group()
|
||||||
|
|
||||||
|
def _create_sequence_parallel_sampler(
|
||||||
|
self,
|
||||||
|
dataset: Dataset,
|
||||||
|
shuffle: bool = True,
|
||||||
|
is_eval: bool = False,
|
||||||
|
) -> DistributedSampler:
|
||||||
|
"""
|
||||||
|
Helper method to create sampler for sequence parallelism (SP).
|
||||||
|
|
||||||
|
We create a distributed sampler with rank equal to the SP group ID, which
|
||||||
|
means that all ranks in the SP group receive the same sample / set of samples
|
||||||
|
per training step. We also set the number of replicas equal to the number of
|
||||||
|
SP groups, which is a bit of a hack / unintended use, but works!
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: Dataset to sample from.
|
||||||
|
shuffle: Whether to shuffle the dataset.
|
||||||
|
is_eval: Whether we are creating a sampler for evaluation or training.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Distributed sampler.
|
||||||
|
"""
|
||||||
|
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
||||||
|
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
||||||
|
|
||||||
|
return DistributedSampler(
|
||||||
|
dataset,
|
||||||
|
num_replicas=num_sp_groups,
|
||||||
|
rank=sp_group_id,
|
||||||
|
seed=self.args.seed if shuffle else None,
|
||||||
|
shuffle=shuffle,
|
||||||
|
drop_last=not is_eval,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
|
||||||
|
"""
|
||||||
|
Get a training sampler configured for sequence parallelism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: The training dataset
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured sequence parallel sampler.
|
||||||
|
"""
|
||||||
|
return self._create_sequence_parallel_sampler(
|
||||||
|
dataset,
|
||||||
|
shuffle=not self.args.curriculum_sampling,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
||||||
|
"""
|
||||||
|
Get an evaluation sampler configured for sequence parallelism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
eval_dataset: The evaluation dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured sequence parallel sampler.
|
||||||
|
"""
|
||||||
|
return self._create_sequence_parallel_sampler(
|
||||||
|
eval_dataset, shuffle=False, is_eval=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
|
||||||
|
"""
|
||||||
|
Calculate the cu_seqlens for the current forward pass and pass the value to
|
||||||
|
the substituted ring_flash_attn. This is accomplished by using the passed
|
||||||
|
`input_ids`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Current batch of inputs.
|
||||||
|
"""
|
||||||
|
# At this point, inputs should already be partitioned by the sequence
|
||||||
|
# parallel data collator
|
||||||
|
batch_size = inputs["input_ids"].shape[0]
|
||||||
|
seq_len = inputs["input_ids"].shape[1]
|
||||||
|
packed_seq_lens = [seq_len] * batch_size
|
||||||
|
|
||||||
|
# Calculate the full sequence length across all GPUs in this SP group
|
||||||
|
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
||||||
|
|
||||||
|
cu_seqlens = torch.cumsum(
|
||||||
|
torch.tensor(
|
||||||
|
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||||
|
),
|
||||||
|
dim=-1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
cu_seqlens = F.pad(
|
||||||
|
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||||
|
)
|
||||||
|
|
||||||
|
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
||||||
43
src/axolotl/core/trainers/relora.py
Normal file
43
src/axolotl/core/trainers/relora.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""Module for ReLoRA trainer"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||||
|
|
||||||
|
|
||||||
|
class ReLoRATrainer(AxolotlTrainer):
|
||||||
|
"""Trainer subclass that uses the `OneCycleLR` scheduler"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "relora"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.lr_scheduler = None
|
||||||
|
|
||||||
|
def create_scheduler(
|
||||||
|
self,
|
||||||
|
num_training_steps: int,
|
||||||
|
optimizer: torch.optim.Optimizer | None = None,
|
||||||
|
):
|
||||||
|
optimizer = self.optimizer if optimizer is None else optimizer
|
||||||
|
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||||
|
|
||||||
|
if self.args.relora_steps:
|
||||||
|
warmup_steps = (
|
||||||
|
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||||
|
)
|
||||||
|
anneal_steps = (
|
||||||
|
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
||||||
|
)
|
||||||
|
self.lr_scheduler = ReLoRAScheduler(
|
||||||
|
optimizer,
|
||||||
|
lr_scheduler,
|
||||||
|
self.args.relora_steps,
|
||||||
|
anneal_steps,
|
||||||
|
warmup_steps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.lr_scheduler = lr_scheduler
|
||||||
|
|
||||||
|
return self.lr_scheduler
|
||||||
@@ -1,16 +1,25 @@
|
|||||||
"""
|
"""Module for TRL PPO trainer"""
|
||||||
module for TRL PPO training
|
|
||||||
"""
|
from typing import Literal, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from trl import PPOTrainer
|
from trl import (
|
||||||
|
CPOTrainer,
|
||||||
|
KTOTrainer,
|
||||||
|
ORPOTrainer,
|
||||||
|
PPOTrainer,
|
||||||
|
PRMTrainer,
|
||||||
|
RewardTrainer,
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
class TRLPPOTrainer(PPOTrainer):
|
class TRLPPOTrainer(PPOTrainer):
|
||||||
"""
|
"""Wrapper for TRL PPO trainer to handle customizations"""
|
||||||
wrapper for ppo trainer to handle customizations
|
|
||||||
"""
|
tag_names = ["axolotl", "ppo"]
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self,
|
self,
|
||||||
@@ -31,9 +40,7 @@ class TRLPPOTrainer(PPOTrainer):
|
|||||||
"batch_size": 16,
|
"batch_size": 16,
|
||||||
}
|
}
|
||||||
|
|
||||||
for epoch, batch in tqdm( # pylint: disable=unused-variable
|
for _, batch in tqdm(enumerate(self.dataloader)):
|
||||||
enumerate(self.dataloader)
|
|
||||||
):
|
|
||||||
query_tensors = batch["input_ids"]
|
query_tensors = batch["input_ids"]
|
||||||
|
|
||||||
# generate model response
|
# generate model response
|
||||||
@@ -65,3 +72,189 @@ class TRLPPOTrainer(PPOTrainer):
|
|||||||
rewards,
|
rewards,
|
||||||
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base ORPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
|
def get_batch_loss_metrics(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
batch: dict[str, Union[list, torch.LongTensor]],
|
||||||
|
train_eval: Literal["train", "eval"] = "train",
|
||||||
|
):
|
||||||
|
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
|
||||||
|
|
||||||
|
# TODO remove once https://github.com/huggingface/trl/pull/3069 is included in a trl release
|
||||||
|
|
||||||
|
metrics = {}
|
||||||
|
|
||||||
|
forward_output = self.concatenated_forward(model, batch)
|
||||||
|
(
|
||||||
|
policy_chosen_logps,
|
||||||
|
policy_rejected_logps,
|
||||||
|
policy_chosen_logits,
|
||||||
|
policy_rejected_logits,
|
||||||
|
policy_nll_loss,
|
||||||
|
) = forward_output[:5]
|
||||||
|
if self.aux_loss_enabled:
|
||||||
|
aux_loss = forward_output[5]
|
||||||
|
|
||||||
|
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = (
|
||||||
|
self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
|
||||||
|
)
|
||||||
|
# full ORPO loss
|
||||||
|
loss = policy_nll_loss - losses.mean()
|
||||||
|
|
||||||
|
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||||
|
|
||||||
|
prefix = "eval_" if train_eval == "eval" else ""
|
||||||
|
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(
|
||||||
|
chosen_rewards
|
||||||
|
).mean()
|
||||||
|
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(
|
||||||
|
rejected_rewards
|
||||||
|
).mean()
|
||||||
|
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(
|
||||||
|
reward_accuracies
|
||||||
|
).mean()
|
||||||
|
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
|
||||||
|
chosen_rewards - rejected_rewards
|
||||||
|
).mean()
|
||||||
|
metrics[f"{prefix}logps/rejected"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}logps/chosen"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
|
||||||
|
policy_rejected_logits.detach().mean()
|
||||||
|
).mean()
|
||||||
|
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
|
||||||
|
policy_chosen_logits.detach().mean()
|
||||||
|
).mean()
|
||||||
|
metrics[f"{prefix}nll_loss"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}log_odds_ratio"] = (
|
||||||
|
self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}log_odds_chosen"] = (
|
||||||
|
self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
|
||||||
|
)
|
||||||
|
for k, v in metrics.items():
|
||||||
|
metrics[k] = v.item()
|
||||||
|
if self.aux_loss_enabled:
|
||||||
|
loss += self.aux_loss_coef * aux_loss
|
||||||
|
|
||||||
|
return loss, metrics
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base KTOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base CPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "cpo"]
|
||||||
|
|
||||||
|
def get_batch_loss_metrics(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
batch: dict[str, Union[list, torch.LongTensor]],
|
||||||
|
train_eval: Literal["train", "eval"] = "train",
|
||||||
|
):
|
||||||
|
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
|
||||||
|
metrics = {}
|
||||||
|
|
||||||
|
forward_output = self.concatenated_forward(model, batch)
|
||||||
|
(
|
||||||
|
policy_chosen_logps,
|
||||||
|
policy_rejected_logps,
|
||||||
|
policy_chosen_logits,
|
||||||
|
policy_rejected_logits,
|
||||||
|
policy_nll_loss,
|
||||||
|
) = forward_output[:5]
|
||||||
|
if self.aux_loss_enabled:
|
||||||
|
aux_loss = forward_output[5]
|
||||||
|
|
||||||
|
losses, chosen_rewards, rejected_rewards = self.cpo_loss(
|
||||||
|
policy_chosen_logps,
|
||||||
|
policy_rejected_logps,
|
||||||
|
)
|
||||||
|
|
||||||
|
loss = losses.mean() + self.cpo_alpha * policy_nll_loss
|
||||||
|
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||||
|
|
||||||
|
prefix = "eval_" if train_eval == "eval" else ""
|
||||||
|
metrics[f"{prefix}rewards/chosen"] = (
|
||||||
|
self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}rewards/rejected"] = (
|
||||||
|
self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}rewards/accuracies"] = (
|
||||||
|
self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}rewards/margins"] = (
|
||||||
|
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards)
|
||||||
|
.mean()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}logps/rejected"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_rejected_logps)
|
||||||
|
.detach()
|
||||||
|
.mean()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}logps/chosen"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_chosen_logps)
|
||||||
|
.detach()
|
||||||
|
.mean()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}logits/rejected"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean())
|
||||||
|
.mean()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}logits/chosen"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean())
|
||||||
|
.mean()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}nll_loss"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.aux_loss_enabled:
|
||||||
|
loss += self.aux_loss_coef * aux_loss
|
||||||
|
|
||||||
|
return loss, metrics
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base RewardTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "reward"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base trl.PRMTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "prm"]
|
||||||
|
|||||||
33
src/axolotl/core/trainers/utils.py
Normal file
33
src/axolotl/core/trainers/utils.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
"""Utils for Axolotl trainers"""
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
||||||
|
if isinstance(tag_names, str):
|
||||||
|
tag_names = [tag_names]
|
||||||
|
|
||||||
|
if kwargs is not None:
|
||||||
|
if "tags" not in kwargs:
|
||||||
|
kwargs["tags"] = tag_names
|
||||||
|
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
||||||
|
kwargs["tags"].extend(tag_names)
|
||||||
|
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
||||||
|
tag_names.append(kwargs["tags"])
|
||||||
|
kwargs["tags"] = tag_names
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
||||||
|
if isinstance(dataset_tags, str):
|
||||||
|
dataset_tags = [dataset_tags]
|
||||||
|
|
||||||
|
if (dataset_tags is not None) and (kwargs is not None):
|
||||||
|
if "dataset_tags" not in kwargs:
|
||||||
|
kwargs["dataset_tags"] = dataset_tags
|
||||||
|
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
||||||
|
kwargs["dataset_tags"].extend(dataset_tags)
|
||||||
|
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
||||||
|
dataset_tags.append(kwargs["dataset_tags"])
|
||||||
|
kwargs["dataset_tags"] = dataset_tags
|
||||||
|
|
||||||
|
return kwargs
|
||||||
@@ -5,6 +5,7 @@ extra axolotl specific training args
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from PIL.Image import Resampling
|
||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||||
|
|
||||||
@@ -207,14 +208,33 @@ class AxolotlTrainingMixins:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sequence_parallel_degree: Optional[int] = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "The number of workers to use in sequence parallelism"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# multi-modal section
|
||||||
|
|
||||||
|
image_size: int | tuple[int, int] | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The size of the image to resize to"},
|
||||||
|
)
|
||||||
|
|
||||||
|
image_resize_algorithm: Resampling | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The algorithm to use for image resizing"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# end of multi-modal section
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||||
"""
|
"""
|
||||||
Training arguments for Causal trainer
|
Training arguments for Causal trainer
|
||||||
|
|
||||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
|
This code is duplicated due to HF TrainingArguments not setting output_dir with a
|
||||||
so it can't be used as a mixin.
|
default value so it can't be used as a mixin.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ from typing import Dict, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
|
from datasets import Dataset
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.train import TrainDatasetMeta
|
from axolotl.train import TrainDatasetMeta
|
||||||
@@ -25,18 +27,18 @@ LOG = get_logger("axolotl.evaluate")
|
|||||||
|
|
||||||
|
|
||||||
def evaluate_dataset(
|
def evaluate_dataset(
|
||||||
trainer, dataset, dataset_type: str, flash_optimum: bool = False
|
trainer: Trainer, dataset: Dataset, dataset_type: str, flash_optimum: bool = False
|
||||||
) -> Optional[Dict[str, float]]:
|
) -> Optional[Dict[str, float]]:
|
||||||
"""Helper function to evaluate a single dataset safely.
|
"""Helper function to evaluate a single dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
trainer: The trainer instance
|
trainer: The trainer instance.
|
||||||
dataset: Dataset to evaluate
|
dataset: Dataset to evaluate.
|
||||||
dataset_type: Type of dataset ('train' or 'eval')
|
dataset_type: Type of dataset ('train' or 'eval').
|
||||||
flash_optimum: Whether to use flash optimum
|
flash_optimum: Whether to use flash optimum.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary of metrics or None if dataset is None
|
Dictionary of metrics or None if dataset is None.
|
||||||
"""
|
"""
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
return None
|
return None
|
||||||
@@ -63,17 +65,14 @@ def evaluate_dataset(
|
|||||||
|
|
||||||
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Evaluate a model on training and validation datasets
|
Evaluate a model on training and validation datasets.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
dataset_meta: Dataset metadata containing training and evaluation datasets.
|
dataset_meta: Dataset metadata containing training and evaluation datasets.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple containing:
|
Dictionary mapping metric names to their values.
|
||||||
- The model (either PeftModel or PreTrainedModel)
|
|
||||||
- The tokenizer
|
|
||||||
- Dictionary of evaluation metrics
|
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
|
|||||||
@@ -11,19 +11,17 @@
|
|||||||
# the License.
|
# the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
module to handle merging the plugins' input arguments with the base configurations.
|
Module to handle merging the plugins' input arguments with the base configurations.
|
||||||
|
|
||||||
this was moved here to prevent circular imports
|
This was moved here to prevent circular imports.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
from axolotl.utils.schemas.config import (
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
)
|
)
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
||||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def merge_input_args():
|
def merge_input_args():
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# Cut Cross Entropy
|
# Cut Cross Entropy
|
||||||
|
|
||||||
Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy operation during loss calculation.
|
Cut Cross Entropy (CCE) reduces VRAM usage through optimization on the cross-entropy operation during loss calculation.
|
||||||
|
|
||||||
See https://github.com/apple/ml-cross-entropy
|
See https://github.com/apple/ml-cross-entropy
|
||||||
|
|
||||||
@@ -29,6 +29,20 @@ plugins:
|
|||||||
cut_cross_entropy: true
|
cut_cross_entropy: true
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Supported Models
|
||||||
|
|
||||||
|
- llama
|
||||||
|
- phi3
|
||||||
|
- gemma
|
||||||
|
- gemma2
|
||||||
|
- gemma3
|
||||||
|
- gemma3_text
|
||||||
|
- mistral
|
||||||
|
- mistral3
|
||||||
|
- qwen2
|
||||||
|
- cohere
|
||||||
|
- cohere2
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
```bib
|
```bib
|
||||||
|
|||||||
@@ -25,8 +25,8 @@ import torch
|
|||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
from axolotl.utils import get_pytorch_version
|
from axolotl.utils import get_pytorch_version
|
||||||
|
from axolotl.utils.distributed import zero_only
|
||||||
|
|
||||||
from ...utils.distributed import zero_only
|
|
||||||
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
|
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
|
||||||
@@ -72,7 +72,9 @@ class CutCrossEntropyPlugin(BasePlugin):
|
|||||||
if cfg.cut_cross_entropy:
|
if cfg.cut_cross_entropy:
|
||||||
self._check_requirements()
|
self._check_requirements()
|
||||||
|
|
||||||
from cut_cross_entropy.transformers import cce_patch
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
|
||||||
|
cce_patch,
|
||||||
|
)
|
||||||
|
|
||||||
with zero_only():
|
with zero_only():
|
||||||
LOG.info(
|
LOG.info(
|
||||||
|
|||||||
201
src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py
Normal file
201
src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
"""Cohere and Cohere2 CCE patch."""
|
||||||
|
|
||||||
|
# This patch is based off transformers 4.50.0.
|
||||||
|
# It patches the forward function for CohereForCausalLM and Cohere2ForCausalLM.
|
||||||
|
# It scales the hidden states by the logit scale in advance instead of the logits as the
|
||||||
|
# operation is done internally and should be mathematically equivalent.
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
from transformers.models.cohere.modeling_cohere import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
COHERE_INPUTS_DOCSTRING,
|
||||||
|
KwargsForCausalLM,
|
||||||
|
)
|
||||||
|
from transformers.processing_utils import Unpack
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
def cce_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
**kwargs: Unpack[KwargsForCausalLM],
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>> from transformers import AutoTokenizer, CohereForCausalLM
|
||||||
|
|
||||||
|
>> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
||||||
|
>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
||||||
|
|
||||||
|
>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>> # Generate
|
||||||
|
>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
# scale weight by logit_scale in-place of logits
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states[:, slice_indices, :],
|
||||||
|
self.lm_head.weight * self.logit_scale,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
|
logits = logits * self.logit_scale # main diff from Llama
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(
|
||||||
|
logits=logits,
|
||||||
|
labels=labels,
|
||||||
|
vocab_size=self.config.vocab_size,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_cohere(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.cohere import modeling_cohere
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_cohere.CohereForCausalLM
|
||||||
|
), f"Expected a CohereForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_cohere.CohereForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def patch_cohere2(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.cohere2 import modeling_cohere2
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_cohere2.Cohere2ForCausalLM
|
||||||
|
), f"Expected a Cohere2ForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_cohere2.Cohere2ForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
175
src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py
Normal file
175
src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
"""Gemma CCE patch"""
|
||||||
|
|
||||||
|
# This patch is based off transformers 4.50.0.
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
from transformers.models.gemma.modeling_gemma import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
GEMMA_INPUTS_DOCSTRING,
|
||||||
|
KwargsForCausalLM,
|
||||||
|
)
|
||||||
|
from transformers.processing_utils import Unpack
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
def cce_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
**kwargs: Unpack[KwargsForCausalLM],
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, GemmaForCausalLM
|
||||||
|
|
||||||
|
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
|
||||||
|
|
||||||
|
>>> prompt = "What is your favorite condiment?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"What is your favorite condiment?"
|
||||||
|
```"""
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states[:, slice_indices, :],
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(
|
||||||
|
logits=logits,
|
||||||
|
labels=labels,
|
||||||
|
vocab_size=self.config.vocab_size,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_gemma(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.gemma import modeling_gemma
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_gemma.GemmaForCausalLM
|
||||||
|
), f"Expected a GemmaForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_gemma.GemmaForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
459
src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py
Normal file
459
src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py
Normal file
@@ -0,0 +1,459 @@
|
|||||||
|
"""Gemma2 and Gemma3 (text and multimodal) CCE patch."""
|
||||||
|
|
||||||
|
# Implementation originally adapted from https://github.com/apple/ml-cross-entropy/pull/29
|
||||||
|
# and updated for transformers 4.50.0.
|
||||||
|
# This is a modified version of the patch that allows for deferred logits calculation for gemma3 and works
|
||||||
|
# with both gemma3 (text and multimodal) models.
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
)
|
||||||
|
from torch import nn
|
||||||
|
from transformers.cache_utils import Cache, HybridCache
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
from transformers.models.gemma3.modeling_gemma3 import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
GEMMA3_INPUTS_DOCSTRING,
|
||||||
|
Gemma3CausalLMOutputWithPast,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
is_torchdynamo_compiling,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
def cce_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[HybridCache] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
defer_logits_calculation: bool = False,
|
||||||
|
**loss_kwargs,
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
defer_logits_calculation (`bool`, *optional*):
|
||||||
|
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||||
|
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, Gemma3ForCausalLM
|
||||||
|
|
||||||
|
>>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
|
||||||
|
|
||||||
|
>>> prompt = "What is your favorite condiment?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"What is your favorite condiment?"
|
||||||
|
```"""
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**loss_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states[:, slice_indices, :],
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
softcap=getattr(self.config, "final_logit_softcapping", None),
|
||||||
|
**loss_kwargs,
|
||||||
|
)
|
||||||
|
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||||
|
# defer logits calculation to the ConditionalGeneration forward
|
||||||
|
logits = hidden_states[:, slice_indices, :]
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
|
if self.config.final_logit_softcapping is not None:
|
||||||
|
logits = logits / self.config.final_logit_softcapping
|
||||||
|
logits = torch.tanh(logits)
|
||||||
|
logits = logits * self.config.final_logit_softcapping
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
def cce_forward_multimodal(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
pixel_values: torch.FloatTensor | None = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
|
||||||
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
**lm_kwargs,
|
||||||
|
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
||||||
|
|
||||||
|
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
|
||||||
|
|
||||||
|
>>> prompt = "answer en Where is the cow standing?"
|
||||||
|
>>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(**inputs, max_length=30)
|
||||||
|
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"answer en Where is the cow standing?\nbeach"
|
||||||
|
```"""
|
||||||
|
|
||||||
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
is_training = token_type_ids is not None and labels is not None
|
||||||
|
|
||||||
|
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
|
||||||
|
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
|
||||||
|
special_image_mask = input_ids == self.config.image_token_index
|
||||||
|
llm_input_ids = input_ids.clone()
|
||||||
|
llm_input_ids[special_image_mask] = 0
|
||||||
|
else:
|
||||||
|
llm_input_ids = input_ids # type: ignore
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
past_seen_tokens = (
|
||||||
|
past_key_values.get_seq_length() if past_key_values is not None else 0 # type: ignore
|
||||||
|
)
|
||||||
|
cache_position = torch.arange( # type: ignore
|
||||||
|
past_seen_tokens,
|
||||||
|
past_seen_tokens + inputs_embeds.shape[1],
|
||||||
|
device=inputs_embeds.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Merge text and images
|
||||||
|
if pixel_values is not None:
|
||||||
|
image_features = self.get_image_features(pixel_values)
|
||||||
|
|
||||||
|
if input_ids is None:
|
||||||
|
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||||
|
torch.tensor(
|
||||||
|
self.config.image_token_index,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=inputs_embeds.device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
|
||||||
|
-1
|
||||||
|
)
|
||||||
|
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
||||||
|
inputs_embeds.device
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
not is_torchdynamo_compiling()
|
||||||
|
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
||||||
|
):
|
||||||
|
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of images does not match number of special image tokens in the input text. "
|
||||||
|
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
||||||
|
"tokens from image embeddings."
|
||||||
|
)
|
||||||
|
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
|
||||||
|
|
||||||
|
# mask out pad-token-ids in labels for BC
|
||||||
|
if labels is not None and self.pad_token_id in labels:
|
||||||
|
logger.warning_once(
|
||||||
|
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
|
||||||
|
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
|
||||||
|
)
|
||||||
|
labels = torch.where( # type: ignore
|
||||||
|
input_ids == self.pad_token_id, self.config.ignore_index, labels
|
||||||
|
)
|
||||||
|
|
||||||
|
causal_mask = self._update_causal_mask( # pylint: disable=protected-access
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
past_key_values,
|
||||||
|
cache_position,
|
||||||
|
inputs_embeds,
|
||||||
|
is_training,
|
||||||
|
)
|
||||||
|
outputs = self.language_model(
|
||||||
|
attention_mask=causal_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
logits_to_keep=logits_to_keep,
|
||||||
|
defer_logits_calculation=True, # enable deferred logits calculation
|
||||||
|
**lm_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states,
|
||||||
|
self.language_model.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
softcap=getattr(self.config, "final_logit_softcapping", None),
|
||||||
|
**lm_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = hidden_states
|
||||||
|
if labels is not None:
|
||||||
|
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||||
|
logits = logits.float()
|
||||||
|
shift_logits = logits[..., :-1, :]
|
||||||
|
shift_labels = labels[..., 1:]
|
||||||
|
if attention_mask is not None:
|
||||||
|
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||||
|
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||||
|
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(
|
||||||
|
logits.device
|
||||||
|
)
|
||||||
|
shift_logits = shift_logits[
|
||||||
|
shift_attention_mask.to(logits.device) != 0
|
||||||
|
].contiguous()
|
||||||
|
shift_labels = shift_labels[
|
||||||
|
shift_attention_mask.to(shift_labels.device) != 0
|
||||||
|
].contiguous()
|
||||||
|
else:
|
||||||
|
shift_logits = shift_logits.contiguous()
|
||||||
|
shift_labels = shift_labels.contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
|
||||||
|
flat_labels = shift_labels.view(-1).to(shift_logits.device)
|
||||||
|
loss = loss_fct(flat_logits, flat_labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return Gemma3CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
image_hidden_states=image_features if pixel_values is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_gemma2(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.gemma2 import modeling_gemma2
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_gemma2.Gemma2ForCausalLM
|
||||||
|
), f"Expected a Gemma2ForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_gemma2.Gemma2ForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def patch_gemma3_text(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.gemma3 import modeling_gemma3
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_gemma3.Gemma3ForCausalLM
|
||||||
|
), f"Expected a Gemma3ForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def patch_gemma3(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.gemma3 import modeling_gemma3
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_gemma3.Gemma3ForConditionalGeneration
|
||||||
|
), f"Expected a Gemma3ForConditionalGeneration model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||||
|
|
||||||
|
# patch the causal model to enable deferred logits calculation
|
||||||
|
maybe_model.language_model.forward = MethodType(
|
||||||
|
cce_forward, maybe_model.language_model
|
||||||
|
)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_gemma3.Gemma3ForConditionalGeneration.forward = cce_forward_multimodal
|
||||||
|
# patch the causal model to enable deferred logits calculation
|
||||||
|
modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
@@ -0,0 +1,392 @@
|
|||||||
|
"""Mistral and Mistral3 CCE patch."""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from torch import nn
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
from transformers.models.mistral3.modeling_mistral3 import (
|
||||||
|
Mistral3CausalLMOutputWithPast,
|
||||||
|
)
|
||||||
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
MISTRAL_INPUTS_DOCSTRING,
|
||||||
|
KwargsForCausalLM,
|
||||||
|
)
|
||||||
|
from transformers.processing_utils import Unpack
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
is_torchdynamo_compiling,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
def cce_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] | None = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
defer_logits_calculation: bool = False,
|
||||||
|
**kwargs: Unpack[KwargsForCausalLM],
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
defer_logits_calculation (`bool`, *optional*):
|
||||||
|
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||||
|
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, MistralForCausalLM
|
||||||
|
|
||||||
|
>>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf")
|
||||||
|
|
||||||
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states[:, slice_indices, :],
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||||
|
# defer logits calculation to the ConditionalGeneration forward
|
||||||
|
logits = hidden_states[:, slice_indices, :]
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(
|
||||||
|
logits=logits,
|
||||||
|
labels=labels,
|
||||||
|
vocab_size=self.config.vocab_size,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cce_forward_multimodal(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
pixel_values: torch.FloatTensor | None = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
vision_feature_layer: Optional[Union[int, list[int]]] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
image_sizes: torch.Tensor | None = None,
|
||||||
|
**lm_kwargs,
|
||||||
|
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||||
|
|
||||||
|
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||||
|
|
||||||
|
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
|
||||||
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
||||||
|
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"What is the image?The image depicts two cats lying on a pink blanket."
|
||||||
|
```"""
|
||||||
|
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
vision_feature_layer = (
|
||||||
|
vision_feature_layer
|
||||||
|
if vision_feature_layer is not None
|
||||||
|
else self.config.vision_feature_layer
|
||||||
|
)
|
||||||
|
|
||||||
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if pixel_values is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||||
|
)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
|
if pixel_values is not None:
|
||||||
|
image_features = self.get_image_features(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
vision_feature_layer=vision_feature_layer,
|
||||||
|
image_sizes=image_sizes,
|
||||||
|
)
|
||||||
|
|
||||||
|
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||||
|
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
||||||
|
inputs_embeds.device
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
not is_torchdynamo_compiling()
|
||||||
|
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
||||||
|
):
|
||||||
|
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||||
|
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||||
|
raise ValueError(
|
||||||
|
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||||
|
)
|
||||||
|
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
|
||||||
|
|
||||||
|
outputs = self.language_model(
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
logits_to_keep=logits_to_keep,
|
||||||
|
defer_logits_calculation=True, # enable deferred logits calculation
|
||||||
|
**lm_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states,
|
||||||
|
self.language_model.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**lm_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = hidden_states
|
||||||
|
if labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
if attention_mask is not None:
|
||||||
|
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||||
|
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||||
|
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
|
||||||
|
logits.device
|
||||||
|
)
|
||||||
|
shift_logits = logits[..., :-1, :][
|
||||||
|
shift_attention_mask.to(logits.device) != 0
|
||||||
|
].contiguous()
|
||||||
|
shift_labels = labels[..., 1:][
|
||||||
|
shift_attention_mask.to(labels.device) != 0
|
||||||
|
].contiguous()
|
||||||
|
else:
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = nn.CrossEntropyLoss()
|
||||||
|
loss = loss_fct(
|
||||||
|
shift_logits.view(-1, shift_logits.size(-1)),
|
||||||
|
shift_labels.view(-1).to(shift_logits.device),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return Mistral3CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
image_hidden_states=image_features if pixel_values is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_mistral(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.mistral import modeling_mistral
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_mistral.MistralForCausalLM
|
||||||
|
), f"Expected a MistralForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_mistral.MistralForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def patch_mistral3(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.mistral import modeling_mistral
|
||||||
|
from transformers.models.mistral3 import modeling_mistral3
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_mistral3.Mistral3ForConditionalGeneration
|
||||||
|
), f"Expected a Mistral3ForConditionalGeneration model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||||
|
|
||||||
|
# patch the causal model to enable deferred logits calculation
|
||||||
|
maybe_model.language_model.forward = MethodType(
|
||||||
|
cce_forward, maybe_model.language_model
|
||||||
|
)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_mistral3.Mistral3ForConditionalGeneration.forward = cce_forward_multimodal
|
||||||
|
# patch the causal model to enable deferred logits calculation
|
||||||
|
modeling_mistral.MistralForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
379
src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py
Normal file
379
src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py
Normal file
@@ -0,0 +1,379 @@
|
|||||||
|
"""Mllama CCE patch."""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
from transformers.models.mllama.modeling_mllama import (
|
||||||
|
MLLAMA_INPUTS_DOCSTRING,
|
||||||
|
_prepare_cross_attention_mask,
|
||||||
|
)
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
|
||||||
|
)
|
||||||
|
def cce_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
cross_attention_states: Optional[torch.LongTensor] = None,
|
||||||
|
cross_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
defer_logits_calculation: bool = False,
|
||||||
|
**loss_kwargs,
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
defer_logits_calculation (`bool`, *optional*):
|
||||||
|
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||||
|
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, MllamaForCausalLM
|
||||||
|
|
||||||
|
>>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")
|
||||||
|
|
||||||
|
>>> prompt = "If I had to write a haiku, it would be:"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
|
||||||
|
>>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
>>> print(result)
|
||||||
|
If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
|
||||||
|
I love the idea of snowflakes gently falling, each one
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
cross_attention_states=cross_attention_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
cross_attention_mask=cross_attention_mask,
|
||||||
|
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states[:, slice_indices, :],
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**loss_kwargs,
|
||||||
|
)
|
||||||
|
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||||
|
# defer logits calculation to the ConditionalGeneration forward
|
||||||
|
logits = hidden_states[:, slice_indices, :]
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=CausalLMOutputWithPast, config_class="MllamaConfig"
|
||||||
|
)
|
||||||
|
def cce_forward_multimodal(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
|
aspect_ratio_mask: Optional[torch.Tensor] = None,
|
||||||
|
aspect_ratio_ids: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
cross_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
cross_attention_states: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
**loss_kwargs,
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, MllamaForConditionalGeneration
|
||||||
|
|
||||||
|
>>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
|
||||||
|
>>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
|
||||||
|
>>> processor = AutoProcessor.from_pretrained(checkpoint)
|
||||||
|
|
||||||
|
>>> prompt = "<|image|>If I had to write a haiku for this one"
|
||||||
|
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> output = model.generate(**inputs, max_new_tokens=15)
|
||||||
|
|
||||||
|
>>> prompt_len = inputs.input_ids.shape[-1]
|
||||||
|
>>> generated_ids = output[:, prompt_len:]
|
||||||
|
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||||
|
>>> print(generated_text)
|
||||||
|
[', it would be:.\\nA stop sign in Chinatown.\\n']
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if pixel_values is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||||
|
)
|
||||||
|
|
||||||
|
if pixel_values is not None and cross_attention_states is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
|
||||||
|
)
|
||||||
|
|
||||||
|
if pixel_values is not None:
|
||||||
|
if aspect_ratio_ids is None:
|
||||||
|
raise ValueError(
|
||||||
|
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
|
||||||
|
)
|
||||||
|
# get vision tokens from vision model
|
||||||
|
vision_outputs = self.vision_model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
aspect_ratio_ids=aspect_ratio_ids,
|
||||||
|
aspect_ratio_mask=aspect_ratio_mask,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
cross_attention_states = vision_outputs[0]
|
||||||
|
cross_attention_states = self.multi_modal_projector(
|
||||||
|
cross_attention_states
|
||||||
|
).reshape(
|
||||||
|
-1, cross_attention_states.shape[-2], self.hidden_size # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
if cross_attention_mask is not None:
|
||||||
|
cross_attention_mask, full_text_row_masked_out_mask = (
|
||||||
|
_prepare_cross_attention_mask(
|
||||||
|
cross_attention_mask,
|
||||||
|
num_vision_tokens=self.vision_model.num_patches,
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
full_text_row_masked_out_mask = None
|
||||||
|
|
||||||
|
if cross_attention_mask is not None and cache_position is not None:
|
||||||
|
cross_attention_mask = cross_attention_mask[:, :, cache_position]
|
||||||
|
full_text_row_masked_out_mask = full_text_row_masked_out_mask[
|
||||||
|
:, :, cache_position
|
||||||
|
]
|
||||||
|
|
||||||
|
outputs = self.language_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
cross_attention_states=cross_attention_states,
|
||||||
|
cross_attention_mask=cross_attention_mask,
|
||||||
|
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
logits_to_keep=logits_to_keep,
|
||||||
|
defer_logits_calculation=True, # enable deferred logits calculation
|
||||||
|
**loss_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states,
|
||||||
|
self.language_model.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**loss_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Temporary fix to calculate the loss in main class, as the model's vocab size may be resized
|
||||||
|
logits = hidden_states
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(
|
||||||
|
logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (loss,) + outputs if loss is not None else outputs
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=outputs.logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_mllama(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.mllama import modeling_mllama
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_mllama.MllamaForConditionalGeneration
|
||||||
|
), f"Expected a MllamaForConditionalGeneration model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||||
|
|
||||||
|
# patch the language model
|
||||||
|
maybe_model.language_model.forward = MethodType(
|
||||||
|
cce_forward, maybe_model.language_model
|
||||||
|
)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_mllama.MllamaForConditionalGeneration.forward = cce_forward_multimodal
|
||||||
|
|
||||||
|
# patch the causal language model
|
||||||
|
modeling_mllama.MllamaForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
@@ -0,0 +1,85 @@
|
|||||||
|
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
"""Cut Cross Entropy patcher"""
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
|
||||||
|
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
|
||||||
|
from cut_cross_entropy.transformers.llama import patch_llama
|
||||||
|
from cut_cross_entropy.transformers.phi3 import patch_phi3
|
||||||
|
from cut_cross_entropy.transformers.qwen2 import patch_qwen2
|
||||||
|
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT
|
||||||
|
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
|
||||||
|
patch_cohere,
|
||||||
|
patch_cohere2,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma import patch_gemma
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
|
||||||
|
patch_gemma2,
|
||||||
|
patch_gemma3,
|
||||||
|
patch_gemma3_text,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
|
||||||
|
patch_mistral,
|
||||||
|
patch_mistral3,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama
|
||||||
|
|
||||||
|
CUT_CROSS_ENTROPY_MODEL_MAPPING = {
|
||||||
|
"llama": patch_llama,
|
||||||
|
"mllama": patch_mllama,
|
||||||
|
"phi3": patch_phi3,
|
||||||
|
"gemma": patch_gemma,
|
||||||
|
"gemma2": patch_gemma2,
|
||||||
|
"gemma3": patch_gemma3,
|
||||||
|
"gemma3_text": patch_gemma3_text,
|
||||||
|
"mistral": patch_mistral,
|
||||||
|
"mistral3": patch_mistral3,
|
||||||
|
"qwen2": patch_qwen2,
|
||||||
|
"cohere": patch_cohere,
|
||||||
|
"cohere2": patch_cohere2,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def cce_patch(
|
||||||
|
model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig,
|
||||||
|
impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
|
||||||
|
reduction: str = "mean",
|
||||||
|
filter_eps: float | str | None = "auto",
|
||||||
|
accum_e_fp32: bool = False,
|
||||||
|
accum_c_fp32: bool = False,
|
||||||
|
filter_e_grad: bool = True,
|
||||||
|
filter_c_grad: bool = True,
|
||||||
|
train_only: bool = False,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
if isinstance(impl, LinearCrossEntropyImpl):
|
||||||
|
impl = impl.name.lower()
|
||||||
|
|
||||||
|
if impl not in (v.name.lower() for v in LinearCrossEntropyImpl):
|
||||||
|
raise ValueError(f"Unknown {impl=}")
|
||||||
|
|
||||||
|
if isinstance(model_type_or_model, transformers.PreTrainedModel):
|
||||||
|
model_type = model_type_or_model.config.model_type
|
||||||
|
elif isinstance(model_type_or_model, transformers.PretrainedConfig):
|
||||||
|
model_type = model_type_or_model.model_type
|
||||||
|
else:
|
||||||
|
model_type = model_type_or_model
|
||||||
|
|
||||||
|
patch_options = PatchOptions(
|
||||||
|
impl=impl,
|
||||||
|
reduction=reduction,
|
||||||
|
filter_eps=filter_eps,
|
||||||
|
accum_e_fp32=accum_e_fp32,
|
||||||
|
accum_c_fp32=accum_c_fp32,
|
||||||
|
filter_e_grad=filter_e_grad,
|
||||||
|
filter_c_grad=filter_c_grad,
|
||||||
|
train_only=train_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING:
|
||||||
|
return CUT_CROSS_ENTROPY_MODEL_MAPPING[model_type](
|
||||||
|
model_type_or_model, patch_options
|
||||||
|
)
|
||||||
|
|
||||||
|
raise RuntimeError(f"Unknown model type {model_type}")
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
"""Monkeypatch for apply_lce to add softcap."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from cut_cross_entropy import linear_cross_entropy
|
||||||
|
from cut_cross_entropy.transformers.utils import PatchOptions
|
||||||
|
|
||||||
|
|
||||||
|
def apply_lce(
|
||||||
|
e: torch.Tensor,
|
||||||
|
c: torch.Tensor,
|
||||||
|
labels: torch.Tensor,
|
||||||
|
opts: PatchOptions,
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
softcap: float | None = None,
|
||||||
|
**loss_kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Monkey patch for apply_lce to support softcap kwarg."""
|
||||||
|
num_items_in_batch = loss_kwargs.get("num_items_in_batch", None)
|
||||||
|
cce_kwargs = opts.to_kwargs()
|
||||||
|
if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean":
|
||||||
|
cce_kwargs["reduction"] = "sum"
|
||||||
|
else:
|
||||||
|
num_items_in_batch = None
|
||||||
|
|
||||||
|
loss = linear_cross_entropy(
|
||||||
|
e,
|
||||||
|
c,
|
||||||
|
labels.to(e.device),
|
||||||
|
bias=bias,
|
||||||
|
shift=True,
|
||||||
|
softcap=softcap,
|
||||||
|
**cce_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_items_in_batch is not None:
|
||||||
|
loss = loss / num_items_in_batch
|
||||||
|
|
||||||
|
return loss
|
||||||
@@ -20,6 +20,26 @@ liger_layer_norm: true
|
|||||||
liger_fused_linear_cross_entropy: true
|
liger_fused_linear_cross_entropy: true
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Supported Models
|
||||||
|
|
||||||
|
- deepseek_v2
|
||||||
|
- gemma
|
||||||
|
- gemma2
|
||||||
|
- gemma3 (partial support, no support for FLCE yet)
|
||||||
|
- granite
|
||||||
|
- jamba
|
||||||
|
- llama
|
||||||
|
- mistral
|
||||||
|
- mixtral
|
||||||
|
- mllama
|
||||||
|
- mllama_text_model
|
||||||
|
- olmo2
|
||||||
|
- paligemma
|
||||||
|
- phi3
|
||||||
|
- qwen2
|
||||||
|
- qwen2_5_vl
|
||||||
|
- qwen2_vl
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
```bib
|
```bib
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ It is designed to be performant, correct, and light-weight.
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
|
||||||
@@ -41,11 +42,18 @@ class LigerPlugin(BasePlugin):
|
|||||||
def pre_model_load(self, cfg):
|
def pre_model_load(self, cfg):
|
||||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||||
|
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
||||||
|
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||||
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
||||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||||
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
||||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||||
|
|
||||||
|
if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
||||||
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
||||||
liger_fn_sig = inspect.signature(apply_liger_fn)
|
liger_fn_sig = inspect.signature(apply_liger_fn)
|
||||||
@@ -82,6 +90,8 @@ class LigerPlugin(BasePlugin):
|
|||||||
modeling_jamba.JambaRMSNorm = LigerRMSNorm
|
modeling_jamba.JambaRMSNorm = LigerRMSNorm
|
||||||
if cfg.liger_glu_activation:
|
if cfg.liger_glu_activation:
|
||||||
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
||||||
|
if cfg.liger_layer_norm:
|
||||||
|
modeling_jamba.nn.LayerNorm = LigerLayerNorm
|
||||||
if cfg.liger_cross_entropy:
|
if cfg.liger_cross_entropy:
|
||||||
from transformers.loss.loss_utils import nn
|
from transformers.loss.loss_utils import nn
|
||||||
|
|
||||||
@@ -104,13 +114,51 @@ class LigerPlugin(BasePlugin):
|
|||||||
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
|
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
|
||||||
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
|
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
|
||||||
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
|
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
|
||||||
|
if cfg.liger_glu_activation:
|
||||||
|
logging.warning("liger_glu_activation is not supported for DeepseekV2.")
|
||||||
if cfg.liger_rms_norm:
|
if cfg.liger_rms_norm:
|
||||||
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
||||||
if cfg.liger_glu_activation:
|
if cfg.liger_glu_activation:
|
||||||
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
||||||
|
if cfg.liger_layer_norm:
|
||||||
|
modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward
|
||||||
if cfg.liger_cross_entropy:
|
if cfg.liger_cross_entropy:
|
||||||
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
|
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
|
||||||
# nn.CrossEntropyLoss in the forward method.
|
# nn.CrossEntropyLoss in the forward method.
|
||||||
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
if cfg.liger_fused_linear_cross_entropy:
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||||
|
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
|
||||||
|
from transformers.models.gemma3 import modeling_gemma3
|
||||||
|
|
||||||
|
if cfg.liger_rope:
|
||||||
|
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||||
|
if cfg.liger_rms_norm:
|
||||||
|
|
||||||
|
def _liger_rms_norm_wrapper(dim, **kwargs):
|
||||||
|
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
|
||||||
|
return LigerRMSNorm(hidden_size=dim, **kwargs)
|
||||||
|
|
||||||
|
modeling_gemma3.Gemma3RMSNorm = partial(
|
||||||
|
_liger_rms_norm_wrapper,
|
||||||
|
offset=1.0,
|
||||||
|
casting_mode="gemma",
|
||||||
|
init_fn="zeros",
|
||||||
|
in_place=False,
|
||||||
|
)
|
||||||
|
if cfg.liger_glu_activation:
|
||||||
|
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
|
||||||
|
if cfg.liger_layer_norm:
|
||||||
|
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
|
||||||
|
|
||||||
|
if cfg.liger_cross_entropy:
|
||||||
|
from transformers.loss.loss_utils import nn
|
||||||
|
|
||||||
|
nn.functional.cross_entropy = liger_cross_entropy
|
||||||
|
|
||||||
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Fused linear cross entropy is not yet supported for Gemma3."
|
||||||
|
)
|
||||||
|
elif cfg.model_config_type in ["deepseek_v3"]:
|
||||||
|
raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
|
||||||
|
|||||||
89
src/axolotl/monkeypatch/attention/ring_attn.py
Normal file
89
src/axolotl/monkeypatch/attention/ring_attn.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
"""
|
||||||
|
Ring attention group registration and flash attention patching.
|
||||||
|
|
||||||
|
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
|
||||||
|
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
|
||||||
|
their sequence parallel version of Flash Attention 2.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch.distributed as dist
|
||||||
|
from accelerate.logging import get_logger
|
||||||
|
|
||||||
|
from axolotl.logging_config import configure_logging
|
||||||
|
|
||||||
|
configure_logging()
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
RING_ATTN_GROUP = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_ring_attn_group() -> dist.ProcessGroup:
|
||||||
|
"""
|
||||||
|
Getter for ring attention group on this rank.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The process group for ring attention for this rank.
|
||||||
|
"""
|
||||||
|
return RING_ATTN_GROUP
|
||||||
|
|
||||||
|
|
||||||
|
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
||||||
|
"""
|
||||||
|
Setter for ring attention group on this rank.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
Process group for ring attention.
|
||||||
|
"""
|
||||||
|
global RING_ATTN_GROUP # pylint: disable=global-statement
|
||||||
|
RING_ATTN_GROUP = ring_attn_group
|
||||||
|
|
||||||
|
|
||||||
|
def register_ring_attn(sequence_parallel_degree: int):
|
||||||
|
"""
|
||||||
|
Create ring attention group and substitute flash attn with ring flash attn.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sequence_parallel_degree: Sequence parallelism factor.
|
||||||
|
"""
|
||||||
|
LOG.info(
|
||||||
|
"Enabling ring attention sequence parallelism: "
|
||||||
|
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||||
|
)
|
||||||
|
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
assert sequence_parallel_degree <= world_size, (
|
||||||
|
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||||
|
f"must be less than or equal to world_size ({world_size})"
|
||||||
|
)
|
||||||
|
assert world_size % sequence_parallel_degree == 0, (
|
||||||
|
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||||
|
f"must evenly divide world_size ({world_size})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Detailed logging of group formation
|
||||||
|
rank = dist.get_rank()
|
||||||
|
group_assignments = {}
|
||||||
|
|
||||||
|
for i in range(world_size // sequence_parallel_degree):
|
||||||
|
ring_attn_ranks = list(
|
||||||
|
range(
|
||||||
|
i * sequence_parallel_degree,
|
||||||
|
(i + 1) * sequence_parallel_degree,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
||||||
|
|
||||||
|
# Track which GPUs are in which groups
|
||||||
|
for r in ring_attn_ranks:
|
||||||
|
group_assignments[r] = i
|
||||||
|
|
||||||
|
if rank in ring_attn_ranks:
|
||||||
|
set_ring_attn_group(group)
|
||||||
|
|
||||||
|
# Log the GPU group assignments
|
||||||
|
if rank == 0:
|
||||||
|
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||||
|
|
||||||
|
from ring_flash_attn import substitute_hf_flash_attn
|
||||||
|
|
||||||
|
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
|
||||||
@@ -22,6 +22,9 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"phi3",
|
"phi3",
|
||||||
"gemma",
|
"gemma",
|
||||||
"gemma2",
|
"gemma2",
|
||||||
|
"gemma3_text",
|
||||||
|
"cohere",
|
||||||
|
"cohere2",
|
||||||
"gemmoe",
|
"gemmoe",
|
||||||
"starcoder2",
|
"starcoder2",
|
||||||
"deepseek_v2",
|
"deepseek_v2",
|
||||||
|
|||||||
278
src/axolotl/processing_strategies.py
Normal file
278
src/axolotl/processing_strategies.py
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types"""
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
from PIL.Image import Resampling
|
||||||
|
from torch import Tensor
|
||||||
|
from transformers import ProcessorMixin
|
||||||
|
from transformers.image_utils import load_image
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessingStrategy:
|
||||||
|
"""Base Processing Strategy class"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
processor: ProcessorMixin,
|
||||||
|
chat_template: Optional[str] = None,
|
||||||
|
image_size: int | tuple[int, int] | None = None,
|
||||||
|
image_resize_algorithm: Resampling | None = None,
|
||||||
|
):
|
||||||
|
self.processor = processor
|
||||||
|
self.chat_template = chat_template
|
||||||
|
self.image_token = None
|
||||||
|
self.image_token_id = None
|
||||||
|
|
||||||
|
self.image_size = image_size
|
||||||
|
self.image_resize_algorithm = (
|
||||||
|
image_resize_algorithm or Image.Resampling.BILINEAR
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(processor, "image_token"):
|
||||||
|
self.image_token = processor.image_token
|
||||||
|
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||||
|
self.image_token
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, examples: list[dict]) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Preprocess conversation examples to ensure consistent format.
|
||||||
|
Converts different conversation formats to OpenAI format with 'messages'.
|
||||||
|
Supports two formats:
|
||||||
|
1. OpenAI format with 'messages'
|
||||||
|
2. Legacy format with 'conversations'
|
||||||
|
|
||||||
|
Args:
|
||||||
|
examples: list of conversation dictionaries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of dicts in OpenAI format with 'messages' key
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the conversation format is not supported
|
||||||
|
"""
|
||||||
|
role_mapping = {
|
||||||
|
"human": "user",
|
||||||
|
"gpt": "assistant",
|
||||||
|
}
|
||||||
|
|
||||||
|
def normalize_role(role: str) -> str:
|
||||||
|
"""Normalize role names to OpenAI format. Default to original role if not found."""
|
||||||
|
return role_mapping.get(role, role)
|
||||||
|
|
||||||
|
def convert_legacy_format(example: dict) -> dict:
|
||||||
|
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
|
||||||
|
messages = [
|
||||||
|
{"role": normalize_role(convo["from"]), "content": convo["value"]}
|
||||||
|
for convo in example["conversations"]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create new dict without 'conversations' key
|
||||||
|
result = deepcopy(example)
|
||||||
|
result.pop("conversations")
|
||||||
|
result["messages"] = messages
|
||||||
|
return result
|
||||||
|
|
||||||
|
def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]:
|
||||||
|
"""Convert regular messages format to Messages format with content type"""
|
||||||
|
|
||||||
|
new_messages = []
|
||||||
|
for message in messages:
|
||||||
|
if isinstance(message["content"], str):
|
||||||
|
new_messages.append(
|
||||||
|
{
|
||||||
|
"role": message["role"],
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": message["content"],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif isinstance(message["content"], list):
|
||||||
|
content = message["content"]
|
||||||
|
|
||||||
|
new_messages.append(
|
||||||
|
{
|
||||||
|
"role": message["role"],
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_messages
|
||||||
|
|
||||||
|
processed_examples = []
|
||||||
|
for example in examples:
|
||||||
|
if not ("messages" in example or "conversations" in example):
|
||||||
|
raise ValueError(
|
||||||
|
"Only `messages` and `conversations` message keys are currently supported."
|
||||||
|
)
|
||||||
|
|
||||||
|
processed_example = None
|
||||||
|
if "messages" in example: # OpenAI format
|
||||||
|
processed_example = example
|
||||||
|
else: # Legacy format
|
||||||
|
processed_example = convert_legacy_format(example)
|
||||||
|
|
||||||
|
# convert regular messages format to Messages format with content type
|
||||||
|
# for compatibility with apply_chat_template
|
||||||
|
processed_example["messages"] = convert_messages_to_multimedia_messages(
|
||||||
|
processed_example["messages"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# find the image key if it exists
|
||||||
|
possible_image_keys = ["images", "image"]
|
||||||
|
image_key = None
|
||||||
|
for key in possible_image_keys:
|
||||||
|
if key in processed_example:
|
||||||
|
image_key = key
|
||||||
|
break
|
||||||
|
|
||||||
|
# if the image key exists, add the image to the first message
|
||||||
|
if image_key is not None:
|
||||||
|
# TODO: check if it's normal to be single image only for common datasets
|
||||||
|
# From observation, it's usually a list of single image but some datasets may have several columns for images
|
||||||
|
# Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages
|
||||||
|
image_value = processed_example[image_key][0]
|
||||||
|
|
||||||
|
# Handle image loading (Image, url, path, base64)
|
||||||
|
image_value = load_image(image_value)
|
||||||
|
|
||||||
|
if self.image_size is not None:
|
||||||
|
assert hasattr(
|
||||||
|
image_value, "resize"
|
||||||
|
), "Image does not have a resize method"
|
||||||
|
|
||||||
|
if isinstance(self.image_size, tuple):
|
||||||
|
image_value = image_value.resize(
|
||||||
|
self.image_size, self.image_resize_algorithm
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Set the padding value; here we use black (0, 0, 0) for RGB images
|
||||||
|
padding_color = (0, 0, 0)
|
||||||
|
|
||||||
|
# When image_size is an int (square target), preserve aspect ratio then pad
|
||||||
|
# This is to prevent aspect ratio distortion when resizing to square
|
||||||
|
image_value = ImageOps.pad(
|
||||||
|
image_value,
|
||||||
|
(self.image_size, self.image_size),
|
||||||
|
method=self.image_resize_algorithm,
|
||||||
|
color=padding_color,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Look for any image type in the first message
|
||||||
|
# some dataset have an {type: "image"} in the first message
|
||||||
|
ind_to_add = None
|
||||||
|
|
||||||
|
for i, content in enumerate(
|
||||||
|
processed_example["messages"][0]["content"]
|
||||||
|
):
|
||||||
|
# Usually datasets created with image columns, don't have it in the messages itself
|
||||||
|
if content["type"] == "image" and all(
|
||||||
|
k not in content for k in ["image", "url", "path", "base64"]
|
||||||
|
):
|
||||||
|
ind_to_add = i
|
||||||
|
break
|
||||||
|
|
||||||
|
# If an image type is found, add the image to that index
|
||||||
|
if ind_to_add is not None:
|
||||||
|
processed_example["messages"][0]["content"][ind_to_add][
|
||||||
|
"image"
|
||||||
|
] = image_value
|
||||||
|
else:
|
||||||
|
# if no image type is found, add it to end of the first message
|
||||||
|
processed_example["messages"][0]["content"].append(
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"image": image_value,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
processed_examples.append(processed_example)
|
||||||
|
|
||||||
|
return processed_examples
|
||||||
|
|
||||||
|
def process_labels(self, input_ids: Tensor) -> Tensor:
|
||||||
|
labels = input_ids.clone()
|
||||||
|
|
||||||
|
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
||||||
|
labels[labels == self.processor.tokenizer.pad_token_id] = -100
|
||||||
|
|
||||||
|
# Ignore the image token index in the loss computation (model specific)
|
||||||
|
labels[labels == self.image_token_id] = -100
|
||||||
|
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VLProcessingStrategy(ProcessingStrategy):
|
||||||
|
"""Processing Strategy class for Qwen2-VL"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
processor: ProcessorMixin,
|
||||||
|
chat_template: Optional[str] = None,
|
||||||
|
image_size: int | tuple[int, int] | None = None,
|
||||||
|
image_resize_algorithm: Resampling | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
|
||||||
|
self.image_token = "<|image_pad|>" # nosec
|
||||||
|
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||||
|
self.image_token
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3ProcessingStrategy(ProcessingStrategy):
|
||||||
|
"""Processing Strategy class for Gemma3"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
processor: ProcessorMixin,
|
||||||
|
chat_template: Optional[str] = None,
|
||||||
|
image_size: int | tuple[int, int] | None = None,
|
||||||
|
image_resize_algorithm: Resampling | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
|
||||||
|
self.image_token = processor.tokenizer.special_tokens_map["boi_token"]
|
||||||
|
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||||
|
self.image_token
|
||||||
|
)
|
||||||
|
|
||||||
|
def process_labels(self, input_ids):
|
||||||
|
labels = input_ids.clone()
|
||||||
|
|
||||||
|
# Follows https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora
|
||||||
|
labels[labels == self.processor.tokenizer.pad_token_id] = -100
|
||||||
|
labels[labels == self.image_token_id] = -100
|
||||||
|
labels[labels == 262144] = -100 # corresponds to <image_soft_token>
|
||||||
|
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
def get_processing_strategy(
|
||||||
|
processor: ProcessorMixin,
|
||||||
|
chat_template,
|
||||||
|
chat_template_type,
|
||||||
|
image_size: int | tuple[int, int] | None = None,
|
||||||
|
image_resize_algorithm: Resampling | None = None,
|
||||||
|
):
|
||||||
|
if chat_template_type == "qwen2_vl":
|
||||||
|
return Qwen2VLProcessingStrategy(
|
||||||
|
processor, chat_template, image_size, image_resize_algorithm
|
||||||
|
)
|
||||||
|
if chat_template_type == "gemma3":
|
||||||
|
return Gemma3ProcessingStrategy(
|
||||||
|
processor, chat_template, image_size, image_resize_algorithm
|
||||||
|
)
|
||||||
|
if chat_template_type in [
|
||||||
|
"llama3_2_vision",
|
||||||
|
"llava",
|
||||||
|
"mistral_v7_tekken",
|
||||||
|
"pixtral",
|
||||||
|
]:
|
||||||
|
return ProcessingStrategy(
|
||||||
|
processor, chat_template, image_size, image_resize_algorithm
|
||||||
|
)
|
||||||
|
raise ValueError(f"Unsupported chat template type: {chat_template_type}")
|
||||||
@@ -13,7 +13,7 @@ from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnaly
|
|||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig
|
from axolotl.utils.schemas.datasets import DatasetConfig
|
||||||
|
|
||||||
# Configure the logger
|
# Configure the logger
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
@@ -411,11 +411,15 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
if turn_idx >= len(turns):
|
if turn_idx >= len(turns):
|
||||||
raise ValueError(f"Turn index {turn_idx} out of range")
|
raise ValueError(f"Turn index {turn_idx} out of range")
|
||||||
|
|
||||||
# mistral does not output message if it contains only system message
|
# mistral/gemma3 does not output message if it contains only system message
|
||||||
if (
|
if (
|
||||||
turn_idx == 0
|
turn_idx == 0
|
||||||
and turns[0].get("role") == "system"
|
and turns[0].get("role") == "system"
|
||||||
and "mistral" in self.tokenizer.name_or_path.lower()
|
and (
|
||||||
|
"mistral" in self.tokenizer.name_or_path.lower()
|
||||||
|
# gemma3 uses gemma tokenizer
|
||||||
|
or "gemma" in self.tokenizer.name_or_path.lower()
|
||||||
|
)
|
||||||
):
|
):
|
||||||
return -1, -1
|
return -1, -1
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ DPO prompt strategies for using tokenizer chat templates.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic
|
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
|
||||||
|
|
||||||
|
|
||||||
def default(
|
def default(
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import transformers.modelcard
|
|||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from accelerate.utils import save_fsdp_model
|
from accelerate.utils import save_fsdp_model
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
from huggingface_hub.errors import OfflineModeIsEnabled
|
||||||
from peft import PeftConfig, PeftModel
|
from peft import PeftConfig, PeftModel
|
||||||
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
@@ -169,7 +170,7 @@ def execute_training(
|
|||||||
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
|
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Execute the training process with appropriate backend configurations.
|
Execute the training process with appropriate SDP kernel configurations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
@@ -177,9 +178,6 @@ def execute_training(
|
|||||||
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
||||||
"""
|
"""
|
||||||
LOG.info("Starting trainer...")
|
LOG.info("Starting trainer...")
|
||||||
if cfg.group_by_length:
|
|
||||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
|
||||||
|
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
with torch.backends.cuda.sdp_kernel(
|
with torch.backends.cuda.sdp_kernel(
|
||||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||||
@@ -305,7 +303,7 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
|
|||||||
model_card_kwarg["dataset_tags"] = dataset_tags
|
model_card_kwarg["dataset_tags"] = dataset_tags
|
||||||
|
|
||||||
trainer.create_model_card(**model_card_kwarg)
|
trainer.create_model_card(**model_card_kwarg)
|
||||||
except (AttributeError, UnicodeDecodeError):
|
except (AttributeError, UnicodeDecodeError, OfflineModeIsEnabled):
|
||||||
pass
|
pass
|
||||||
elif cfg.hub_model_id:
|
elif cfg.hub_model_id:
|
||||||
# Defensively push to the hub to ensure the model card is updated
|
# Defensively push to the hub to ensure the model card is updated
|
||||||
@@ -317,6 +315,7 @@ def save_initial_configs(
|
|||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
model: PreTrainedModel,
|
model: PreTrainedModel,
|
||||||
peft_config: PeftConfig | None,
|
peft_config: PeftConfig | None,
|
||||||
|
processor: ProcessorMixin | None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Save initial configurations before training.
|
Save initial configurations before training.
|
||||||
@@ -344,6 +343,10 @@ def save_initial_configs(
|
|||||||
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
||||||
model.config.save_pretrained(str(output_dir))
|
model.config.save_pretrained(str(output_dir))
|
||||||
|
|
||||||
|
if processor:
|
||||||
|
LOG.info(f"Pre-saving processor to {cfg.output_dir}...")
|
||||||
|
processor.save_pretrained(str(output_dir))
|
||||||
|
|
||||||
|
|
||||||
def setup_model_card(cfg: DictDefault):
|
def setup_model_card(cfg: DictDefault):
|
||||||
"""
|
"""
|
||||||
@@ -411,6 +414,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
|
|||||||
PeftModel | PreTrainedModel,
|
PeftModel | PreTrainedModel,
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
PeftConfig | None,
|
PeftConfig | None,
|
||||||
|
ProcessorMixin | None,
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
|
Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
|
||||||
@@ -426,6 +430,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
|
|||||||
- Model
|
- Model
|
||||||
- Tokenizer
|
- Tokenizer
|
||||||
- PEFT config
|
- PEFT config
|
||||||
|
- Processor
|
||||||
"""
|
"""
|
||||||
# Load tokenizer, processor and model
|
# Load tokenizer, processor and model
|
||||||
model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
|
model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
|
||||||
@@ -456,6 +461,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
|
|||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
peft_config,
|
peft_config,
|
||||||
|
processor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -478,6 +484,7 @@ def train(
|
|||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
peft_config,
|
peft_config,
|
||||||
|
processor,
|
||||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||||
|
|
||||||
# Determine if we need to resume from a checkpoint
|
# Determine if we need to resume from a checkpoint
|
||||||
@@ -493,7 +500,7 @@ def train(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Save initial configs
|
# Save initial configs
|
||||||
save_initial_configs(cfg, tokenizer, model, peft_config)
|
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
|
||||||
|
|
||||||
# Set up signal handler for graceful termination
|
# Set up signal handler for graceful termination
|
||||||
setup_signal_handler(cfg, model, safe_serialization)
|
setup_signal_handler(cfg, model, safe_serialization)
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ from trl.models import unwrap_model_for_generation
|
|||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.callbacks.perplexity import Perplexity
|
from axolotl.utils.callbacks.perplexity import Perplexity
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
|
||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import (
|
||||||
barrier,
|
barrier,
|
||||||
broadcast_dict,
|
broadcast_dict,
|
||||||
@@ -43,6 +42,7 @@ from axolotl.utils.distributed import (
|
|||||||
is_main_process,
|
is_main_process,
|
||||||
zero_first,
|
zero_first,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1,14 +1,59 @@
|
|||||||
"""
|
"""
|
||||||
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
Data collators for axolotl to pad labels and position_ids for packed sequences. Also
|
||||||
|
includes logic for handling sequence parallelism collation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_position_ids_for_slice(
|
||||||
|
position_ids: torch.Tensor, start_idx: int
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Adjust position IDs for a sliced sequence to maintain proper relative positions.
|
||||||
|
This handles the case where position IDs might not be contiguous due to sample
|
||||||
|
packing.
|
||||||
|
"""
|
||||||
|
# Convert to tensor if not already
|
||||||
|
# Find the boundaries between samples (where position_ids reset)
|
||||||
|
adjusted_pos_ids = position_ids.clone()
|
||||||
|
|
||||||
|
# Process each sequence in the batch
|
||||||
|
for i in range(position_ids.shape[0]):
|
||||||
|
seq = position_ids[i]
|
||||||
|
|
||||||
|
# Find sample boundaries
|
||||||
|
boundaries = []
|
||||||
|
for j in range(1, len(seq)):
|
||||||
|
if seq[j] < seq[j - 1]:
|
||||||
|
boundaries.append(j)
|
||||||
|
|
||||||
|
# No need to adjust if there are no boundaries or this is a single sample
|
||||||
|
if not boundaries:
|
||||||
|
adjusted_pos_ids[i] = seq - start_idx
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Adjust each segment separately
|
||||||
|
prev_boundary = 0
|
||||||
|
for boundary in boundaries:
|
||||||
|
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
|
||||||
|
prev_boundary = boundary
|
||||||
|
|
||||||
|
# Last segment
|
||||||
|
adjusted_pos_ids[i, prev_boundary:] -= start_idx
|
||||||
|
|
||||||
|
return adjusted_pos_ids
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataCollatorForSeq2Seq:
|
class DataCollatorForSeq2Seq:
|
||||||
@@ -43,6 +88,8 @@ class DataCollatorForSeq2Seq:
|
|||||||
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
||||||
return_tensors (`str`):
|
return_tensors (`str`):
|
||||||
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
||||||
|
sequence_parallel_degree (`int`):
|
||||||
|
The degree of sequence parallelism. Default to 1 for no sequence parallelism.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tokenizer: PreTrainedTokenizerBase
|
tokenizer: PreTrainedTokenizerBase
|
||||||
@@ -53,6 +100,16 @@ class DataCollatorForSeq2Seq:
|
|||||||
label_pad_token_id: int = -100
|
label_pad_token_id: int = -100
|
||||||
position_pad_token_id: int = 0
|
position_pad_token_id: int = 0
|
||||||
return_tensors: str = "pt"
|
return_tensors: str = "pt"
|
||||||
|
sequence_parallel_degree: int = 1
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.sequence_parallel_degree > 1:
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||||
|
|
||||||
|
# Get information about our position in the SP group
|
||||||
|
sp_group = get_ring_attn_group()
|
||||||
|
self.local_rank = dist.get_rank(group=sp_group)
|
||||||
|
self.local_world_size = dist.get_world_size(group=sp_group)
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
labels = None
|
labels = None
|
||||||
@@ -119,8 +176,43 @@ class DataCollatorForSeq2Seq:
|
|||||||
)
|
)
|
||||||
features["decoder_input_ids"] = decoder_input_ids
|
features["decoder_input_ids"] = decoder_input_ids
|
||||||
|
|
||||||
|
if self.sequence_parallel_degree > 1:
|
||||||
|
features = self.apply_sequence_parallelism(features)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
def apply_sequence_parallelism(
|
||||||
|
self, batch: dict[str, torch.Tensor]
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Apply sequence parallelism slicing to a batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: Batch dictionary from parent collator.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sliced batch dictionary.
|
||||||
|
"""
|
||||||
|
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
|
||||||
|
|
||||||
|
for key in keys_to_slice:
|
||||||
|
if key in batch:
|
||||||
|
seq_len = batch[key].shape[1]
|
||||||
|
slice_size = seq_len // self.local_world_size
|
||||||
|
start_idx = self.local_rank * slice_size
|
||||||
|
end_idx = (
|
||||||
|
start_idx + slice_size
|
||||||
|
if self.local_rank < self.local_world_size - 1
|
||||||
|
else seq_len
|
||||||
|
)
|
||||||
|
batch[key] = batch[key][:, start_idx:end_idx]
|
||||||
|
|
||||||
|
# Special handling for position_ids
|
||||||
|
if key == "position_ids" and self.local_rank > 0:
|
||||||
|
batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||||
@@ -148,6 +240,7 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
np.array(item[feature]) for item in features_ if feature in item
|
np.array(item[feature]) for item in features_ if feature in item
|
||||||
]
|
]
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
|
|
||||||
return super().__call__(out_features, return_tensors=return_tensors)
|
return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
@@ -177,6 +270,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
np.array(item[feature]) for item in features_ if feature in item
|
np.array(item[feature]) for item in features_ if feature in item
|
||||||
]
|
]
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
|
|
||||||
return super().__call__(out_features, return_tensors=return_tensors)
|
return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,15 +2,17 @@
|
|||||||
Collators for multi-modal chat messages and packing
|
Collators for multi-modal chat messages and packing
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from copy import deepcopy
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from PIL import Image
|
import torch
|
||||||
from transformers import PreTrainedTokenizerBase, ProcessorMixin
|
from torch import Tensor
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.data.data_collator import DataCollatorMixin
|
from transformers.data.data_collator import DataCollatorMixin
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
|
from axolotl.processing_strategies import ProcessingStrategy
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MultiModalChatDataCollator(DataCollatorMixin):
|
class MultiModalChatDataCollator(DataCollatorMixin):
|
||||||
@@ -19,11 +21,9 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
tokenizer: PreTrainedTokenizerBase
|
tokenizer: PreTrainedTokenizerBase
|
||||||
processor: ProcessorMixin
|
processing_strategy: ProcessingStrategy
|
||||||
return_tensors: str = "pt"
|
|
||||||
chat_template: Optional[str] = None
|
|
||||||
packing: bool = False
|
packing: bool = False
|
||||||
max_images: int = -1
|
return_tensors: str = "pt"
|
||||||
padding: Union[bool, str, PaddingStrategy] = True
|
padding: Union[bool, str, PaddingStrategy] = True
|
||||||
pad_to_multiple_of: Optional[int] = None
|
pad_to_multiple_of: Optional[int] = None
|
||||||
|
|
||||||
@@ -31,162 +31,62 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
if self.packing:
|
if self.packing:
|
||||||
raise ValueError("Packing is currently not supported.")
|
raise ValueError("Packing is currently not supported.")
|
||||||
|
|
||||||
def torch_call(
|
def torch_call(self, examples: list[dict]) -> dict[str, Any]:
|
||||||
self, examples: list[Union[list[int], Any, dict[str, Any]]]
|
return self.process_rows(examples)
|
||||||
) -> dict[str, Any]:
|
|
||||||
# Handle dict or lists with proper padding and conversion to tensor.
|
|
||||||
|
|
||||||
return self.__class__.process_rows(
|
|
||||||
examples, self.processor, self.chat_template, self.max_images
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def process_rows(examples, processor, chat_template, max_images, length_only=False):
|
|
||||||
# HINT: use `_torch_collate_batch` to stack and pad tensors
|
|
||||||
# see also DataCollatorWithFlattening and DefaultDataCollator
|
|
||||||
|
|
||||||
# *** This is COPIED from the trl example sft_vlm.py code ***
|
|
||||||
# use this as a starting point
|
|
||||||
|
|
||||||
def _preprocess(examples: list[dict]) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Preprocess conversation examples to ensure consistent format.
|
|
||||||
|
|
||||||
Converts different conversation formats to OpenAI format with 'messages'.
|
|
||||||
Supports two formats:
|
|
||||||
1. OpenAI format with 'messages'
|
|
||||||
2. Legacy format with 'conversations'
|
|
||||||
|
|
||||||
Args:
|
|
||||||
examples: list of conversation dictionaries
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict in OpenAI format with 'messages' key
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the conversation format is not supported
|
|
||||||
"""
|
|
||||||
role_mapping = {
|
|
||||||
"human": "user",
|
|
||||||
"gpt": "assistant",
|
|
||||||
}
|
|
||||||
|
|
||||||
def normalize_role(role: str) -> str:
|
|
||||||
"""Normalize role names to OpenAI format. Default to original role if not found."""
|
|
||||||
return role_mapping.get(role, role)
|
|
||||||
|
|
||||||
def convert_legacy_format(example: dict) -> dict:
|
|
||||||
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
|
|
||||||
messages = [
|
|
||||||
{
|
|
||||||
"role": normalize_role(convo["from"]),
|
|
||||||
"content": convo["value"],
|
|
||||||
}
|
|
||||||
for convo in example["conversations"]
|
|
||||||
]
|
|
||||||
|
|
||||||
# Create new dict without 'conversations' key
|
|
||||||
result = deepcopy(example)
|
|
||||||
result.pop("conversations")
|
|
||||||
return {"messages": messages, **result}
|
|
||||||
|
|
||||||
processed_examples = []
|
|
||||||
for example in examples:
|
|
||||||
# OpenAI format
|
|
||||||
if "messages" in example:
|
|
||||||
processed_examples.append(example)
|
|
||||||
|
|
||||||
# Legacy format
|
|
||||||
elif "conversations" in example:
|
|
||||||
processed_examples.append(convert_legacy_format(example))
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Only `messages` and `conversations` message keys are currently supported."
|
|
||||||
)
|
|
||||||
|
|
||||||
return processed_examples
|
|
||||||
|
|
||||||
def _process_images(examples, max_images):
|
|
||||||
"""
|
|
||||||
Process images from examples, ensuring consistency in image presence and applying max_images limit.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
examples: List of dictionaries that may contain 'images' key
|
|
||||||
max_images: Maximum number of images to keep per example (0 means no limit)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Either None (if no images) or List[Image objects] (if all examples have images)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If there's a mix of None and non-None images
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_image(example):
|
|
||||||
if "images" not in example:
|
|
||||||
return None
|
|
||||||
images = example["images"]
|
|
||||||
if isinstance(images, str):
|
|
||||||
return Image.open(images)
|
|
||||||
return images
|
|
||||||
|
|
||||||
images = [get_image(example) for example in examples]
|
|
||||||
|
|
||||||
# Count None and non-None images
|
|
||||||
none_count = sum(1 for img in images if img is None)
|
|
||||||
|
|
||||||
# All images are None
|
|
||||||
if none_count == len(images):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Mix of None and non-None images
|
|
||||||
if none_count > 0:
|
|
||||||
raise ValueError(
|
|
||||||
"All images should be either None or not None. "
|
|
||||||
"Please provide images for all examples or None."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply max_images limit if specified
|
|
||||||
if max_images > 0:
|
|
||||||
images = [
|
|
||||||
(
|
|
||||||
img_batch[:max_images]
|
|
||||||
if isinstance(img_batch, (list, tuple))
|
|
||||||
else img_batch
|
|
||||||
)
|
|
||||||
for img_batch in images
|
|
||||||
]
|
|
||||||
|
|
||||||
return images
|
|
||||||
|
|
||||||
|
def process_rows(
|
||||||
|
self,
|
||||||
|
examples: list[dict],
|
||||||
|
) -> dict[str, Tensor]:
|
||||||
# Preprocess the examples
|
# Preprocess the examples
|
||||||
examples = _preprocess(examples)
|
examples = self.processing_strategy(examples)
|
||||||
|
|
||||||
# Get the texts and images, and apply the chat template
|
# Initialize batch
|
||||||
texts = [
|
batch: dict[str, Any] = {}
|
||||||
processor.apply_chat_template(
|
|
||||||
example["messages"], chat_template=chat_template, tokenize=False
|
# Process each example
|
||||||
|
for example in examples:
|
||||||
|
# Apply chat template to process the example
|
||||||
|
# This method requires transformers>=4.49.0
|
||||||
|
result = self.processing_strategy.processor.apply_chat_template(
|
||||||
|
example["messages"],
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
return_dict=True,
|
||||||
|
chat_template=self.processing_strategy.chat_template,
|
||||||
)
|
)
|
||||||
for example in examples
|
|
||||||
]
|
|
||||||
|
|
||||||
images = _process_images(examples, max_images=max_images)
|
# TODO: Check if need handling for len(input_ids) > sequence_len
|
||||||
|
|
||||||
# Tokenize the texts and process the images
|
# Add the processed tensors to our batch
|
||||||
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
for key in result.keys():
|
||||||
|
if key not in batch:
|
||||||
|
batch[key] = []
|
||||||
|
|
||||||
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
batch[key].append(result[key].squeeze(0))
|
||||||
labels = batch["input_ids"].clone()
|
|
||||||
labels[labels == processor.tokenizer.pad_token_id] = -100 #
|
# Pad sequences to the same length
|
||||||
# Ignore the image token index in the loss computation (model specific)
|
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||||
image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
batch["input_ids"],
|
||||||
processor.image_token
|
batch_first=True,
|
||||||
|
padding_value=self.tokenizer.pad_token_id,
|
||||||
)
|
)
|
||||||
labels[labels == image_token_id] = -100
|
|
||||||
batch["labels"] = labels
|
|
||||||
|
|
||||||
if length_only:
|
attention_mask = torch.nn.utils.rnn.pad_sequence(
|
||||||
return {
|
batch["attention_mask"], batch_first=True, padding_value=0
|
||||||
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
|
)
|
||||||
}
|
|
||||||
return batch
|
# Create the final batch
|
||||||
|
final_batch = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process the labels
|
||||||
|
final_batch["labels"] = self.processing_strategy.process_labels(
|
||||||
|
final_batch["input_ids"]
|
||||||
|
)
|
||||||
|
|
||||||
|
return final_batch
|
||||||
|
|||||||
@@ -12,19 +12,13 @@ from transformers.utils.import_utils import is_torch_npu_available
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.integrations.config import merge_input_args
|
from axolotl.integrations.config import merge_input_args
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.models import MULTIMODAL_AUTO_MODEL_MAPPING, load_model_config
|
||||||
|
from axolotl.utils.schemas.config import (
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
)
|
)
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
||||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
|
||||||
)
|
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
|
||||||
DPODataset,
|
|
||||||
KTODataset,
|
|
||||||
SFTDataset,
|
|
||||||
)
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
from axolotl.utils.models import load_model_config
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -131,6 +125,9 @@ def normalize_config(cfg):
|
|||||||
with open(ds_config_path, encoding="utf-8") as f:
|
with open(ds_config_path, encoding="utf-8") as f:
|
||||||
cfg.deepspeed = json.load(f)
|
cfg.deepspeed = json.load(f)
|
||||||
|
|
||||||
|
if cfg.sequence_parallel_degree is None:
|
||||||
|
cfg.sequence_parallel_degree = 1
|
||||||
|
|
||||||
if cfg.saves_per_epoch:
|
if cfg.saves_per_epoch:
|
||||||
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
||||||
if save_steps < 1.0: # prevent saves on every step
|
if save_steps < 1.0: # prevent saves on every step
|
||||||
@@ -161,7 +158,7 @@ def normalize_config(cfg):
|
|||||||
|
|
||||||
cfg.is_multimodal = (
|
cfg.is_multimodal = (
|
||||||
hasattr(model_config, "model_type")
|
hasattr(model_config, "model_type")
|
||||||
and model_config.model_type in ["llava", "mllama"]
|
and model_config.model_type in MULTIMODAL_AUTO_MODEL_MAPPING
|
||||||
or any(
|
or any(
|
||||||
multimodal_name in cfg.base_model.lower()
|
multimodal_name in cfg.base_model.lower()
|
||||||
for multimodal_name in [
|
for multimodal_name in [
|
||||||
@@ -174,7 +171,6 @@ def normalize_config(cfg):
|
|||||||
cfg.processor_config = (
|
cfg.processor_config = (
|
||||||
cfg.processor_config or cfg.base_model_config or cfg.base_model
|
cfg.processor_config or cfg.base_model_config or cfg.base_model
|
||||||
)
|
)
|
||||||
model_config = model_config.text_config
|
|
||||||
|
|
||||||
cfg.model_config_type = model_config.model_type
|
cfg.model_config_type = model_config.model_type
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,12 @@ from pathlib import Path
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download, snapshot_download
|
||||||
from huggingface_hub.errors import HFValidationError
|
from huggingface_hub.errors import (
|
||||||
|
HFValidationError,
|
||||||
|
RepositoryNotFoundError,
|
||||||
|
RevisionNotFoundError,
|
||||||
|
)
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
@@ -70,20 +74,25 @@ def load_dataset_w_config(
|
|||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
|
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
|
||||||
ds_from_hub = False
|
ds_from_hub = False
|
||||||
ds_trust_remote_code = config_dataset.trust_remote_code
|
|
||||||
try:
|
try:
|
||||||
# this is just a basic check to see if the path is a
|
# this is just a basic check to see if the path is a
|
||||||
# valid HF dataset that's loadable
|
# valid HF dataset that's loadable
|
||||||
load_dataset(
|
snapshot_download(
|
||||||
config_dataset.path,
|
repo_id=config_dataset.path,
|
||||||
name=config_dataset.name,
|
repo_type="dataset",
|
||||||
streaming=True,
|
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
revision=config_dataset.revision,
|
revision=config_dataset.revision,
|
||||||
trust_remote_code=ds_trust_remote_code,
|
ignore_patterns=["*"],
|
||||||
)
|
)
|
||||||
ds_from_hub = True
|
ds_from_hub = True
|
||||||
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
except (
|
||||||
|
RepositoryNotFoundError,
|
||||||
|
RevisionNotFoundError,
|
||||||
|
FileNotFoundError,
|
||||||
|
ConnectionError,
|
||||||
|
HFValidationError,
|
||||||
|
ValueError,
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
ds_from_cloud = False
|
ds_from_cloud = False
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
import types
|
import types
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import addict
|
import addict
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
@@ -25,7 +25,7 @@ from peft import (
|
|||||||
prepare_model_for_kbit_training,
|
prepare_model_for_kbit_training,
|
||||||
)
|
)
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import ( # noqa: F401
|
from transformers import (
|
||||||
AddedToken,
|
AddedToken,
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
@@ -34,12 +34,17 @@ from transformers import ( # noqa: F401
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
AwqConfig,
|
AwqConfig,
|
||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
|
Gemma3ForConditionalGeneration,
|
||||||
GPTQConfig,
|
GPTQConfig,
|
||||||
LlavaForConditionalGeneration,
|
LlavaForConditionalGeneration,
|
||||||
|
Mistral3ForConditionalGeneration,
|
||||||
MllamaForConditionalGeneration,
|
MllamaForConditionalGeneration,
|
||||||
|
PretrainedConfig,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
ProcessorMixin,
|
ProcessorMixin,
|
||||||
|
Qwen2_5_VLForConditionalGeneration,
|
||||||
|
Qwen2VLForConditionalGeneration,
|
||||||
)
|
)
|
||||||
from transformers.integrations.deepspeed import (
|
from transformers.integrations.deepspeed import (
|
||||||
HfTrainerDeepSpeedConfig,
|
HfTrainerDeepSpeedConfig,
|
||||||
@@ -67,7 +72,16 @@ from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrap
|
|||||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MULTIMODAL_AUTO_MODEL_MAPPING = {
|
||||||
|
"mllama": MllamaForConditionalGeneration,
|
||||||
|
"llava": LlavaForConditionalGeneration,
|
||||||
|
"qwen2_vl": Qwen2VLForConditionalGeneration,
|
||||||
|
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
|
||||||
|
"mistral3": Mistral3ForConditionalGeneration,
|
||||||
|
"gemma3": Gemma3ForConditionalGeneration,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# copied from accelerator.FullyShardedDataParallelPlugin
|
# copied from accelerator.FullyShardedDataParallelPlugin
|
||||||
@@ -94,9 +108,30 @@ def get_module_class_from_name(module, name):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
|
||||||
|
# Set use_cache to False
|
||||||
|
if hasattr(model_config, "use_cache"):
|
||||||
|
model_config.use_cache = False
|
||||||
|
|
||||||
if cfg.is_multimodal:
|
if cfg.is_multimodal:
|
||||||
model_config = model_config.text_config
|
# For multimodal configs, use_cache is set in the text_config
|
||||||
|
if hasattr(model_config, "get_text_config"):
|
||||||
|
text_config = model_config.get_text_config()
|
||||||
|
if hasattr(text_config, "use_cache"):
|
||||||
|
text_config.use_cache = False
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"No text config found for multimodal model. Please raise an Issue with model details."
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if image_size is not set and load image size from model config if available
|
||||||
|
if (
|
||||||
|
cfg.image_size is None
|
||||||
|
and hasattr(model_config, "vision_config")
|
||||||
|
and hasattr(model_config.vision_config, "image_size")
|
||||||
|
):
|
||||||
|
cfg.image_size = model_config.vision_config.image_size
|
||||||
|
LOG.debug(f"Loaded image size: {cfg.image_size} from model config")
|
||||||
|
|
||||||
quant_config_exists = (
|
quant_config_exists = (
|
||||||
hasattr(model_config, "quantization_config")
|
hasattr(model_config, "quantization_config")
|
||||||
@@ -435,6 +470,31 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
|||||||
**processor_kwargs,
|
**processor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Attempt to load image size from processor if available
|
||||||
|
if (
|
||||||
|
cfg.image_size is None
|
||||||
|
and hasattr(processor, "size")
|
||||||
|
and any(dim in processor.size for dim in ["width", "height"])
|
||||||
|
):
|
||||||
|
im_width = None
|
||||||
|
im_height = None
|
||||||
|
if "width" in processor.size:
|
||||||
|
im_width = processor.size["width"]
|
||||||
|
if "height" in processor.size:
|
||||||
|
im_height = processor.size["height"]
|
||||||
|
|
||||||
|
# If both width and height are set, use a tuple
|
||||||
|
if im_width is not None and im_height is not None:
|
||||||
|
cfg.image_size = (im_width, im_height)
|
||||||
|
# If only width is set, use as integer
|
||||||
|
elif im_width is not None:
|
||||||
|
cfg.image_size = im_width
|
||||||
|
# If only height is set, use as integer
|
||||||
|
elif im_height is not None:
|
||||||
|
cfg.image_size = im_height
|
||||||
|
|
||||||
|
LOG.debug(f"Loaded image size: {cfg.image_size} from processor")
|
||||||
|
|
||||||
return processor
|
return processor
|
||||||
|
|
||||||
|
|
||||||
@@ -471,12 +531,8 @@ class ModelLoader:
|
|||||||
|
|
||||||
# init model config
|
# init model config
|
||||||
self.model_config = load_model_config(cfg)
|
self.model_config = load_model_config(cfg)
|
||||||
if cfg.is_multimodal:
|
|
||||||
self.text_model_config = self.model_config.text_config
|
|
||||||
else:
|
|
||||||
self.text_model_config = self.model_config
|
|
||||||
|
|
||||||
self.AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name
|
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||||
|
|
||||||
def apply_patches(self) -> None:
|
def apply_patches(self) -> None:
|
||||||
# load any patches from plugins
|
# load any patches from plugins
|
||||||
@@ -547,6 +603,14 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_self_attn_lora(self.cfg)
|
patch_self_attn_lora(self.cfg)
|
||||||
|
|
||||||
|
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
||||||
|
|
||||||
|
# Initialize ring attn for sequence parallelism. This must be done after
|
||||||
|
# model init but before the first forward pass, since it modifies flash
|
||||||
|
# attn to use ring comm for SP training across multiple GPUs.
|
||||||
|
register_ring_attn(self.cfg.sequence_parallel_degree)
|
||||||
|
|
||||||
def patch_attention(self) -> None:
|
def patch_attention(self) -> None:
|
||||||
if hasattr(self.model_config, "model_type"):
|
if hasattr(self.model_config, "model_type"):
|
||||||
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
||||||
@@ -603,7 +667,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_self_attn_lora()
|
patch_self_attn_lora()
|
||||||
|
|
||||||
def patch_llama_derived_model(self) -> None:
|
def patch_llama_derived_model(self):
|
||||||
"""Modify all llama derived models in one block"""
|
"""Modify all llama derived models in one block"""
|
||||||
self.patch_loss_llama()
|
self.patch_loss_llama()
|
||||||
|
|
||||||
@@ -653,25 +717,16 @@ class ModelLoader:
|
|||||||
"Shifted-sparse attention not currently implemented without flash attention."
|
"Shifted-sparse attention not currently implemented without flash attention."
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_auto_model_loader(self) -> None:
|
def set_auto_model_loader(self):
|
||||||
"""set self.AutoModelLoader
|
"""
|
||||||
- default value: AutoModelForCausalLM (set at __init__)
|
Set self.auto_model_loader. Defaults to `transformers.AutoModelForCausalLM`
|
||||||
- when using a multi modality model, self.AutoModelLoader should
|
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
|
||||||
be set according to model type of the model
|
should be set according to the type of the model.
|
||||||
"""
|
"""
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
if self.model_config.model_type == "llava":
|
self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
|
||||||
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
self.model_config.model_type, AutoModelForVision2Seq
|
||||||
LlavaForConditionalGeneration
|
)
|
||||||
)
|
|
||||||
elif self.model_config.model_type == "mllama":
|
|
||||||
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
|
||||||
MllamaForConditionalGeneration
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.AutoModelLoader = (
|
|
||||||
AutoModelForVision2Seq # pylint: disable=invalid-name
|
|
||||||
)
|
|
||||||
|
|
||||||
def set_device_map_config(self) -> None:
|
def set_device_map_config(self) -> None:
|
||||||
device_map = self.cfg.device_map
|
device_map = self.cfg.device_map
|
||||||
@@ -695,7 +750,7 @@ class ModelLoader:
|
|||||||
from accelerate import infer_auto_device_map
|
from accelerate import infer_auto_device_map
|
||||||
|
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model_canvas = self.AutoModelLoader.from_config(
|
model_canvas = self.auto_model_loader.from_config(
|
||||||
self.model_config,
|
self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
)
|
)
|
||||||
@@ -892,8 +947,6 @@ class ModelLoader:
|
|||||||
quantization_config = (
|
quantization_config = (
|
||||||
quantization_config or self.model_kwargs["quantization_config"]
|
quantization_config or self.model_kwargs["quantization_config"]
|
||||||
)
|
)
|
||||||
if self.cfg.is_multimodal:
|
|
||||||
self.model_config.text_config = self.text_model_config
|
|
||||||
self.model = load_sharded_model_quant(
|
self.model = load_sharded_model_quant(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
self.model_config,
|
self.model_config,
|
||||||
@@ -914,13 +967,26 @@ class ModelLoader:
|
|||||||
|
|
||||||
_ = _configure_zero3_memory_efficient_loading()
|
_ = _configure_zero3_memory_efficient_loading()
|
||||||
|
|
||||||
if self.cfg.is_multimodal:
|
# Load model with random initialization if specified
|
||||||
self.model_config.text_config = self.text_model_config
|
if self.cfg.random_init_weights:
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
# AutoModel classes support the from_config method
|
||||||
self.base_model,
|
if self.auto_model_loader in [
|
||||||
config=self.model_config,
|
AutoModelForCausalLM,
|
||||||
**self.model_kwargs,
|
AutoModelForVision2Seq,
|
||||||
)
|
]:
|
||||||
|
self.model = self.auto_model_loader.from_config(
|
||||||
|
config=self.model_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.model = self.auto_model_loader(
|
||||||
|
config=self.model_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
|
self.base_model,
|
||||||
|
config=self.model_config,
|
||||||
|
**self.model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# TODO (MengqingCao) split these patches seperately
|
# TODO (MengqingCao) split these patches seperately
|
||||||
if self.cfg.flash_attention and not self.inference:
|
if self.cfg.flash_attention and not self.inference:
|
||||||
@@ -955,10 +1021,8 @@ class ModelLoader:
|
|||||||
and self.model_type != "AutoModelForCausalLM"
|
and self.model_type != "AutoModelForCausalLM"
|
||||||
and not self.cfg.trust_remote_code
|
and not self.cfg.trust_remote_code
|
||||||
):
|
):
|
||||||
if self.cfg.is_multimodal:
|
|
||||||
self.model_config.text_config = self.text_model_config
|
|
||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
@@ -972,26 +1036,8 @@ class ModelLoader:
|
|||||||
**self.model_kwargs,
|
**self.model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
|
||||||
# when training starts
|
|
||||||
if (
|
|
||||||
hasattr(self.text_model_config, "max_seq_len")
|
|
||||||
and self.text_model_config.max_seq_len
|
|
||||||
and self.cfg.sequence_len > self.text_model_config.max_seq_len
|
|
||||||
):
|
|
||||||
self.text_model_config.max_seq_len = self.cfg.sequence_len
|
|
||||||
LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
|
|
||||||
elif (
|
|
||||||
hasattr(self.text_model_config, "max_sequence_length")
|
|
||||||
and self.text_model_config.max_sequence_length
|
|
||||||
and self.cfg.sequence_len > self.text_model_config.max_sequence_length
|
|
||||||
):
|
|
||||||
self.text_model_config.max_sequence_length = self.cfg.sequence_len
|
|
||||||
LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
|
|
||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
if self.cfg.is_multimodal:
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.model_config.text_config = self.text_model_config
|
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
@@ -1009,9 +1055,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
_ = _configure_zero3_memory_efficient_loading()
|
_ = _configure_zero3_memory_efficient_loading()
|
||||||
|
|
||||||
if self.cfg.is_multimodal:
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.model_config.text_config = self.text_model_config
|
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
@@ -1174,7 +1218,9 @@ class ModelLoader:
|
|||||||
)
|
)
|
||||||
):
|
):
|
||||||
resize_kwargs = {}
|
resize_kwargs = {}
|
||||||
if self.cfg.mean_resizing_embeddings is not None:
|
if self.cfg.mean_resizing_embeddings is not None and not (
|
||||||
|
self.model_config.model_type == "llava"
|
||||||
|
):
|
||||||
resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
|
resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
|
||||||
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
|
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
|
||||||
else:
|
else:
|
||||||
@@ -1273,8 +1319,6 @@ class ModelLoader:
|
|||||||
requires_grad.append(f"{name}: {param.requires_grad}")
|
requires_grad.append(f"{name}: {param.requires_grad}")
|
||||||
if len(requires_grad) == 0:
|
if len(requires_grad) == 0:
|
||||||
LOG.warning("there are no parameters that require gradient updates")
|
LOG.warning("there are no parameters that require gradient updates")
|
||||||
if hasattr(self.model, "config"):
|
|
||||||
self.model.config.use_cache = False
|
|
||||||
|
|
||||||
if self.cfg.flash_optimum:
|
if self.cfg.flash_optimum:
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
@@ -1307,7 +1351,7 @@ def load_model(
|
|||||||
"""
|
"""
|
||||||
Load a model for a given configuration and tokenizer.
|
Load a model for a given configuration and tokenizer.
|
||||||
"""
|
"""
|
||||||
loader = ModelLoader(
|
model_loader = ModelLoader(
|
||||||
cfg,
|
cfg,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
@@ -1315,7 +1359,7 @@ def load_model(
|
|||||||
reference_model=reference_model,
|
reference_model=reference_model,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return loader.load_model()
|
return model_loader.load_model()
|
||||||
|
|
||||||
|
|
||||||
def load_adapter(model, cfg, adapter, inference=False):
|
def load_adapter(model, cfg, adapter, inference=False):
|
||||||
|
|||||||
21
src/axolotl/utils/optimizers/soap/LICENSE
Normal file
21
src/axolotl/utils/optimizers/soap/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2024 Nikhil Vyas
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
495
src/axolotl/utils/optimizers/soap/__init__.py
Normal file
495
src/axolotl/utils/optimizers/soap/__init__.py
Normal file
@@ -0,0 +1,495 @@
|
|||||||
|
# pylint: skip-file
|
||||||
|
# Copied from https://github.com/nikhilvyas/SOAP
|
||||||
|
from itertools import chain
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.optim as optim
|
||||||
|
|
||||||
|
# Parts of the code are modifications of Pytorch's AdamW optimizer
|
||||||
|
# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py
|
||||||
|
|
||||||
|
|
||||||
|
class SOAP(optim.Optimizer):
|
||||||
|
"""
|
||||||
|
Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
params (`Iterable[nn.parameter.Parameter]`):
|
||||||
|
Iterable of parameters to optimize or dictionaries defining parameter groups.
|
||||||
|
lr (`float`, *optional*, defaults to 0.003):
|
||||||
|
The learning rate to use.
|
||||||
|
betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
|
||||||
|
Adam's betas parameters (b1, b2).
|
||||||
|
shampoo_beta (`float`, *optional*, defaults to -1):
|
||||||
|
If >= 0, use this beta for the preconditioner (L and R in paper, state["GG"] below) moving average instead of betas[1].
|
||||||
|
eps (`float`, *optional*, defaults to 1e-08):
|
||||||
|
Adam's epsilon for numerical stability.
|
||||||
|
weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
|
||||||
|
precondition_frequency (`int`, *optional*, defaults to 10):
|
||||||
|
How often to update the preconditioner.
|
||||||
|
max_precond_dim (`int`, *optional*, defaults to 10000):
|
||||||
|
Maximum dimension of the preconditioner.
|
||||||
|
Set to 10000, so that we exclude most common vocab sizes while including layers.
|
||||||
|
merge_dims (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to merge dimensions of the preconditioner.
|
||||||
|
precondition_1d (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to precondition 1D gradients.
|
||||||
|
normalize_grads (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to normalize gradients per layer.
|
||||||
|
Helps at large precondition_frequency (~100 in our experiments),
|
||||||
|
but hurts performance at small precondition_frequency (~10 in our experiments).
|
||||||
|
data_format (`str`, *optional*, defaults to `channels_first`):
|
||||||
|
Data format of the input for convolutional layers.
|
||||||
|
Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW.
|
||||||
|
correct_bias (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to use bias correction in Adam.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params,
|
||||||
|
lr: float = 3e-3,
|
||||||
|
betas=(0.95, 0.95),
|
||||||
|
shampoo_beta: float = -1,
|
||||||
|
eps: float = 1e-8,
|
||||||
|
weight_decay: float = 0.01,
|
||||||
|
precondition_frequency: int = 10,
|
||||||
|
max_precond_dim: int = 10000, #
|
||||||
|
merge_dims: bool = False, # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.
|
||||||
|
precondition_1d: bool = False,
|
||||||
|
normalize_grads: bool = False,
|
||||||
|
data_format: str = "channels_first",
|
||||||
|
correct_bias: bool = True,
|
||||||
|
):
|
||||||
|
defaults = {
|
||||||
|
"lr": lr,
|
||||||
|
"betas": betas,
|
||||||
|
"shampoo_beta": shampoo_beta,
|
||||||
|
"eps": eps,
|
||||||
|
"weight_decay": weight_decay,
|
||||||
|
"precondition_frequency": precondition_frequency,
|
||||||
|
"max_precond_dim": max_precond_dim,
|
||||||
|
"merge_dims": merge_dims,
|
||||||
|
"precondition_1d": precondition_1d,
|
||||||
|
"normalize_grads": normalize_grads,
|
||||||
|
"correct_bias": correct_bias,
|
||||||
|
}
|
||||||
|
super().__init__(params, defaults)
|
||||||
|
self._data_format = data_format
|
||||||
|
|
||||||
|
def merge_dims(self, grad, max_precond_dim):
|
||||||
|
"""
|
||||||
|
Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.
|
||||||
|
"""
|
||||||
|
assert self._data_format in ["channels_first", "channels_last"]
|
||||||
|
if self._data_format == "channels_last" and grad.dim() == 4:
|
||||||
|
grad = grad.permute(0, 3, 1, 2)
|
||||||
|
shape = grad.shape
|
||||||
|
new_shape = []
|
||||||
|
|
||||||
|
curr_shape = 1
|
||||||
|
for sh in shape:
|
||||||
|
temp_shape = curr_shape * sh
|
||||||
|
if temp_shape > max_precond_dim:
|
||||||
|
if curr_shape > 1:
|
||||||
|
new_shape.append(curr_shape)
|
||||||
|
curr_shape = sh
|
||||||
|
else:
|
||||||
|
new_shape.append(sh)
|
||||||
|
curr_shape = 1
|
||||||
|
else:
|
||||||
|
curr_shape = temp_shape
|
||||||
|
|
||||||
|
if curr_shape > 1 or len(new_shape) == 0:
|
||||||
|
new_shape.append(curr_shape)
|
||||||
|
|
||||||
|
new_grad = grad.reshape(new_shape)
|
||||||
|
return new_grad
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step(self, closure=None):
|
||||||
|
"""
|
||||||
|
Performs a single optimization step.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
|
||||||
|
"""
|
||||||
|
if closure is None:
|
||||||
|
loss = None
|
||||||
|
else:
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
for group in self.param_groups:
|
||||||
|
for p in group["params"]:
|
||||||
|
if p.grad is None:
|
||||||
|
continue
|
||||||
|
grad = p.grad
|
||||||
|
|
||||||
|
state = self.state[p]
|
||||||
|
|
||||||
|
if "step" not in state:
|
||||||
|
state["step"] = 0
|
||||||
|
|
||||||
|
# State initialization
|
||||||
|
if "exp_avg" not in state:
|
||||||
|
# Exponential moving average of gradient values
|
||||||
|
state["exp_avg"] = torch.zeros_like(grad)
|
||||||
|
# Exponential moving average of squared gradient values
|
||||||
|
state["exp_avg_sq"] = torch.zeros_like(grad)
|
||||||
|
|
||||||
|
if "Q" not in state:
|
||||||
|
self.init_preconditioner(
|
||||||
|
grad,
|
||||||
|
state,
|
||||||
|
precondition_frequency=group["precondition_frequency"],
|
||||||
|
precondition_1d=group["precondition_1d"],
|
||||||
|
shampoo_beta=(
|
||||||
|
group["shampoo_beta"]
|
||||||
|
if group["shampoo_beta"] >= 0
|
||||||
|
else group["betas"][1]
|
||||||
|
),
|
||||||
|
max_precond_dim=group["max_precond_dim"],
|
||||||
|
merge_dims=group["merge_dims"],
|
||||||
|
)
|
||||||
|
self.update_preconditioner(
|
||||||
|
grad,
|
||||||
|
state,
|
||||||
|
max_precond_dim=group["max_precond_dim"],
|
||||||
|
merge_dims=group["merge_dims"],
|
||||||
|
precondition_1d=group["precondition_1d"],
|
||||||
|
)
|
||||||
|
continue # first step is skipped so that we never use the current gradients in the projection.
|
||||||
|
|
||||||
|
# Projecting gradients to the eigenbases of Shampoo's preconditioner
|
||||||
|
# i.e. projecting to the eigenbases of matrices in state["GG"]
|
||||||
|
grad_projected = self.project(
|
||||||
|
grad,
|
||||||
|
state,
|
||||||
|
merge_dims=group["merge_dims"],
|
||||||
|
max_precond_dim=group["max_precond_dim"],
|
||||||
|
)
|
||||||
|
|
||||||
|
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||||
|
beta1, beta2 = group["betas"]
|
||||||
|
|
||||||
|
state["step"] += 1
|
||||||
|
|
||||||
|
# Decay the first and second moment running average coefficient
|
||||||
|
# In-place operations to update the averages at the same time
|
||||||
|
exp_avg.mul_(beta1).add_(grad_projected, alpha=(1.0 - beta1))
|
||||||
|
exp_avg_sq.mul_(beta2).add_(
|
||||||
|
grad_projected.square(), alpha=(1.0 - beta2)
|
||||||
|
)
|
||||||
|
|
||||||
|
denom = exp_avg_sq.sqrt().add_(group["eps"])
|
||||||
|
|
||||||
|
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
|
||||||
|
# i.e. projecting to the eigenbases of matrices in state["GG"]
|
||||||
|
# exp_avg_projected = self.project(
|
||||||
|
# exp_avg,
|
||||||
|
# state,
|
||||||
|
# merge_dims=group["merge_dims"],
|
||||||
|
# max_precond_dim=group["max_precond_dim"],
|
||||||
|
# )
|
||||||
|
exp_avg_projected = exp_avg
|
||||||
|
|
||||||
|
step_size = group["lr"]
|
||||||
|
if group["correct_bias"]:
|
||||||
|
bias_correction1 = 1.0 - beta1 ** (state["step"])
|
||||||
|
bias_correction2 = 1.0 - beta2 ** (state["step"])
|
||||||
|
step_size = step_size * (bias_correction2**0.5) / bias_correction1
|
||||||
|
|
||||||
|
# Projecting back the preconditioned (by Adam) exponential moving average of gradients
|
||||||
|
# to the original space
|
||||||
|
norm_grad = self.project_back(
|
||||||
|
exp_avg_projected / denom,
|
||||||
|
state,
|
||||||
|
merge_dims=group["merge_dims"],
|
||||||
|
max_precond_dim=group["max_precond_dim"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if group["normalize_grads"]:
|
||||||
|
norm_grad = norm_grad / (1e-30 + torch.mean(norm_grad**2) ** 0.5)
|
||||||
|
|
||||||
|
p.add_(norm_grad, alpha=-step_size)
|
||||||
|
|
||||||
|
# From AdamW code: Just adding the square of the weights to the loss function is *not*
|
||||||
|
# the correct way of using L2 regularization/weight decay with Adam,
|
||||||
|
# since that will interact with the m and v parameters in strange ways.
|
||||||
|
#
|
||||||
|
# Instead we want to decay the weights in a manner that doesn't interact
|
||||||
|
# with the m/v parameters. This is equivalent to adding the square
|
||||||
|
# of the weights to the loss with plain (non-momentum) SGD.
|
||||||
|
# Add weight decay at the end (fixed version)
|
||||||
|
if group["weight_decay"] > 0.0:
|
||||||
|
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
|
||||||
|
|
||||||
|
# Update is done after the gradient step to avoid using current gradients in the projection.
|
||||||
|
self.update_preconditioner(
|
||||||
|
grad,
|
||||||
|
state,
|
||||||
|
max_precond_dim=group["max_precond_dim"],
|
||||||
|
merge_dims=group["merge_dims"],
|
||||||
|
precondition_1d=group["precondition_1d"],
|
||||||
|
)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def init_preconditioner(
|
||||||
|
self,
|
||||||
|
grad,
|
||||||
|
state,
|
||||||
|
precondition_frequency=10,
|
||||||
|
shampoo_beta=0.95,
|
||||||
|
max_precond_dim=10000,
|
||||||
|
precondition_1d=False,
|
||||||
|
merge_dims=False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initializes the preconditioner matrices (L and R in the paper).
|
||||||
|
"""
|
||||||
|
state["GG"] = (
|
||||||
|
[]
|
||||||
|
) # Will hold all the preconditioner matrices (L and R in the paper).
|
||||||
|
if grad.dim() == 1:
|
||||||
|
if not precondition_1d or grad.shape[0] > max_precond_dim:
|
||||||
|
state["GG"].append([])
|
||||||
|
else:
|
||||||
|
state["GG"].append(
|
||||||
|
torch.zeros(grad.shape[0], grad.shape[0], device=grad.device)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if merge_dims:
|
||||||
|
grad = self.merge_dims(grad, max_precond_dim)
|
||||||
|
|
||||||
|
for sh in grad.shape:
|
||||||
|
if sh > max_precond_dim:
|
||||||
|
state["GG"].append([])
|
||||||
|
else:
|
||||||
|
state["GG"].append(torch.zeros(sh, sh, device=grad.device))
|
||||||
|
|
||||||
|
state["Q"] = None # Will hold all the eigenbases of the preconditioner.
|
||||||
|
state["precondition_frequency"] = precondition_frequency
|
||||||
|
state["shampoo_beta"] = shampoo_beta
|
||||||
|
|
||||||
|
def project(self, grad, state, merge_dims=False, max_precond_dim=10000):
|
||||||
|
"""
|
||||||
|
Projects the gradient to the eigenbases of the preconditioner.
|
||||||
|
"""
|
||||||
|
original_shape = grad.shape
|
||||||
|
if merge_dims:
|
||||||
|
if grad.dim() == 4 and self._data_format == "channels_last":
|
||||||
|
permuted_shape = grad.permute(0, 3, 1, 2).shape
|
||||||
|
grad = self.merge_dims(grad, max_precond_dim)
|
||||||
|
|
||||||
|
for mat in state["Q"]:
|
||||||
|
if len(mat) > 0:
|
||||||
|
grad = torch.tensordot(
|
||||||
|
grad,
|
||||||
|
mat,
|
||||||
|
dims=[[0], [0]],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
permute_order = list(range(1, len(grad.shape))) + [0]
|
||||||
|
grad = grad.permute(permute_order)
|
||||||
|
|
||||||
|
if merge_dims:
|
||||||
|
if self._data_format == "channels_last" and len(original_shape) == 4:
|
||||||
|
grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
|
||||||
|
else:
|
||||||
|
grad = grad.reshape(original_shape)
|
||||||
|
return grad
|
||||||
|
|
||||||
|
def update_preconditioner(
|
||||||
|
self,
|
||||||
|
grad,
|
||||||
|
state,
|
||||||
|
max_precond_dim=10000,
|
||||||
|
merge_dims=False,
|
||||||
|
precondition_1d=False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
|
||||||
|
"""
|
||||||
|
if state["Q"] is not None:
|
||||||
|
state["exp_avg"] = self.project_back(
|
||||||
|
state["exp_avg"],
|
||||||
|
state,
|
||||||
|
merge_dims=merge_dims,
|
||||||
|
max_precond_dim=max_precond_dim,
|
||||||
|
)
|
||||||
|
if grad.dim() == 1:
|
||||||
|
if precondition_1d and grad.shape[0] <= max_precond_dim:
|
||||||
|
state["GG"][0].lerp_(
|
||||||
|
grad.unsqueeze(1) @ grad.unsqueeze(0), 1 - state["shampoo_beta"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if merge_dims:
|
||||||
|
new_grad = self.merge_dims(grad, max_precond_dim)
|
||||||
|
for idx, sh in enumerate(new_grad.shape):
|
||||||
|
if sh <= max_precond_dim:
|
||||||
|
outer_product = torch.tensordot(
|
||||||
|
new_grad,
|
||||||
|
new_grad,
|
||||||
|
dims=[
|
||||||
|
[
|
||||||
|
*chain(
|
||||||
|
range(idx), range(idx + 1, len(new_grad.shape))
|
||||||
|
)
|
||||||
|
]
|
||||||
|
]
|
||||||
|
* 2,
|
||||||
|
)
|
||||||
|
state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"])
|
||||||
|
else:
|
||||||
|
for idx, sh in enumerate(grad.shape):
|
||||||
|
if sh <= max_precond_dim:
|
||||||
|
outer_product = torch.tensordot(
|
||||||
|
grad,
|
||||||
|
grad,
|
||||||
|
# Contracts across all dimensions except for k.
|
||||||
|
dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]]
|
||||||
|
* 2,
|
||||||
|
)
|
||||||
|
state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"])
|
||||||
|
|
||||||
|
if state["Q"] is None:
|
||||||
|
state["Q"] = self.get_orthogonal_matrix(state["GG"])
|
||||||
|
if state["step"] > 0 and state["step"] % state["precondition_frequency"] == 0:
|
||||||
|
state["Q"] = self.get_orthogonal_matrix_QR(
|
||||||
|
state, max_precond_dim, merge_dims
|
||||||
|
)
|
||||||
|
# state["Q"] = self.get_fast_QR(state, max_precond_dim, merge_dims)
|
||||||
|
|
||||||
|
if state["step"] > 0:
|
||||||
|
state["exp_avg"] = self.project(
|
||||||
|
state["exp_avg"],
|
||||||
|
state,
|
||||||
|
merge_dims=merge_dims,
|
||||||
|
max_precond_dim=max_precond_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
|
||||||
|
"""
|
||||||
|
Projects the gradient back to the original space.
|
||||||
|
"""
|
||||||
|
original_shape = grad.shape
|
||||||
|
if merge_dims:
|
||||||
|
if self._data_format == "channels_last" and grad.dim() == 4:
|
||||||
|
permuted_shape = grad.permute(0, 3, 1, 2).shape
|
||||||
|
grad = self.merge_dims(grad, max_precond_dim)
|
||||||
|
for mat in state["Q"]:
|
||||||
|
if len(mat) > 0:
|
||||||
|
grad = torch.tensordot(
|
||||||
|
grad,
|
||||||
|
mat,
|
||||||
|
dims=[[0], [1]],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
permute_order = list(range(1, len(grad.shape))) + [0]
|
||||||
|
grad = grad.permute(permute_order)
|
||||||
|
|
||||||
|
if merge_dims:
|
||||||
|
if self._data_format == "channels_last" and len(original_shape) == 4:
|
||||||
|
grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
|
||||||
|
else:
|
||||||
|
grad = grad.reshape(original_shape)
|
||||||
|
return grad
|
||||||
|
|
||||||
|
def get_orthogonal_matrix(self, mat):
|
||||||
|
"""
|
||||||
|
Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
|
||||||
|
"""
|
||||||
|
matrix = []
|
||||||
|
for m in mat:
|
||||||
|
if len(m) == 0:
|
||||||
|
matrix.append([])
|
||||||
|
continue
|
||||||
|
if m.data.dtype != torch.float:
|
||||||
|
float_data = False
|
||||||
|
original_type = m.data.dtype
|
||||||
|
original_device = m.data.device
|
||||||
|
matrix.append(m.data.float())
|
||||||
|
else:
|
||||||
|
float_data = True
|
||||||
|
matrix.append(m.data)
|
||||||
|
|
||||||
|
final = []
|
||||||
|
for m in matrix:
|
||||||
|
if len(m) == 0:
|
||||||
|
final.append([])
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
_, Q = torch.linalg.eigh(
|
||||||
|
m + 1e-30 * torch.eye(m.shape[0], device=m.device)
|
||||||
|
)
|
||||||
|
except: # pylint: disable=bare-except # noqa: E722
|
||||||
|
_, Q = torch.linalg.eigh(
|
||||||
|
m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device)
|
||||||
|
)
|
||||||
|
Q = Q.to(m.dtype)
|
||||||
|
Q = torch.flip(Q, [1])
|
||||||
|
|
||||||
|
if not float_data:
|
||||||
|
Q = Q.to(original_device).type(original_type)
|
||||||
|
final.append(Q)
|
||||||
|
return final
|
||||||
|
|
||||||
|
def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):
|
||||||
|
"""
|
||||||
|
Computes the eigenbases of the preconditioner using one round of power iteration
|
||||||
|
followed by torch.linalg.qr decomposition.
|
||||||
|
"""
|
||||||
|
precond_list = state["GG"]
|
||||||
|
orth_list = state["Q"]
|
||||||
|
|
||||||
|
matrix = []
|
||||||
|
orth_matrix = []
|
||||||
|
for m, o in zip(precond_list, orth_list):
|
||||||
|
if len(m) == 0:
|
||||||
|
matrix.append([])
|
||||||
|
orth_matrix.append([])
|
||||||
|
continue
|
||||||
|
if m.data.dtype != torch.float:
|
||||||
|
float_data = False
|
||||||
|
original_type = m.data.dtype
|
||||||
|
original_device = m.data.device
|
||||||
|
matrix.append(m.data.float())
|
||||||
|
orth_matrix.append(o.data.float())
|
||||||
|
else:
|
||||||
|
float_data = True
|
||||||
|
matrix.append(m.data.float())
|
||||||
|
orth_matrix.append(o.data.float())
|
||||||
|
|
||||||
|
orig_shape = state["exp_avg_sq"].shape
|
||||||
|
if self._data_format == "channels_last" and len(orig_shape) == 4:
|
||||||
|
permuted_shape = state["exp_avg_sq"].permute(0, 3, 1, 2).shape
|
||||||
|
if merge_dims:
|
||||||
|
exp_avg_sq = self.merge_dims(state["exp_avg_sq"], max_precond_dim)
|
||||||
|
else:
|
||||||
|
exp_avg_sq = state["exp_avg_sq"]
|
||||||
|
|
||||||
|
final = []
|
||||||
|
for ind, (m, o) in enumerate(zip(matrix, orth_matrix)):
|
||||||
|
if len(m) == 0:
|
||||||
|
final.append([])
|
||||||
|
continue
|
||||||
|
est_eig = torch.diag(o.T @ m @ o)
|
||||||
|
sort_idx = torch.argsort(est_eig, descending=True)
|
||||||
|
exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
|
||||||
|
o = o[:, sort_idx]
|
||||||
|
power_iter = m @ o
|
||||||
|
Q, _ = torch.linalg.qr(power_iter)
|
||||||
|
|
||||||
|
if not float_data:
|
||||||
|
Q = Q.to(original_device).type(original_type)
|
||||||
|
final.append(Q)
|
||||||
|
|
||||||
|
if merge_dims:
|
||||||
|
if self._data_format == "channels_last" and len(orig_shape) == 4:
|
||||||
|
exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1)
|
||||||
|
else:
|
||||||
|
exp_avg_sq = exp_avg_sq.reshape(orig_shape)
|
||||||
|
|
||||||
|
state["exp_avg_sq"] = exp_avg_sq
|
||||||
|
return final
|
||||||
@@ -104,9 +104,7 @@ def allocate(
|
|||||||
|
|
||||||
|
|
||||||
class MultipackBatchSampler(BatchSampler):
|
class MultipackBatchSampler(BatchSampler):
|
||||||
"""
|
"""Batch sampler class for multipack"""
|
||||||
Batch Sampler class for multipack
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
165
src/axolotl/utils/schemas/datasets.py
Normal file
165
src/axolotl/utils/schemas/datasets.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
"""Pydantic models for datasets-related configuration"""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
|
from axolotl.utils.schemas.enums import ChatTemplate
|
||||||
|
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
|
||||||
|
|
||||||
|
|
||||||
|
class UserDefinedPrompterType(BaseModel):
|
||||||
|
"""Structure for user defined prompt types"""
|
||||||
|
|
||||||
|
system_prompt: str | None = None
|
||||||
|
system_format: str | None = None
|
||||||
|
field_system: str | None = None
|
||||||
|
field_instruction: str | None = None
|
||||||
|
field_input: str | None = None
|
||||||
|
field_output: str | None = None
|
||||||
|
|
||||||
|
format: str | None = None
|
||||||
|
no_input_format: str | None = None
|
||||||
|
field: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SFTDataset(BaseModel):
|
||||||
|
"""SFT configuration subset"""
|
||||||
|
|
||||||
|
path: str | None = None
|
||||||
|
split: str | None = None
|
||||||
|
type: str | UserDefinedPrompterType | None = None
|
||||||
|
input_transform: str | None = None
|
||||||
|
shards: int | None = None
|
||||||
|
shards_idx: int | None = None
|
||||||
|
preprocess_shards: int | None = None
|
||||||
|
conversation: str | None = None
|
||||||
|
# Do not make this too strict or it will break the validator to choose different dataset class
|
||||||
|
chat_template: ChatTemplate | str | None = None
|
||||||
|
chat_template_jinja: str | None = None
|
||||||
|
data_files: str | list[str] | None = None
|
||||||
|
input_format: str | None = None
|
||||||
|
name: str | None = None
|
||||||
|
ds_type: str | None = None
|
||||||
|
train_on_split: str | None = None
|
||||||
|
field: str | None = None
|
||||||
|
field_human: str | None = None
|
||||||
|
field_model: str | None = None
|
||||||
|
field_messages: str | None = None
|
||||||
|
# deprecated, use message_property_mappings
|
||||||
|
message_field_role: str | None = None
|
||||||
|
# deprecated, use message_property_mappings
|
||||||
|
message_field_content: str | None = None
|
||||||
|
message_property_mappings: dict[str, str] | None = None
|
||||||
|
message_field_training: str | None = None
|
||||||
|
message_field_training_detail: str | None = None
|
||||||
|
logprobs_field: str | None = None
|
||||||
|
temperature: float | None = None
|
||||||
|
roles_to_train: list[str] | None = None
|
||||||
|
train_on_eos: str | None = None
|
||||||
|
roles: dict[str, list[str]] | None = None
|
||||||
|
drop_system_message: bool | None = None
|
||||||
|
trust_remote_code: bool | None = False
|
||||||
|
revision: str | None = None
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def handle_legacy_message_fields(cls, data):
|
||||||
|
"""Handle backwards compatibility between legacy message field mapping and new property mapping system."""
|
||||||
|
return handle_legacy_message_fields_logic(data)
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
def check_chat_template_config(cls, data):
|
||||||
|
if isinstance(data, BaseModel):
|
||||||
|
data = data.model_dump()
|
||||||
|
|
||||||
|
# Set chat_template to tokenizer_default if not set
|
||||||
|
if data.get("type") == "chat_template" and not data.get("chat_template"):
|
||||||
|
data["chat_template"] = ChatTemplate.tokenizer_default
|
||||||
|
|
||||||
|
# if chat_template is set to jinja, chat_template_jinja is required
|
||||||
|
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
||||||
|
"chat_template_jinja"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"chat_template_jinja is required when chat_template is set to jinja"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If chat_template_jinja is set, set chat_template to jinja
|
||||||
|
if data.get("chat_template_jinja") and not data.get("chat_template"):
|
||||||
|
data["chat_template"] = ChatTemplate.jinja
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class PretrainingDataset(BaseModel):
|
||||||
|
"""Pretraining dataset configuration subset"""
|
||||||
|
|
||||||
|
name: str | None = None
|
||||||
|
path: str | None = None
|
||||||
|
split: str | None = "train"
|
||||||
|
text_column: str | None = "text"
|
||||||
|
type: str | None = "pretrain"
|
||||||
|
trust_remote_code: bool | None = False
|
||||||
|
data_files: str | None = None
|
||||||
|
skip: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class UserDefinedDPOType(BaseModel):
|
||||||
|
"""User defined typing for DPO"""
|
||||||
|
|
||||||
|
field_system: str | None = None
|
||||||
|
field_prompt: str | None = None
|
||||||
|
field_chosen: str | None = None
|
||||||
|
field_rejected: str | None = None
|
||||||
|
prompt_format: str | None = None
|
||||||
|
chosen_format: str | None = None
|
||||||
|
rejected_format: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DPODataset(BaseModel):
|
||||||
|
"""DPO configuration subset"""
|
||||||
|
|
||||||
|
path: str | None = None
|
||||||
|
split: str | None = None
|
||||||
|
type: UserDefinedDPOType | str | None = None
|
||||||
|
data_files: list[str] | None = None
|
||||||
|
revision: str | None = None
|
||||||
|
field_messages: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class StepwiseSupervisedDataset(BaseModel):
|
||||||
|
"""Stepwise supervised dataset configuration subset"""
|
||||||
|
|
||||||
|
path: str | None = None
|
||||||
|
split: str | None = None
|
||||||
|
data_files: list[str] | None = None
|
||||||
|
revision: str | None = None
|
||||||
|
step_separator: str | None = None
|
||||||
|
max_completion_length: int | None = None
|
||||||
|
train_on_last_step_only: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class UserDefinedKTOType(BaseModel):
|
||||||
|
"""User defined typing for KTO"""
|
||||||
|
|
||||||
|
field_system: str | None = None
|
||||||
|
field_prompt: str | None = None
|
||||||
|
field_completion: str | None = None
|
||||||
|
field_label: bool | None = None
|
||||||
|
prompt_format: str | None = None
|
||||||
|
completion_format: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class KTODataset(BaseModel):
|
||||||
|
"""KTO configuration subset"""
|
||||||
|
|
||||||
|
path: str | None = None
|
||||||
|
split: str | None = None
|
||||||
|
type: UserDefinedKTOType | str | None = None
|
||||||
|
data_files: list[str] | None = None
|
||||||
|
trust_remote_code: bool | None = False
|
||||||
|
revision: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
|
||||||
68
src/axolotl/utils/schemas/deprecated.py
Normal file
68
src/axolotl/utils/schemas/deprecated.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""Pydantic models for deprecated and remapped configuration parameters"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DeprecatedParameters(BaseModel):
|
||||||
|
"""configurations that are deprecated"""
|
||||||
|
|
||||||
|
max_packed_sequence_len: int | None = None
|
||||||
|
rope_scaling: Any | None = None
|
||||||
|
noisy_embedding_alpha: float | None = None
|
||||||
|
dpo_beta: float | None = None
|
||||||
|
evaluation_strategy: str | None = None
|
||||||
|
|
||||||
|
@field_validator("max_packed_sequence_len")
|
||||||
|
@classmethod
|
||||||
|
def validate_max_packed_sequence_len(cls, max_packed_sequence_len):
|
||||||
|
if max_packed_sequence_len:
|
||||||
|
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
|
||||||
|
return max_packed_sequence_len
|
||||||
|
|
||||||
|
@field_validator("rope_scaling")
|
||||||
|
@classmethod
|
||||||
|
def validate_rope_scaling(cls, rope_scaling):
|
||||||
|
if rope_scaling:
|
||||||
|
raise DeprecationWarning(
|
||||||
|
"`rope_scaling` is no longer supported, it should now be be a key under `model_config`"
|
||||||
|
)
|
||||||
|
return rope_scaling
|
||||||
|
|
||||||
|
@field_validator("noisy_embedding_alpha")
|
||||||
|
@classmethod
|
||||||
|
def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha):
|
||||||
|
if noisy_embedding_alpha:
|
||||||
|
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
||||||
|
return noisy_embedding_alpha
|
||||||
|
|
||||||
|
@field_validator("dpo_beta")
|
||||||
|
@classmethod
|
||||||
|
def validate_dpo_beta(cls, dpo_beta):
|
||||||
|
if dpo_beta is not None:
|
||||||
|
LOG.warning("dpo_beta is deprecated, use rl_beta instead")
|
||||||
|
return dpo_beta
|
||||||
|
|
||||||
|
@field_validator("evaluation_strategy")
|
||||||
|
@classmethod
|
||||||
|
def validate_evaluation_strategy(cls, evaluation_strategy):
|
||||||
|
if evaluation_strategy is not None:
|
||||||
|
LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
|
||||||
|
return evaluation_strategy
|
||||||
|
|
||||||
|
|
||||||
|
class RemappedParameters(BaseModel):
|
||||||
|
"""Parameters that have been remapped to other names"""
|
||||||
|
|
||||||
|
overrides_of_model_config: dict[str, Any] | None = Field(
|
||||||
|
default=None, alias="model_config"
|
||||||
|
)
|
||||||
|
overrides_of_model_kwargs: dict[str, Any] | None = Field(
|
||||||
|
default=None, alias="model_kwargs"
|
||||||
|
)
|
||||||
|
type_of_model: str | None = Field(default=None, alias="model_type")
|
||||||
|
revision_of_model: str | None = Field(default=None, alias="model_revision")
|
||||||
55
src/axolotl/utils/schemas/enums.py
Normal file
55
src/axolotl/utils/schemas/enums.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
"""Enums for Axolotl input config"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class RLType(str, Enum):
|
||||||
|
"""RL trainer type configuration subset"""
|
||||||
|
|
||||||
|
dpo = "dpo" # pylint: disable=invalid-name
|
||||||
|
grpo = "grpo" # pylint: disable=invalid-name
|
||||||
|
ipo = "ipo" # pylint: disable=invalid-name
|
||||||
|
orpo = "orpo" # pylint: disable=invalid-name
|
||||||
|
kto = "kto" # pylint: disable=invalid-name
|
||||||
|
simpo = "simpo" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class ChatTemplate(str, Enum):
|
||||||
|
"""Chat templates configuration subset"""
|
||||||
|
|
||||||
|
alpaca = "alpaca" # pylint: disable=invalid-name
|
||||||
|
chatml = "chatml" # pylint: disable=invalid-name
|
||||||
|
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
|
||||||
|
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
|
||||||
|
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
|
||||||
|
mistral_v7_tekken = "mistral_v7_tekken" # pylint: disable=invalid-name
|
||||||
|
gemma = "gemma" # pylint: disable=invalid-name
|
||||||
|
cohere = "cohere" # pylint: disable=invalid-name
|
||||||
|
llama3 = "llama3" # pylint: disable=invalid-name
|
||||||
|
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
||||||
|
phi_3 = "phi_3" # pylint: disable=invalid-name
|
||||||
|
phi_35 = "phi_35" # pylint: disable=invalid-name
|
||||||
|
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
||||||
|
deepseek_v3 = "deepseek_v3" # pylint: disable=invalid-name
|
||||||
|
jamba = "jamba" # pylint: disable=invalid-name
|
||||||
|
jinja = "jinja" # pylint: disable=invalid-name
|
||||||
|
qwen_25 = "qwen_25" # pylint: disable=invalid-name
|
||||||
|
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
||||||
|
exaone = "exaone" # pylint: disable=invalid-name
|
||||||
|
metharme = "metharme" # pylint: disable=invalid-name
|
||||||
|
pixtral = "pixtral" # pylint: disable=invalid-name
|
||||||
|
llava = "llava" # pylint: disable=invalid-name
|
||||||
|
qwen2_vl = "qwen2_vl" # pylint: disable=invalid-name
|
||||||
|
gemma3 = "gemma3" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class CustomSupportedOptimizers(str, Enum):
|
||||||
|
"""Custom supported optimizers"""
|
||||||
|
|
||||||
|
optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name
|
||||||
|
ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name
|
||||||
|
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
|
||||||
|
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
|
||||||
|
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
|
||||||
|
muon = "muon" # pylint: disable=invalid-name
|
||||||
|
soap = "soap" # pylint: disable=invalid-name
|
||||||
108
src/axolotl/utils/schemas/integrations.py
Normal file
108
src/axolotl/utils/schemas/integrations.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""Pydantic models for Axolotl integrations"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MLFlowConfig(BaseModel):
|
||||||
|
"""MLFlow configuration subset"""
|
||||||
|
|
||||||
|
use_mlflow: bool | None = None
|
||||||
|
mlflow_tracking_uri: str | None = None
|
||||||
|
mlflow_experiment_name: str | None = None
|
||||||
|
mlflow_run_name: str | None = None
|
||||||
|
hf_mlflow_log_artifacts: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class LISAConfig(BaseModel):
|
||||||
|
"""LISA configuration subset"""
|
||||||
|
|
||||||
|
lisa_n_layers: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "the number of activate layers in LISA"},
|
||||||
|
)
|
||||||
|
lisa_step_interval: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "how often to switch layers in LISA"},
|
||||||
|
)
|
||||||
|
lisa_layers_attribute: str | None = Field(
|
||||||
|
default="model.layers",
|
||||||
|
json_schema_extra={"description": "path under the model to access the layers"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WandbConfig(BaseModel):
|
||||||
|
"""Wandb configuration subset"""
|
||||||
|
|
||||||
|
use_wandb: bool | None = None
|
||||||
|
wandb_name: str | None = None
|
||||||
|
wandb_run_id: str | None = None
|
||||||
|
wandb_mode: str | None = None
|
||||||
|
wandb_project: str | None = None
|
||||||
|
wandb_entity: str | None = None
|
||||||
|
wandb_watch: str | None = None
|
||||||
|
wandb_log_model: str | None = None
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_wandb_run(cls, data):
|
||||||
|
if data.get("wandb_run_id") and not data.get("wandb_name"):
|
||||||
|
data["wandb_name"] = data.get("wandb_run_id")
|
||||||
|
|
||||||
|
LOG.warning(
|
||||||
|
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class CometConfig(BaseModel):
|
||||||
|
"""Comet configuration subset"""
|
||||||
|
|
||||||
|
use_comet: bool | None = None
|
||||||
|
comet_api_key: str | None = None
|
||||||
|
comet_workspace: str | None = None
|
||||||
|
comet_project_name: str | None = None
|
||||||
|
comet_experiment_key: str | None = None
|
||||||
|
comet_mode: str | None = None
|
||||||
|
comet_online: bool | None = None
|
||||||
|
comet_experiment_config: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class GradioConfig(BaseModel):
|
||||||
|
"""Gradio configuration subset"""
|
||||||
|
|
||||||
|
gradio_title: str | None = None
|
||||||
|
gradio_share: bool | None = None
|
||||||
|
gradio_server_name: str | None = None
|
||||||
|
gradio_server_port: int | None = None
|
||||||
|
gradio_max_new_tokens: int | None = None
|
||||||
|
gradio_temperature: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class RayConfig(BaseModel):
|
||||||
|
"""Ray launcher configuration subset"""
|
||||||
|
|
||||||
|
use_ray: bool = Field(default=False)
|
||||||
|
ray_run_name: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"help": "The training results will be saved at `saves/ray_run_name`."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ray_num_workers: int = Field(
|
||||||
|
default=1,
|
||||||
|
json_schema_extra={
|
||||||
|
"help": "The number of workers for Ray training. Default is 1 worker."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resources_per_worker: dict = Field(
|
||||||
|
default_factory=lambda: {"GPU": 1},
|
||||||
|
json_schema_extra={
|
||||||
|
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
|
||||||
|
},
|
||||||
|
)
|
||||||
55
src/axolotl/utils/schemas/model.py
Normal file
55
src/axolotl/utils/schemas/model.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
"""Pydantic models for model input / output, etc. configuration"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInputConfig(BaseModel):
|
||||||
|
"""Model configuration subset"""
|
||||||
|
|
||||||
|
model_config = {"protected_namespaces": ()}
|
||||||
|
|
||||||
|
base_model: str
|
||||||
|
base_model_config: str | None = None
|
||||||
|
cls_model_config: str | None = None
|
||||||
|
tokenizer_config: str | None = None
|
||||||
|
tokenizer_use_fast: bool | None = None
|
||||||
|
tokenizer_legacy: bool | None = None
|
||||||
|
tokenizer_type: str | None = Field(
|
||||||
|
default=None, json_schema_extra={"description": "transformers tokenizer class"}
|
||||||
|
)
|
||||||
|
processor_type: str | None = Field(
|
||||||
|
default=None, json_schema_extra={"description": "transformers processor class"}
|
||||||
|
)
|
||||||
|
trust_remote_code: bool | None = None
|
||||||
|
|
||||||
|
@field_validator("trust_remote_code")
|
||||||
|
@classmethod
|
||||||
|
def hint_trust_remote_code(cls, trust_remote_code):
|
||||||
|
if trust_remote_code:
|
||||||
|
LOG.warning(
|
||||||
|
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||||
|
)
|
||||||
|
return trust_remote_code
|
||||||
|
|
||||||
|
|
||||||
|
class ModelOutputConfig(BaseModel):
|
||||||
|
"""model save configuration subset"""
|
||||||
|
|
||||||
|
output_dir: str = Field(default="./model-out")
|
||||||
|
hub_model_id: str | None = None
|
||||||
|
hub_strategy: str | None = None
|
||||||
|
save_safetensors: bool | None = True
|
||||||
|
|
||||||
|
|
||||||
|
class SpecialTokensConfig(BaseModel):
|
||||||
|
"""Special tokens configuration subset"""
|
||||||
|
|
||||||
|
bos_token: str | None = None
|
||||||
|
eos_token: str | None = None
|
||||||
|
pad_token: str | None = None
|
||||||
|
unk_token: str | None = None
|
||||||
|
additional_special_tokens: list[str] | None = None
|
||||||
48
src/axolotl/utils/schemas/multimodal.py
Normal file
48
src/axolotl/utils/schemas/multimodal.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""Pydantic models for multimodal-related configuration"""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from PIL.Image import Resampling
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModalConfig(BaseModel):
|
||||||
|
"""Multi-modal configuration subset"""
|
||||||
|
|
||||||
|
image_size: int | tuple[int, int] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": (
|
||||||
|
"The size of the image to resize to. It can be an integer (resized into padded-square image) or a tuple (width, height)."
|
||||||
|
"If not provided, we will attempt to load from preprocessor.size, otherwise, images won't be resized."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
image_resize_algorithm: (
|
||||||
|
Literal["bilinear", "bicubic", "lanczos"] | Resampling | None
|
||||||
|
) = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "The resampling algorithm to use for image resizing. Default is bilinear. Please refer to PIL.Image.Resampling for more details."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("image_resize_algorithm", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def convert_image_resize_algorithm(cls, image_resize_algorithm):
|
||||||
|
"""
|
||||||
|
Convert the image resize algorithm to a PIL.Image.Resampling enum.
|
||||||
|
"""
|
||||||
|
if isinstance(image_resize_algorithm, str):
|
||||||
|
image_resize_algorithm = image_resize_algorithm.lower()
|
||||||
|
if image_resize_algorithm == "bilinear":
|
||||||
|
image_resize_algorithm = Resampling.BILINEAR
|
||||||
|
elif image_resize_algorithm == "bicubic":
|
||||||
|
image_resize_algorithm = Resampling.BICUBIC
|
||||||
|
elif image_resize_algorithm == "lanczos":
|
||||||
|
image_resize_algorithm = Resampling.LANCZOS
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid image resize algorithm: {image_resize_algorithm}"
|
||||||
|
)
|
||||||
|
return image_resize_algorithm
|
||||||
132
src/axolotl/utils/schemas/peft.py
Normal file
132
src/axolotl/utils/schemas/peft.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
"""Pydantic models for PEFT-related configuration"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
|
||||||
|
|
||||||
|
class LoftQConfig(BaseModel):
|
||||||
|
"""LoftQ configuration subset"""
|
||||||
|
|
||||||
|
loftq_bits: int = Field(
|
||||||
|
default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
|
||||||
|
)
|
||||||
|
# loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
|
||||||
|
|
||||||
|
|
||||||
|
class PeftConfig(BaseModel):
|
||||||
|
"""peftq configuration subset"""
|
||||||
|
|
||||||
|
loftq_config: LoftQConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class LoraConfig(BaseModel):
|
||||||
|
"""Peft / LoRA configuration subset"""
|
||||||
|
|
||||||
|
load_in_8bit: bool | None = Field(default=False)
|
||||||
|
load_in_4bit: bool | None = Field(default=False)
|
||||||
|
|
||||||
|
adapter: str | None = None
|
||||||
|
lora_model_dir: str | None = None
|
||||||
|
lora_r: int | None = None
|
||||||
|
lora_alpha: int | None = None
|
||||||
|
lora_fan_in_fan_out: bool | None = None
|
||||||
|
lora_target_modules: str | list[str] | None = None
|
||||||
|
lora_target_linear: bool | None = None
|
||||||
|
lora_modules_to_save: list[str] | None = None
|
||||||
|
lora_dropout: float | None = 0.0
|
||||||
|
peft_layers_to_transform: list[int] | None = None
|
||||||
|
peft_layers_pattern: list[str] | None = None
|
||||||
|
peft: PeftConfig | None = None
|
||||||
|
peft_use_dora: bool | None = None
|
||||||
|
peft_use_rslora: bool | None = None
|
||||||
|
peft_layer_replication: list[tuple[int, int]] | None = None
|
||||||
|
peft_init_lora_weights: bool | str | None = None
|
||||||
|
|
||||||
|
qlora_sharded_model_loading: bool | None = Field(
|
||||||
|
default=False,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "load qlora model in sharded format for FSDP using answer.ai technique."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
lora_on_cpu: bool | None = None
|
||||||
|
gptq: bool | None = None
|
||||||
|
bnb_config_kwargs: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
loraplus_lr_ratio: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
loraplus_lr_embedding: float | None = Field(
|
||||||
|
default=1e-6,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "loraplus learning rate for lora embedding layers."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
merge_lora: bool | None = None
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_adapter(cls, data):
|
||||||
|
if (
|
||||||
|
not data.get("adapter")
|
||||||
|
and not data.get("inference")
|
||||||
|
and (data.get("load_in_8bit") or data.get("load_in_4bit"))
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
|
||||||
|
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_qlora(self):
|
||||||
|
if self.adapter == "qlora":
|
||||||
|
if self.merge_lora:
|
||||||
|
# can't merge qlora if loaded in 8bit or 4bit
|
||||||
|
if self.load_in_8bit:
|
||||||
|
raise ValueError("Can't merge qlora if loaded in 8bit")
|
||||||
|
|
||||||
|
if self.gptq:
|
||||||
|
raise ValueError("Can't merge qlora if gptq")
|
||||||
|
|
||||||
|
if self.load_in_4bit:
|
||||||
|
raise ValueError("Can't merge qlora if loaded in 4bit")
|
||||||
|
|
||||||
|
else:
|
||||||
|
if self.load_in_8bit:
|
||||||
|
raise ValueError("Can't load qlora in 8bit")
|
||||||
|
|
||||||
|
if self.gptq:
|
||||||
|
raise ValueError("Can't load qlora if gptq")
|
||||||
|
|
||||||
|
if not self.load_in_4bit:
|
||||||
|
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||||
|
return self
|
||||||
|
|
||||||
|
@field_validator("loraplus_lr_embedding")
|
||||||
|
@classmethod
|
||||||
|
def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding):
|
||||||
|
if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str):
|
||||||
|
loraplus_lr_embedding = float(loraplus_lr_embedding)
|
||||||
|
return loraplus_lr_embedding
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_lora_dropout(cls, data):
|
||||||
|
if data.get("adapter") is not None and data.get("lora_dropout") is None:
|
||||||
|
data["lora_dropout"] = 0.0
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class ReLoRAConfig(BaseModel):
|
||||||
|
"""ReLoRA configuration subset"""
|
||||||
|
|
||||||
|
relora_steps: int | None = None
|
||||||
|
relora_warmup_steps: int | None = None
|
||||||
|
relora_anneal_steps: int | None = None
|
||||||
|
relora_prune_ratio: float | None = None
|
||||||
|
relora_cpu_offload: bool | None = None
|
||||||
99
src/axolotl/utils/schemas/training.py
Normal file
99
src/axolotl/utils/schemas/training.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
"""Pydantic models for training hyperparameters"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
from transformers import SchedulerType
|
||||||
|
from transformers.training_args import OptimizerNames
|
||||||
|
|
||||||
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LrGroup(BaseModel):
|
||||||
|
"""Custom learning rate group configuration"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
modules: list[str]
|
||||||
|
lr: float
|
||||||
|
|
||||||
|
|
||||||
|
class HyperparametersConfig(BaseModel):
|
||||||
|
"""Training hyperparams configuration subset"""
|
||||||
|
|
||||||
|
gradient_accumulation_steps: int | None = Field(default=1)
|
||||||
|
micro_batch_size: int | None = Field(
|
||||||
|
default=1,
|
||||||
|
json_schema_extra={"description": "per gpu micro batch size for training"},
|
||||||
|
)
|
||||||
|
batch_size: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Total batch size, we do not recommended setting this manually"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
eval_batch_size: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
auto_find_batch_size: bool | None = None
|
||||||
|
|
||||||
|
train_on_inputs: bool | None = False
|
||||||
|
group_by_length: bool | None = None
|
||||||
|
|
||||||
|
learning_rate: str | float
|
||||||
|
embedding_lr: float | None = None
|
||||||
|
embedding_lr_scale: float | None = None
|
||||||
|
weight_decay: float | None = 0.0
|
||||||
|
optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = (
|
||||||
|
OptimizerNames.ADAMW_TORCH_FUSED
|
||||||
|
)
|
||||||
|
optim_args: (str | dict[str, Any]) | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
|
||||||
|
)
|
||||||
|
optim_target_modules: (list[str] | Literal["all_linear"]) | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "The target modules to optimize, i.e. the module names that you would like to train."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
torchdistx_path: str | None = None
|
||||||
|
lr_scheduler: (SchedulerType | Literal["one_cycle"] | Literal["rex"]) | None = (
|
||||||
|
SchedulerType.COSINE
|
||||||
|
)
|
||||||
|
lr_scheduler_kwargs: dict[str, Any] | None = None
|
||||||
|
lr_quadratic_warmup: bool | None = None
|
||||||
|
cosine_min_lr_ratio: float | None = None
|
||||||
|
cosine_constant_lr_ratio: float | None = None
|
||||||
|
lr_div_factor: float | None = None
|
||||||
|
lr_groups: list[LrGroup] | None = None
|
||||||
|
|
||||||
|
adam_epsilon: float | None = None
|
||||||
|
adam_beta1: float | None = None
|
||||||
|
adam_beta2: float | None = None
|
||||||
|
max_grad_norm: float | None = None
|
||||||
|
num_epochs: float = Field(default=1.0)
|
||||||
|
|
||||||
|
@field_validator("batch_size")
|
||||||
|
@classmethod
|
||||||
|
def hint_batch_size_set(cls, batch_size):
|
||||||
|
if batch_size:
|
||||||
|
LOG.warning(
|
||||||
|
"%s\n%s",
|
||||||
|
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||||
|
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||||
|
)
|
||||||
|
return batch_size
|
||||||
|
|
||||||
|
@field_validator("learning_rate")
|
||||||
|
@classmethod
|
||||||
|
def convert_learning_rate(cls, learning_rate):
|
||||||
|
if learning_rate and isinstance(learning_rate, str):
|
||||||
|
learning_rate = float(learning_rate)
|
||||||
|
return learning_rate
|
||||||
@@ -1,8 +1,4 @@
|
|||||||
"""
|
"""Pydantic models for TRL trainer configuration"""
|
||||||
GRPO specific configuration args
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -12,11 +8,11 @@ class TRLConfig(BaseModel):
|
|||||||
Input args for TRL.
|
Input args for TRL.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
beta: Optional[float] = Field(
|
beta: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Beta for RL training"},
|
json_schema_extra={"description": "Beta for RL training"},
|
||||||
)
|
)
|
||||||
max_completion_length: Optional[int] = Field(
|
max_completion_length: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Maximum length of the completion for RL training"
|
"description": "Maximum length of the completion for RL training"
|
||||||
@@ -25,50 +21,50 @@ class TRLConfig(BaseModel):
|
|||||||
|
|
||||||
# GRPO specific args
|
# GRPO specific args
|
||||||
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
||||||
use_vllm: Optional[bool] = Field(
|
use_vllm: bool | None = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
||||||
)
|
)
|
||||||
vllm_device: Optional[str] = Field(
|
vllm_device: str | None = Field(
|
||||||
default="auto",
|
default="auto",
|
||||||
json_schema_extra={"description": "Device to use for VLLM"},
|
json_schema_extra={"description": "Device to use for VLLM"},
|
||||||
)
|
)
|
||||||
vllm_gpu_memory_utilization: Optional[float] = Field(
|
vllm_gpu_memory_utilization: float | None = Field(
|
||||||
default=0.9,
|
default=0.9,
|
||||||
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
||||||
)
|
)
|
||||||
vllm_dtype: Optional[str] = Field(
|
vllm_dtype: str | None = Field(
|
||||||
default="auto",
|
default="auto",
|
||||||
json_schema_extra={"description": "Data type for VLLM"},
|
json_schema_extra={"description": "Data type for VLLM"},
|
||||||
)
|
)
|
||||||
vllm_max_model_len: Optional[int] = Field(
|
vllm_max_model_len: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Maximum length of the model context for VLLM"
|
"description": "Maximum length of the model context for VLLM"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
reward_funcs: Optional[list[str]] = Field(
|
reward_funcs: list[str] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "List of reward functions to load"},
|
json_schema_extra={"description": "List of reward functions to load"},
|
||||||
)
|
)
|
||||||
reward_weights: Optional[list[float]] = Field(
|
reward_weights: list[float] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Weights for each reward function. Must match the number of reward functions."
|
"description": "Weights for each reward function. Must match the number of reward functions."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
num_generations: Optional[int] = Field(
|
num_generations: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
|
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
log_completions: Optional[bool] = Field(
|
log_completions: bool | None = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={"description": "Whether to log completions"},
|
json_schema_extra={"description": "Whether to log completions"},
|
||||||
)
|
)
|
||||||
sync_ref_model: Optional[bool] = Field(
|
sync_ref_model: bool | None = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": (
|
"description": (
|
||||||
@@ -77,13 +73,13 @@ class TRLConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ref_model_mixup_alpha: Optional[float] = Field(
|
ref_model_mixup_alpha: float | None = Field(
|
||||||
default=0.9,
|
default=0.9,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
|
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ref_model_sync_steps: Optional[int] = Field(
|
ref_model_sync_steps: int | None = Field(
|
||||||
default=64,
|
default=64,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
||||||
79
src/axolotl/utils/schemas/utils.py
Normal file
79
src/axolotl/utils/schemas/utils.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""Utilities for Axolotl Pydantic models"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_legacy_message_fields_logic(data: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Handle backwards compatibility between legacy message field mapping and new property mapping system.
|
||||||
|
|
||||||
|
Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options:
|
||||||
|
- message_field_role: Mapped to the role field
|
||||||
|
- message_field_content: Mapped to the content field
|
||||||
|
|
||||||
|
The new system uses message_property_mappings to support arbitrary field mappings:
|
||||||
|
message_property_mappings:
|
||||||
|
role: source_role_field
|
||||||
|
content: source_content_field
|
||||||
|
additional_field: source_field
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary containing configuration data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated dictionary with message field mappings consolidated
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If there are conflicts between legacy and new mappings
|
||||||
|
"""
|
||||||
|
data = data.copy() # Create a copy to avoid modifying the original
|
||||||
|
|
||||||
|
if data.get("message_property_mappings") is None:
|
||||||
|
data["message_property_mappings"] = {}
|
||||||
|
|
||||||
|
# Check for conflicts and handle role
|
||||||
|
if "message_field_role" in data:
|
||||||
|
LOG.warning(
|
||||||
|
"message_field_role is deprecated, use message_property_mappings instead. "
|
||||||
|
f"Example: message_property_mappings: {{role: {data['message_field_role']}}}"
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
"role" in data["message_property_mappings"]
|
||||||
|
and data["message_property_mappings"]["role"] != data["message_field_role"]
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Conflicting message role fields: message_field_role='{data['message_field_role']}' "
|
||||||
|
f"conflicts with message_property_mappings.role='{data['message_property_mappings']['role']}'"
|
||||||
|
)
|
||||||
|
data["message_property_mappings"]["role"] = data["message_field_role"] or "role"
|
||||||
|
|
||||||
|
del data["message_field_role"]
|
||||||
|
elif "role" not in data["message_property_mappings"]:
|
||||||
|
data["message_property_mappings"]["role"] = "role"
|
||||||
|
|
||||||
|
# Check for conflicts and handle content
|
||||||
|
if "message_field_content" in data:
|
||||||
|
LOG.warning(
|
||||||
|
"message_field_content is deprecated, use message_property_mappings instead. "
|
||||||
|
f"Example: message_property_mappings: {{content: {data['message_field_content']}}}"
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
"content" in data["message_property_mappings"]
|
||||||
|
and data["message_property_mappings"]["content"]
|
||||||
|
!= data["message_field_content"]
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
|
||||||
|
f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
|
||||||
|
)
|
||||||
|
data["message_property_mappings"]["content"] = (
|
||||||
|
data["message_field_content"] or "content"
|
||||||
|
)
|
||||||
|
|
||||||
|
del data["message_field_content"]
|
||||||
|
elif "content" not in data["message_property_mappings"]:
|
||||||
|
data["message_property_mappings"]["content"] = "content"
|
||||||
|
|
||||||
|
return data
|
||||||
@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Add position_id column (PoSE)",
|
desc="Add position_id column (PoSE)",
|
||||||
)
|
)
|
||||||
elif cfg.sample_packing:
|
elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
|
||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
||||||
@@ -356,7 +356,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
**filter_map_kwargs,
|
**filter_map_kwargs,
|
||||||
**drop_long_kwargs,
|
**drop_long_kwargs,
|
||||||
)
|
)
|
||||||
if cfg.eval_sample_packing is not False:
|
if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
add_position_ids,
|
add_position_ids,
|
||||||
@@ -443,6 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
|
* cfg.sequence_parallel_degree
|
||||||
)
|
)
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
|
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
|
||||||
@@ -473,7 +474,11 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
||||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||||
# on the agreed on value for sample_packing_eff_est
|
# on the agreed on value for sample_packing_eff_est
|
||||||
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
|
total_num_steps = int(
|
||||||
|
math.floor(
|
||||||
|
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def calc_sample_packing_eff_est(estimates: List[float]):
|
def calc_sample_packing_eff_est(estimates: List[float]):
|
||||||
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
||||||
@@ -494,7 +499,12 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
total_num_steps = int(
|
total_num_steps = int(
|
||||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
math.ceil(
|
||||||
|
len(train_dataset)
|
||||||
|
* cfg.num_epochs
|
||||||
|
* cfg.sequence_parallel_degree
|
||||||
|
/ cfg.batch_size
|
||||||
|
)
|
||||||
)
|
)
|
||||||
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
||||||
return total_num_steps
|
return total_num_steps
|
||||||
|
|||||||
90
styles.css
90
styles.css
@@ -14,7 +14,7 @@
|
|||||||
h1 {
|
h1 {
|
||||||
font-family: var(--font-title);
|
font-family: var(--font-title);
|
||||||
font-weight: 400;
|
font-weight: 400;
|
||||||
font-size: 5rem;
|
font-size: 3rem;
|
||||||
line-height: 1.1;
|
line-height: 1.1;
|
||||||
letter-spacing: -0.05em;
|
letter-spacing: -0.05em;
|
||||||
font-feature-settings: "ss01" on;
|
font-feature-settings: "ss01" on;
|
||||||
@@ -24,7 +24,7 @@ h1 {
|
|||||||
h2 {
|
h2 {
|
||||||
font-family: var(--font-title);
|
font-family: var(--font-title);
|
||||||
font-weight: 500;
|
font-weight: 500;
|
||||||
font-size: 2rem;
|
font-size: 1.5rem;
|
||||||
line-height: 1.2;
|
line-height: 1.2;
|
||||||
letter-spacing: -0.03em;
|
letter-spacing: -0.03em;
|
||||||
font-feature-settings: "ss01" on;
|
font-feature-settings: "ss01" on;
|
||||||
@@ -35,7 +35,7 @@ h3,
|
|||||||
h4 {
|
h4 {
|
||||||
font-family: var(--font-body);
|
font-family: var(--font-body);
|
||||||
font-weight: 400;
|
font-weight: 400;
|
||||||
font-size: 1.5rem;
|
font-size: 1.25rem;
|
||||||
line-height: 1.5;
|
line-height: 1.5;
|
||||||
letter-spacing: -0.02em;
|
letter-spacing: -0.02em;
|
||||||
}
|
}
|
||||||
@@ -191,3 +191,87 @@ code span.er {
|
|||||||
color: #5cb85c !important;
|
color: #5cb85c !important;
|
||||||
text-decoration: none !important;
|
text-decoration: none !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* API Documentation Styling */
|
||||||
|
|
||||||
|
/* Improve docstring section rendering */
|
||||||
|
.level3 p {
|
||||||
|
white-space: pre-line !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Format docstring sections */
|
||||||
|
.level3 p strong {
|
||||||
|
display: block;
|
||||||
|
margin-top: 1em;
|
||||||
|
font-weight: bold;
|
||||||
|
color: var(--cyan);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Add spacing after sections */
|
||||||
|
.level3 p:has(strong) {
|
||||||
|
margin-bottom: 0.5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Format Args and Returns sections */
|
||||||
|
p:has(code) {
|
||||||
|
line-height: 1.6;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Function signatures */
|
||||||
|
.sourceCode {
|
||||||
|
margin-bottom: 1.5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Parameter tables */
|
||||||
|
.doc-section-parameters table,
|
||||||
|
.doc-section-returns table {
|
||||||
|
margin-top: 1em;
|
||||||
|
margin-bottom: 1.5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Make parameter and returns headers smaller */
|
||||||
|
h2.anchored[data-anchor-id="parameters"],
|
||||||
|
h2.anchored[data-anchor-id="returns"],
|
||||||
|
.doc-section-parameters h4,
|
||||||
|
.doc-section-returns h4 {
|
||||||
|
font-size: 1.25rem;
|
||||||
|
margin-top: 2rem;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
color: var(--lime);
|
||||||
|
border-bottom: 1px solid var(--lime);
|
||||||
|
padding-bottom: 0.3rem;
|
||||||
|
font-family: var(--font-body);
|
||||||
|
font-weight: 500;
|
||||||
|
letter-spacing: normal;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Style documentation tables */
|
||||||
|
table {
|
||||||
|
width: 100%;
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
|
border-collapse: collapse;
|
||||||
|
}
|
||||||
|
|
||||||
|
table th {
|
||||||
|
background-color: #1a1a1a;
|
||||||
|
padding: 0.5rem 1rem;
|
||||||
|
border-bottom: 2px solid var(--greige-600);
|
||||||
|
text-align: left;
|
||||||
|
}
|
||||||
|
|
||||||
|
table td {
|
||||||
|
padding: 0.5rem 1rem;
|
||||||
|
border-bottom: 1px solid var(--greige-600);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Code in table cells */
|
||||||
|
table td code {
|
||||||
|
background-color: transparent !important;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Improve spacing in parameter and return tables */
|
||||||
|
.doc-section-parameters,
|
||||||
|
.doc-section-returns {
|
||||||
|
margin-top: 1rem;
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,7 +11,11 @@ import time
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
from datasets import load_dataset
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
|
||||||
|
|
||||||
|
|
||||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||||
@@ -25,9 +29,11 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
|
|||||||
except (
|
except (
|
||||||
requests.exceptions.ReadTimeout,
|
requests.exceptions.ReadTimeout,
|
||||||
requests.exceptions.ConnectionError,
|
requests.exceptions.ConnectionError,
|
||||||
|
requests.exceptions.HTTPError,
|
||||||
) as exc:
|
) as exc:
|
||||||
if attempt < max_retries - 1:
|
if attempt < max_retries - 1:
|
||||||
time.sleep(delay)
|
wait = 2**attempt * delay # in seconds
|
||||||
|
time.sleep(wait)
|
||||||
else:
|
else:
|
||||||
raise exc
|
raise exc
|
||||||
|
|
||||||
@@ -37,6 +43,7 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
|
|||||||
|
|
||||||
|
|
||||||
@retry_on_request_exceptions(max_retries=3, delay=5)
|
@retry_on_request_exceptions(max_retries=3, delay=5)
|
||||||
|
@disable_hf_offline
|
||||||
def snapshot_download_w_retry(*args, **kwargs):
|
def snapshot_download_w_retry(*args, **kwargs):
|
||||||
return snapshot_download(*args, **kwargs)
|
return snapshot_download(*args, **kwargs)
|
||||||
|
|
||||||
@@ -44,19 +51,19 @@ def snapshot_download_w_retry(*args, **kwargs):
|
|||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_smollm2_135m_model():
|
def download_smollm2_135m_model():
|
||||||
# download the model
|
# download the model
|
||||||
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M")
|
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_llama_68m_random_model():
|
def download_llama_68m_random_model():
|
||||||
# download the model
|
# download the model
|
||||||
snapshot_download_w_retry("JackFram/llama-68m")
|
snapshot_download_w_retry("JackFram/llama-68m", repo_type="model")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_qwen_2_5_half_billion_model():
|
def download_qwen_2_5_half_billion_model():
|
||||||
# download the model
|
# download the model
|
||||||
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B")
|
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
@@ -101,6 +108,37 @@ def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_fozzie_alpaca_dpo_dataset():
|
||||||
|
# download the dataset
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
|
||||||
|
)
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||||
|
repo_type="dataset",
|
||||||
|
revision="ea82cff",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
@disable_hf_offline
|
||||||
|
def dataset_fozzie_alpaca_dpo_dataset(
|
||||||
|
download_fozzie_alpaca_dpo_dataset,
|
||||||
|
): # pylint: disable=unused-argument,redefined-outer-name
|
||||||
|
return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
@disable_hf_offline
|
||||||
|
def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
|
||||||
|
download_fozzie_alpaca_dpo_dataset,
|
||||||
|
): # pylint: disable=unused-argument,redefined-outer-name
|
||||||
|
return load_dataset(
|
||||||
|
"fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
|
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
|
||||||
# download the dataset
|
# download the dataset
|
||||||
@@ -109,10 +147,141 @@ def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_argilla_dpo_pairs_dataset():
|
||||||
|
# download the dataset
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_tiny_shakespeare_dataset():
|
def download_tiny_shakespeare_dataset():
|
||||||
# download the dataset
|
# download the dataset
|
||||||
snapshot_download_w_retry("Trelis/tiny-shakespeare", repo_type="dataset")
|
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_deepseek_model_fixture():
|
||||||
|
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_huggyllama_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"huggyllama/llama-7b",
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*token*", "config.json"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_llama_1b_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"NousResearch/Llama-3.2-1B",
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*token*", "config.json"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_llama3_8b_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"NousResearch/Meta-Llama-3-8B", repo_type="model", allow_patterns=["*token*"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_llama3_8b_instruct_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"NousResearch/Meta-Llama-3-8B-Instruct",
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*token*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_phi_35_mini_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"microsoft/Phi-3.5-mini-instruct", repo_type="model", allow_patterns=["*token*"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_phi_3_medium_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"microsoft/Phi-3-medium-128k-instruct",
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*token*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_mistral_7b_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"casperhansen/mistral-7b-instruct-v0.1-awq",
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*token*", "config.json"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_gemma_2b_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"unsloth/gemma-2b-it",
|
||||||
|
revision="703fb4a",
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*token*", "config.json"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_gemma2_9b_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"mlx-community/gemma-2-9b-it-4bit",
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*token*", "config.json"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_mlx_mistral_7b_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"mlx-community/Mistral-7B-Instruct-v0.3-4bit",
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*token*", "config.json"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_llama2_model_fixture():
|
||||||
|
# download the tokenizer only
|
||||||
|
snapshot_download_w_retry(
|
||||||
|
"NousResearch/Llama-2-7b-hf",
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*token*", "config.json"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
@enable_hf_offline
|
||||||
|
def tokenizer_huggyllama(
|
||||||
|
download_huggyllama_model_fixture,
|
||||||
|
): # pylint: disable=unused-argument,redefined-outer-name
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||||
|
tokenizer.pad_token = "</s>"
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -178,3 +347,34 @@ def cleanup_monkeypatches():
|
|||||||
module_globals = module_name_tuple[1]
|
module_globals = module_name_tuple[1]
|
||||||
for module_global in module_globals:
|
for module_global in module_globals:
|
||||||
globals().pop(module_global, None)
|
globals().pop(module_global, None)
|
||||||
|
|
||||||
|
|
||||||
|
# # pylint: disable=redefined-outer-name,unused-argument
|
||||||
|
# def test_load_fixtures(
|
||||||
|
# download_smollm2_135m_model,
|
||||||
|
# download_llama_68m_random_model,
|
||||||
|
# download_qwen_2_5_half_billion_model,
|
||||||
|
# download_tatsu_lab_alpaca_dataset,
|
||||||
|
# download_mhenrichsen_alpaca_2k_dataset,
|
||||||
|
# download_mhenrichsen_alpaca_2k_w_revision_dataset,
|
||||||
|
# download_mlabonne_finetome_100k_dataset,
|
||||||
|
# download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
|
||||||
|
# download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset,
|
||||||
|
# download_fozzie_alpaca_dpo_dataset,
|
||||||
|
# download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
|
||||||
|
# download_argilla_dpo_pairs_dataset,
|
||||||
|
# download_tiny_shakespeare_dataset,
|
||||||
|
# download_deepseek_model_fixture,
|
||||||
|
# download_huggyllama_model_fixture,
|
||||||
|
# download_llama_1b_model_fixture,
|
||||||
|
# download_llama3_8b_model_fixture,
|
||||||
|
# download_llama3_8b_instruct_model_fixture,
|
||||||
|
# download_phi_35_mini_model_fixture,
|
||||||
|
# download_phi_3_medium_model_fixture,
|
||||||
|
# download_mistral_7b_model_fixture,
|
||||||
|
# download_gemma_2b_model_fixture,
|
||||||
|
# download_gemma2_9b_model_fixture,
|
||||||
|
# download_mlx_mistral_7b_model_fixture,
|
||||||
|
# download_llama2_model_fixture,
|
||||||
|
# ):
|
||||||
|
# pass
|
||||||
|
|||||||
@@ -10,10 +10,13 @@ from transformers import AddedToken, AutoTokenizer
|
|||||||
from axolotl.core.chat.format.chatml import format_message
|
from axolotl.core.chat.format.chatml import format_message
|
||||||
from axolotl.core.chat.messages import ChatFormattedChats, Chats
|
from axolotl.core.chat.messages import ChatFormattedChats, Chats
|
||||||
|
|
||||||
|
from tests.hf_offline_utils import enable_hf_offline # noqa
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", name="llama_tokenizer")
|
@pytest.fixture(scope="session", name="llama_tokenizer")
|
||||||
|
@enable_hf_offline
|
||||||
def llama_tokenizer_fixture():
|
def llama_tokenizer_fixture():
|
||||||
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B")
|
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", name="chatml_tokenizer")
|
@pytest.fixture(scope="session", name="chatml_tokenizer")
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ e2e tests for kd trainer support in Axolotl
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from e2e.utils import check_tensorboard, require_torch_2_5_1
|
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
@@ -13,6 +12,8 @@ from axolotl.train import train
|
|||||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="kd_min_cfg")
|
@pytest.fixture(name="kd_min_cfg")
|
||||||
def min_cfg(temp_dir):
|
def min_cfg(temp_dir):
|
||||||
|
|||||||
@@ -2,15 +2,13 @@
|
|||||||
Simple end-to-end test for Liger integration
|
Simple end-to-end test for Liger integration
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from e2e.utils import require_torch_2_4_1
|
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists
|
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
|
||||||
|
|
||||||
|
|
||||||
class LigerIntegrationTestCase:
|
class LigerIntegrationTestCase:
|
||||||
|
|||||||
@@ -8,11 +8,12 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
from e2e.utils import require_vllm
|
|
||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.e2e.utils import require_vllm
|
||||||
|
|
||||||
|
|
||||||
class TestGRPO:
|
class TestGRPO:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -9,12 +9,13 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
from e2e.utils import check_tensorboard
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from transformers.testing_utils import get_torch_dist_unique_port
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.e2e.utils import check_tensorboard
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user