Compare commits
4 Commits
mm_mc_chat
...
pre-commit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
156fede4f7 | ||
|
|
dcbbd7af79 | ||
|
|
21bac7ce1a | ||
|
|
aaa4571826 |
7
.github/workflows/docs.yml
vendored
7
.github/workflows/docs.yml
vendored
@@ -20,12 +20,9 @@ jobs:
|
|||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.11'
|
python-version: '3.11'
|
||||||
- name: Install dependencies
|
- name: install dependencies
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install jupyter quartodoc
|
python3 -m pip install jupyter
|
||||||
python3 -m pip install -e . --no-deps
|
|
||||||
- name: Build autodoc
|
|
||||||
run: quartodoc build
|
|
||||||
- name: Publish to GitHub Pages (and render)
|
- name: Publish to GitHub Pages (and render)
|
||||||
uses: quarto-dev/quarto-actions/publish@v2
|
uses: quarto-dev/quarto-actions/publish@v2
|
||||||
with:
|
with:
|
||||||
|
|||||||
6
.github/workflows/tests.yml
vendored
6
.github/workflows/tests.yml
vendored
@@ -98,9 +98,8 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
||||||
pytest -v tests/patched/
|
pytest -v tests/patched/
|
||||||
pytest -v tests/cli/
|
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
@@ -173,9 +172,8 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
||||||
pytest -v tests/patched/
|
pytest -v tests/patched/
|
||||||
pytest -v tests/cli/
|
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -181,10 +181,6 @@ prepared-datasets/
|
|||||||
submit.sh
|
submit.sh
|
||||||
*.out*
|
*.out*
|
||||||
|
|
||||||
# Quartodoc generated files
|
|
||||||
objects.json
|
|
||||||
site_libs/
|
|
||||||
|
|
||||||
typings/
|
typings/
|
||||||
out/
|
out/
|
||||||
|
|
||||||
|
|||||||
@@ -97,7 +97,6 @@ That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github
|
|||||||
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
|
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
|
||||||
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
|
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
|
||||||
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
|
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
|
||||||
- [API Reference](https://axolotl-ai-cloud.github.io/axolotl/docs/api/) - Auto-generated code documentation
|
|
||||||
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
|
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
|
||||||
|
|
||||||
## 🤝 Getting Help
|
## 🤝 Getting Help
|
||||||
|
|||||||
194
_quarto.yml
194
_quarto.yml
@@ -1,179 +1,6 @@
|
|||||||
project:
|
project:
|
||||||
type: website
|
type: website
|
||||||
|
|
||||||
quartodoc:
|
|
||||||
dir: docs/api
|
|
||||||
package: axolotl
|
|
||||||
title: API Reference
|
|
||||||
parser: google
|
|
||||||
|
|
||||||
sections:
|
|
||||||
- title: Core
|
|
||||||
desc: Core functionality for training
|
|
||||||
contents:
|
|
||||||
- train
|
|
||||||
- evaluate
|
|
||||||
- datasets
|
|
||||||
- convert
|
|
||||||
- prompt_tokenizers
|
|
||||||
- logging_config
|
|
||||||
- core.trainer_builder
|
|
||||||
- core.training_args
|
|
||||||
- core.chat.messages
|
|
||||||
- core.chat.format.chatml
|
|
||||||
- core.chat.format.llama3x
|
|
||||||
- core.chat.format.shared
|
|
||||||
- core.datasets.chat
|
|
||||||
- core.datasets.transforms.chat_builder
|
|
||||||
- title: CLI
|
|
||||||
desc: Command-line interface
|
|
||||||
contents:
|
|
||||||
- cli.main
|
|
||||||
- cli.train
|
|
||||||
- cli.evaluate
|
|
||||||
- cli.args
|
|
||||||
- cli.checks
|
|
||||||
- cli.config
|
|
||||||
- cli.inference
|
|
||||||
- cli.merge_lora
|
|
||||||
- cli.merge_sharded_fsdp_weights
|
|
||||||
- cli.preprocess
|
|
||||||
- cli.sweeps
|
|
||||||
- cli.utils
|
|
||||||
- cli.cloud.base
|
|
||||||
- cli.cloud.modal_
|
|
||||||
- title: Trainers
|
|
||||||
desc: Training implementations
|
|
||||||
contents:
|
|
||||||
- core.trainers.base
|
|
||||||
- core.trainers.trl
|
|
||||||
- core.trainers.dpo.trainer
|
|
||||||
- core.trainers.grpo.trainer
|
|
||||||
- title: Prompt Strategies
|
|
||||||
desc: Prompt formatting strategies
|
|
||||||
contents:
|
|
||||||
- prompt_strategies.base
|
|
||||||
- prompt_strategies.chat_template
|
|
||||||
- prompt_strategies.alpaca_chat
|
|
||||||
- prompt_strategies.alpaca_instruct
|
|
||||||
- prompt_strategies.alpaca_w_system
|
|
||||||
- prompt_strategies.user_defined
|
|
||||||
- prompt_strategies.llama2_chat
|
|
||||||
- prompt_strategies.completion
|
|
||||||
- prompt_strategies.input_output
|
|
||||||
- prompt_strategies.stepwise_supervised
|
|
||||||
- prompt_strategies.metharme
|
|
||||||
- prompt_strategies.orcamini
|
|
||||||
- prompt_strategies.pygmalion
|
|
||||||
- prompt_strategies.messages.chat
|
|
||||||
- prompt_strategies.dpo.chat_template
|
|
||||||
- prompt_strategies.dpo.llama3
|
|
||||||
- prompt_strategies.dpo.chatml
|
|
||||||
- prompt_strategies.dpo.zephyr
|
|
||||||
- prompt_strategies.dpo.user_defined
|
|
||||||
- prompt_strategies.dpo.passthrough
|
|
||||||
- prompt_strategies.kto.llama3
|
|
||||||
- prompt_strategies.kto.chatml
|
|
||||||
- prompt_strategies.kto.user_defined
|
|
||||||
- prompt_strategies.orpo.chat_template
|
|
||||||
- prompt_strategies.bradley_terry.llama3
|
|
||||||
- title: Kernels
|
|
||||||
desc: Low-level performance optimizations
|
|
||||||
contents:
|
|
||||||
- kernels.lora
|
|
||||||
- kernels.geglu
|
|
||||||
- kernels.swiglu
|
|
||||||
- kernels.quantize
|
|
||||||
- kernels.utils
|
|
||||||
- title: MonkeyPatches
|
|
||||||
desc: Runtime patches for model optimizations
|
|
||||||
contents:
|
|
||||||
- monkeypatch.llama_attn_hijack_flash
|
|
||||||
- monkeypatch.llama_attn_hijack_xformers
|
|
||||||
- monkeypatch.mistral_attn_hijack_flash
|
|
||||||
- monkeypatch.multipack
|
|
||||||
- monkeypatch.relora
|
|
||||||
- monkeypatch.llama_expand_mask
|
|
||||||
- monkeypatch.lora_kernels
|
|
||||||
- monkeypatch.utils
|
|
||||||
- monkeypatch.btlm_attn_hijack_flash
|
|
||||||
- monkeypatch.llama_patch_multipack
|
|
||||||
- monkeypatch.stablelm_attn_hijack_flash
|
|
||||||
- monkeypatch.trainer_fsdp_optim
|
|
||||||
- monkeypatch.transformers_fa_utils
|
|
||||||
- monkeypatch.unsloth_
|
|
||||||
- monkeypatch.attention.mllama
|
|
||||||
- monkeypatch.data.batch_dataset_fetcher
|
|
||||||
- monkeypatch.mixtral
|
|
||||||
- title: Utils
|
|
||||||
desc: Utility functions
|
|
||||||
contents:
|
|
||||||
- utils.models
|
|
||||||
- utils.tokenization
|
|
||||||
- utils.chat_templates
|
|
||||||
- utils.lora
|
|
||||||
- utils.lora_embeddings
|
|
||||||
- utils.model_shard_quant
|
|
||||||
- utils.bench
|
|
||||||
- utils.freeze
|
|
||||||
- utils.trainer
|
|
||||||
- utils.schedulers
|
|
||||||
- utils.distributed
|
|
||||||
- utils.dict
|
|
||||||
- utils.optimizers.adopt
|
|
||||||
- utils.data.pretraining
|
|
||||||
- utils.data.sft
|
|
||||||
- utils.gradient_checkpointing.unsloth
|
|
||||||
- title: Schemas
|
|
||||||
desc: Pydantic data models for Axolotl config
|
|
||||||
contents:
|
|
||||||
- utils.schemas.config
|
|
||||||
- utils.schemas.model
|
|
||||||
- utils.schemas.training
|
|
||||||
- utils.schemas.datasets
|
|
||||||
- utils.schemas.peft
|
|
||||||
- utils.schemas.trl
|
|
||||||
- utils.schemas.multimodal
|
|
||||||
- utils.schemas.integrations
|
|
||||||
- utils.schemas.enums
|
|
||||||
- utils.schemas.utils
|
|
||||||
- title: Integrations
|
|
||||||
desc: Third-party integrations and extensions
|
|
||||||
contents:
|
|
||||||
- integrations.base
|
|
||||||
- integrations.cut_cross_entropy.args
|
|
||||||
- integrations.grokfast.optimizer
|
|
||||||
- integrations.kd.trainer
|
|
||||||
- integrations.liger.args
|
|
||||||
- integrations.lm_eval.args
|
|
||||||
- integrations.spectrum.args
|
|
||||||
- title: Common
|
|
||||||
desc: Common utilities and shared functionality
|
|
||||||
contents:
|
|
||||||
- common.architectures
|
|
||||||
- common.const
|
|
||||||
- common.datasets
|
|
||||||
- title: Models
|
|
||||||
desc: Custom model implementations
|
|
||||||
contents:
|
|
||||||
- models.mamba.modeling_mamba
|
|
||||||
- title: Data Processing
|
|
||||||
desc: Data processing utilities
|
|
||||||
contents:
|
|
||||||
- utils.collators.core
|
|
||||||
- utils.collators.batching
|
|
||||||
- utils.collators.mamba
|
|
||||||
- utils.collators.mm_chat
|
|
||||||
- utils.samplers.multipack
|
|
||||||
- title: Callbacks
|
|
||||||
desc: Training callbacks
|
|
||||||
contents:
|
|
||||||
- utils.callbacks.perplexity
|
|
||||||
- utils.callbacks.profiler
|
|
||||||
- utils.callbacks.lisa
|
|
||||||
- utils.callbacks.mlflow_
|
|
||||||
- utils.callbacks.comet_
|
|
||||||
|
|
||||||
website:
|
website:
|
||||||
title: "Axolotl"
|
title: "Axolotl"
|
||||||
description: "We make fine-tuning accessible, scalable, and fun"
|
description: "We make fine-tuning accessible, scalable, and fun"
|
||||||
@@ -208,8 +35,6 @@ website:
|
|||||||
- docs/inference.qmd
|
- docs/inference.qmd
|
||||||
- docs/cli.qmd
|
- docs/cli.qmd
|
||||||
- docs/config.qmd
|
- docs/config.qmd
|
||||||
- text: "API Reference"
|
|
||||||
href: docs/api
|
|
||||||
|
|
||||||
- section: "Dataset Formats"
|
- section: "Dataset Formats"
|
||||||
contents: docs/dataset-formats/*
|
contents: docs/dataset-formats/*
|
||||||
@@ -255,22 +80,3 @@ format:
|
|||||||
theme: darkly
|
theme: darkly
|
||||||
css: styles.css
|
css: styles.css
|
||||||
toc: true
|
toc: true
|
||||||
# Enable better handling of line breaks in markdown
|
|
||||||
preserve-tabs: true
|
|
||||||
html-math-method: mathjax
|
|
||||||
# Improved markdown processing options
|
|
||||||
md-extensions:
|
|
||||||
- markdown_it
|
|
||||||
- def_list
|
|
||||||
- attr_list
|
|
||||||
- fenced_divs
|
|
||||||
- tables
|
|
||||||
- html_admonition
|
|
||||||
- lineblocks
|
|
||||||
- fancy_lists
|
|
||||||
# Control whitespace handling
|
|
||||||
whitespace: preserve
|
|
||||||
# Process newlines in paragraphs
|
|
||||||
wrap: preserve
|
|
||||||
# Better line break handling
|
|
||||||
preserve-linebreaks: true
|
|
||||||
|
|||||||
@@ -33,9 +33,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|||||||
|
|
||||||
RUN pip install packaging==23.2 setuptools==75.8.0
|
RUN pip install packaging==23.2 setuptools==75.8.0
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN python scripts/unsloth_install.py | sh
|
RUN python scripts/unsloth_install.py | sh
|
||||||
|
|||||||
@@ -3,10 +3,9 @@ set -e
|
|||||||
|
|
||||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||||
|
|
||||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/
|
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
||||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
|
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
|
||||||
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
|
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
|
||||||
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
||||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
||||||
pytest -v --durations=10 /workspace/axolotl/tests/cli
|
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||||
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/
|
|
||||||
|
|||||||
2
docs/.gitignore
vendored
2
docs/.gitignore
vendored
@@ -1,4 +1,2 @@
|
|||||||
/.quarto/
|
/.quarto/
|
||||||
_site/
|
_site/
|
||||||
/api/*.qmd
|
|
||||||
/api/*.html
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
---
|
---
|
||||||
title: "Command Line Interface (CLI)"
|
title: "CLI Reference"
|
||||||
format:
|
format:
|
||||||
html:
|
html:
|
||||||
toc: true
|
toc: true
|
||||||
|
|||||||
@@ -32,9 +32,6 @@ tokenizer_legacy:
|
|||||||
resize_token_embeddings_to_32x:
|
resize_token_embeddings_to_32x:
|
||||||
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
||||||
shrink_embeddings:
|
shrink_embeddings:
|
||||||
# Whether to load the model with randomly initialized weights. Useful for
|
|
||||||
# pre-training a model from scratch or debugging purposes.
|
|
||||||
random_init_weights:
|
|
||||||
|
|
||||||
# (Internal use only)
|
# (Internal use only)
|
||||||
# Used to identify which the model is based on
|
# Used to identify which the model is based on
|
||||||
@@ -466,7 +463,6 @@ auto_find_batch_size: # Optional[bool]
|
|||||||
|
|
||||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||||
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||||
do_causal_lm_eval: # Whether to run causal language model evaluation for metrics in `eval_causal_lm_metrics`.
|
|
||||||
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
|
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
|
||||||
|
|
||||||
profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
|
profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
|
||||||
@@ -507,58 +503,36 @@ lr_div_factor: # Learning rate div factor
|
|||||||
|
|
||||||
# Specify optimizer
|
# Specify optimizer
|
||||||
# Valid values are driven by the Transformers OptimizerNames class, see:
|
# Valid values are driven by the Transformers OptimizerNames class, see:
|
||||||
# https://github.com/huggingface/transformers/blob/cbf924b76c03828101a34069a96d209314114fd5/src/transformers/training_args.py#L144-L189
|
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
|
||||||
#
|
#
|
||||||
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
|
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
|
||||||
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
|
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
|
||||||
# in the examples/ for your model and fine-tuning use case.
|
# in the examples/ for your model and fine-tuning use case.
|
||||||
#
|
#
|
||||||
# Valid values for 'optimizer' include:
|
# Valid values for 'optimizer' include:
|
||||||
|
# - adamw_hf
|
||||||
# - adamw_torch
|
# - adamw_torch
|
||||||
# - adamw_torch_fused
|
# - adamw_torch_fused
|
||||||
# - adamw_torch_xla
|
# - adamw_torch_xla
|
||||||
# - adamw_torch_npu_fused
|
|
||||||
# - adamw_apex_fused
|
# - adamw_apex_fused
|
||||||
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
|
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
|
||||||
# - adafactor
|
# - adafactor
|
||||||
# - adamw_anyprecision
|
# - adamw_anyprecision
|
||||||
# - adamw_torch_4bit
|
|
||||||
# - ademamix
|
|
||||||
# - sgd
|
# - sgd
|
||||||
# - adagrad
|
# - adagrad
|
||||||
# - adamw_bnb_8bit
|
# - adamw_bnb_8bit
|
||||||
# - adamw_8bit # alias for adamw_bnb_8bit
|
|
||||||
# - ademamix_8bit
|
|
||||||
# - lion_8bit
|
# - lion_8bit
|
||||||
# - lion_32bit
|
# - lion_32bit
|
||||||
# - paged_adamw_32bit
|
# - paged_adamw_32bit
|
||||||
# - paged_adamw_8bit
|
# - paged_adamw_8bit
|
||||||
# - paged_ademamix_32bit
|
|
||||||
# - paged_ademamix_8bit
|
|
||||||
# - paged_lion_32bit
|
# - paged_lion_32bit
|
||||||
# - paged_lion_8bit
|
# - paged_lion_8bit
|
||||||
# - rmsprop
|
|
||||||
# - rmsprop_bnb
|
|
||||||
# - rmsprop_bnb_8bit
|
|
||||||
# - rmsprop_bnb_32bit
|
|
||||||
# - galore_adamw
|
# - galore_adamw
|
||||||
# - galore_adamw_8bit
|
# - galore_adamw_8bit
|
||||||
# - galore_adafactor
|
# - galore_adafactor
|
||||||
# - galore_adamw_layerwise
|
# - galore_adamw_layerwise
|
||||||
# - galore_adamw_8bit_layerwise
|
# - galore_adamw_8bit_layerwise
|
||||||
# - galore_adafactor_layerwise
|
# - galore_adafactor_layerwise
|
||||||
# - lomo
|
|
||||||
# - adalomo
|
|
||||||
# - grokadamw
|
|
||||||
# - schedule_free_adamw
|
|
||||||
# - schedule_free_sgd
|
|
||||||
# - apollo_adamw
|
|
||||||
# - apollo_adamw_layerwise
|
|
||||||
#
|
|
||||||
# Additional custom optimizers include:
|
|
||||||
# - optimi_adamw
|
|
||||||
# - ao_adamw_8bit
|
|
||||||
# - ao_adamw_fp8
|
|
||||||
optimizer:
|
optimizer:
|
||||||
# Dictionary of arguments to pass to the optimizer
|
# Dictionary of arguments to pass to the optimizer
|
||||||
optim_args:
|
optim_args:
|
||||||
@@ -610,14 +584,6 @@ resume_from_checkpoint:
|
|||||||
# Be careful with this being turned on between different models.
|
# Be careful with this being turned on between different models.
|
||||||
auto_resume_from_checkpoints: false
|
auto_resume_from_checkpoints: false
|
||||||
|
|
||||||
## Multimodal section
|
|
||||||
# int | tuple[int, int] | None . Size to resize images to, width x height.
|
|
||||||
# Will read from model/processor config if not set.
|
|
||||||
image_size:
|
|
||||||
# str. Algorithm to use for image resizing. "bilinear", "bicubic", "lanczos". Default is "bilinear".
|
|
||||||
image_resize_algorithm: 'bilinear'
|
|
||||||
## End of multimodal section
|
|
||||||
|
|
||||||
# Don't mess with this, it's here for accelerate and torchrun
|
# Don't mess with this, it's here for accelerate and torchrun
|
||||||
local_rank:
|
local_rank:
|
||||||
|
|
||||||
@@ -651,14 +617,6 @@ ddp_timeout:
|
|||||||
ddp_bucket_cap_mb:
|
ddp_bucket_cap_mb:
|
||||||
ddp_broadcast_buffers:
|
ddp_broadcast_buffers:
|
||||||
|
|
||||||
# Sequence parallelism
|
|
||||||
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
|
|
||||||
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
|
|
||||||
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
|
|
||||||
# subsequences, or set to 4 to split into four equal-sized subsequences.
|
|
||||||
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
|
|
||||||
sequence_parallel_degree:
|
|
||||||
|
|
||||||
# Path to torch distx for optim 'adamw_anyprecision'
|
# Path to torch distx for optim 'adamw_anyprecision'
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ description: How datasets are processed
|
|||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
|
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
|
||||||
the [dataset format](dataset-formats) and prompt strategies to:
|
the [dataset format](docs/dataset-formats) and prompt strategies to:
|
||||||
|
|
||||||
- parse the dataset based on the *dataset format*
|
- parse the dataset based on the *dataset format*
|
||||||
- transform the dataset to how you would interact with the model based on the *prompt strategy*
|
- transform the dataset to how you would interact with the model based on the *prompt strategy*
|
||||||
|
|||||||
@@ -103,7 +103,8 @@ This uses the same tags as the [`main` image](#sec-main-tags).
|
|||||||
|
|
||||||
- `JUPYTER_DISABLE`: Disable Jupyter lab.
|
- `JUPYTER_DISABLE`: Disable Jupyter lab.
|
||||||
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
|
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
|
||||||
- `PUBLIC_KEY` / `SSH_KEY`: Add a public key for the SSH service.
|
- `PUBLIC_KEY`: Add a public key for the SSH service.
|
||||||
|
- `SSH_KEY`: Add a private key for the SSH service.
|
||||||
|
|
||||||
#### Volume mounts
|
#### Volume mounts
|
||||||
|
|
||||||
|
|||||||
@@ -37,10 +37,6 @@ description: Frequently asked questions
|
|||||||
|
|
||||||
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
|
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
|
||||||
|
|
||||||
**Q: How to know the value to use for `fsdp_transformer_layer_cls_to_wrap`?**
|
|
||||||
|
|
||||||
> A: This is the class name of the transformer layer to wrap with FSDP. For example, for `LlamaForCausalLM`, the value is `LlamaDecoderLayer`. To find this for a specific model, check the model's `PreTrainedModel` definition and look for `_no_split_modules` variable in the `modeling_<model_name>.py` file within `transformers` library.
|
|
||||||
|
|
||||||
### Chat templates
|
### Chat templates
|
||||||
|
|
||||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||||
|
|||||||
@@ -1,171 +1,28 @@
|
|||||||
---
|
# MultiModal / Vision Language Models (BETA)
|
||||||
title: MultiModal / Vision Language Models (BETA)
|
|
||||||
format:
|
|
||||||
html:
|
|
||||||
toc: true
|
|
||||||
toc-depth: 3
|
|
||||||
---
|
|
||||||
|
|
||||||
## Supported Models
|
### Supported Models
|
||||||
|
|
||||||
- [Mllama](#sec-mllama)
|
- Mllama, i.e. llama with vision models
|
||||||
- [Pixtral](#sec-pixtral)
|
|
||||||
- [Llava-1.5](#sec-llava-15)
|
|
||||||
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
|
||||||
- [Gemma-3](#sec-gemma-3)
|
|
||||||
- [Qwen2-VL](#sec-qwen2-vl)
|
|
||||||
- [Qwen2.5-VL](#sec-qwen25-vl)
|
|
||||||
|
|
||||||
## Usage
|
### Usage
|
||||||
|
|
||||||
Multimodal support is limited and doesn't have full feature parity.
|
Currently multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama w/ LoRA,
|
||||||
|
you'll need to use the following in YAML in combination with the rest of the required hyperparams.
|
||||||
Here are the hyperparams you'll need to use to finetune a multimodal model.
|
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
|
||||||
processor_type: AutoProcessor
|
processor_type: AutoProcessor
|
||||||
|
|
||||||
skip_prepare_dataset: true
|
skip_prepare_dataset: true
|
||||||
remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
|
|
||||||
sample_packing: false # not yet supported with multimodal
|
|
||||||
|
|
||||||
chat_template: # see in next section
|
chat_template: llama3_2_vision
|
||||||
|
|
||||||
# example dataset
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
type: chat_template
|
type: chat_template
|
||||||
split: train[:1%]
|
split: train[:1%]
|
||||||
field_messages: messages
|
field_messages: messages
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
# (optional) if doing lora, only finetune the Language model,
|
# only finetune the Language model, leave the vision model and vision tower frozen
|
||||||
# leave the vision model and vision tower frozen
|
|
||||||
# load_in_8bit: true
|
|
||||||
adapter: lora
|
|
||||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
# (optional) if you want to resize images to a set size
|
|
||||||
image_size: 512
|
|
||||||
image_resize_algorithm: bilinear
|
|
||||||
```
|
|
||||||
|
|
||||||
Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.
|
|
||||||
|
|
||||||
::: {.callout-warning}
|
|
||||||
Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
|
|
||||||
:::
|
|
||||||
|
|
||||||
### Mllama {#sec-mllama}
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
base_model: meta-llama/Llama-3.2-11B-Vision-Instruct
|
|
||||||
|
|
||||||
chat_template: llama3_2_vision
|
|
||||||
```
|
|
||||||
|
|
||||||
### Pixtral {#sec-pixtral}
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
base_model: mistralai/Pixtral-12B-2409
|
|
||||||
|
|
||||||
chat_template: pixtral
|
|
||||||
```
|
|
||||||
|
|
||||||
### Llava-1.5 {#sec-llava-15}
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
base_model: llava-hf/llava-1.5-7b-hf
|
|
||||||
|
|
||||||
chat_template: llava
|
|
||||||
```
|
|
||||||
|
|
||||||
### Mistral-Small-3.1 {#sec-mistral-small-31}
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
|
||||||
|
|
||||||
chat_template: mistral_v7_tekken
|
|
||||||
```
|
|
||||||
|
|
||||||
### Gemma-3 {#sec-gemma-3}
|
|
||||||
|
|
||||||
::: {.callout-tip}
|
|
||||||
The Gemma3-1B model is a text-only model, so please train as regular text model.
|
|
||||||
:::
|
|
||||||
|
|
||||||
For multi-modal 4B/12B/27B models, use the following config:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
base_model: google/gemma-3-4b-it
|
|
||||||
|
|
||||||
chat_template: gemma3
|
|
||||||
```
|
|
||||||
|
|
||||||
### Qwen2-VL {#sec-qwen2-vl}
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
base_model: Qwen/Qwen2-VL-7B-Instruct
|
|
||||||
|
|
||||||
chat_template: qwen2_vl
|
|
||||||
```
|
|
||||||
|
|
||||||
### Qwen2.5-VL {#sec-qwen25-vl}
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
base_model: Qwen/Qwen2.5-VL-7B-Instruct
|
|
||||||
|
|
||||||
chat_template: qwen2_vl # same as qwen2-vl
|
|
||||||
```
|
|
||||||
|
|
||||||
## Dataset Format
|
|
||||||
|
|
||||||
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
|
|
||||||
|
|
||||||
- A message is a list of `role` and `content`.
|
|
||||||
- `role` can be `system`, `user`, `assistant`, etc.
|
|
||||||
- `content` is a list of `type` and (`text` or `image` or `path` or `url` or `base64`).
|
|
||||||
|
|
||||||
::: {.callout-note}
|
|
||||||
For backwards compatibility:
|
|
||||||
|
|
||||||
- If the dataset has a `images` or `image` column of `list[Image]`, it will be appended to the first `content` list as `{"type": "image", "image": ...}`. However, if the content already has a `{"type": "image"}` but no `image` key, it will be set the `image` key.
|
|
||||||
- If `content` is a string, it will be converted to a list with `type` as `text`.
|
|
||||||
:::
|
|
||||||
|
|
||||||
::: {.callout-tip}
|
|
||||||
For image loading, you can use the following keys within `content` alongside `"type": "image"`:
|
|
||||||
|
|
||||||
- `"path": "/path/to/image.jpg"`
|
|
||||||
- `"url": "https://example.com/image.jpg"`
|
|
||||||
- `"base64": "..."`
|
|
||||||
- `"image": PIL.Image`
|
|
||||||
:::
|
|
||||||
|
|
||||||
Here is an example of a multi-modal dataset:
|
|
||||||
```json
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": "You are a helpful assistant."}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
|
|
||||||
{"type": "text", "text": "Describe this image in detail."}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": "The image is a bee."}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -1,90 +0,0 @@
|
|||||||
---
|
|
||||||
title: Sequence Parallelism
|
|
||||||
description: Train with long sequences split across multiple GPUs.
|
|
||||||
---
|
|
||||||
|
|
||||||
# Sequence Parallelism
|
|
||||||
|
|
||||||
Sequence parallelism is a technique that splits sequences across multiple GPUs,
|
|
||||||
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
|
|
||||||
GPU processes a different portion of the sequence, and the results are aggregated
|
|
||||||
through a ring communication pattern.
|
|
||||||
|
|
||||||
## When to Use Sequence Parallelism
|
|
||||||
|
|
||||||
Use sequence parallelism when:
|
|
||||||
|
|
||||||
- You need to train with sequence lengths that don't fit into a single GPU's memory
|
|
||||||
- You have multiple GPUs available
|
|
||||||
- You're experiencing OOM (Out Of Memory) errors with long sequences
|
|
||||||
|
|
||||||
## Configuration
|
|
||||||
|
|
||||||
To enable sequence parallelism, add the following to your configuration file:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# Set to a divisor (> 1) of the number of GPUs available
|
|
||||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
|
||||||
```
|
|
||||||
|
|
||||||
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
|
||||||
|
|
||||||
- With 8 GPUs, valid values would be 2, 4, or 8
|
|
||||||
- With 4 GPUs, valid values would be 2 or 4
|
|
||||||
|
|
||||||
## Implementation Details
|
|
||||||
|
|
||||||
When sequence parallelism is enabled:
|
|
||||||
|
|
||||||
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
|
|
||||||
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
|
|
||||||
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
|
|
||||||
4. The trainer uses special ring communication patterns for attention operations
|
|
||||||
|
|
||||||
## Requirements
|
|
||||||
|
|
||||||
To use sequence parallelism, you need:
|
|
||||||
|
|
||||||
- Multiple GPUs (at least 2)
|
|
||||||
- The `ring-flash-attn` package. Install with:
|
|
||||||
- `pip install axolotl[ring-flash-attn]` (preferred)
|
|
||||||
- `pip install ring-flash-attn>=0.1.4`
|
|
||||||
|
|
||||||
## Limitations
|
|
||||||
|
|
||||||
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
|
|
||||||
- May have a small performance overhead due to communication between GPUs
|
|
||||||
|
|
||||||
## Example
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# Example config with sequence parallelism
|
|
||||||
base_model: meta-llama/Llama-3-8B-Instruct
|
|
||||||
sequence_len: 8192
|
|
||||||
sequence_parallel_degree: 2 # Split each sequence into 4 parts
|
|
||||||
flash_attention: true # Required with sequence parallelism
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
This will train the Llama 3 8B model with 8K context length, with each sequence split
|
|
||||||
into 2 subsequences of length 4096 across 2 GPUs.
|
|
||||||
|
|
||||||
## Sample Packing with Sequence Parallelism
|
|
||||||
|
|
||||||
Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
|
|
||||||
|
|
||||||
1. Samples are first packed together
|
|
||||||
2. The packed sequences are then divided across GPUs in the sequence parallel group
|
|
||||||
3. Position IDs are automatically adjusted to maintain proper relative positions
|
|
||||||
|
|
||||||
## Effect on Batch Size
|
|
||||||
|
|
||||||
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
|
||||||
|
|
||||||
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
|
||||||
- The number of batches processed per step decreases
|
|
||||||
|
|
||||||
For example:
|
|
||||||
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
|
||||||
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
|
||||||
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
base_model: CohereForAI/c4ai-command-r7b-12-2024
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
# huggingface repo
|
|
||||||
chat_template: cohere
|
|
||||||
datasets:
|
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
|
||||||
type: chat_template
|
|
||||||
field_messages: conversations
|
|
||||||
message_property_mappings:
|
|
||||||
role: from
|
|
||||||
content: value
|
|
||||||
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
eval_sample_packing: false
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 4
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch:
|
|
||||||
eval_table_size:
|
|
||||||
eval_max_new_tokens: 128
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
base_model: google/gemma-3-4b-it
|
|
||||||
processor_type: AutoProcessor
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
# these 3 lines are needed for now to handle vision chat templates w images
|
|
||||||
skip_prepare_dataset: true
|
|
||||||
remove_unused_columns: false
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
chat_template: gemma3
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
|
||||||
type: chat_template
|
|
||||||
split: train[:1%]
|
|
||||||
field_messages: messages
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.01
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
pad_to_sequence_len: false
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: true
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
eager_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
@@ -1,74 +0,0 @@
|
|||||||
base_model: google/gemma-3-1b-it
|
|
||||||
# optionally might have model_type or tokenizer_type
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
# huggingface repo
|
|
||||||
chat_template: gemma3_text
|
|
||||||
datasets:
|
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
|
||||||
type: chat_template
|
|
||||||
field_messages: conversations
|
|
||||||
message_property_mappings:
|
|
||||||
role: from
|
|
||||||
content: value
|
|
||||||
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
eval_sample_packing: false
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 4
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch:
|
|
||||||
eval_table_size:
|
|
||||||
eval_max_new_tokens: 128
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
base_model: llava-hf/llava-1.5-7b-hf
|
|
||||||
processor_type: AutoProcessor
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
# these 3 lines are needed for now to handle vision chat templates w images
|
|
||||||
skip_prepare_dataset: true
|
|
||||||
remove_unused_columns: false
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
chat_template: llava
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
|
||||||
type: chat_template
|
|
||||||
split: train[:1%]
|
|
||||||
field_messages: messages
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 8192
|
|
||||||
pad_to_sequence_len: false
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: true
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
eager_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
@@ -1,66 +0,0 @@
|
|||||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
|
||||||
processor_type: AutoProcessor
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
load_in_8bit: true
|
|
||||||
|
|
||||||
# these 3 lines are needed for now to handle vision chat templates w images
|
|
||||||
skip_prepare_dataset: true
|
|
||||||
remove_unused_columns: false
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
chat_template: mistral_v7_tekken
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
|
||||||
type: chat_template
|
|
||||||
split: train[:1%]
|
|
||||||
field_messages: messages
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.01
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
pad_to_sequence_len: false
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: true
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet.
|
|
||||||
eager_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
base_model: mistral-community/pixtral-12b
|
|
||||||
processor_type: AutoProcessor
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
# these 3 lines are needed for now to handle vision chat templates w images
|
|
||||||
skip_prepare_dataset: true
|
|
||||||
remove_unused_columns: false
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
chat_template: pixtral
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
|
||||||
type: chat_template
|
|
||||||
split: train[:1%]
|
|
||||||
field_messages: messages
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 8192
|
|
||||||
pad_to_sequence_len: false
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: true
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet
|
|
||||||
eager_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
pad_token: <pad>
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
base_model: Qwen/Qwen2-VL-7B-Instruct
|
|
||||||
processor_type: AutoProcessor
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
# these 3 lines are needed for now to handle vision chat templates w images
|
|
||||||
skip_prepare_dataset: true
|
|
||||||
remove_unused_columns: false
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
chat_template: qwen2_vl
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
|
||||||
type: chat_template
|
|
||||||
split: train[:1%]
|
|
||||||
field_messages: messages
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 8192
|
|
||||||
pad_to_sequence_len: false
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: true
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
eager_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
@@ -2,5 +2,3 @@ pre-commit
|
|||||||
black
|
black
|
||||||
mypy
|
mypy
|
||||||
types-requests
|
types-requests
|
||||||
quartodoc
|
|
||||||
jupyter
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
bitsandbytes==0.45.3
|
bitsandbytes==0.45.3
|
||||||
triton>=3.0.0
|
triton>=3.0.0
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
|
flash-attn==2.7.4.post1
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
autoawq==0.2.7.post3
|
autoawq==0.2.7.post3
|
||||||
liger-kernel==0.5.3
|
liger-kernel==0.5.3
|
||||||
@@ -12,7 +13,7 @@ liger-kernel==0.5.3
|
|||||||
packaging==23.2
|
packaging==23.2
|
||||||
|
|
||||||
peft==0.15.0
|
peft==0.15.0
|
||||||
transformers==4.50.0
|
transformers==4.49.0
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
accelerate==1.5.2
|
accelerate==1.5.2
|
||||||
datasets==3.4.1
|
datasets==3.4.1
|
||||||
@@ -35,7 +36,6 @@ einops
|
|||||||
colorama
|
colorama
|
||||||
numba
|
numba
|
||||||
numpy>=1.24.4,<=2.0.1
|
numpy>=1.24.4,<=2.0.1
|
||||||
|
|
||||||
# qlora things
|
# qlora things
|
||||||
evaluate==0.4.1
|
evaluate==0.4.1
|
||||||
scipy
|
scipy
|
||||||
|
|||||||
315
requirements_env.txt
Normal file
315
requirements_env.txt
Normal file
@@ -0,0 +1,315 @@
|
|||||||
|
accelerate==0.34.1
|
||||||
|
addict==2.4.0
|
||||||
|
aiofiles==23.2.1
|
||||||
|
aiohttp==3.9.0
|
||||||
|
aiosignal==1.3.1
|
||||||
|
aiostream==0.5.2
|
||||||
|
alembic==1.13.1
|
||||||
|
annotated-types==0.6.0
|
||||||
|
annoy==1.17.3
|
||||||
|
ansible==6.7.0
|
||||||
|
ansible-core==2.13.13
|
||||||
|
ansible-vault==2.1.0
|
||||||
|
anyio==3.7.1
|
||||||
|
appdirs==1.4.4
|
||||||
|
art==6.0
|
||||||
|
asgiref==3.7.2
|
||||||
|
async-timeout==4.0.2
|
||||||
|
attrdict==2.0.1
|
||||||
|
attrs==22.2.0
|
||||||
|
awscli==1.32.75
|
||||||
|
-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
|
||||||
|
backoff==2.2.1
|
||||||
|
base58==2.1.1
|
||||||
|
beartype==0.17.2
|
||||||
|
bitnet==0.2.1
|
||||||
|
bitsandbytes==0.42.0
|
||||||
|
bittensor==6.7.0
|
||||||
|
black==23.7.0
|
||||||
|
blinker==1.7.0
|
||||||
|
boto3==1.34.75
|
||||||
|
botocore==1.34.75
|
||||||
|
cachetools==5.3.3
|
||||||
|
cachy==0.1.1
|
||||||
|
certifi==2023.7.22
|
||||||
|
cffi==1.16.0
|
||||||
|
cfgv==3.3.1
|
||||||
|
chai-guanaco==1.2.4
|
||||||
|
charset-normalizer==3.2.0
|
||||||
|
cleo==0.6.8
|
||||||
|
click==8.1.7
|
||||||
|
cloudpickle==2.0.0
|
||||||
|
cohere==4.11.2
|
||||||
|
colorama==0.4.4
|
||||||
|
coloredlogs==15.0.1
|
||||||
|
CoLT5-attention==0.10.20
|
||||||
|
contextlib2==21.6.0
|
||||||
|
contourpy==1.2.0
|
||||||
|
cryptography==41.0.3
|
||||||
|
cycler==0.12.1
|
||||||
|
cytoolz==0.12.3
|
||||||
|
databricks-cli==0.18.0
|
||||||
|
dataclasses-json==0.5.7
|
||||||
|
datasets==2.11.0
|
||||||
|
ddt==1.6.0
|
||||||
|
decorator==5.1.1
|
||||||
|
deepspeed==0.15.0
|
||||||
|
# Editable Git install with no remote (dialogpt==0.1)
|
||||||
|
-e /Users/wing/Projects/ml/dialogpt/src
|
||||||
|
dill==0.3.6
|
||||||
|
distlib==0.3.6
|
||||||
|
docker==7.0.0
|
||||||
|
docker-pycreds==0.4.0
|
||||||
|
docstring-parser==0.15
|
||||||
|
docutils==0.16
|
||||||
|
ecdsa==0.18.0
|
||||||
|
einops==0.7.0
|
||||||
|
einops-exts==0.0.4
|
||||||
|
einx==0.1.3
|
||||||
|
entrypoints==0.4
|
||||||
|
eth-hash==0.6.0
|
||||||
|
eth-keys==0.5.0
|
||||||
|
eth-typing==4.0.0
|
||||||
|
eth-utils==2.3.1
|
||||||
|
evaluate==0.4.0
|
||||||
|
exceptiongroup==1.1.1
|
||||||
|
fastapi==0.109.2
|
||||||
|
fastcore==1.5.29
|
||||||
|
ffmpy==0.4.0
|
||||||
|
filelock==3.12.2
|
||||||
|
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
|
||||||
|
fire==0.5.0
|
||||||
|
first==2.0.2
|
||||||
|
flake8==7.0.0
|
||||||
|
Flask==3.0.1
|
||||||
|
fonttools==4.47.2
|
||||||
|
frozendict==2.4.1
|
||||||
|
frozenlist==1.3.3
|
||||||
|
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
||||||
|
fsspec==2023.6.0
|
||||||
|
fuzzywuzzy==0.18.0
|
||||||
|
gitdb==4.0.10
|
||||||
|
GitPython==3.1.31
|
||||||
|
google-pasta==0.2.0
|
||||||
|
gradio==4.42.0
|
||||||
|
gradio_client==1.3.0
|
||||||
|
greenlet==2.0.2
|
||||||
|
grpclib==0.4.7
|
||||||
|
gunicorn==21.2.0
|
||||||
|
h11==0.14.0
|
||||||
|
h2==4.1.0
|
||||||
|
hpack==4.0.0
|
||||||
|
httpcore==0.17.3
|
||||||
|
httpx==0.24.1
|
||||||
|
huggingface-hub==0.23.4
|
||||||
|
humanfriendly==10.0
|
||||||
|
hyperframe==6.0.1
|
||||||
|
identify==2.5.24
|
||||||
|
idna==3.4
|
||||||
|
immutables==0.20
|
||||||
|
importlib-metadata==6.7.0
|
||||||
|
importlib-resources==6.1.1
|
||||||
|
inflection==0.5.1
|
||||||
|
iniconfig==2.0.0
|
||||||
|
itsdangerous==2.1.2
|
||||||
|
Jinja2==3.1.2
|
||||||
|
jmespath==1.0.1
|
||||||
|
joblib==1.3.2
|
||||||
|
jsonlines==3.1.0
|
||||||
|
jsonschema==2.6.0
|
||||||
|
kiwisolver==1.4.5
|
||||||
|
langchain==0.0.144
|
||||||
|
Levenshtein==0.24.0
|
||||||
|
libcst==1.1.0
|
||||||
|
liger-kernel==0.0.0
|
||||||
|
lion-pytorch==0.1.2
|
||||||
|
llama-cpp-python==0.1.36
|
||||||
|
llvmlite==0.40.1
|
||||||
|
local-attention==1.9.0
|
||||||
|
loguru==0.7.0
|
||||||
|
Mako==1.3.2
|
||||||
|
Markdown==3.5.2
|
||||||
|
markdown-it-py==3.0.0
|
||||||
|
markdown2==2.4.10
|
||||||
|
MarkupSafe==2.1.2
|
||||||
|
marshmallow==3.19.0
|
||||||
|
marshmallow-enum==1.5.1
|
||||||
|
matplotlib==3.8.2
|
||||||
|
mccabe==0.7.0
|
||||||
|
mdurl==0.1.2
|
||||||
|
MEGABYTE-pytorch==0.0.7
|
||||||
|
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
|
||||||
|
mlflow==2.10.0
|
||||||
|
modal==0.62.77
|
||||||
|
more-itertools==10.2.0
|
||||||
|
mpmath==1.2.1
|
||||||
|
msgpack==1.0.7
|
||||||
|
msgpack-numpy-opentensor==0.5.0
|
||||||
|
multidict==6.0.4
|
||||||
|
multiprocess==0.70.14
|
||||||
|
munch==2.5.0
|
||||||
|
mypy==1.3.0
|
||||||
|
mypy-extensions==1.0.0
|
||||||
|
nest-asyncio==1.6.0
|
||||||
|
netaddr==0.10.1
|
||||||
|
networkx==3.0rc1
|
||||||
|
nh3==0.2.14
|
||||||
|
nodeenv==1.8.0
|
||||||
|
nomic==2.0.2
|
||||||
|
numba==0.57.1
|
||||||
|
numexpr==2.8.4
|
||||||
|
numpy==1.24.4
|
||||||
|
oauthlib==3.2.2
|
||||||
|
openai==0.27.4
|
||||||
|
openapi==1.1.0
|
||||||
|
openapi-schema-pydantic==1.2.4
|
||||||
|
optimum==1.8.6
|
||||||
|
orjson==3.10.7
|
||||||
|
packaging==23.1
|
||||||
|
pandas==2.0.0
|
||||||
|
parameterized==0.9.0
|
||||||
|
password-strength==0.0.3.post2
|
||||||
|
pastel==0.1.1
|
||||||
|
pathos==0.3.0
|
||||||
|
pathspec==0.11.1
|
||||||
|
pathtools==0.1.2
|
||||||
|
peft==0.11.1
|
||||||
|
pendulum==3.0.0
|
||||||
|
Pillow==9.5.0
|
||||||
|
pip-tools==1.11.0
|
||||||
|
platformdirs==3.2.0
|
||||||
|
pluggy==1.4.0
|
||||||
|
poetry==0.7.1
|
||||||
|
pox==0.3.2
|
||||||
|
ppft==1.7.6.6
|
||||||
|
pre-commit==3.3.2
|
||||||
|
prettytable==3.10.0
|
||||||
|
prompt-toolkit==3.0.39
|
||||||
|
protobuf==3.20.2
|
||||||
|
protobuf3-to-dict==0.1.5
|
||||||
|
psutil==5.9.5
|
||||||
|
psycopg==3.1.18
|
||||||
|
PuLP==2.8.0
|
||||||
|
py==1.11.0
|
||||||
|
py-bip39-bindings==0.1.11
|
||||||
|
py-cpuinfo==9.0.0
|
||||||
|
py-ed25519-zebra-bindings==1.0.1
|
||||||
|
py-sr25519-bindings==0.2.0
|
||||||
|
pyarrow==11.0.0
|
||||||
|
pyasn1==0.6.0
|
||||||
|
pycodestyle==2.11.1
|
||||||
|
pycparser==2.21
|
||||||
|
pycryptodome==3.20.0
|
||||||
|
pydantic==2.5.3
|
||||||
|
pydantic_core==2.14.6
|
||||||
|
pydub==0.25.1
|
||||||
|
pyfiglet==0.8.post1
|
||||||
|
pyflakes==3.2.0
|
||||||
|
Pygments==2.15.1
|
||||||
|
PyJWT==2.8.0
|
||||||
|
pylev==1.4.0
|
||||||
|
PyNaCl==1.5.0
|
||||||
|
pynvml==11.5.0
|
||||||
|
pyparsing==2.4.7
|
||||||
|
pyrsistent==0.14.11
|
||||||
|
pytest==8.0.2
|
||||||
|
pytest-asyncio==0.23.4
|
||||||
|
python-dateutil==2.8.2
|
||||||
|
python-dotenv==1.0.1
|
||||||
|
python-Levenshtein==0.24.0
|
||||||
|
python-multipart==0.0.9
|
||||||
|
pytz==2023.3
|
||||||
|
PyYAML==6.0.1
|
||||||
|
querystring-parser==1.2.4
|
||||||
|
rapidfuzz==3.6.1
|
||||||
|
regex==2023.6.3
|
||||||
|
requests==2.31.0
|
||||||
|
requests-toolbelt==0.8.0
|
||||||
|
resolvelib==0.8.1
|
||||||
|
responses==0.18.0
|
||||||
|
retry==0.9.2
|
||||||
|
rich==13.7.0
|
||||||
|
rsa==4.7.2
|
||||||
|
ruff==0.6.3
|
||||||
|
s3transfer==0.10.1
|
||||||
|
safetensors==0.4.5
|
||||||
|
sagemaker==2.148.0
|
||||||
|
scalecodec==1.2.7
|
||||||
|
schedulefree==1.2.1
|
||||||
|
schema==0.7.5
|
||||||
|
scikit-learn==1.4.0
|
||||||
|
scipy==1.9.3
|
||||||
|
seaborn==0.13.2
|
||||||
|
semantic-version==2.10.0
|
||||||
|
sentencepiece==0.2.0
|
||||||
|
sentry-sdk==1.19.1
|
||||||
|
setproctitle==1.3.2
|
||||||
|
shellingham==1.5.4
|
||||||
|
shortuuid==1.0.11
|
||||||
|
shtab==1.6.5
|
||||||
|
sigtools==4.0.1
|
||||||
|
six==1.16.0
|
||||||
|
skypilot==0.4.1
|
||||||
|
smdebug-rulesconfig==1.0.1
|
||||||
|
smmap==5.0.0
|
||||||
|
sniffio==1.3.0
|
||||||
|
SQLAlchemy==1.4.47
|
||||||
|
sqlparse==0.4.4
|
||||||
|
starlette==0.36.3
|
||||||
|
substrate-interface==1.5.2
|
||||||
|
svgwrite==1.4.3
|
||||||
|
sympy==1.11.1
|
||||||
|
synchronicity==0.6.7
|
||||||
|
tabulate==0.9.0
|
||||||
|
tblib==1.7.0
|
||||||
|
tenacity==8.2.2
|
||||||
|
tensor-parallel==2.0.0
|
||||||
|
termcolor==2.2.0
|
||||||
|
text2art==0.2.0
|
||||||
|
threadpoolctl==3.2.0
|
||||||
|
tiktoken==0.6.0
|
||||||
|
time-machine==2.14.1
|
||||||
|
timm==0.9.16
|
||||||
|
tokenizers==0.19.1
|
||||||
|
tokenmonster==1.1.12
|
||||||
|
toml==0.9.6
|
||||||
|
tomli==2.0.1
|
||||||
|
tomlkit==0.12.0
|
||||||
|
toolz==0.12.1
|
||||||
|
torch==2.2.0
|
||||||
|
torchdata==0.6.1
|
||||||
|
torchdiffeq==0.2.3
|
||||||
|
TorchFix==0.4.0
|
||||||
|
torchtext==0.15.2
|
||||||
|
torchvision==0.17.0
|
||||||
|
tqdm==4.66.2
|
||||||
|
transformers==4.44.2
|
||||||
|
trl==0.9.6
|
||||||
|
typer==0.12.5
|
||||||
|
types-certifi==2021.10.8.3
|
||||||
|
types-requests==2.31.0.20240125
|
||||||
|
types-setuptools==69.0.0.20240125
|
||||||
|
types-toml==0.10.8.7
|
||||||
|
typing==3.7.4.3
|
||||||
|
typing-inspect==0.8.0
|
||||||
|
typing_extensions==4.9.0
|
||||||
|
tyro==0.5.18
|
||||||
|
tzdata==2023.3
|
||||||
|
unique-names-generator==1.0.2
|
||||||
|
urllib3==2.2.2
|
||||||
|
uvicorn==0.22.0
|
||||||
|
vector_quantize_pytorch==1.14.1
|
||||||
|
virtualenv==20.23.0
|
||||||
|
voyager==2.0.2
|
||||||
|
wandb==0.16.2
|
||||||
|
watchfiles==0.21.0
|
||||||
|
wavedrom==2.0.3.post3
|
||||||
|
wcwidth==0.2.6
|
||||||
|
websocket-client==1.7.0
|
||||||
|
websockets==12.0
|
||||||
|
Werkzeug==3.0.1
|
||||||
|
wonderwords==2.2.0
|
||||||
|
xxhash==3.2.0
|
||||||
|
yarl==1.8.2
|
||||||
|
zetascale==2.2.7
|
||||||
|
zipp==3.15.0
|
||||||
22
setup.py
22
setup.py
@@ -16,7 +16,13 @@ def parse_requirements():
|
|||||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||||
lines = [r.strip() for r in requirements_file.readlines()]
|
lines = [r.strip() for r in requirements_file.readlines()]
|
||||||
for line in lines:
|
for line in lines:
|
||||||
is_extras = "deepspeed" in line or "mamba-ssm" in line
|
is_extras = (
|
||||||
|
"flash-attn" in line
|
||||||
|
or "flash-attention" in line
|
||||||
|
or "deepspeed" in line
|
||||||
|
or "mamba-ssm" in line
|
||||||
|
or "lion-pytorch" in line
|
||||||
|
)
|
||||||
if line.startswith("--extra-index-url"):
|
if line.startswith("--extra-index-url"):
|
||||||
# Handle custom index URLs
|
# Handle custom index URLs
|
||||||
_, url = line.split()
|
_, url = line.split()
|
||||||
@@ -33,6 +39,7 @@ def parse_requirements():
|
|||||||
"bitsandbytes",
|
"bitsandbytes",
|
||||||
"triton",
|
"triton",
|
||||||
"mamba-ssm",
|
"mamba-ssm",
|
||||||
|
"flash-attn",
|
||||||
"xformers",
|
"xformers",
|
||||||
"autoawq",
|
"autoawq",
|
||||||
"liger-kernel",
|
"liger-kernel",
|
||||||
@@ -117,8 +124,9 @@ setup(
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": ["flash-attn==2.7.4.post1"],
|
"flash-attn": [
|
||||||
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
|
"flash-attn==2.7.4.post1",
|
||||||
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed==0.16.4",
|
"deepspeed==0.16.4",
|
||||||
"deepspeed-kernels",
|
"deepspeed-kernels",
|
||||||
@@ -133,15 +141,15 @@ setup(
|
|||||||
"mlflow": [
|
"mlflow": [
|
||||||
"mlflow",
|
"mlflow",
|
||||||
],
|
],
|
||||||
|
"lion-pytorch": [
|
||||||
|
"lion-pytorch==0.1.2",
|
||||||
|
],
|
||||||
"galore": [
|
"galore": [
|
||||||
"galore_torch",
|
"galore_torch",
|
||||||
],
|
],
|
||||||
"apollo": [
|
|
||||||
"apollo-torch",
|
|
||||||
],
|
|
||||||
"optimizers": [
|
"optimizers": [
|
||||||
"galore_torch",
|
"galore_torch",
|
||||||
"apollo-torch",
|
"lion-pytorch==0.1.2",
|
||||||
"lomo-optim==0.1.1",
|
"lomo-optim==0.1.1",
|
||||||
"torch-optimi==0.2.1",
|
"torch-optimi==0.2.1",
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ def do_inference(
|
|||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
cli_args: Inference-specific CLI arguments.
|
cli_args: Inference-specific CLI arguments.
|
||||||
"""
|
"""
|
||||||
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)
|
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
|
||||||
prompter = cli_args.prompter
|
prompter = cli_args.prompter
|
||||||
|
|
||||||
prompter_module = None
|
prompter_module = None
|
||||||
@@ -151,7 +151,7 @@ def do_inference_gradio(
|
|||||||
"""
|
"""
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)
|
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
|
||||||
prompter = cli_args.prompter
|
prompter = cli_args.prompter
|
||||||
|
|
||||||
prompter_module = None
|
prompter_module = None
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from axolotl.cli.utils import (
|
|||||||
)
|
)
|
||||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
|||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
|
|
||||||
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
|
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
|
||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
LOG.info("Running merge of LoRA with base model...")
|
LOG.info("Running merge of LoRA with base model...")
|
||||||
@@ -44,9 +44,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
|||||||
)
|
)
|
||||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||||
|
|
||||||
if processor:
|
|
||||||
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -17,14 +17,13 @@ from axolotl.cli.config import load_cfg
|
|||||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
|
||||||
from axolotl.utils.config import normalize_config, resolve_dtype
|
from axolotl.utils.config import normalize_config, resolve_dtype
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||||
"""
|
"""
|
||||||
Trains a `transformers` model by first loading the dataset(s) specified in the
|
Trains a `transformers` model by first loading the dataset(s) specified in the
|
||||||
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
|
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
|
||||||
@@ -34,9 +33,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
|||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
cli_args: Training-specific CLI arguments.
|
cli_args: Training-specific CLI arguments.
|
||||||
"""
|
"""
|
||||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
|
||||||
set_pytorch_cuda_alloc_conf()
|
|
||||||
|
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
@@ -48,13 +44,16 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
|||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
del model, tokenizer, trainer
|
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
|
||||||
|
del model
|
||||||
|
del tokenizer
|
||||||
|
del trainer
|
||||||
|
|
||||||
plugin_manager.post_train_unload(cfg)
|
plugin_manager.post_train_unload(cfg)
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Parses `axolotl` config, CLI args, and calls `do_train`.
|
Parses `axolotl` config, CLI args, and calls `do_train`.
|
||||||
|
|
||||||
|
|||||||
@@ -13,16 +13,11 @@ from typing import Any, Callable, Type, Union, get_args, get_origin
|
|||||||
import click
|
import click
|
||||||
import requests
|
import requests
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from transformers import (
|
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||||
PreTrainedModel,
|
|
||||||
PreTrainedTokenizer,
|
|
||||||
PreTrainedTokenizerFast,
|
|
||||||
ProcessorMixin,
|
|
||||||
)
|
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
|
|
||||||
configure_logging()
|
configure_logging()
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@@ -300,13 +295,9 @@ def load_model_and_tokenizer(
|
|||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
inference: bool = False,
|
inference: bool = False,
|
||||||
) -> tuple[
|
) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
|
||||||
PreTrainedModel,
|
|
||||||
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
|
|
||||||
ProcessorMixin | None,
|
|
||||||
]:
|
|
||||||
"""
|
"""
|
||||||
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
|
Helper function for loading a model and tokenizer specified in the given `axolotl`
|
||||||
config.
|
config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -314,7 +305,7 @@ def load_model_and_tokenizer(
|
|||||||
inference: Boolean denoting inference mode.
|
inference: Boolean denoting inference mode.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
|
`transformers` model and tokenizer.
|
||||||
"""
|
"""
|
||||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
@@ -322,9 +313,4 @@ def load_model_and_tokenizer(
|
|||||||
LOG.info("loading model...")
|
LOG.info("loading model...")
|
||||||
model, _ = load_model(cfg, tokenizer, inference=inference)
|
model, _ = load_model(cfg, tokenizer, inference=inference)
|
||||||
|
|
||||||
processor = None
|
return model, tokenizer
|
||||||
if cfg.is_multimodal:
|
|
||||||
LOG.info("loading processor...")
|
|
||||||
processor = load_processor(cfg, tokenizer)
|
|
||||||
|
|
||||||
return model, tokenizer, processor
|
|
||||||
|
|||||||
@@ -13,7 +13,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
"""Builder for the training args and trainer"""
|
"""
|
||||||
|
Builder for the training args and trainer
|
||||||
|
"""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import importlib
|
import importlib
|
||||||
@@ -36,7 +38,7 @@ from transformers import (
|
|||||||
from transformers.training_args import OptimizerNames
|
from transformers.training_args import OptimizerNames
|
||||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||||
|
|
||||||
from axolotl.core.trainers import (
|
from axolotl.core.trainers.base import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
AxolotlKTOTrainer,
|
AxolotlKTOTrainer,
|
||||||
AxolotlMambaTrainer,
|
AxolotlMambaTrainer,
|
||||||
@@ -60,7 +62,6 @@ from axolotl.core.training_args import (
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||||
from axolotl.processing_strategies import get_processing_strategy
|
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
@@ -84,8 +85,8 @@ from axolotl.utils.collators import (
|
|||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
|
||||||
from axolotl.utils.models import ensure_dtype
|
from axolotl.utils.models import ensure_dtype
|
||||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||||
@@ -748,12 +749,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.accelerator_config
|
self.cfg.accelerator_config
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.image_size:
|
|
||||||
training_arguments_kwargs["image_size"] = self.cfg.image_size
|
|
||||||
if self.cfg.image_resize_algorithm:
|
|
||||||
training_arguments_kwargs["image_resize_algorithm"] = (
|
|
||||||
self.cfg.image_resize_algorithm
|
|
||||||
)
|
|
||||||
if self.cfg.kd_ce_alpha is not None:
|
if self.cfg.kd_ce_alpha is not None:
|
||||||
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
||||||
if self.cfg.kd_alpha is not None:
|
if self.cfg.kd_alpha is not None:
|
||||||
@@ -769,10 +764,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.kd_top_k_before_softmax
|
self.cfg.kd_top_k_before_softmax
|
||||||
)
|
)
|
||||||
|
|
||||||
training_arguments_kwargs["sequence_parallel_degree"] = (
|
|
||||||
self.cfg.sequence_parallel_degree
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_args_cls = AxolotlRewardConfig
|
training_args_cls = AxolotlRewardConfig
|
||||||
elif self.cfg.process_reward_model:
|
elif self.cfg.process_reward_model:
|
||||||
@@ -856,10 +847,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||||
):
|
):
|
||||||
if training_args.pretraining:
|
if training_args.pretraining:
|
||||||
if (
|
if self.cfg.pretraining_sample_concatenation is False:
|
||||||
self.cfg.pretraining_sample_concatenation is False
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
or self.cfg.micro_batch_size > 1
|
if self.cfg.micro_batch_size > 1:
|
||||||
):
|
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -887,7 +877,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if "max_length" in kwargs:
|
if "max_length" in kwargs:
|
||||||
kwargs.pop("max_length")
|
kwargs.pop("max_length")
|
||||||
elif use_batch_sampler_collator:
|
elif use_batch_sampler_collator:
|
||||||
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or (
|
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||||
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
|
elif (
|
||||||
self.cfg.model_config_type in ["llama"]
|
self.cfg.model_config_type in ["llama"]
|
||||||
and self.cfg.flash_attention is not True
|
and self.cfg.flash_attention is not True
|
||||||
):
|
):
|
||||||
@@ -897,13 +889,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
if self.cfg.processor_type and self.processor:
|
if self.cfg.processor_type and self.processor:
|
||||||
collator = MultiModalChatDataCollator
|
collator = MultiModalChatDataCollator
|
||||||
kwargs["processing_strategy"] = get_processing_strategy(
|
kwargs["processor"] = self.processor
|
||||||
self.processor,
|
kwargs["chat_template"] = training_args.chat_template
|
||||||
training_args.chat_template,
|
|
||||||
self.cfg.chat_template,
|
|
||||||
image_size=training_args.image_size,
|
|
||||||
image_resize_algorithm=training_args.image_resize_algorithm,
|
|
||||||
)
|
|
||||||
elif self.cfg.batch_flattening:
|
elif self.cfg.batch_flattening:
|
||||||
collator = DataCollatorWithFlattening
|
collator = DataCollatorWithFlattening
|
||||||
collator_args.pop(0)
|
collator_args.pop(0)
|
||||||
@@ -923,8 +910,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
collator = DataCollatorForSeq2Seq
|
collator = DataCollatorForSeq2Seq
|
||||||
|
|
||||||
kwargs["return_tensors"] = "pt"
|
kwargs["return_tensors"] = "pt"
|
||||||
if issubclass(collator, DataCollatorForSeq2Seq):
|
|
||||||
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
|
|
||||||
|
|
||||||
return collator(
|
return collator(
|
||||||
*collator_args,
|
*collator_args,
|
||||||
|
|||||||
@@ -1,18 +0,0 @@
|
|||||||
"""Init for axolotl.core.trainers"""
|
|
||||||
|
|
||||||
# pylint: disable=unused-import
|
|
||||||
# flake8: noqa
|
|
||||||
|
|
||||||
from .base import AxolotlTrainer
|
|
||||||
from .dpo.trainer import AxolotlDPOTrainer
|
|
||||||
from .grpo.trainer import AxolotlGRPOTrainer
|
|
||||||
from .mamba import AxolotlMambaTrainer
|
|
||||||
from .relora import ReLoRATrainer
|
|
||||||
from .trl import (
|
|
||||||
AxolotlCPOTrainer,
|
|
||||||
AxolotlKTOTrainer,
|
|
||||||
AxolotlORPOTrainer,
|
|
||||||
AxolotlPRMTrainer,
|
|
||||||
AxolotlRewardTrainer,
|
|
||||||
TRLPPOTrainer,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,47 +1,365 @@
|
|||||||
"""Module for customized trainers"""
|
"""
|
||||||
|
module for customized trainers
|
||||||
# pylint: disable=too-many-lines
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# pylint: disable=too-many-lines
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Literal
|
from typing import Dict, Literal, Optional
|
||||||
|
|
||||||
import datasets
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import (
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
BatchSampler,
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
DataLoader,
|
|
||||||
RandomSampler,
|
|
||||||
Sampler,
|
|
||||||
SequentialSampler,
|
|
||||||
)
|
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||||
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
|
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
from typing_extensions import override
|
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import (
|
from axolotl.integrations.base import BaseOptimizerFactory
|
||||||
OptimizerMixin,
|
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||||
SchedulerMixin,
|
|
||||||
SequenceParallelMixin,
|
|
||||||
)
|
|
||||||
from axolotl.core.trainers.utils import (
|
|
||||||
sanitize_kwargs_for_ds_tagging,
|
|
||||||
sanitize_kwargs_for_tagging,
|
|
||||||
)
|
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
from axolotl.utils.schedulers import (
|
||||||
|
RexLR,
|
||||||
|
get_cosine_schedule_with_min_lr,
|
||||||
|
get_cosine_schedule_with_quadratic_warmup,
|
||||||
|
get_cosine_schedule_with_warmup_decay_constant,
|
||||||
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
if is_sagemaker_mp_enabled():
|
||||||
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
|
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
||||||
"""Extend the base Trainer for axolotl helpers"""
|
if isinstance(tag_names, str):
|
||||||
|
tag_names = [tag_names]
|
||||||
|
|
||||||
|
if kwargs is not None:
|
||||||
|
if "tags" not in kwargs:
|
||||||
|
kwargs["tags"] = tag_names
|
||||||
|
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
||||||
|
kwargs["tags"].extend(tag_names)
|
||||||
|
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
||||||
|
tag_names.append(kwargs["tags"])
|
||||||
|
kwargs["tags"] = tag_names
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
||||||
|
if isinstance(dataset_tags, str):
|
||||||
|
dataset_tags = [dataset_tags]
|
||||||
|
|
||||||
|
if (dataset_tags is not None) and (kwargs is not None):
|
||||||
|
if "dataset_tags" not in kwargs:
|
||||||
|
kwargs["dataset_tags"] = dataset_tags
|
||||||
|
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
||||||
|
kwargs["dataset_tags"].extend(dataset_tags)
|
||||||
|
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
||||||
|
dataset_tags.append(kwargs["dataset_tags"])
|
||||||
|
kwargs["dataset_tags"] = dataset_tags
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerMixin(Trainer):
|
||||||
|
"""
|
||||||
|
Mixin class for scheduler setup in CausalTrainer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|
||||||
|
def create_scheduler(
|
||||||
|
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||||
|
passed as an argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_training_steps (int): The number of training steps to do.
|
||||||
|
optimizer (torch.optim.Optimizer): The training optimizer
|
||||||
|
"""
|
||||||
|
use_cosine_quadratic = (
|
||||||
|
self.args.lr_scheduler_type == "cosine"
|
||||||
|
and self.args.lr_quadratic_warmup is True
|
||||||
|
)
|
||||||
|
|
||||||
|
use_cosine_min_lr = (
|
||||||
|
self.args.lr_scheduler_type == "cosine"
|
||||||
|
and self.args.cosine_min_lr_ratio is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||||
|
# fmt: on
|
||||||
|
if self.args.alternate_lr_scheduler_type == "one_cycle":
|
||||||
|
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||||
|
pct_start = num_warmup_steps / num_training_steps
|
||||||
|
extra_lr_kwargs = {}
|
||||||
|
if "pct_start" not in self.args.lr_scheduler_kwargs:
|
||||||
|
extra_lr_kwargs["pct_start"] = pct_start
|
||||||
|
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
|
||||||
|
extra_lr_kwargs["anneal_strategy"] = "cos"
|
||||||
|
|
||||||
|
self.lr_scheduler = OneCycleLR(
|
||||||
|
optimizer,
|
||||||
|
max_lr=self.args.learning_rate,
|
||||||
|
total_steps=num_training_steps,
|
||||||
|
**extra_lr_kwargs,
|
||||||
|
**self.args.lr_scheduler_kwargs,
|
||||||
|
)
|
||||||
|
elif self.args.alternate_lr_scheduler_type == "rex":
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
|
||||||
|
self.lr_scheduler = RexLR(
|
||||||
|
optimizer=optimizer,
|
||||||
|
max_lr=self.args.learning_rate,
|
||||||
|
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
||||||
|
total_steps=num_training_steps,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
)
|
||||||
|
elif use_cosine_quadratic:
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||||
|
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
)
|
||||||
|
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
|
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
||||||
|
)
|
||||||
|
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||||
|
else:
|
||||||
|
if use_cosine_quadratic:
|
||||||
|
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
|
return self.lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizerMixin(Trainer):
|
||||||
|
"""
|
||||||
|
Mixin class for shared handling of building custom optimizers
|
||||||
|
"""
|
||||||
|
|
||||||
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|
||||||
|
def create_optimizer_grouped_parameters(
|
||||||
|
self, opt_model, optimizer_kwargs
|
||||||
|
) -> list[dict]:
|
||||||
|
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||||
|
params: dict = {
|
||||||
|
"to_weight_decay": {}, # LayerNorm and bias
|
||||||
|
"embeddings": {}, # lm_head, embed_tokens,
|
||||||
|
"no_weight_decay": {},
|
||||||
|
}
|
||||||
|
lr_groups_lookup = {}
|
||||||
|
lr_groups_learning_rates = {}
|
||||||
|
if self.args.lr_groups:
|
||||||
|
for lr_group in self.args.lr_groups:
|
||||||
|
group_name = lr_group["name"]
|
||||||
|
group_modules = lr_group["modules"]
|
||||||
|
for module in group_modules:
|
||||||
|
lr_groups_lookup[module] = group_name
|
||||||
|
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
||||||
|
params[f"to_weight_decay_{group_name}"] = {}
|
||||||
|
|
||||||
|
for name, param in opt_model.named_parameters():
|
||||||
|
if not param.requires_grad:
|
||||||
|
continue
|
||||||
|
if name.endswith("modules_to_save.default.weight") or any(
|
||||||
|
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||||
|
):
|
||||||
|
params["embeddings"][name] = param
|
||||||
|
elif name in decay_parameters:
|
||||||
|
lr_group_modules = [
|
||||||
|
group_modules
|
||||||
|
for group_modules in lr_groups_lookup
|
||||||
|
if group_modules in name
|
||||||
|
]
|
||||||
|
if lr_groups_lookup and any(lr_group_modules):
|
||||||
|
lr_group_module = lr_group_modules[0]
|
||||||
|
group_name = lr_groups_lookup[lr_group_module]
|
||||||
|
params[f"to_weight_decay_{group_name}"][name] = param
|
||||||
|
else:
|
||||||
|
params["to_weight_decay"][name] = param
|
||||||
|
else:
|
||||||
|
params["no_weight_decay"][name] = param
|
||||||
|
optimizer_grouped_parameters = []
|
||||||
|
if params["to_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["to_weight_decay"].values()),
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["embeddings"]:
|
||||||
|
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||||
|
if self.args.embedding_lr_scale:
|
||||||
|
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||||
|
elif self.args.embedding_lr:
|
||||||
|
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["embeddings"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": lr,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["no_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["no_weight_decay"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
for group_name, group_lr in lr_groups_learning_rates.items():
|
||||||
|
if params[f"to_weight_decay_{group_name}"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(
|
||||||
|
params[f"to_weight_decay_{group_name}"].values()
|
||||||
|
),
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
"lr": group_lr,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return optimizer_grouped_parameters
|
||||||
|
|
||||||
|
def create_optimizer(self):
|
||||||
|
if (
|
||||||
|
self.args.loraplus_lr_ratio is None
|
||||||
|
and self.args.embedding_lr_scale is None
|
||||||
|
and self.args.embedding_lr is None
|
||||||
|
and self.args.lr_groups is None
|
||||||
|
and self.optimizer_cls_and_kwargs is None
|
||||||
|
):
|
||||||
|
return super().create_optimizer()
|
||||||
|
|
||||||
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
|
|
||||||
|
if (
|
||||||
|
not self.optimizer
|
||||||
|
and self.optimizer_cls_and_kwargs is not None
|
||||||
|
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
||||||
|
):
|
||||||
|
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||||
|
self.optimizer = optimizer_factory_cls()(
|
||||||
|
opt_model, self.args, **optimizer_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.optimizer:
|
||||||
|
if self.optimizer_cls_and_kwargs is not None:
|
||||||
|
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||||
|
else:
|
||||||
|
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
||||||
|
self.args, opt_model
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
||||||
|
opt_model, optimizer_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.loraplus_lr_ratio is not None:
|
||||||
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
|
loraplus_lr_embedding = getattr(
|
||||||
|
self.args, "loraplus_lr_embedding", 1e-6
|
||||||
|
)
|
||||||
|
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
opt_model,
|
||||||
|
optimizer_cls,
|
||||||
|
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||||
|
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||||
|
**optimizer_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
|
# e.g. for GaLore optimizer.
|
||||||
|
if "params" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||||
|
|
||||||
|
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
|
# e.g. for LOMO optimizer.
|
||||||
|
if "model" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||||
|
|
||||||
|
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||||
|
# to avoid arguments conflicts.
|
||||||
|
if "optimizer_dict" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
||||||
|
"optimizer_dict"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.optimizer = optimizer_cls(
|
||||||
|
optimizer_grouped_parameters, **optimizer_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if optimizer_cls.__name__ == "Adam8bit":
|
||||||
|
import bitsandbytes
|
||||||
|
|
||||||
|
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||||
|
|
||||||
|
skipped = 0
|
||||||
|
for module in opt_model.modules():
|
||||||
|
if isinstance(module, nn.Embedding):
|
||||||
|
skipped += sum(
|
||||||
|
{
|
||||||
|
p.data_ptr(): p.numel() for p in module.parameters()
|
||||||
|
}.values()
|
||||||
|
)
|
||||||
|
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||||
|
manager.register_module_override(
|
||||||
|
module, "weight", {"optim_bits": 32}
|
||||||
|
)
|
||||||
|
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||||
|
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||||
|
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.optimizer
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.optimizer
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||||
|
"""
|
||||||
|
Extend the base Trainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
tag_names = ["axolotl"]
|
tag_names = ["axolotl"]
|
||||||
@@ -58,18 +376,12 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
self.eval_data_collator = eval_data_collator
|
self.eval_data_collator = eval_data_collator
|
||||||
self.dataset_tags = dataset_tags
|
self.dataset_tags = dataset_tags
|
||||||
self._signature_columns = None # workaround for pylint
|
self._signature_columns = None # workaround for pylint
|
||||||
|
|
||||||
super().__init__(*_args, **kwargs)
|
super().__init__(*_args, **kwargs)
|
||||||
|
|
||||||
self.train_data_collator = self.data_collator
|
self.train_data_collator = self.data_collator
|
||||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
|
|
||||||
# Initialize sequence parallelism if enabled
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
self._setup_sequence_parallel()
|
|
||||||
|
|
||||||
def _wrap_model(self, model, training=True, dataloader=None):
|
def _wrap_model(self, model, training=True, dataloader=None):
|
||||||
if self.args.torch_compile:
|
if self.args.torch_compile:
|
||||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||||
@@ -82,20 +394,8 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
)
|
)
|
||||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||||
|
|
||||||
def _create_multipack_sampler(
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||||
self, base_sampler: Sampler, dataset: Dataset
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
) -> MultipackBatchSampler:
|
|
||||||
"""
|
|
||||||
Helper method to create a `MultipackBatchSampler` for multipacking sequences
|
|
||||||
for training.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
base_sampler: Sampler to wrap with `MultipackBatchSampler`.
|
|
||||||
dataset: Dataset to sample from.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Multipack (sample packing) batch sampler.
|
|
||||||
"""
|
|
||||||
if self.args.multipack_real_batches:
|
if self.args.multipack_real_batches:
|
||||||
batch_size = self.args.per_device_train_batch_size
|
batch_size = self.args.per_device_train_batch_size
|
||||||
batch_max_len = self.args.max_seq_length
|
batch_max_len = self.args.max_seq_length
|
||||||
@@ -106,223 +406,130 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
)
|
)
|
||||||
batch_max_len = train_batch_size * self.args.max_seq_length
|
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||||
|
|
||||||
|
if self.args.curriculum_sampling:
|
||||||
|
sampler = SequentialSampler(self.train_dataset)
|
||||||
|
else:
|
||||||
|
sampler = RandomSampler(self.train_dataset)
|
||||||
|
|
||||||
return MultipackBatchSampler(
|
return MultipackBatchSampler(
|
||||||
base_sampler,
|
sampler,
|
||||||
lengths=get_dataset_lengths(dataset),
|
lengths=get_dataset_lengths(self.train_dataset),
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
batch_max_len=batch_max_len,
|
batch_max_len=batch_max_len,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
group_size=self.args.sample_packing_group_size,
|
||||||
|
bin_size=self.args.sample_packing_bin_size,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
|
if self.args.curriculum_sampling:
|
||||||
def _get_train_sampler(self) -> Sampler | None:
|
return SequentialSampler(self.train_dataset)
|
||||||
"""
|
|
||||||
Helper method to get the sampler for training. Handles cases for sequence
|
|
||||||
parallelism, sample packing, and curriculum sampling (sequential).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
If the dataset is non-empty, a sampler is returned, the type of which
|
|
||||||
depends on the passed training args.
|
|
||||||
"""
|
|
||||||
use_sample_packing = self.args.sample_packing and not self.args.pretraining
|
|
||||||
|
|
||||||
# Determine the base sampler first
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
base_sampler = self._sp_get_train_sampler(self.train_dataset)
|
|
||||||
elif self.args.curriculum_sampling:
|
|
||||||
base_sampler = SequentialSampler(self.train_dataset)
|
|
||||||
elif use_sample_packing:
|
|
||||||
base_sampler = RandomSampler(self.train_dataset)
|
|
||||||
else:
|
|
||||||
# Default to parent class implementation for standard random sampling
|
|
||||||
return super()._get_train_sampler()
|
return super()._get_train_sampler()
|
||||||
|
|
||||||
# Apply multipack wrapper if needed
|
def _get_eval_sampler(
|
||||||
if use_sample_packing:
|
self, eval_dataset: Dataset
|
||||||
return self._create_multipack_sampler(
|
) -> Optional[torch.utils.data.Sampler]:
|
||||||
base_sampler=base_sampler,
|
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||||
dataset=self.train_dataset,
|
if self.args.multipack_real_batches:
|
||||||
)
|
batch_size = self.args.per_device_eval_batch_size
|
||||||
|
batch_max_len = self.args.max_seq_length
|
||||||
return base_sampler
|
|
||||||
|
|
||||||
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
|
|
||||||
"""
|
|
||||||
Helper method to get the sampler for evaluation. Handles sequence parallelism
|
|
||||||
and sample packing cases.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
If the dataset is non-empty, a sampler is returned, the type of which
|
|
||||||
depends on the passed training args.
|
|
||||||
"""
|
|
||||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
|
||||||
|
|
||||||
# Multipacking enabled if training is enabled and eval is not explicitly disabled
|
|
||||||
use_multipack = (
|
|
||||||
self.args.sample_packing and self.args.eval_sample_packing is not False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine the base sampler
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
base_sampler = self._sp_get_eval_sampler(eval_dataset)
|
|
||||||
elif use_multipack:
|
|
||||||
base_sampler = SequentialSampler(eval_dataset)
|
|
||||||
else:
|
else:
|
||||||
|
batch_size = 1
|
||||||
|
batch_max_len = (
|
||||||
|
self.args.per_device_eval_batch_size * self.args.max_seq_length
|
||||||
|
)
|
||||||
|
return MultipackBatchSampler(
|
||||||
|
SequentialSampler(eval_dataset),
|
||||||
|
lengths=get_dataset_lengths(self.eval_dataset),
|
||||||
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
|
batch_max_len=batch_max_len,
|
||||||
|
batch_size=batch_size,
|
||||||
|
group_size=self.args.sample_packing_group_size,
|
||||||
|
bin_size=self.args.sample_packing_bin_size,
|
||||||
|
drop_last=True,
|
||||||
|
)
|
||||||
return super()._get_eval_sampler(eval_dataset)
|
return super()._get_eval_sampler(eval_dataset)
|
||||||
|
|
||||||
# Apply multipack wrapper if needed
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
if use_multipack:
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
return self._create_multipack_sampler(
|
train_dataset = self.train_dataset
|
||||||
base_sampler=base_sampler,
|
if "length" in train_dataset.features.keys():
|
||||||
dataset=eval_dataset,
|
train_dataset = train_dataset.remove_columns(["length"])
|
||||||
)
|
data_collator = self.data_collator
|
||||||
|
dataloader_params = {
|
||||||
return base_sampler
|
"batch_size": self._train_batch_size,
|
||||||
|
"collate_fn": data_collator,
|
||||||
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
|
|
||||||
"""Create common dataloader parameters for train or eval."""
|
|
||||||
batch_size = custom_batch_size or (
|
|
||||||
self.args.eval_batch_size if is_eval else self._train_batch_size
|
|
||||||
)
|
|
||||||
|
|
||||||
params = {
|
|
||||||
"batch_size": batch_size,
|
|
||||||
"collate_fn": self.data_collator,
|
|
||||||
"num_workers": self.args.dataloader_num_workers,
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
"pin_memory": self.args.dataloader_pin_memory,
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add persistent workers only for training
|
|
||||||
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
|
|
||||||
params["persistent_workers"] = self.args.dataloader_persistent_workers
|
|
||||||
|
|
||||||
# Add prefetch factor if specified
|
|
||||||
if self.args.dataloader_prefetch_factor:
|
if self.args.dataloader_prefetch_factor:
|
||||||
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
dataloader_params["prefetch_factor"] = (
|
||||||
|
self.args.dataloader_prefetch_factor
|
||||||
|
)
|
||||||
|
|
||||||
return params
|
sampler = self._get_train_sampler()
|
||||||
|
|
||||||
def _prepare_dataloader(
|
|
||||||
self, dataset, sampler, is_eval=False, custom_batch_size=None
|
|
||||||
):
|
|
||||||
"""Prepare a dataloader with the given dataset and sampler."""
|
|
||||||
# Get base parameters
|
|
||||||
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
|
|
||||||
|
|
||||||
# Add sampler configuration
|
|
||||||
if not isinstance(dataset, torch.utils.data.IterableDataset):
|
|
||||||
if isinstance(sampler, BatchSampler):
|
if isinstance(sampler, BatchSampler):
|
||||||
# batch_size and batch_sampler are mutually exclusive
|
|
||||||
dataloader_params["batch_sampler"] = sampler
|
dataloader_params["batch_sampler"] = sampler
|
||||||
del dataloader_params["batch_size"]
|
del dataloader_params["batch_size"]
|
||||||
else:
|
else:
|
||||||
dataloader_params["sampler"] = sampler
|
dataloader_params["sampler"] = sampler
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
|
||||||
if not is_eval:
|
|
||||||
dataloader_params["worker_init_fn"] = seed_worker
|
dataloader_params["worker_init_fn"] = seed_worker
|
||||||
|
|
||||||
# Create the dataloader
|
|
||||||
dataloader = DataLoader(dataset, **dataloader_params)
|
|
||||||
|
|
||||||
if self.args.sample_packing and (
|
|
||||||
(not is_eval and not self.args.pretraining)
|
|
||||||
or (is_eval and self.args.eval_sample_packing is not False)
|
|
||||||
):
|
|
||||||
self.accelerator.even_batches = False
|
self.accelerator.even_batches = False
|
||||||
|
return self.accelerator.prepare_data_loader(
|
||||||
# Return unprepared dataloader if using sequence parallelism
|
DataLoader(train_dataset, **dataloader_params)
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
return dataloader
|
|
||||||
|
|
||||||
# Otherwise prepare with accelerator
|
|
||||||
return self.accelerator.prepare_data_loader(dataloader)
|
|
||||||
|
|
||||||
def get_train_dataloader(self) -> DataLoader:
|
|
||||||
"""Get dataloader for training"""
|
|
||||||
train_dataset = self.train_dataset
|
|
||||||
data_collator = self.data_collator # type: ignore
|
|
||||||
|
|
||||||
# Handle dataset preprocessing
|
|
||||||
if isinstance(train_dataset, datasets.Dataset):
|
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
|
||||||
train_dataset = train_dataset.remove_columns(["length"])
|
|
||||||
if not self.args.sample_packing or self.args.pretraining:
|
|
||||||
train_dataset = self._remove_unused_columns(
|
|
||||||
train_dataset, description="training"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
|
||||||
data_collator,
|
|
||||||
description="training",
|
|
||||||
)
|
)
|
||||||
|
return super().get_train_dataloader()
|
||||||
|
|
||||||
# Get sampler and create dataloader
|
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||||
sampler = self._get_train_sampler()
|
|
||||||
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
|
|
||||||
|
|
||||||
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
|
|
||||||
"""Get dataloader for evaluation"""
|
|
||||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
|
||||||
|
|
||||||
# Handle special case: sample packing is enabled but eval_sample_packing is False
|
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.eval_data_collator
|
self.eval_data_collator
|
||||||
)
|
)
|
||||||
if "length" in eval_dataset.column_names:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||||
dataloader = super().get_eval_dataloader(eval_dataset)
|
dataloader = super().get_eval_dataloader(eval_dataset)
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.train_data_collator
|
self.train_data_collator
|
||||||
)
|
)
|
||||||
|
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
# Handle sample packing or sequence parallelism
|
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||||
if (
|
eval_dataset = (
|
||||||
self.args.sample_packing
|
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
and self.args.eval_sample_packing is not False
|
|
||||||
or self.args.sequence_parallel_degree > 1
|
|
||||||
):
|
|
||||||
# Get appropriate data collator
|
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
|
||||||
self.eval_data_collator
|
|
||||||
if hasattr(self, "eval_data_collator") and self.eval_data_collator
|
|
||||||
else self.data_collator
|
|
||||||
)
|
)
|
||||||
if "length" in eval_dataset.column_names:
|
|
||||||
|
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||||
|
data_collator = self.data_collator
|
||||||
# Handle dataset preprocessing for SP
|
dataloader_params = {
|
||||||
if self.args.sequence_parallel_degree > 1:
|
"batch_size": self.args.eval_batch_size,
|
||||||
if isinstance(eval_dataset, datasets.Dataset):
|
"collate_fn": data_collator,
|
||||||
eval_dataset = self._remove_unused_columns(
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
eval_dataset, description="evaluation"
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
|
}
|
||||||
|
if self.args.dataloader_prefetch_factor:
|
||||||
|
dataloader_params["prefetch_factor"] = (
|
||||||
|
self.args.dataloader_prefetch_factor
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(eval_sampler, BatchSampler):
|
||||||
|
dataloader_params["batch_sampler"] = eval_sampler
|
||||||
|
del dataloader_params["batch_size"]
|
||||||
else:
|
else:
|
||||||
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
dataloader_params["sampler"] = eval_sampler
|
||||||
self.data_collator, description="evaluation"
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
)
|
|
||||||
|
|
||||||
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
|
self.accelerator.even_batches = False
|
||||||
batch_size = (
|
return self.accelerator.prepare_data_loader(
|
||||||
self.args.eval_batch_size
|
DataLoader(eval_dataset, **dataloader_params)
|
||||||
if self.args.sample_packing
|
|
||||||
else self.args.per_device_eval_batch_size
|
|
||||||
)
|
)
|
||||||
sampler = self._get_eval_sampler(eval_dataset)
|
|
||||||
dataloader = self._prepare_dataloader(
|
|
||||||
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
|
|
||||||
)
|
|
||||||
|
|
||||||
return dataloader
|
|
||||||
|
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
|
|
||||||
def _get_bench_sampler(
|
def _get_bench_sampler(
|
||||||
self, bench_dataset: Dataset
|
self, bench_dataset: Dataset
|
||||||
) -> torch.utils.data.Sampler | None:
|
) -> Optional[torch.utils.data.Sampler]:
|
||||||
if self.args.world_size <= 1:
|
if self.args.world_size <= 1:
|
||||||
return SequentialSampler(bench_dataset)
|
return SequentialSampler(bench_dataset)
|
||||||
return None
|
return None
|
||||||
@@ -347,7 +554,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
return DataLoader(bench_dataset, **dataloader_params)
|
return DataLoader(bench_dataset, **dataloader_params)
|
||||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||||
|
|
||||||
@override
|
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||||
):
|
):
|
||||||
@@ -364,7 +570,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
return_outputs=return_outputs,
|
return_outputs=return_outputs,
|
||||||
num_items_in_batch=num_items_in_batch,
|
num_items_in_batch=num_items_in_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
return super().compute_loss(
|
return super().compute_loss(
|
||||||
model,
|
model,
|
||||||
inputs,
|
inputs,
|
||||||
@@ -539,10 +744,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
"""
|
"""
|
||||||
kwargs = sanitize_kwargs_for_ds_tagging(
|
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
)
|
)
|
||||||
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
@@ -559,13 +764,15 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Log `logs` on the various objects watching training, including stored metrics.
|
Log `logs` on the various objects watching training, including stored metrics.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logs: The values to log.
|
logs (`Dict[str, float]`):
|
||||||
start_time: The start of training.
|
The values to log.
|
||||||
|
start_time (`Optional[float]`):
|
||||||
|
The start of training.
|
||||||
"""
|
"""
|
||||||
# logs either has 'loss' or 'eval_loss'
|
# logs either has 'loss' or 'eval_loss'
|
||||||
train_eval = "train" if "loss" in logs else "eval"
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
@@ -577,7 +784,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
return super().log(logs, start_time)
|
return super().log(logs, start_time)
|
||||||
|
|
||||||
def store_metrics(
|
def store_metrics(
|
||||||
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||||
) -> None:
|
) -> None:
|
||||||
for key, value in metrics.items():
|
for key, value in metrics.items():
|
||||||
self._stored_metrics[train_eval][key].append(value)
|
self._stored_metrics[train_eval][key].append(value)
|
||||||
@@ -590,26 +797,110 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, **kwargs)
|
||||||
|
|
||||||
def training_step(
|
|
||||||
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
|
"""
|
||||||
|
Mamba specific trainer to handle loss calculation
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "mamba"]
|
||||||
|
|
||||||
|
def compute_loss(
|
||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model,
|
||||||
inputs: dict[str, torch.Tensor | Any],
|
inputs,
|
||||||
num_items_in_batch: int | None = None,
|
return_outputs=False, # pylint: disable=unused-argument
|
||||||
) -> torch.Tensor:
|
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
input_ids = inputs.pop("input_ids")
|
||||||
|
lm_logits = model(input_ids).logits
|
||||||
|
|
||||||
|
labels = input_ids.to(lm_logits.device)
|
||||||
|
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||||
|
labels = labels[:, 1:].contiguous()
|
||||||
|
|
||||||
|
loss_fct = torch.nn.CrossEntropyLoss()
|
||||||
|
lm_loss = loss_fct(
|
||||||
|
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
return lm_loss
|
||||||
|
|
||||||
|
|
||||||
|
class ReLoRATrainer(AxolotlTrainer):
|
||||||
"""
|
"""
|
||||||
Perform a training step on a batch of inputs. Overrides the
|
Trainer subclass that uses the OneCycleLR scheduler
|
||||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
|
||||||
enabled.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: Model to perform training step for.
|
|
||||||
inputs: Dictionary mapping.
|
|
||||||
"""
|
"""
|
||||||
# Set up sequence parallelism for this step if enabled
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
self._update_ring_flash_attn_params(inputs)
|
|
||||||
|
|
||||||
# Proceed with normal training step
|
tag_names = ["axolotl", "relora"]
|
||||||
loss = super().training_step(model, inputs, num_items_in_batch)
|
|
||||||
|
|
||||||
return loss
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.lr_scheduler = None
|
||||||
|
|
||||||
|
def create_scheduler(
|
||||||
|
self,
|
||||||
|
num_training_steps: int,
|
||||||
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
|
):
|
||||||
|
optimizer = self.optimizer if optimizer is None else optimizer
|
||||||
|
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||||
|
|
||||||
|
if self.args.relora_steps:
|
||||||
|
warmup_steps = (
|
||||||
|
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||||
|
)
|
||||||
|
anneal_steps = (
|
||||||
|
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
||||||
|
)
|
||||||
|
self.lr_scheduler = ReLoRAScheduler(
|
||||||
|
optimizer,
|
||||||
|
lr_scheduler,
|
||||||
|
self.args.relora_steps,
|
||||||
|
anneal_steps,
|
||||||
|
warmup_steps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.lr_scheduler = lr_scheduler
|
||||||
|
|
||||||
|
return self.lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base ORPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base KTOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base CPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "cpo"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base RewardTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "reward"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base trl.PRMTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "prm"]
|
||||||
|
|||||||
@@ -13,10 +13,10 @@ from transformers import Trainer
|
|||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import SchedulerMixin
|
from axolotl.core.trainers.base import (
|
||||||
from axolotl.core.trainers.utils import (
|
SchedulerMixin,
|
||||||
sanitize_kwargs_for_ds_tagging,
|
_sanitize_kwargs_for_ds_tagging,
|
||||||
sanitize_kwargs_for_tagging,
|
_sanitize_kwargs_for_tagging,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
@@ -74,10 +74,10 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
"""
|
"""
|
||||||
kwargs = sanitize_kwargs_for_ds_tagging(
|
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
)
|
)
|
||||||
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import logging
|
|||||||
from trl.trainer.grpo_trainer import RewardFunc
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
|
|
||||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||||
from axolotl.utils.schemas.trl import TRLConfig
|
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|||||||
@@ -1,32 +0,0 @@
|
|||||||
"""Module for mamba trainer"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
|
||||||
"""Mamba specific trainer to handle loss calculation"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "mamba"]
|
|
||||||
|
|
||||||
def compute_loss(
|
|
||||||
self,
|
|
||||||
model,
|
|
||||||
inputs,
|
|
||||||
return_outputs=False, # pylint: disable=unused-argument
|
|
||||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
input_ids = inputs.pop("input_ids")
|
|
||||||
lm_logits = model(input_ids).logits
|
|
||||||
|
|
||||||
labels = input_ids.to(lm_logits.device)
|
|
||||||
shift_logits = lm_logits[:, :-1, :].contiguous()
|
|
||||||
labels = labels[:, 1:].contiguous()
|
|
||||||
|
|
||||||
loss_fct = torch.nn.CrossEntropyLoss()
|
|
||||||
lm_loss = loss_fct(
|
|
||||||
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
|
||||||
)
|
|
||||||
|
|
||||||
return lm_loss
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
"""Init for axolotl.core.trainers.mixins"""
|
|
||||||
|
|
||||||
# pylint: disable=unused-import
|
|
||||||
# flake8: noqa
|
|
||||||
|
|
||||||
from .optimizer import OptimizerMixin
|
|
||||||
from .scheduler import SchedulerMixin
|
|
||||||
from .sequence_parallel import SequenceParallelMixin
|
|
||||||
@@ -1,201 +0,0 @@
|
|||||||
"""Module for Axolotl trainer optimizer mixin"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
|
||||||
from torch import nn
|
|
||||||
from transformers.trainer import Trainer
|
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
|
||||||
|
|
||||||
from axolotl.integrations.base import BaseOptimizerFactory
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
|
||||||
import smdistributed.modelparallel.torch as smp
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class OptimizerMixin(Trainer):
|
|
||||||
"""Mixin class for shared handling of building custom optimizers"""
|
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
|
||||||
|
|
||||||
def create_optimizer_grouped_parameters(
|
|
||||||
self, opt_model, optimizer_kwargs
|
|
||||||
) -> list[dict]:
|
|
||||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
|
||||||
params: dict = {
|
|
||||||
"to_weight_decay": {}, # LayerNorm and bias
|
|
||||||
"embeddings": {}, # lm_head, embed_tokens,
|
|
||||||
"no_weight_decay": {},
|
|
||||||
}
|
|
||||||
lr_groups_lookup = {}
|
|
||||||
lr_groups_learning_rates = {}
|
|
||||||
if self.args.lr_groups:
|
|
||||||
for lr_group in self.args.lr_groups:
|
|
||||||
group_name = lr_group["name"]
|
|
||||||
group_modules = lr_group["modules"]
|
|
||||||
for module in group_modules:
|
|
||||||
lr_groups_lookup[module] = group_name
|
|
||||||
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
|
||||||
params[f"to_weight_decay_{group_name}"] = {}
|
|
||||||
|
|
||||||
for name, param in opt_model.named_parameters():
|
|
||||||
if not param.requires_grad:
|
|
||||||
continue
|
|
||||||
if name.endswith("modules_to_save.default.weight") or any(
|
|
||||||
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
|
||||||
):
|
|
||||||
params["embeddings"][name] = param
|
|
||||||
elif name in decay_parameters:
|
|
||||||
lr_group_modules = [
|
|
||||||
group_modules
|
|
||||||
for group_modules in lr_groups_lookup
|
|
||||||
if group_modules in name
|
|
||||||
]
|
|
||||||
if lr_groups_lookup and any(lr_group_modules):
|
|
||||||
lr_group_module = lr_group_modules[0]
|
|
||||||
group_name = lr_groups_lookup[lr_group_module]
|
|
||||||
params[f"to_weight_decay_{group_name}"][name] = param
|
|
||||||
else:
|
|
||||||
params["to_weight_decay"][name] = param
|
|
||||||
else:
|
|
||||||
params["no_weight_decay"][name] = param
|
|
||||||
optimizer_grouped_parameters = []
|
|
||||||
if params["to_weight_decay"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["to_weight_decay"].values()),
|
|
||||||
"weight_decay": self.args.weight_decay,
|
|
||||||
"lr": optimizer_kwargs["lr"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if params["embeddings"]:
|
|
||||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
|
||||||
if self.args.embedding_lr_scale:
|
|
||||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
|
||||||
elif self.args.embedding_lr:
|
|
||||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["embeddings"].values()),
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"lr": lr,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if params["no_weight_decay"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["no_weight_decay"].values()),
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"lr": optimizer_kwargs["lr"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
for group_name, group_lr in lr_groups_learning_rates.items():
|
|
||||||
if params[f"to_weight_decay_{group_name}"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(
|
|
||||||
params[f"to_weight_decay_{group_name}"].values()
|
|
||||||
),
|
|
||||||
"weight_decay": self.args.weight_decay,
|
|
||||||
"lr": group_lr,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return optimizer_grouped_parameters
|
|
||||||
|
|
||||||
def create_optimizer(self):
|
|
||||||
if (
|
|
||||||
self.args.loraplus_lr_ratio is None
|
|
||||||
and self.args.embedding_lr_scale is None
|
|
||||||
and self.args.embedding_lr is None
|
|
||||||
and self.args.lr_groups is None
|
|
||||||
and self.optimizer_cls_and_kwargs is None
|
|
||||||
):
|
|
||||||
return super().create_optimizer()
|
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
|
||||||
|
|
||||||
if (
|
|
||||||
not self.optimizer
|
|
||||||
and self.optimizer_cls_and_kwargs is not None
|
|
||||||
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
|
||||||
):
|
|
||||||
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
|
||||||
self.optimizer = optimizer_factory_cls()(
|
|
||||||
opt_model, self.args, **optimizer_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.optimizer:
|
|
||||||
if self.optimizer_cls_and_kwargs is not None:
|
|
||||||
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
|
||||||
else:
|
|
||||||
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
|
||||||
self.args, opt_model
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
|
||||||
opt_model, optimizer_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.loraplus_lr_ratio is not None:
|
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
|
||||||
loraplus_lr_embedding = getattr(
|
|
||||||
self.args, "loraplus_lr_embedding", 1e-6
|
|
||||||
)
|
|
||||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
opt_model,
|
|
||||||
optimizer_cls,
|
|
||||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
|
||||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
|
||||||
**optimizer_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
|
||||||
# e.g. for GaLore optimizer.
|
|
||||||
if "params" in optimizer_kwargs:
|
|
||||||
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
|
||||||
|
|
||||||
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
|
||||||
# e.g. for LOMO optimizer.
|
|
||||||
if "model" in optimizer_kwargs:
|
|
||||||
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
|
||||||
|
|
||||||
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
|
||||||
# to avoid arguments conflicts.
|
|
||||||
if "optimizer_dict" in optimizer_kwargs:
|
|
||||||
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
|
||||||
"optimizer_dict"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.optimizer = optimizer_cls(
|
|
||||||
optimizer_grouped_parameters, **optimizer_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if optimizer_cls.__name__ == "Adam8bit":
|
|
||||||
import bitsandbytes
|
|
||||||
|
|
||||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
|
||||||
|
|
||||||
skipped = 0
|
|
||||||
for module in opt_model.modules():
|
|
||||||
if isinstance(module, nn.Embedding):
|
|
||||||
skipped += sum(
|
|
||||||
{
|
|
||||||
p.data_ptr(): p.numel() for p in module.parameters()
|
|
||||||
}.values()
|
|
||||||
)
|
|
||||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
|
||||||
manager.register_module_override(
|
|
||||||
module, "weight", {"optim_bits": 32}
|
|
||||||
)
|
|
||||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
|
||||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
self.optimizer
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.optimizer
|
|
||||||
@@ -1,113 +0,0 @@
|
|||||||
"""Module for Axolotl trainer scheduler mixin"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
|
||||||
from transformers.trainer import Trainer
|
|
||||||
|
|
||||||
from axolotl.utils.schedulers import (
|
|
||||||
RexLR,
|
|
||||||
get_cosine_schedule_with_min_lr,
|
|
||||||
get_cosine_schedule_with_quadratic_warmup,
|
|
||||||
get_cosine_schedule_with_warmup_decay_constant,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SchedulerMixin(Trainer):
|
|
||||||
"""
|
|
||||||
Mixin class for scheduler setup in CausalTrainer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
|
||||||
|
|
||||||
def create_scheduler(
|
|
||||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
|
||||||
passed as an argument.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_training_steps (int): The number of training steps to do.
|
|
||||||
optimizer (torch.optim.Optimizer): The training optimizer
|
|
||||||
"""
|
|
||||||
use_cosine_quadratic = (
|
|
||||||
self.args.lr_scheduler_type == "cosine"
|
|
||||||
and self.args.lr_quadratic_warmup is True
|
|
||||||
)
|
|
||||||
|
|
||||||
use_cosine_min_lr = (
|
|
||||||
self.args.lr_scheduler_type == "cosine"
|
|
||||||
and self.args.cosine_min_lr_ratio is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
|
||||||
# fmt: on
|
|
||||||
if self.args.alternate_lr_scheduler_type == "one_cycle":
|
|
||||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
|
||||||
pct_start = num_warmup_steps / num_training_steps
|
|
||||||
extra_lr_kwargs = {}
|
|
||||||
if "pct_start" not in self.args.lr_scheduler_kwargs:
|
|
||||||
extra_lr_kwargs["pct_start"] = pct_start
|
|
||||||
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
|
|
||||||
extra_lr_kwargs["anneal_strategy"] = "cos"
|
|
||||||
|
|
||||||
self.lr_scheduler = OneCycleLR(
|
|
||||||
optimizer,
|
|
||||||
max_lr=self.args.learning_rate,
|
|
||||||
total_steps=num_training_steps,
|
|
||||||
**extra_lr_kwargs,
|
|
||||||
**self.args.lr_scheduler_kwargs,
|
|
||||||
)
|
|
||||||
elif self.args.alternate_lr_scheduler_type == "rex":
|
|
||||||
if use_cosine_min_lr:
|
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
|
|
||||||
self.lr_scheduler = RexLR(
|
|
||||||
optimizer=optimizer,
|
|
||||||
max_lr=self.args.learning_rate,
|
|
||||||
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
|
||||||
total_steps=num_training_steps,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
)
|
|
||||||
elif use_cosine_quadratic:
|
|
||||||
if use_cosine_min_lr:
|
|
||||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
|
||||||
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
)
|
|
||||||
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
|
||||||
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
|
||||||
)
|
|
||||||
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
|
||||||
else:
|
|
||||||
if use_cosine_quadratic:
|
|
||||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
|
||||||
|
|
||||||
if use_cosine_min_lr:
|
|
||||||
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
|
||||||
|
|
||||||
return self.lr_scheduler
|
|
||||||
@@ -1,131 +0,0 @@
|
|||||||
"""Module for Axolotl trainer sequence parallelism mixin"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from datasets import Dataset
|
|
||||||
from torch.utils.data import DistributedSampler, Sampler
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from ring_flash_attn import update_ring_flash_attn_params
|
|
||||||
except ImportError:
|
|
||||||
# We pass silently here, but raise an ImportError in our Axolotl config validation
|
|
||||||
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SequenceParallelMixin:
|
|
||||||
"""
|
|
||||||
Mixin class for sequence parallelism support in trainers.
|
|
||||||
|
|
||||||
This mixin provides functionality for handling sequence parallelism,
|
|
||||||
including creating appropriate samplers, managing data partitioning,
|
|
||||||
and updating ring flash attention parameters during training.
|
|
||||||
"""
|
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
|
||||||
|
|
||||||
def _setup_sequence_parallel(self):
|
|
||||||
"""Set up sequence parallelism environment."""
|
|
||||||
self.ring_attn_group = get_ring_attn_group()
|
|
||||||
|
|
||||||
def _create_sequence_parallel_sampler(
|
|
||||||
self,
|
|
||||||
dataset: Dataset,
|
|
||||||
shuffle: bool = True,
|
|
||||||
is_eval: bool = False,
|
|
||||||
) -> DistributedSampler:
|
|
||||||
"""
|
|
||||||
Helper method to create sampler for sequence parallelism (SP).
|
|
||||||
|
|
||||||
We create a distributed sampler with rank equal to the SP group ID, which
|
|
||||||
means that all ranks in the SP group receive the same sample / set of samples
|
|
||||||
per training step. We also set the number of replicas equal to the number of
|
|
||||||
SP groups, which is a bit of a hack / unintended use, but works!
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset: Dataset to sample from.
|
|
||||||
shuffle: Whether to shuffle the dataset.
|
|
||||||
is_eval: Whether we are creating a sampler for evaluation or training.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Distributed sampler.
|
|
||||||
"""
|
|
||||||
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
|
||||||
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
|
||||||
|
|
||||||
return DistributedSampler(
|
|
||||||
dataset,
|
|
||||||
num_replicas=num_sp_groups,
|
|
||||||
rank=sp_group_id,
|
|
||||||
seed=self.args.seed if shuffle else None,
|
|
||||||
shuffle=shuffle,
|
|
||||||
drop_last=not is_eval,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
|
|
||||||
"""
|
|
||||||
Get a training sampler configured for sequence parallelism.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset: The training dataset
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured sequence parallel sampler.
|
|
||||||
"""
|
|
||||||
return self._create_sequence_parallel_sampler(
|
|
||||||
dataset,
|
|
||||||
shuffle=not self.args.curriculum_sampling,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
|
||||||
"""
|
|
||||||
Get an evaluation sampler configured for sequence parallelism.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
eval_dataset: The evaluation dataset.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured sequence parallel sampler.
|
|
||||||
"""
|
|
||||||
return self._create_sequence_parallel_sampler(
|
|
||||||
eval_dataset, shuffle=False, is_eval=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
|
|
||||||
"""
|
|
||||||
Calculate the cu_seqlens for the current forward pass and pass the value to
|
|
||||||
the substituted ring_flash_attn. This is accomplished by using the passed
|
|
||||||
`input_ids`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs: Current batch of inputs.
|
|
||||||
"""
|
|
||||||
# At this point, inputs should already be partitioned by the sequence
|
|
||||||
# parallel data collator
|
|
||||||
batch_size = inputs["input_ids"].shape[0]
|
|
||||||
seq_len = inputs["input_ids"].shape[1]
|
|
||||||
packed_seq_lens = [seq_len] * batch_size
|
|
||||||
|
|
||||||
# Calculate the full sequence length across all GPUs in this SP group
|
|
||||||
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
|
||||||
|
|
||||||
cu_seqlens = torch.cumsum(
|
|
||||||
torch.tensor(
|
|
||||||
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
|
||||||
),
|
|
||||||
dim=-1,
|
|
||||||
dtype=torch.int32,
|
|
||||||
)
|
|
||||||
cu_seqlens = F.pad(
|
|
||||||
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
|
||||||
)
|
|
||||||
|
|
||||||
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
"""Module for ReLoRA trainer"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
|
||||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
|
||||||
|
|
||||||
|
|
||||||
class ReLoRATrainer(AxolotlTrainer):
|
|
||||||
"""Trainer subclass that uses the `OneCycleLR` scheduler"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "relora"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.lr_scheduler = None
|
|
||||||
|
|
||||||
def create_scheduler(
|
|
||||||
self,
|
|
||||||
num_training_steps: int,
|
|
||||||
optimizer: torch.optim.Optimizer | None = None,
|
|
||||||
):
|
|
||||||
optimizer = self.optimizer if optimizer is None else optimizer
|
|
||||||
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
|
||||||
|
|
||||||
if self.args.relora_steps:
|
|
||||||
warmup_steps = (
|
|
||||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
|
||||||
)
|
|
||||||
anneal_steps = (
|
|
||||||
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
|
||||||
)
|
|
||||||
self.lr_scheduler = ReLoRAScheduler(
|
|
||||||
optimizer,
|
|
||||||
lr_scheduler,
|
|
||||||
self.args.relora_steps,
|
|
||||||
anneal_steps,
|
|
||||||
warmup_steps,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.lr_scheduler = lr_scheduler
|
|
||||||
|
|
||||||
return self.lr_scheduler
|
|
||||||
@@ -1,25 +1,16 @@
|
|||||||
"""Module for TRL PPO trainer"""
|
"""
|
||||||
|
module for TRL PPO training
|
||||||
from typing import Literal, Union
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from trl import (
|
from trl import PPOTrainer
|
||||||
CPOTrainer,
|
|
||||||
KTOTrainer,
|
|
||||||
ORPOTrainer,
|
|
||||||
PPOTrainer,
|
|
||||||
PRMTrainer,
|
|
||||||
RewardTrainer,
|
|
||||||
)
|
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
|
||||||
|
|
||||||
|
|
||||||
class TRLPPOTrainer(PPOTrainer):
|
class TRLPPOTrainer(PPOTrainer):
|
||||||
"""Wrapper for TRL PPO trainer to handle customizations"""
|
"""
|
||||||
|
wrapper for ppo trainer to handle customizations
|
||||||
tag_names = ["axolotl", "ppo"]
|
"""
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self,
|
self,
|
||||||
@@ -40,7 +31,9 @@ class TRLPPOTrainer(PPOTrainer):
|
|||||||
"batch_size": 16,
|
"batch_size": 16,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, batch in tqdm(enumerate(self.dataloader)):
|
for epoch, batch in tqdm( # pylint: disable=unused-variable
|
||||||
|
enumerate(self.dataloader)
|
||||||
|
):
|
||||||
query_tensors = batch["input_ids"]
|
query_tensors = batch["input_ids"]
|
||||||
|
|
||||||
# generate model response
|
# generate model response
|
||||||
@@ -72,189 +65,3 @@ class TRLPPOTrainer(PPOTrainer):
|
|||||||
rewards,
|
rewards,
|
||||||
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base ORPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "orpo"]
|
|
||||||
|
|
||||||
def get_batch_loss_metrics(
|
|
||||||
self,
|
|
||||||
model,
|
|
||||||
batch: dict[str, Union[list, torch.LongTensor]],
|
|
||||||
train_eval: Literal["train", "eval"] = "train",
|
|
||||||
):
|
|
||||||
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
|
|
||||||
|
|
||||||
# TODO remove once https://github.com/huggingface/trl/pull/3069 is included in a trl release
|
|
||||||
|
|
||||||
metrics = {}
|
|
||||||
|
|
||||||
forward_output = self.concatenated_forward(model, batch)
|
|
||||||
(
|
|
||||||
policy_chosen_logps,
|
|
||||||
policy_rejected_logps,
|
|
||||||
policy_chosen_logits,
|
|
||||||
policy_rejected_logits,
|
|
||||||
policy_nll_loss,
|
|
||||||
) = forward_output[:5]
|
|
||||||
if self.aux_loss_enabled:
|
|
||||||
aux_loss = forward_output[5]
|
|
||||||
|
|
||||||
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = (
|
|
||||||
self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
|
|
||||||
)
|
|
||||||
# full ORPO loss
|
|
||||||
loss = policy_nll_loss - losses.mean()
|
|
||||||
|
|
||||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
|
||||||
|
|
||||||
prefix = "eval_" if train_eval == "eval" else ""
|
|
||||||
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(
|
|
||||||
chosen_rewards
|
|
||||||
).mean()
|
|
||||||
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(
|
|
||||||
rejected_rewards
|
|
||||||
).mean()
|
|
||||||
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(
|
|
||||||
reward_accuracies
|
|
||||||
).mean()
|
|
||||||
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
|
|
||||||
chosen_rewards - rejected_rewards
|
|
||||||
).mean()
|
|
||||||
metrics[f"{prefix}logps/rejected"] = (
|
|
||||||
self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}logps/chosen"] = (
|
|
||||||
self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
|
|
||||||
policy_rejected_logits.detach().mean()
|
|
||||||
).mean()
|
|
||||||
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
|
|
||||||
policy_chosen_logits.detach().mean()
|
|
||||||
).mean()
|
|
||||||
metrics[f"{prefix}nll_loss"] = (
|
|
||||||
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}log_odds_ratio"] = (
|
|
||||||
self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}log_odds_chosen"] = (
|
|
||||||
self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
|
|
||||||
)
|
|
||||||
for k, v in metrics.items():
|
|
||||||
metrics[k] = v.item()
|
|
||||||
if self.aux_loss_enabled:
|
|
||||||
loss += self.aux_loss_coef * aux_loss
|
|
||||||
|
|
||||||
return loss, metrics
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base KTOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "kto"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base CPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "cpo"]
|
|
||||||
|
|
||||||
def get_batch_loss_metrics(
|
|
||||||
self,
|
|
||||||
model,
|
|
||||||
batch: dict[str, Union[list, torch.LongTensor]],
|
|
||||||
train_eval: Literal["train", "eval"] = "train",
|
|
||||||
):
|
|
||||||
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
|
|
||||||
metrics = {}
|
|
||||||
|
|
||||||
forward_output = self.concatenated_forward(model, batch)
|
|
||||||
(
|
|
||||||
policy_chosen_logps,
|
|
||||||
policy_rejected_logps,
|
|
||||||
policy_chosen_logits,
|
|
||||||
policy_rejected_logits,
|
|
||||||
policy_nll_loss,
|
|
||||||
) = forward_output[:5]
|
|
||||||
if self.aux_loss_enabled:
|
|
||||||
aux_loss = forward_output[5]
|
|
||||||
|
|
||||||
losses, chosen_rewards, rejected_rewards = self.cpo_loss(
|
|
||||||
policy_chosen_logps,
|
|
||||||
policy_rejected_logps,
|
|
||||||
)
|
|
||||||
|
|
||||||
loss = losses.mean() + self.cpo_alpha * policy_nll_loss
|
|
||||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
|
||||||
|
|
||||||
prefix = "eval_" if train_eval == "eval" else ""
|
|
||||||
metrics[f"{prefix}rewards/chosen"] = (
|
|
||||||
self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}rewards/rejected"] = (
|
|
||||||
self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}rewards/accuracies"] = (
|
|
||||||
self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}rewards/margins"] = (
|
|
||||||
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards)
|
|
||||||
.mean()
|
|
||||||
.item()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}logps/rejected"] = (
|
|
||||||
self.accelerator.gather_for_metrics(policy_rejected_logps)
|
|
||||||
.detach()
|
|
||||||
.mean()
|
|
||||||
.item()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}logps/chosen"] = (
|
|
||||||
self.accelerator.gather_for_metrics(policy_chosen_logps)
|
|
||||||
.detach()
|
|
||||||
.mean()
|
|
||||||
.item()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}logits/rejected"] = (
|
|
||||||
self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean())
|
|
||||||
.mean()
|
|
||||||
.item()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}logits/chosen"] = (
|
|
||||||
self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean())
|
|
||||||
.mean()
|
|
||||||
.item()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}nll_loss"] = (
|
|
||||||
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.aux_loss_enabled:
|
|
||||||
loss += self.aux_loss_coef * aux_loss
|
|
||||||
|
|
||||||
return loss, metrics
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base RewardTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "reward"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base trl.PRMTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "prm"]
|
|
||||||
|
|||||||
@@ -1,33 +0,0 @@
|
|||||||
"""Utils for Axolotl trainers"""
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
|
||||||
if isinstance(tag_names, str):
|
|
||||||
tag_names = [tag_names]
|
|
||||||
|
|
||||||
if kwargs is not None:
|
|
||||||
if "tags" not in kwargs:
|
|
||||||
kwargs["tags"] = tag_names
|
|
||||||
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
|
||||||
kwargs["tags"].extend(tag_names)
|
|
||||||
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
|
||||||
tag_names.append(kwargs["tags"])
|
|
||||||
kwargs["tags"] = tag_names
|
|
||||||
|
|
||||||
return kwargs
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
|
||||||
if isinstance(dataset_tags, str):
|
|
||||||
dataset_tags = [dataset_tags]
|
|
||||||
|
|
||||||
if (dataset_tags is not None) and (kwargs is not None):
|
|
||||||
if "dataset_tags" not in kwargs:
|
|
||||||
kwargs["dataset_tags"] = dataset_tags
|
|
||||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
|
||||||
kwargs["dataset_tags"].extend(dataset_tags)
|
|
||||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
|
||||||
dataset_tags.append(kwargs["dataset_tags"])
|
|
||||||
kwargs["dataset_tags"] = dataset_tags
|
|
||||||
|
|
||||||
return kwargs
|
|
||||||
@@ -5,7 +5,6 @@ extra axolotl specific training args
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from PIL.Image import Resampling
|
|
||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||||
|
|
||||||
@@ -208,33 +207,14 @@ class AxolotlTrainingMixins:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
sequence_parallel_degree: Optional[int] = field(
|
|
||||||
default=1,
|
|
||||||
metadata={"help": "The number of workers to use in sequence parallelism"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# multi-modal section
|
|
||||||
|
|
||||||
image_size: int | tuple[int, int] | None = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "The size of the image to resize to"},
|
|
||||||
)
|
|
||||||
|
|
||||||
image_resize_algorithm: Resampling | None = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "The algorithm to use for image resizing"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# end of multi-modal section
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||||
"""
|
"""
|
||||||
Training arguments for Causal trainer
|
Training arguments for Causal trainer
|
||||||
|
|
||||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a
|
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
|
||||||
default value so it can't be used as a mixin.
|
so it can't be used as a mixin.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ from typing import Dict, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from datasets import Dataset
|
|
||||||
from transformers.trainer import Trainer
|
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.train import TrainDatasetMeta
|
from axolotl.train import TrainDatasetMeta
|
||||||
@@ -27,18 +25,18 @@ LOG = get_logger("axolotl.evaluate")
|
|||||||
|
|
||||||
|
|
||||||
def evaluate_dataset(
|
def evaluate_dataset(
|
||||||
trainer: Trainer, dataset: Dataset, dataset_type: str, flash_optimum: bool = False
|
trainer, dataset, dataset_type: str, flash_optimum: bool = False
|
||||||
) -> Optional[Dict[str, float]]:
|
) -> Optional[Dict[str, float]]:
|
||||||
"""Helper function to evaluate a single dataset.
|
"""Helper function to evaluate a single dataset safely.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
trainer: The trainer instance.
|
trainer: The trainer instance
|
||||||
dataset: Dataset to evaluate.
|
dataset: Dataset to evaluate
|
||||||
dataset_type: Type of dataset ('train' or 'eval').
|
dataset_type: Type of dataset ('train' or 'eval')
|
||||||
flash_optimum: Whether to use flash optimum.
|
flash_optimum: Whether to use flash optimum
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary of metrics or None if dataset is None.
|
Dictionary of metrics or None if dataset is None
|
||||||
"""
|
"""
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
return None
|
return None
|
||||||
@@ -65,14 +63,17 @@ def evaluate_dataset(
|
|||||||
|
|
||||||
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Evaluate a model on training and validation datasets.
|
Evaluate a model on training and validation datasets
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
dataset_meta: Dataset metadata containing training and evaluation datasets.
|
dataset_meta: Dataset metadata containing training and evaluation datasets.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary mapping metric names to their values.
|
Tuple containing:
|
||||||
|
- The model (either PeftModel or PreTrainedModel)
|
||||||
|
- The tokenizer
|
||||||
|
- Dictionary of evaluation metrics
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
|
|||||||
@@ -11,17 +11,19 @@
|
|||||||
# the License.
|
# the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Module to handle merging the plugins' input arguments with the base configurations.
|
module to handle merging the plugins' input arguments with the base configurations.
|
||||||
|
|
||||||
This was moved here to prevent circular imports.
|
this was moved here to prevent circular imports
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from axolotl.utils.schemas.config import (
|
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
)
|
)
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||||
|
AxolotlInputConfig as AxolotlInputConfigBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def merge_input_args():
|
def merge_input_args():
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# Cut Cross Entropy
|
# Cut Cross Entropy
|
||||||
|
|
||||||
Cut Cross Entropy (CCE) reduces VRAM usage through optimization on the cross-entropy operation during loss calculation.
|
Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy operation during loss calculation.
|
||||||
|
|
||||||
See https://github.com/apple/ml-cross-entropy
|
See https://github.com/apple/ml-cross-entropy
|
||||||
|
|
||||||
@@ -29,20 +29,6 @@ plugins:
|
|||||||
cut_cross_entropy: true
|
cut_cross_entropy: true
|
||||||
```
|
```
|
||||||
|
|
||||||
## Supported Models
|
|
||||||
|
|
||||||
- llama
|
|
||||||
- phi3
|
|
||||||
- gemma
|
|
||||||
- gemma2
|
|
||||||
- gemma3
|
|
||||||
- gemma3_text
|
|
||||||
- mistral
|
|
||||||
- mistral3
|
|
||||||
- qwen2
|
|
||||||
- cohere
|
|
||||||
- cohere2
|
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
```bib
|
```bib
|
||||||
|
|||||||
@@ -72,9 +72,7 @@ class CutCrossEntropyPlugin(BasePlugin):
|
|||||||
if cfg.cut_cross_entropy:
|
if cfg.cut_cross_entropy:
|
||||||
self._check_requirements()
|
self._check_requirements()
|
||||||
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
|
from cut_cross_entropy.transformers import cce_patch
|
||||||
cce_patch,
|
|
||||||
)
|
|
||||||
|
|
||||||
with zero_only():
|
with zero_only():
|
||||||
LOG.info(
|
LOG.info(
|
||||||
|
|||||||
@@ -1,201 +0,0 @@
|
|||||||
"""Cohere and Cohere2 CCE patch."""
|
|
||||||
|
|
||||||
# This patch is based off transformers 4.50.0.
|
|
||||||
# It patches the forward function for CohereForCausalLM and Cohere2ForCausalLM.
|
|
||||||
# It scales the hidden states by the logit scale in advance instead of the logits as the
|
|
||||||
# operation is done internally and should be mathematically equivalent.
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.cohere.modeling_cohere import (
|
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
COHERE_INPUTS_DOCSTRING,
|
|
||||||
KwargsForCausalLM,
|
|
||||||
)
|
|
||||||
from transformers.processing_utils import Unpack
|
|
||||||
from transformers.utils import (
|
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def cce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
**kwargs: Unpack[KwargsForCausalLM],
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>> from transformers import AutoTokenizer, CohereForCausalLM
|
|
||||||
|
|
||||||
>> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
|
||||||
>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
|
||||||
|
|
||||||
>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
||||||
>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>> # Generate
|
|
||||||
>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
||||||
>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
|
||||||
```"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
# scale weight by logit_scale in-place of logits
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states[:, slice_indices, :],
|
|
||||||
self.lm_head.weight * self.logit_scale,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
||||||
logits = logits * self.logit_scale # main diff from Llama
|
|
||||||
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(
|
|
||||||
logits=logits,
|
|
||||||
labels=labels,
|
|
||||||
vocab_size=self.config.vocab_size,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_cohere(
|
|
||||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
|
||||||
patch_options: PatchOptions,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
global _PATCH_OPTS # pylint: disable=global-statement
|
|
||||||
from transformers.models.cohere import modeling_cohere
|
|
||||||
|
|
||||||
_PATCH_OPTS = patch_options
|
|
||||||
|
|
||||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
|
||||||
assert isinstance(
|
|
||||||
maybe_model, modeling_cohere.CohereForCausalLM
|
|
||||||
), f"Expected a CohereForCausalLM model. Got {type(maybe_model)}."
|
|
||||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
|
||||||
return maybe_model
|
|
||||||
|
|
||||||
modeling_cohere.CohereForCausalLM.forward = cce_forward
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def patch_cohere2(
|
|
||||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
|
||||||
patch_options: PatchOptions,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
global _PATCH_OPTS # pylint: disable=global-statement
|
|
||||||
from transformers.models.cohere2 import modeling_cohere2
|
|
||||||
|
|
||||||
_PATCH_OPTS = patch_options
|
|
||||||
|
|
||||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
|
||||||
assert isinstance(
|
|
||||||
maybe_model, modeling_cohere2.Cohere2ForCausalLM
|
|
||||||
), f"Expected a Cohere2ForCausalLM model. Got {type(maybe_model)}."
|
|
||||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
|
||||||
return maybe_model
|
|
||||||
|
|
||||||
modeling_cohere2.Cohere2ForCausalLM.forward = cce_forward
|
|
||||||
return None
|
|
||||||
@@ -1,175 +0,0 @@
|
|||||||
"""Gemma CCE patch"""
|
|
||||||
|
|
||||||
# This patch is based off transformers 4.50.0.
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.gemma.modeling_gemma import (
|
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
GEMMA_INPUTS_DOCSTRING,
|
|
||||||
KwargsForCausalLM,
|
|
||||||
)
|
|
||||||
from transformers.processing_utils import Unpack
|
|
||||||
from transformers.utils import (
|
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def cce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
**kwargs: Unpack[KwargsForCausalLM],
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import AutoTokenizer, GemmaForCausalLM
|
|
||||||
|
|
||||||
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
|
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
|
|
||||||
|
|
||||||
>>> prompt = "What is your favorite condiment?"
|
|
||||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
||||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"What is your favorite condiment?"
|
|
||||||
```"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states[:, slice_indices, :],
|
|
||||||
self.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(
|
|
||||||
logits=logits,
|
|
||||||
labels=labels,
|
|
||||||
vocab_size=self.config.vocab_size,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_gemma(
|
|
||||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
|
||||||
patch_options: PatchOptions,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
global _PATCH_OPTS # pylint: disable=global-statement
|
|
||||||
from transformers.models.gemma import modeling_gemma
|
|
||||||
|
|
||||||
_PATCH_OPTS = patch_options
|
|
||||||
|
|
||||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
|
||||||
assert isinstance(
|
|
||||||
maybe_model, modeling_gemma.GemmaForCausalLM
|
|
||||||
), f"Expected a GemmaForCausalLM model. Got {type(maybe_model)}."
|
|
||||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
|
||||||
return maybe_model
|
|
||||||
|
|
||||||
modeling_gemma.GemmaForCausalLM.forward = cce_forward
|
|
||||||
return None
|
|
||||||
@@ -1,465 +0,0 @@
|
|||||||
"""Gemma2 and Gemma3 (text and multimodal) CCE patch."""
|
|
||||||
|
|
||||||
# Implementation originally adapted from https://github.com/apple/ml-cross-entropy/pull/29
|
|
||||||
# and updated for transformers 4.50.0.
|
|
||||||
# This is a modified version of the patch that allows for deferred logits calculation for gemma3 and works
|
|
||||||
# with both gemma3 (text and multimodal) models.
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from torch import nn
|
|
||||||
from transformers.cache_utils import Cache, HybridCache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.gemma3.modeling_gemma3 import (
|
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
GEMMA3_INPUTS_DOCSTRING,
|
|
||||||
Gemma3CausalLMOutputWithPast,
|
|
||||||
logger,
|
|
||||||
)
|
|
||||||
from transformers.utils import (
|
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
is_torchdynamo_compiling,
|
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def cce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[HybridCache] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
defer_logits_calculation: bool = False,
|
|
||||||
**loss_kwargs,
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
defer_logits_calculation (`bool`, *optional*):
|
|
||||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
|
||||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import AutoTokenizer, Gemma3ForCausalLM
|
|
||||||
|
|
||||||
>>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
|
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
|
|
||||||
|
|
||||||
>>> prompt = "What is your favorite condiment?"
|
|
||||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
||||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"What is your favorite condiment?"
|
|
||||||
```"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**loss_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
if self.config.final_logit_softcapping is not None:
|
|
||||||
logger.warning_once(
|
|
||||||
"final_logit_softcapping is not supported for gemma3_text with CCE. Disabling."
|
|
||||||
)
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states[:, slice_indices, :],
|
|
||||||
self.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**loss_kwargs,
|
|
||||||
)
|
|
||||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
|
||||||
# defer logits calculation to the ConditionalGeneration forward
|
|
||||||
logits = hidden_states[:, slice_indices, :]
|
|
||||||
|
|
||||||
if self.config.final_logit_softcapping is not None:
|
|
||||||
logger.warning_once(
|
|
||||||
"final_logit_softcapping is not supported for gemma3 with CCE. Disabling."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
||||||
if self.config.final_logit_softcapping is not None:
|
|
||||||
logits = logits / self.config.final_logit_softcapping
|
|
||||||
logits = torch.tanh(logits)
|
|
||||||
logits = logits * self.config.final_logit_softcapping
|
|
||||||
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def cce_forward_multimodal(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
pixel_values: torch.FloatTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
|
|
||||||
token_type_ids: Optional[torch.LongTensor] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
**lm_kwargs,
|
|
||||||
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from PIL import Image
|
|
||||||
>>> import requests
|
|
||||||
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
|
||||||
|
|
||||||
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
|
|
||||||
>>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
|
|
||||||
|
|
||||||
>>> prompt = "answer en Where is the cow standing?"
|
|
||||||
>>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
|
|
||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
||||||
|
|
||||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(**inputs, max_length=30)
|
|
||||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"answer en Where is the cow standing?\nbeach"
|
|
||||||
```"""
|
|
||||||
|
|
||||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
|
||||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
is_training = token_type_ids is not None and labels is not None
|
|
||||||
|
|
||||||
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
|
|
||||||
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
|
|
||||||
special_image_mask = input_ids == self.config.image_token_index
|
|
||||||
llm_input_ids = input_ids.clone()
|
|
||||||
llm_input_ids[special_image_mask] = 0
|
|
||||||
else:
|
|
||||||
llm_input_ids = input_ids # type: ignore
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
|
||||||
|
|
||||||
if cache_position is None:
|
|
||||||
past_seen_tokens = (
|
|
||||||
past_key_values.get_seq_length() if past_key_values is not None else 0 # type: ignore
|
|
||||||
)
|
|
||||||
cache_position = torch.arange( # type: ignore
|
|
||||||
past_seen_tokens,
|
|
||||||
past_seen_tokens + inputs_embeds.shape[1],
|
|
||||||
device=inputs_embeds.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Merge text and images
|
|
||||||
if pixel_values is not None:
|
|
||||||
image_features = self.get_image_features(pixel_values)
|
|
||||||
|
|
||||||
if input_ids is None:
|
|
||||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
|
||||||
torch.tensor(
|
|
||||||
self.config.image_token_index,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=inputs_embeds.device,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
|
|
||||||
-1
|
|
||||||
)
|
|
||||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
|
||||||
inputs_embeds.device
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
not is_torchdynamo_compiling()
|
|
||||||
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
|
||||||
):
|
|
||||||
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
|
||||||
raise ValueError(
|
|
||||||
f"Number of images does not match number of special image tokens in the input text. "
|
|
||||||
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
|
||||||
"tokens from image embeddings."
|
|
||||||
)
|
|
||||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
|
|
||||||
|
|
||||||
# mask out pad-token-ids in labels for BC
|
|
||||||
if labels is not None and self.pad_token_id in labels:
|
|
||||||
logger.warning_once(
|
|
||||||
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
|
|
||||||
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
|
|
||||||
)
|
|
||||||
labels = torch.where( # type: ignore
|
|
||||||
input_ids == self.pad_token_id, self.config.ignore_index, labels
|
|
||||||
)
|
|
||||||
|
|
||||||
causal_mask = self._update_causal_mask( # pylint: disable=protected-access
|
|
||||||
attention_mask,
|
|
||||||
token_type_ids,
|
|
||||||
past_key_values,
|
|
||||||
cache_position,
|
|
||||||
inputs_embeds,
|
|
||||||
is_training,
|
|
||||||
)
|
|
||||||
outputs = self.language_model(
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
logits_to_keep=logits_to_keep,
|
|
||||||
defer_logits_calculation=True, # enable deferred logits calculation
|
|
||||||
**lm_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states,
|
|
||||||
self.language_model.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**lm_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits = hidden_states
|
|
||||||
if labels is not None:
|
|
||||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
|
||||||
logits = logits.float()
|
|
||||||
shift_logits = logits[..., :-1, :]
|
|
||||||
shift_labels = labels[..., 1:]
|
|
||||||
if attention_mask is not None:
|
|
||||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
|
||||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
|
||||||
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(
|
|
||||||
logits.device
|
|
||||||
)
|
|
||||||
shift_logits = shift_logits[
|
|
||||||
shift_attention_mask.to(logits.device) != 0
|
|
||||||
].contiguous()
|
|
||||||
shift_labels = shift_labels[
|
|
||||||
shift_attention_mask.to(shift_labels.device) != 0
|
|
||||||
].contiguous()
|
|
||||||
else:
|
|
||||||
shift_logits = shift_logits.contiguous()
|
|
||||||
shift_labels = shift_labels.contiguous()
|
|
||||||
# Flatten the tokens
|
|
||||||
loss_fct = nn.CrossEntropyLoss()
|
|
||||||
|
|
||||||
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
|
|
||||||
flat_labels = shift_labels.view(-1).to(shift_logits.device)
|
|
||||||
loss = loss_fct(flat_logits, flat_labels)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return Gemma3CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
image_hidden_states=image_features if pixel_values is not None else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_gemma2(
|
|
||||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
|
||||||
patch_options: PatchOptions,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
global _PATCH_OPTS # pylint: disable=global-statement
|
|
||||||
from transformers.models.gemma2 import modeling_gemma2
|
|
||||||
|
|
||||||
_PATCH_OPTS = patch_options
|
|
||||||
|
|
||||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
|
||||||
assert isinstance(
|
|
||||||
maybe_model, modeling_gemma2.Gemma2ForCausalLM
|
|
||||||
), f"Expected a Gemma2ForCausalLM model. Got {type(maybe_model)}."
|
|
||||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
|
||||||
return maybe_model
|
|
||||||
|
|
||||||
modeling_gemma2.Gemma2ForCausalLM.forward = cce_forward
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def patch_gemma3_text(
|
|
||||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
|
||||||
patch_options: PatchOptions,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
global _PATCH_OPTS # pylint: disable=global-statement
|
|
||||||
from transformers.models.gemma3 import modeling_gemma3
|
|
||||||
|
|
||||||
_PATCH_OPTS = patch_options
|
|
||||||
|
|
||||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
|
||||||
assert isinstance(
|
|
||||||
maybe_model, modeling_gemma3.Gemma3ForCausalLM
|
|
||||||
), f"Expected a Gemma3ForCausalLM model. Got {type(maybe_model)}."
|
|
||||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
|
||||||
return maybe_model
|
|
||||||
|
|
||||||
modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def patch_gemma3(
|
|
||||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
|
||||||
patch_options: PatchOptions,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
global _PATCH_OPTS # pylint: disable=global-statement
|
|
||||||
from transformers.models.gemma3 import modeling_gemma3
|
|
||||||
|
|
||||||
_PATCH_OPTS = patch_options
|
|
||||||
|
|
||||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
|
||||||
assert isinstance(
|
|
||||||
maybe_model, modeling_gemma3.Gemma3ForConditionalGeneration
|
|
||||||
), f"Expected a Gemma3ForConditionalGeneration model. Got {type(maybe_model)}."
|
|
||||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
|
||||||
|
|
||||||
# patch the causal model to enable deferred logits calculation
|
|
||||||
maybe_model.language_model.forward = MethodType(
|
|
||||||
cce_forward, maybe_model.language_model
|
|
||||||
)
|
|
||||||
return maybe_model
|
|
||||||
|
|
||||||
modeling_gemma3.Gemma3ForConditionalGeneration.forward = cce_forward_multimodal
|
|
||||||
# patch the causal model to enable deferred logits calculation
|
|
||||||
modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
|
|
||||||
return None
|
|
||||||
@@ -1,392 +0,0 @@
|
|||||||
"""Mistral and Mistral3 CCE patch."""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from torch import nn
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.mistral3.modeling_mistral3 import (
|
|
||||||
Mistral3CausalLMOutputWithPast,
|
|
||||||
)
|
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
MISTRAL_INPUTS_DOCSTRING,
|
|
||||||
KwargsForCausalLM,
|
|
||||||
)
|
|
||||||
from transformers.processing_utils import Unpack
|
|
||||||
from transformers.utils import (
|
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
is_torchdynamo_compiling,
|
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def cce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] | None = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
defer_logits_calculation: bool = False,
|
|
||||||
**kwargs: Unpack[KwargsForCausalLM],
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
defer_logits_calculation (`bool`, *optional*):
|
|
||||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
|
||||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import AutoTokenizer, MistralForCausalLM
|
|
||||||
|
|
||||||
>>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf")
|
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf")
|
|
||||||
|
|
||||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
||||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
||||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
|
||||||
```"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states[:, slice_indices, :],
|
|
||||||
self.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
|
||||||
# defer logits calculation to the ConditionalGeneration forward
|
|
||||||
logits = hidden_states[:, slice_indices, :]
|
|
||||||
else:
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(
|
|
||||||
logits=logits,
|
|
||||||
labels=labels,
|
|
||||||
vocab_size=self.config.vocab_size,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def cce_forward_multimodal(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
pixel_values: torch.FloatTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
vision_feature_layer: Optional[Union[int, list[int]]] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
image_sizes: torch.Tensor | None = None,
|
|
||||||
**lm_kwargs,
|
|
||||||
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from PIL import Image
|
|
||||||
>>> import requests
|
|
||||||
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
|
||||||
|
|
||||||
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
|
||||||
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
|
||||||
|
|
||||||
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
|
|
||||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
|
||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
||||||
|
|
||||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
|
||||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"What is the image?The image depicts two cats lying on a pink blanket."
|
|
||||||
```"""
|
|
||||||
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
vision_feature_layer = (
|
|
||||||
vision_feature_layer
|
|
||||||
if vision_feature_layer is not None
|
|
||||||
else self.config.vision_feature_layer
|
|
||||||
)
|
|
||||||
|
|
||||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
|
||||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
if pixel_values is not None and inputs_embeds is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
|
||||||
)
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
|
||||||
|
|
||||||
if pixel_values is not None:
|
|
||||||
image_features = self.get_image_features(
|
|
||||||
pixel_values=pixel_values,
|
|
||||||
vision_feature_layer=vision_feature_layer,
|
|
||||||
image_sizes=image_sizes,
|
|
||||||
)
|
|
||||||
|
|
||||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
|
||||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
|
||||||
inputs_embeds.device
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
not is_torchdynamo_compiling()
|
|
||||||
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
|
||||||
):
|
|
||||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
|
||||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
|
||||||
raise ValueError(
|
|
||||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
|
||||||
)
|
|
||||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
|
|
||||||
|
|
||||||
outputs = self.language_model(
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
logits_to_keep=logits_to_keep,
|
|
||||||
defer_logits_calculation=True, # enable deferred logits calculation
|
|
||||||
**lm_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states,
|
|
||||||
self.language_model.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**lm_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits = hidden_states
|
|
||||||
if labels is not None:
|
|
||||||
# Shift so that tokens < n predict n
|
|
||||||
if attention_mask is not None:
|
|
||||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
|
||||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
|
||||||
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
|
|
||||||
logits.device
|
|
||||||
)
|
|
||||||
shift_logits = logits[..., :-1, :][
|
|
||||||
shift_attention_mask.to(logits.device) != 0
|
|
||||||
].contiguous()
|
|
||||||
shift_labels = labels[..., 1:][
|
|
||||||
shift_attention_mask.to(labels.device) != 0
|
|
||||||
].contiguous()
|
|
||||||
else:
|
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
|
||||||
# Flatten the tokens
|
|
||||||
loss_fct = nn.CrossEntropyLoss()
|
|
||||||
loss = loss_fct(
|
|
||||||
shift_logits.view(-1, shift_logits.size(-1)),
|
|
||||||
shift_labels.view(-1).to(shift_logits.device),
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return Mistral3CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
image_hidden_states=image_features if pixel_values is not None else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_mistral(
|
|
||||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
|
||||||
patch_options: PatchOptions,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
global _PATCH_OPTS # pylint: disable=global-statement
|
|
||||||
from transformers.models.mistral import modeling_mistral
|
|
||||||
|
|
||||||
_PATCH_OPTS = patch_options
|
|
||||||
|
|
||||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
|
||||||
assert isinstance(
|
|
||||||
maybe_model, modeling_mistral.MistralForCausalLM
|
|
||||||
), f"Expected a MistralForCausalLM model. Got {type(maybe_model)}."
|
|
||||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
|
||||||
return maybe_model
|
|
||||||
|
|
||||||
modeling_mistral.MistralForCausalLM.forward = cce_forward
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def patch_mistral3(
|
|
||||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
|
||||||
patch_options: PatchOptions,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
global _PATCH_OPTS # pylint: disable=global-statement
|
|
||||||
from transformers.models.mistral import modeling_mistral
|
|
||||||
from transformers.models.mistral3 import modeling_mistral3
|
|
||||||
|
|
||||||
_PATCH_OPTS = patch_options
|
|
||||||
|
|
||||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
|
||||||
assert isinstance(
|
|
||||||
maybe_model, modeling_mistral3.Mistral3ForConditionalGeneration
|
|
||||||
), f"Expected a Mistral3ForConditionalGeneration model. Got {type(maybe_model)}."
|
|
||||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
|
||||||
|
|
||||||
# patch the causal model to enable deferred logits calculation
|
|
||||||
maybe_model.language_model.forward = MethodType(
|
|
||||||
cce_forward, maybe_model.language_model
|
|
||||||
)
|
|
||||||
return maybe_model
|
|
||||||
|
|
||||||
modeling_mistral3.Mistral3ForConditionalGeneration.forward = cce_forward_multimodal
|
|
||||||
# patch the causal model to enable deferred logits calculation
|
|
||||||
modeling_mistral.MistralForCausalLM.forward = cce_forward
|
|
||||||
return None
|
|
||||||
@@ -1,379 +0,0 @@
|
|||||||
"""Mllama CCE patch."""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.mllama.modeling_mllama import (
|
|
||||||
MLLAMA_INPUTS_DOCSTRING,
|
|
||||||
_prepare_cross_attention_mask,
|
|
||||||
)
|
|
||||||
from transformers.utils import (
|
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
|
|
||||||
)
|
|
||||||
def cce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
cross_attention_states: Optional[torch.LongTensor] = None,
|
|
||||||
cross_attention_mask: Optional[torch.LongTensor] = None,
|
|
||||||
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
||||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
defer_logits_calculation: bool = False,
|
|
||||||
**loss_kwargs,
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
defer_logits_calculation (`bool`, *optional*):
|
|
||||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
|
||||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import AutoTokenizer, MllamaForCausalLM
|
|
||||||
|
|
||||||
>>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
|
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")
|
|
||||||
|
|
||||||
>>> prompt = "If I had to write a haiku, it would be:"
|
|
||||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
|
|
||||||
>>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
>>> print(result)
|
|
||||||
If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
|
|
||||||
I love the idea of snowflakes gently falling, each one
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
cross_attention_states=cross_attention_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
cross_attention_mask=cross_attention_mask,
|
|
||||||
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states[:, slice_indices, :],
|
|
||||||
self.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**loss_kwargs,
|
|
||||||
)
|
|
||||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
|
||||||
# defer logits calculation to the ConditionalGeneration forward
|
|
||||||
logits = hidden_states[:, slice_indices, :]
|
|
||||||
else:
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
|
|
||||||
|
|
||||||
loss = None
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class="MllamaConfig"
|
|
||||||
)
|
|
||||||
def cce_forward_multimodal(
|
|
||||||
self,
|
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
|
||||||
pixel_values: Optional[torch.FloatTensor] = None,
|
|
||||||
aspect_ratio_mask: Optional[torch.Tensor] = None,
|
|
||||||
aspect_ratio_ids: Optional[torch.Tensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
cross_attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
cross_attention_states: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
**loss_kwargs,
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from PIL import Image
|
|
||||||
>>> import requests
|
|
||||||
>>> from transformers import AutoProcessor, MllamaForConditionalGeneration
|
|
||||||
|
|
||||||
>>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
|
|
||||||
>>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
|
|
||||||
>>> processor = AutoProcessor.from_pretrained(checkpoint)
|
|
||||||
|
|
||||||
>>> prompt = "<|image|>If I had to write a haiku for this one"
|
|
||||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
|
||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
||||||
|
|
||||||
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> output = model.generate(**inputs, max_new_tokens=15)
|
|
||||||
|
|
||||||
>>> prompt_len = inputs.input_ids.shape[-1]
|
|
||||||
>>> generated_ids = output[:, prompt_len:]
|
|
||||||
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
|
||||||
>>> print(generated_text)
|
|
||||||
[', it would be:.\\nA stop sign in Chinatown.\\n']
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
|
||||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
if pixel_values is not None and inputs_embeds is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
|
||||||
)
|
|
||||||
|
|
||||||
if pixel_values is not None and cross_attention_states is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
|
|
||||||
)
|
|
||||||
|
|
||||||
if pixel_values is not None:
|
|
||||||
if aspect_ratio_ids is None:
|
|
||||||
raise ValueError(
|
|
||||||
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
|
|
||||||
)
|
|
||||||
# get vision tokens from vision model
|
|
||||||
vision_outputs = self.vision_model(
|
|
||||||
pixel_values=pixel_values,
|
|
||||||
aspect_ratio_ids=aspect_ratio_ids,
|
|
||||||
aspect_ratio_mask=aspect_ratio_mask,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
return_dict=return_dict,
|
|
||||||
)
|
|
||||||
cross_attention_states = vision_outputs[0]
|
|
||||||
cross_attention_states = self.multi_modal_projector(
|
|
||||||
cross_attention_states
|
|
||||||
).reshape(
|
|
||||||
-1, cross_attention_states.shape[-2], self.hidden_size # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
if cross_attention_mask is not None:
|
|
||||||
cross_attention_mask, full_text_row_masked_out_mask = (
|
|
||||||
_prepare_cross_attention_mask(
|
|
||||||
cross_attention_mask,
|
|
||||||
num_vision_tokens=self.vision_model.num_patches,
|
|
||||||
dtype=self.dtype,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
full_text_row_masked_out_mask = None
|
|
||||||
|
|
||||||
if cross_attention_mask is not None and cache_position is not None:
|
|
||||||
cross_attention_mask = cross_attention_mask[:, :, cache_position]
|
|
||||||
full_text_row_masked_out_mask = full_text_row_masked_out_mask[
|
|
||||||
:, :, cache_position
|
|
||||||
]
|
|
||||||
|
|
||||||
outputs = self.language_model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
cross_attention_states=cross_attention_states,
|
|
||||||
cross_attention_mask=cross_attention_mask,
|
|
||||||
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
use_cache=use_cache,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
logits_to_keep=logits_to_keep,
|
|
||||||
defer_logits_calculation=True, # enable deferred logits calculation
|
|
||||||
**loss_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states,
|
|
||||||
self.language_model.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**loss_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Temporary fix to calculate the loss in main class, as the model's vocab size may be resized
|
|
||||||
logits = hidden_states
|
|
||||||
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(
|
|
||||||
logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
return (loss,) + outputs if loss is not None else outputs
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=outputs.logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_mllama(
|
|
||||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
|
||||||
patch_options: PatchOptions,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
|
|
||||||
global _PATCH_OPTS # pylint: disable=global-statement
|
|
||||||
from transformers.models.mllama import modeling_mllama
|
|
||||||
|
|
||||||
_PATCH_OPTS = patch_options
|
|
||||||
|
|
||||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
|
||||||
assert isinstance(
|
|
||||||
maybe_model, modeling_mllama.MllamaForConditionalGeneration
|
|
||||||
), f"Expected a MllamaForConditionalGeneration model. Got {type(maybe_model)}."
|
|
||||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
|
||||||
|
|
||||||
# patch the language model
|
|
||||||
maybe_model.language_model.forward = MethodType(
|
|
||||||
cce_forward, maybe_model.language_model
|
|
||||||
)
|
|
||||||
return maybe_model
|
|
||||||
|
|
||||||
modeling_mllama.MllamaForConditionalGeneration.forward = cce_forward_multimodal
|
|
||||||
|
|
||||||
# patch the causal language model
|
|
||||||
modeling_mllama.MllamaForCausalLM.forward = cce_forward
|
|
||||||
return None
|
|
||||||
@@ -1,85 +0,0 @@
|
|||||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
|
||||||
|
|
||||||
"""Cut Cross Entropy patcher"""
|
|
||||||
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
|
|
||||||
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
|
|
||||||
from cut_cross_entropy.transformers.llama import patch_llama
|
|
||||||
from cut_cross_entropy.transformers.phi3 import patch_phi3
|
|
||||||
from cut_cross_entropy.transformers.qwen2 import patch_qwen2
|
|
||||||
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT
|
|
||||||
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
|
|
||||||
patch_cohere,
|
|
||||||
patch_cohere2,
|
|
||||||
)
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma import patch_gemma
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
|
|
||||||
patch_gemma2,
|
|
||||||
patch_gemma3,
|
|
||||||
patch_gemma3_text,
|
|
||||||
)
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
|
|
||||||
patch_mistral,
|
|
||||||
patch_mistral3,
|
|
||||||
)
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama
|
|
||||||
|
|
||||||
CUT_CROSS_ENTROPY_MODEL_MAPPING = {
|
|
||||||
"llama": patch_llama,
|
|
||||||
"mllama": patch_mllama,
|
|
||||||
"phi3": patch_phi3,
|
|
||||||
"gemma": patch_gemma,
|
|
||||||
"gemma2": patch_gemma2,
|
|
||||||
"gemma3": patch_gemma3,
|
|
||||||
"gemma3_text": patch_gemma3_text,
|
|
||||||
"mistral": patch_mistral,
|
|
||||||
"mistral3": patch_mistral3,
|
|
||||||
"qwen2": patch_qwen2,
|
|
||||||
"cohere": patch_cohere,
|
|
||||||
"cohere2": patch_cohere2,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def cce_patch(
|
|
||||||
model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig,
|
|
||||||
impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
|
|
||||||
reduction: str = "mean",
|
|
||||||
filter_eps: float | str | None = "auto",
|
|
||||||
accum_e_fp32: bool = False,
|
|
||||||
accum_c_fp32: bool = False,
|
|
||||||
filter_e_grad: bool = True,
|
|
||||||
filter_c_grad: bool = True,
|
|
||||||
train_only: bool = False,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
if isinstance(impl, LinearCrossEntropyImpl):
|
|
||||||
impl = impl.name.lower()
|
|
||||||
|
|
||||||
if impl not in (v.name.lower() for v in LinearCrossEntropyImpl):
|
|
||||||
raise ValueError(f"Unknown {impl=}")
|
|
||||||
|
|
||||||
if isinstance(model_type_or_model, transformers.PreTrainedModel):
|
|
||||||
model_type = model_type_or_model.config.model_type
|
|
||||||
elif isinstance(model_type_or_model, transformers.PretrainedConfig):
|
|
||||||
model_type = model_type_or_model.model_type
|
|
||||||
else:
|
|
||||||
model_type = model_type_or_model
|
|
||||||
|
|
||||||
patch_options = PatchOptions(
|
|
||||||
impl=impl,
|
|
||||||
reduction=reduction,
|
|
||||||
filter_eps=filter_eps,
|
|
||||||
accum_e_fp32=accum_e_fp32,
|
|
||||||
accum_c_fp32=accum_c_fp32,
|
|
||||||
filter_e_grad=filter_e_grad,
|
|
||||||
filter_c_grad=filter_c_grad,
|
|
||||||
train_only=train_only,
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING:
|
|
||||||
return CUT_CROSS_ENTROPY_MODEL_MAPPING[model_type](
|
|
||||||
model_type_or_model, patch_options
|
|
||||||
)
|
|
||||||
|
|
||||||
raise RuntimeError(f"Unknown model type {model_type}")
|
|
||||||
@@ -114,5 +114,3 @@ class LigerPlugin(BasePlugin):
|
|||||||
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
if cfg.liger_fused_linear_cross_entropy:
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||||
elif cfg.model_config_type in ["gemma3_text", "deepseek_v3"]:
|
|
||||||
raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
|
|
||||||
|
|||||||
@@ -1,89 +0,0 @@
|
|||||||
"""
|
|
||||||
Ring attention group registration and flash attention patching.
|
|
||||||
|
|
||||||
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
|
|
||||||
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
|
|
||||||
their sequence parallel version of Flash Attention 2.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch.distributed as dist
|
|
||||||
from accelerate.logging import get_logger
|
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
|
||||||
|
|
||||||
configure_logging()
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
RING_ATTN_GROUP = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_ring_attn_group() -> dist.ProcessGroup:
|
|
||||||
"""
|
|
||||||
Getter for ring attention group on this rank.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The process group for ring attention for this rank.
|
|
||||||
"""
|
|
||||||
return RING_ATTN_GROUP
|
|
||||||
|
|
||||||
|
|
||||||
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
|
||||||
"""
|
|
||||||
Setter for ring attention group on this rank.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
Process group for ring attention.
|
|
||||||
"""
|
|
||||||
global RING_ATTN_GROUP # pylint: disable=global-statement
|
|
||||||
RING_ATTN_GROUP = ring_attn_group
|
|
||||||
|
|
||||||
|
|
||||||
def register_ring_attn(sequence_parallel_degree: int):
|
|
||||||
"""
|
|
||||||
Create ring attention group and substitute flash attn with ring flash attn.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sequence_parallel_degree: Sequence parallelism factor.
|
|
||||||
"""
|
|
||||||
LOG.info(
|
|
||||||
"Enabling ring attention sequence parallelism: "
|
|
||||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
|
||||||
)
|
|
||||||
|
|
||||||
world_size = dist.get_world_size()
|
|
||||||
assert sequence_parallel_degree <= world_size, (
|
|
||||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
|
||||||
f"must be less than or equal to world_size ({world_size})"
|
|
||||||
)
|
|
||||||
assert world_size % sequence_parallel_degree == 0, (
|
|
||||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
|
||||||
f"must evenly divide world_size ({world_size})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Detailed logging of group formation
|
|
||||||
rank = dist.get_rank()
|
|
||||||
group_assignments = {}
|
|
||||||
|
|
||||||
for i in range(world_size // sequence_parallel_degree):
|
|
||||||
ring_attn_ranks = list(
|
|
||||||
range(
|
|
||||||
i * sequence_parallel_degree,
|
|
||||||
(i + 1) * sequence_parallel_degree,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
|
||||||
|
|
||||||
# Track which GPUs are in which groups
|
|
||||||
for r in ring_attn_ranks:
|
|
||||||
group_assignments[r] = i
|
|
||||||
|
|
||||||
if rank in ring_attn_ranks:
|
|
||||||
set_ring_attn_group(group)
|
|
||||||
|
|
||||||
# Log the GPU group assignments
|
|
||||||
if rank == 0:
|
|
||||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
|
||||||
|
|
||||||
from ring_flash_attn import substitute_hf_flash_attn
|
|
||||||
|
|
||||||
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
|
|
||||||
@@ -22,9 +22,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"phi3",
|
"phi3",
|
||||||
"gemma",
|
"gemma",
|
||||||
"gemma2",
|
"gemma2",
|
||||||
"gemma3_text",
|
|
||||||
"cohere",
|
|
||||||
"cohere2",
|
|
||||||
"gemmoe",
|
"gemmoe",
|
||||||
"starcoder2",
|
"starcoder2",
|
||||||
"deepseek_v2",
|
"deepseek_v2",
|
||||||
|
|||||||
@@ -1,313 +0,0 @@
|
|||||||
"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types"""
|
|
||||||
|
|
||||||
import ast
|
|
||||||
from copy import deepcopy
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from PIL import Image, ImageOps
|
|
||||||
from PIL.Image import Resampling
|
|
||||||
from torch import Tensor
|
|
||||||
from transformers import ProcessorMixin
|
|
||||||
from transformers.image_utils import load_image
|
|
||||||
|
|
||||||
|
|
||||||
class ProcessingStrategy:
|
|
||||||
"""Base Processing Strategy class"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
processor: ProcessorMixin,
|
|
||||||
chat_template: Optional[str] = None,
|
|
||||||
image_size: int | tuple[int, int] | None = None,
|
|
||||||
image_resize_algorithm: Resampling | None = None,
|
|
||||||
):
|
|
||||||
self.processor = processor
|
|
||||||
self.chat_template = chat_template
|
|
||||||
self.image_token = None
|
|
||||||
self.image_token_id = None
|
|
||||||
|
|
||||||
self.image_size = image_size
|
|
||||||
self.image_resize_algorithm = (
|
|
||||||
image_resize_algorithm or Image.Resampling.BILINEAR
|
|
||||||
)
|
|
||||||
|
|
||||||
if hasattr(processor, "image_token"):
|
|
||||||
self.image_token = processor.image_token
|
|
||||||
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
|
||||||
self.image_token
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, examples: list[dict]) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Preprocess conversation examples to ensure consistent format.
|
|
||||||
Converts different conversation formats to OpenAI format with 'messages'.
|
|
||||||
Supports two formats:
|
|
||||||
1. OpenAI format with 'messages'
|
|
||||||
2. Legacy format with 'conversations'
|
|
||||||
|
|
||||||
Args:
|
|
||||||
examples: list of conversation dictionaries
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list of dicts in OpenAI format with 'messages' key
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the conversation format is not supported
|
|
||||||
"""
|
|
||||||
role_mapping = {
|
|
||||||
"human": "user",
|
|
||||||
"gpt": "assistant",
|
|
||||||
}
|
|
||||||
|
|
||||||
def normalize_role(role: str) -> str:
|
|
||||||
"""Normalize role names to OpenAI format. Default to original role if not found."""
|
|
||||||
return role_mapping.get(role, role)
|
|
||||||
|
|
||||||
def convert_legacy_format(example: dict) -> dict:
|
|
||||||
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
|
|
||||||
messages = [
|
|
||||||
{"role": normalize_role(convo["from"]), "content": convo["value"]}
|
|
||||||
for convo in example["conversations"]
|
|
||||||
]
|
|
||||||
|
|
||||||
# Create new dict without 'conversations' key
|
|
||||||
result = deepcopy(example)
|
|
||||||
result.pop("conversations")
|
|
||||||
result["messages"] = messages
|
|
||||||
return result
|
|
||||||
|
|
||||||
def convert_multiple_choice_to_multimedia_messages(
|
|
||||||
messages: dict,
|
|
||||||
) -> list[dict]:
|
|
||||||
|
|
||||||
def construct_prompt(sample):
|
|
||||||
question = sample["question"]
|
|
||||||
options = sample["options"]
|
|
||||||
if isinstance(options, str):
|
|
||||||
options = ast.literal_eval(options)
|
|
||||||
|
|
||||||
example = ""
|
|
||||||
start_chr = "A"
|
|
||||||
prediction_range = []
|
|
||||||
index2ans = {}
|
|
||||||
for option in options:
|
|
||||||
prediction_range.append(start_chr)
|
|
||||||
example += f"({start_chr}) {option}\n"
|
|
||||||
index2ans[start_chr] = option
|
|
||||||
start_chr = chr(ord(start_chr) + 1)
|
|
||||||
|
|
||||||
empty_prompt_sample_structure = "{}\n\n{}\n\nAnswer with the option's letter from the given choices directly."
|
|
||||||
empty_prompt = empty_prompt_sample_structure.format(question, example)
|
|
||||||
|
|
||||||
return empty_prompt
|
|
||||||
|
|
||||||
new_messages = []
|
|
||||||
|
|
||||||
user_content = construct_prompt(messages)
|
|
||||||
assistant_response = messages["answer"]
|
|
||||||
|
|
||||||
new_messages.append(
|
|
||||||
{"role": "user", "content": [{"type": "text", "text": user_content}]}
|
|
||||||
)
|
|
||||||
|
|
||||||
new_messages.append(
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [{"type": "text", "text": assistant_response}],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return new_messages
|
|
||||||
|
|
||||||
def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]:
|
|
||||||
"""Convert regular messages format to Messages format with content type"""
|
|
||||||
|
|
||||||
new_messages = []
|
|
||||||
for message in messages:
|
|
||||||
if isinstance(message["content"], str):
|
|
||||||
new_messages.append(
|
|
||||||
{
|
|
||||||
"role": message["role"],
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": message["content"],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
elif isinstance(message["content"], list):
|
|
||||||
content = message["content"]
|
|
||||||
|
|
||||||
new_messages.append(
|
|
||||||
{
|
|
||||||
"role": message["role"],
|
|
||||||
"content": content,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return new_messages
|
|
||||||
|
|
||||||
processed_examples = []
|
|
||||||
for example in examples:
|
|
||||||
if not (
|
|
||||||
"messages" in example
|
|
||||||
or "conversations" in example
|
|
||||||
or "question" in example
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"Only `messages`, `conversations`, and `question` message keys are currently supported."
|
|
||||||
)
|
|
||||||
|
|
||||||
processed_example = None
|
|
||||||
if "messages" in example: # OpenAI format
|
|
||||||
processed_example = example
|
|
||||||
# convert regular messages format to Messages format with content type
|
|
||||||
# for compatibility with apply_chat_template
|
|
||||||
processed_example["messages"] = convert_messages_to_multimedia_messages(
|
|
||||||
processed_example["messages"]
|
|
||||||
)
|
|
||||||
elif "question" in example: # Multiple choice format
|
|
||||||
processed_example = {}
|
|
||||||
processed_example["messages"] = (
|
|
||||||
convert_multiple_choice_to_multimedia_messages(example)
|
|
||||||
)
|
|
||||||
else: # Legacy format
|
|
||||||
processed_example = convert_legacy_format(example)
|
|
||||||
processed_example["messages"] = convert_messages_to_multimedia_messages(
|
|
||||||
processed_example["messages"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# find the image key if it exists
|
|
||||||
|
|
||||||
image_keys = []
|
|
||||||
for key in example.keys():
|
|
||||||
if "image" in key:
|
|
||||||
image_keys.append(key)
|
|
||||||
|
|
||||||
for im_key in image_keys:
|
|
||||||
if example[im_key] is None:
|
|
||||||
continue
|
|
||||||
if isinstance(example[im_key], list):
|
|
||||||
if len(example[im_key]) == 0:
|
|
||||||
continue
|
|
||||||
image_value = example[im_key][0]
|
|
||||||
else:
|
|
||||||
image_value = example[im_key]
|
|
||||||
|
|
||||||
image_value = load_image(image_value)
|
|
||||||
|
|
||||||
if self.image_size is not None:
|
|
||||||
assert hasattr(
|
|
||||||
image_value, "resize"
|
|
||||||
), "Image does not have a resize method"
|
|
||||||
|
|
||||||
if isinstance(self.image_size, tuple):
|
|
||||||
image_value = image_value.resize(
|
|
||||||
self.image_size, self.image_resize_algorithm
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Set the padding value; here we use black (0, 0, 0) for RGB images
|
|
||||||
padding_color = (0, 0, 0)
|
|
||||||
|
|
||||||
# When image_size is an int (square target), preserve aspect ratio then pad
|
|
||||||
# This is to prevent aspect ratio distortion when resizing to square
|
|
||||||
image_value = ImageOps.pad(
|
|
||||||
image_value,
|
|
||||||
(self.image_size, self.image_size),
|
|
||||||
method=self.image_resize_algorithm,
|
|
||||||
color=padding_color,
|
|
||||||
)
|
|
||||||
|
|
||||||
processed_example["messages"][0]["content"].append(
|
|
||||||
{
|
|
||||||
"type": "image",
|
|
||||||
"image": image_value,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
processed_examples.append(processed_example)
|
|
||||||
|
|
||||||
return processed_examples
|
|
||||||
|
|
||||||
def process_labels(self, input_ids: Tensor) -> Tensor:
|
|
||||||
labels = input_ids.clone()
|
|
||||||
|
|
||||||
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
|
||||||
labels[labels == self.processor.tokenizer.pad_token_id] = -100
|
|
||||||
|
|
||||||
# Ignore the image token index in the loss computation (model specific)
|
|
||||||
labels[labels == self.image_token_id] = -100
|
|
||||||
|
|
||||||
return labels
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen2VLProcessingStrategy(ProcessingStrategy):
|
|
||||||
"""Processing Strategy class for Qwen2-VL"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
processor: ProcessorMixin,
|
|
||||||
chat_template: Optional[str] = None,
|
|
||||||
image_size: int | tuple[int, int] | None = None,
|
|
||||||
image_resize_algorithm: Resampling | None = None,
|
|
||||||
):
|
|
||||||
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
|
|
||||||
self.image_token = "<|image_pad|>" # nosec
|
|
||||||
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
|
||||||
self.image_token
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Gemma3ProcessingStrategy(ProcessingStrategy):
|
|
||||||
"""Processing Strategy class for Gemma3"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
processor: ProcessorMixin,
|
|
||||||
chat_template: Optional[str] = None,
|
|
||||||
image_size: int | tuple[int, int] | None = None,
|
|
||||||
image_resize_algorithm: Resampling | None = None,
|
|
||||||
):
|
|
||||||
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
|
|
||||||
self.image_token = processor.tokenizer.special_tokens_map["boi_token"]
|
|
||||||
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
|
||||||
self.image_token
|
|
||||||
)
|
|
||||||
|
|
||||||
def process_labels(self, input_ids):
|
|
||||||
labels = input_ids.clone()
|
|
||||||
|
|
||||||
# Follows https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora
|
|
||||||
labels[labels == self.processor.tokenizer.pad_token_id] = -100
|
|
||||||
labels[labels == self.image_token_id] = -100
|
|
||||||
labels[labels == 262144] = -100 # corresponds to <image_soft_token>
|
|
||||||
|
|
||||||
return labels
|
|
||||||
|
|
||||||
|
|
||||||
def get_processing_strategy(
|
|
||||||
processor: ProcessorMixin,
|
|
||||||
chat_template,
|
|
||||||
chat_template_type,
|
|
||||||
image_size: int | tuple[int, int] | None = None,
|
|
||||||
image_resize_algorithm: Resampling | None = None,
|
|
||||||
):
|
|
||||||
if chat_template_type == "qwen2_vl":
|
|
||||||
return Qwen2VLProcessingStrategy(
|
|
||||||
processor, chat_template, image_size, image_resize_algorithm
|
|
||||||
)
|
|
||||||
if chat_template_type == "gemma3":
|
|
||||||
return Gemma3ProcessingStrategy(
|
|
||||||
processor, chat_template, image_size, image_resize_algorithm
|
|
||||||
)
|
|
||||||
if chat_template_type in [
|
|
||||||
"llama3_2_vision",
|
|
||||||
"llava",
|
|
||||||
"mistral_v7_tekken",
|
|
||||||
"pixtral",
|
|
||||||
]:
|
|
||||||
return ProcessingStrategy(
|
|
||||||
processor, chat_template, image_size, image_resize_algorithm
|
|
||||||
)
|
|
||||||
raise ValueError(f"Unsupported chat template type: {chat_template_type}")
|
|
||||||
@@ -13,7 +13,7 @@ from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnaly
|
|||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
from axolotl.utils.schemas.datasets import DatasetConfig
|
from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig
|
||||||
|
|
||||||
# Configure the logger
|
# Configure the logger
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ DPO prompt strategies for using tokenizer chat templates.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
||||||
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
|
from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic
|
||||||
|
|
||||||
|
|
||||||
def default(
|
def default(
|
||||||
|
|||||||
@@ -169,7 +169,7 @@ def execute_training(
|
|||||||
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
|
cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Execute the training process with appropriate SDP kernel configurations.
|
Execute the training process with appropriate backend configurations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
@@ -177,6 +177,9 @@ def execute_training(
|
|||||||
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
|
||||||
"""
|
"""
|
||||||
LOG.info("Starting trainer...")
|
LOG.info("Starting trainer...")
|
||||||
|
if cfg.group_by_length:
|
||||||
|
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||||
|
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
with torch.backends.cuda.sdp_kernel(
|
with torch.backends.cuda.sdp_kernel(
|
||||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||||
@@ -314,7 +317,6 @@ def save_initial_configs(
|
|||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
model: PreTrainedModel,
|
model: PreTrainedModel,
|
||||||
peft_config: PeftConfig | None,
|
peft_config: PeftConfig | None,
|
||||||
processor: ProcessorMixin | None,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Save initial configurations before training.
|
Save initial configurations before training.
|
||||||
@@ -342,10 +344,6 @@ def save_initial_configs(
|
|||||||
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
||||||
model.config.save_pretrained(str(output_dir))
|
model.config.save_pretrained(str(output_dir))
|
||||||
|
|
||||||
if processor:
|
|
||||||
LOG.info(f"Pre-saving processor to {cfg.output_dir}...")
|
|
||||||
processor.save_pretrained(str(output_dir))
|
|
||||||
|
|
||||||
|
|
||||||
def setup_model_card(cfg: DictDefault):
|
def setup_model_card(cfg: DictDefault):
|
||||||
"""
|
"""
|
||||||
@@ -413,7 +411,6 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
|
|||||||
PeftModel | PreTrainedModel,
|
PeftModel | PreTrainedModel,
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
PeftConfig | None,
|
PeftConfig | None,
|
||||||
ProcessorMixin | None,
|
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
|
Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
|
||||||
@@ -429,7 +426,6 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
|
|||||||
- Model
|
- Model
|
||||||
- Tokenizer
|
- Tokenizer
|
||||||
- PEFT config
|
- PEFT config
|
||||||
- Processor
|
|
||||||
"""
|
"""
|
||||||
# Load tokenizer, processor and model
|
# Load tokenizer, processor and model
|
||||||
model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
|
model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
|
||||||
@@ -460,7 +456,6 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
|
|||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
peft_config,
|
peft_config,
|
||||||
processor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -483,7 +478,6 @@ def train(
|
|||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
peft_config,
|
peft_config,
|
||||||
processor,
|
|
||||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||||
|
|
||||||
# Determine if we need to resume from a checkpoint
|
# Determine if we need to resume from a checkpoint
|
||||||
@@ -499,7 +493,7 @@ def train(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Save initial configs
|
# Save initial configs
|
||||||
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
|
save_initial_configs(cfg, tokenizer, model, peft_config)
|
||||||
|
|
||||||
# Set up signal handler for graceful termination
|
# Set up signal handler for graceful termination
|
||||||
setup_signal_handler(cfg, model, safe_serialization)
|
setup_signal_handler(cfg, model, safe_serialization)
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from trl.models import unwrap_model_for_generation
|
|||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.callbacks.perplexity import Perplexity
|
from axolotl.utils.callbacks.perplexity import Perplexity
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import (
|
||||||
barrier,
|
barrier,
|
||||||
broadcast_dict,
|
broadcast_dict,
|
||||||
@@ -42,7 +43,6 @@ from axolotl.utils.distributed import (
|
|||||||
is_main_process,
|
is_main_process,
|
||||||
zero_first,
|
zero_first,
|
||||||
)
|
)
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1,59 +1,14 @@
|
|||||||
"""
|
"""
|
||||||
Data collators for axolotl to pad labels and position_ids for packed sequences. Also
|
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
||||||
includes logic for handling sequence parallelism collation.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def adjust_position_ids_for_slice(
|
|
||||||
position_ids: torch.Tensor, start_idx: int
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Adjust position IDs for a sliced sequence to maintain proper relative positions.
|
|
||||||
This handles the case where position IDs might not be contiguous due to sample
|
|
||||||
packing.
|
|
||||||
"""
|
|
||||||
# Convert to tensor if not already
|
|
||||||
# Find the boundaries between samples (where position_ids reset)
|
|
||||||
adjusted_pos_ids = position_ids.clone()
|
|
||||||
|
|
||||||
# Process each sequence in the batch
|
|
||||||
for i in range(position_ids.shape[0]):
|
|
||||||
seq = position_ids[i]
|
|
||||||
|
|
||||||
# Find sample boundaries
|
|
||||||
boundaries = []
|
|
||||||
for j in range(1, len(seq)):
|
|
||||||
if seq[j] < seq[j - 1]:
|
|
||||||
boundaries.append(j)
|
|
||||||
|
|
||||||
# No need to adjust if there are no boundaries or this is a single sample
|
|
||||||
if not boundaries:
|
|
||||||
adjusted_pos_ids[i] = seq - start_idx
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Adjust each segment separately
|
|
||||||
prev_boundary = 0
|
|
||||||
for boundary in boundaries:
|
|
||||||
adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
|
|
||||||
prev_boundary = boundary
|
|
||||||
|
|
||||||
# Last segment
|
|
||||||
adjusted_pos_ids[i, prev_boundary:] -= start_idx
|
|
||||||
|
|
||||||
return adjusted_pos_ids
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataCollatorForSeq2Seq:
|
class DataCollatorForSeq2Seq:
|
||||||
@@ -88,8 +43,6 @@ class DataCollatorForSeq2Seq:
|
|||||||
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
||||||
return_tensors (`str`):
|
return_tensors (`str`):
|
||||||
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
||||||
sequence_parallel_degree (`int`):
|
|
||||||
The degree of sequence parallelism. Default to 1 for no sequence parallelism.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tokenizer: PreTrainedTokenizerBase
|
tokenizer: PreTrainedTokenizerBase
|
||||||
@@ -100,16 +53,6 @@ class DataCollatorForSeq2Seq:
|
|||||||
label_pad_token_id: int = -100
|
label_pad_token_id: int = -100
|
||||||
position_pad_token_id: int = 0
|
position_pad_token_id: int = 0
|
||||||
return_tensors: str = "pt"
|
return_tensors: str = "pt"
|
||||||
sequence_parallel_degree: int = 1
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.sequence_parallel_degree > 1:
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
|
||||||
|
|
||||||
# Get information about our position in the SP group
|
|
||||||
sp_group = get_ring_attn_group()
|
|
||||||
self.local_rank = dist.get_rank(group=sp_group)
|
|
||||||
self.local_world_size = dist.get_world_size(group=sp_group)
|
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
labels = None
|
labels = None
|
||||||
@@ -176,43 +119,8 @@ class DataCollatorForSeq2Seq:
|
|||||||
)
|
)
|
||||||
features["decoder_input_ids"] = decoder_input_ids
|
features["decoder_input_ids"] = decoder_input_ids
|
||||||
|
|
||||||
if self.sequence_parallel_degree > 1:
|
|
||||||
features = self.apply_sequence_parallelism(features)
|
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
def apply_sequence_parallelism(
|
|
||||||
self, batch: dict[str, torch.Tensor]
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Apply sequence parallelism slicing to a batch.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
batch: Batch dictionary from parent collator.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Sliced batch dictionary.
|
|
||||||
"""
|
|
||||||
keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
|
|
||||||
|
|
||||||
for key in keys_to_slice:
|
|
||||||
if key in batch:
|
|
||||||
seq_len = batch[key].shape[1]
|
|
||||||
slice_size = seq_len // self.local_world_size
|
|
||||||
start_idx = self.local_rank * slice_size
|
|
||||||
end_idx = (
|
|
||||||
start_idx + slice_size
|
|
||||||
if self.local_rank < self.local_world_size - 1
|
|
||||||
else seq_len
|
|
||||||
)
|
|
||||||
batch[key] = batch[key][:, start_idx:end_idx]
|
|
||||||
|
|
||||||
# Special handling for position_ids
|
|
||||||
if key == "position_ids" and self.local_rank > 0:
|
|
||||||
batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
|
|
||||||
|
|
||||||
return batch
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||||
@@ -240,7 +148,6 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
np.array(item[feature]) for item in features_ if feature in item
|
np.array(item[feature]) for item in features_ if feature in item
|
||||||
]
|
]
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
|
|
||||||
return super().__call__(out_features, return_tensors=return_tensors)
|
return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
@@ -270,7 +177,6 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
np.array(item[feature]) for item in features_ if feature in item
|
np.array(item[feature]) for item in features_ if feature in item
|
||||||
]
|
]
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
|
|
||||||
return super().__call__(out_features, return_tensors=return_tensors)
|
return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,17 +2,15 @@
|
|||||||
Collators for multi-modal chat messages and packing
|
Collators for multi-modal chat messages and packing
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import torch
|
from PIL import Image
|
||||||
from torch import Tensor
|
from transformers import PreTrainedTokenizerBase, ProcessorMixin
|
||||||
from transformers import PreTrainedTokenizerBase
|
|
||||||
from transformers.data.data_collator import DataCollatorMixin
|
from transformers.data.data_collator import DataCollatorMixin
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
from axolotl.processing_strategies import ProcessingStrategy
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MultiModalChatDataCollator(DataCollatorMixin):
|
class MultiModalChatDataCollator(DataCollatorMixin):
|
||||||
@@ -21,9 +19,11 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
tokenizer: PreTrainedTokenizerBase
|
tokenizer: PreTrainedTokenizerBase
|
||||||
processing_strategy: ProcessingStrategy
|
processor: ProcessorMixin
|
||||||
packing: bool = False
|
|
||||||
return_tensors: str = "pt"
|
return_tensors: str = "pt"
|
||||||
|
chat_template: Optional[str] = None
|
||||||
|
packing: bool = False
|
||||||
|
max_images: int = -1
|
||||||
padding: Union[bool, str, PaddingStrategy] = True
|
padding: Union[bool, str, PaddingStrategy] = True
|
||||||
pad_to_multiple_of: Optional[int] = None
|
pad_to_multiple_of: Optional[int] = None
|
||||||
|
|
||||||
@@ -31,62 +31,162 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
if self.packing:
|
if self.packing:
|
||||||
raise ValueError("Packing is currently not supported.")
|
raise ValueError("Packing is currently not supported.")
|
||||||
|
|
||||||
def torch_call(self, examples: list[dict]) -> dict[str, Any]:
|
def torch_call(
|
||||||
return self.process_rows(examples)
|
self, examples: list[Union[list[int], Any, dict[str, Any]]]
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
# Handle dict or lists with proper padding and conversion to tensor.
|
||||||
|
|
||||||
def process_rows(
|
return self.__class__.process_rows(
|
||||||
self,
|
examples, self.processor, self.chat_template, self.max_images
|
||||||
examples: list[dict],
|
|
||||||
) -> dict[str, Tensor]:
|
|
||||||
# Preprocess the examples
|
|
||||||
examples = self.processing_strategy(examples)
|
|
||||||
|
|
||||||
# Initialize batch
|
|
||||||
batch: dict[str, Any] = {}
|
|
||||||
|
|
||||||
# Process each example
|
|
||||||
for example in examples:
|
|
||||||
# Apply chat template to process the example
|
|
||||||
# This method requires transformers>=4.49.0
|
|
||||||
result = self.processing_strategy.processor.apply_chat_template(
|
|
||||||
example["messages"],
|
|
||||||
add_generation_prompt=True,
|
|
||||||
tokenize=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
padding=True,
|
|
||||||
return_dict=True,
|
|
||||||
chat_template=self.processing_strategy.chat_template,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Check if need handling for len(input_ids) > sequence_len
|
@staticmethod
|
||||||
|
def process_rows(examples, processor, chat_template, max_images, length_only=False):
|
||||||
|
# HINT: use `_torch_collate_batch` to stack and pad tensors
|
||||||
|
# see also DataCollatorWithFlattening and DefaultDataCollator
|
||||||
|
|
||||||
# Add the processed tensors to our batch
|
# *** This is COPIED from the trl example sft_vlm.py code ***
|
||||||
for key in result.keys():
|
# use this as a starting point
|
||||||
if key not in batch:
|
|
||||||
batch[key] = []
|
|
||||||
|
|
||||||
batch[key].append(result[key].squeeze(0))
|
def _preprocess(examples: list[dict]) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Preprocess conversation examples to ensure consistent format.
|
||||||
|
|
||||||
# Pad sequences to the same length
|
Converts different conversation formats to OpenAI format with 'messages'.
|
||||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
Supports two formats:
|
||||||
batch["input_ids"],
|
1. OpenAI format with 'messages'
|
||||||
batch_first=True,
|
2. Legacy format with 'conversations'
|
||||||
padding_value=self.tokenizer.pad_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
attention_mask = torch.nn.utils.rnn.pad_sequence(
|
Args:
|
||||||
batch["attention_mask"], batch_first=True, padding_value=0
|
examples: list of conversation dictionaries
|
||||||
)
|
|
||||||
|
|
||||||
# Create the final batch
|
Returns:
|
||||||
final_batch = {
|
dict in OpenAI format with 'messages' key
|
||||||
"input_ids": input_ids,
|
|
||||||
"attention_mask": attention_mask,
|
Raises:
|
||||||
|
ValueError: If the conversation format is not supported
|
||||||
|
"""
|
||||||
|
role_mapping = {
|
||||||
|
"human": "user",
|
||||||
|
"gpt": "assistant",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Process the labels
|
def normalize_role(role: str) -> str:
|
||||||
final_batch["labels"] = self.processing_strategy.process_labels(
|
"""Normalize role names to OpenAI format. Default to original role if not found."""
|
||||||
final_batch["input_ids"]
|
return role_mapping.get(role, role)
|
||||||
|
|
||||||
|
def convert_legacy_format(example: dict) -> dict:
|
||||||
|
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": normalize_role(convo["from"]),
|
||||||
|
"content": convo["value"],
|
||||||
|
}
|
||||||
|
for convo in example["conversations"]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create new dict without 'conversations' key
|
||||||
|
result = deepcopy(example)
|
||||||
|
result.pop("conversations")
|
||||||
|
return {"messages": messages, **result}
|
||||||
|
|
||||||
|
processed_examples = []
|
||||||
|
for example in examples:
|
||||||
|
# OpenAI format
|
||||||
|
if "messages" in example:
|
||||||
|
processed_examples.append(example)
|
||||||
|
|
||||||
|
# Legacy format
|
||||||
|
elif "conversations" in example:
|
||||||
|
processed_examples.append(convert_legacy_format(example))
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Only `messages` and `conversations` message keys are currently supported."
|
||||||
)
|
)
|
||||||
|
|
||||||
return final_batch
|
return processed_examples
|
||||||
|
|
||||||
|
def _process_images(examples, max_images):
|
||||||
|
"""
|
||||||
|
Process images from examples, ensuring consistency in image presence and applying max_images limit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
examples: List of dictionaries that may contain 'images' key
|
||||||
|
max_images: Maximum number of images to keep per example (0 means no limit)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Either None (if no images) or List[Image objects] (if all examples have images)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If there's a mix of None and non-None images
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_image(example):
|
||||||
|
if "images" not in example:
|
||||||
|
return None
|
||||||
|
images = example["images"]
|
||||||
|
if isinstance(images, str):
|
||||||
|
return Image.open(images)
|
||||||
|
return images
|
||||||
|
|
||||||
|
images = [get_image(example) for example in examples]
|
||||||
|
|
||||||
|
# Count None and non-None images
|
||||||
|
none_count = sum(1 for img in images if img is None)
|
||||||
|
|
||||||
|
# All images are None
|
||||||
|
if none_count == len(images):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Mix of None and non-None images
|
||||||
|
if none_count > 0:
|
||||||
|
raise ValueError(
|
||||||
|
"All images should be either None or not None. "
|
||||||
|
"Please provide images for all examples or None."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply max_images limit if specified
|
||||||
|
if max_images > 0:
|
||||||
|
images = [
|
||||||
|
(
|
||||||
|
img_batch[:max_images]
|
||||||
|
if isinstance(img_batch, (list, tuple))
|
||||||
|
else img_batch
|
||||||
|
)
|
||||||
|
for img_batch in images
|
||||||
|
]
|
||||||
|
|
||||||
|
return images
|
||||||
|
|
||||||
|
# Preprocess the examples
|
||||||
|
examples = _preprocess(examples)
|
||||||
|
|
||||||
|
# Get the texts and images, and apply the chat template
|
||||||
|
texts = [
|
||||||
|
processor.apply_chat_template(
|
||||||
|
example["messages"], chat_template=chat_template, tokenize=False
|
||||||
|
)
|
||||||
|
for example in examples
|
||||||
|
]
|
||||||
|
|
||||||
|
images = _process_images(examples, max_images=max_images)
|
||||||
|
|
||||||
|
# Tokenize the texts and process the images
|
||||||
|
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
||||||
|
|
||||||
|
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
||||||
|
labels = batch["input_ids"].clone()
|
||||||
|
labels[labels == processor.tokenizer.pad_token_id] = -100 #
|
||||||
|
# Ignore the image token index in the loss computation (model specific)
|
||||||
|
image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||||
|
processor.image_token
|
||||||
|
)
|
||||||
|
labels[labels == image_token_id] = -100
|
||||||
|
batch["labels"] = labels
|
||||||
|
|
||||||
|
if length_only:
|
||||||
|
return {
|
||||||
|
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
|
||||||
|
}
|
||||||
|
return batch
|
||||||
|
|||||||
@@ -12,13 +12,19 @@ from transformers.utils.import_utils import is_torch_npu_available
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.integrations.config import merge_input_args
|
from axolotl.integrations.config import merge_input_args
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||||
from axolotl.utils.models import MULTIMODAL_AUTO_MODEL_MAPPING, load_model_config
|
|
||||||
from axolotl.utils.schemas.config import (
|
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
)
|
)
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||||
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
|
AxolotlInputConfig as AxolotlInputConfigBase,
|
||||||
|
)
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||||
|
DPODataset,
|
||||||
|
KTODataset,
|
||||||
|
SFTDataset,
|
||||||
|
)
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.models import load_model_config
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -125,9 +131,6 @@ def normalize_config(cfg):
|
|||||||
with open(ds_config_path, encoding="utf-8") as f:
|
with open(ds_config_path, encoding="utf-8") as f:
|
||||||
cfg.deepspeed = json.load(f)
|
cfg.deepspeed = json.load(f)
|
||||||
|
|
||||||
if cfg.sequence_parallel_degree is None:
|
|
||||||
cfg.sequence_parallel_degree = 1
|
|
||||||
|
|
||||||
if cfg.saves_per_epoch:
|
if cfg.saves_per_epoch:
|
||||||
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
||||||
if save_steps < 1.0: # prevent saves on every step
|
if save_steps < 1.0: # prevent saves on every step
|
||||||
@@ -158,7 +161,7 @@ def normalize_config(cfg):
|
|||||||
|
|
||||||
cfg.is_multimodal = (
|
cfg.is_multimodal = (
|
||||||
hasattr(model_config, "model_type")
|
hasattr(model_config, "model_type")
|
||||||
and model_config.model_type in MULTIMODAL_AUTO_MODEL_MAPPING
|
and model_config.model_type in ["llava", "mllama"]
|
||||||
or any(
|
or any(
|
||||||
multimodal_name in cfg.base_model.lower()
|
multimodal_name in cfg.base_model.lower()
|
||||||
for multimodal_name in [
|
for multimodal_name in [
|
||||||
@@ -171,6 +174,7 @@ def normalize_config(cfg):
|
|||||||
cfg.processor_config = (
|
cfg.processor_config = (
|
||||||
cfg.processor_config or cfg.base_model_config or cfg.base_model
|
cfg.processor_config or cfg.base_model_config or cfg.base_model
|
||||||
)
|
)
|
||||||
|
model_config = model_config.text_config
|
||||||
|
|
||||||
cfg.model_config_type = model_config.model_type
|
cfg.model_config_type = model_config.model_type
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,8 @@
|
|||||||
"""Pydantic models for TRL trainer configuration"""
|
"""
|
||||||
|
GRPO specific configuration args
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -8,11 +12,11 @@ class TRLConfig(BaseModel):
|
|||||||
Input args for TRL.
|
Input args for TRL.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
beta: float | None = Field(
|
beta: Optional[float] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Beta for RL training"},
|
json_schema_extra={"description": "Beta for RL training"},
|
||||||
)
|
)
|
||||||
max_completion_length: int | None = Field(
|
max_completion_length: Optional[int] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Maximum length of the completion for RL training"
|
"description": "Maximum length of the completion for RL training"
|
||||||
@@ -21,50 +25,50 @@ class TRLConfig(BaseModel):
|
|||||||
|
|
||||||
# GRPO specific args
|
# GRPO specific args
|
||||||
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
||||||
use_vllm: bool | None = Field(
|
use_vllm: Optional[bool] = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
||||||
)
|
)
|
||||||
vllm_device: str | None = Field(
|
vllm_device: Optional[str] = Field(
|
||||||
default="auto",
|
default="auto",
|
||||||
json_schema_extra={"description": "Device to use for VLLM"},
|
json_schema_extra={"description": "Device to use for VLLM"},
|
||||||
)
|
)
|
||||||
vllm_gpu_memory_utilization: float | None = Field(
|
vllm_gpu_memory_utilization: Optional[float] = Field(
|
||||||
default=0.9,
|
default=0.9,
|
||||||
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
||||||
)
|
)
|
||||||
vllm_dtype: str | None = Field(
|
vllm_dtype: Optional[str] = Field(
|
||||||
default="auto",
|
default="auto",
|
||||||
json_schema_extra={"description": "Data type for VLLM"},
|
json_schema_extra={"description": "Data type for VLLM"},
|
||||||
)
|
)
|
||||||
vllm_max_model_len: int | None = Field(
|
vllm_max_model_len: Optional[int] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Maximum length of the model context for VLLM"
|
"description": "Maximum length of the model context for VLLM"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
reward_funcs: list[str] | None = Field(
|
reward_funcs: Optional[list[str]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "List of reward functions to load"},
|
json_schema_extra={"description": "List of reward functions to load"},
|
||||||
)
|
)
|
||||||
reward_weights: list[float] | None = Field(
|
reward_weights: Optional[list[float]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Weights for each reward function. Must match the number of reward functions."
|
"description": "Weights for each reward function. Must match the number of reward functions."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
num_generations: int | None = Field(
|
num_generations: Optional[int] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
|
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
log_completions: bool | None = Field(
|
log_completions: Optional[bool] = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={"description": "Whether to log completions"},
|
json_schema_extra={"description": "Whether to log completions"},
|
||||||
)
|
)
|
||||||
sync_ref_model: bool | None = Field(
|
sync_ref_model: Optional[bool] = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": (
|
"description": (
|
||||||
@@ -73,13 +77,13 @@ class TRLConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ref_model_mixup_alpha: float | None = Field(
|
ref_model_mixup_alpha: Optional[float] = Field(
|
||||||
default=0.9,
|
default=0.9,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
|
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ref_model_sync_steps: int | None = Field(
|
ref_model_sync_steps: Optional[int] = Field(
|
||||||
default=64,
|
default=64,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
||||||
@@ -34,16 +34,12 @@ from transformers import ( # noqa: F401
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
AwqConfig,
|
AwqConfig,
|
||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
Gemma3ForConditionalGeneration,
|
|
||||||
GPTQConfig,
|
GPTQConfig,
|
||||||
LlavaForConditionalGeneration,
|
LlavaForConditionalGeneration,
|
||||||
Mistral3ForConditionalGeneration,
|
|
||||||
MllamaForConditionalGeneration,
|
MllamaForConditionalGeneration,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
ProcessorMixin,
|
ProcessorMixin,
|
||||||
Qwen2_5_VLForConditionalGeneration,
|
|
||||||
Qwen2VLForConditionalGeneration,
|
|
||||||
)
|
)
|
||||||
from transformers.integrations.deepspeed import (
|
from transformers.integrations.deepspeed import (
|
||||||
HfTrainerDeepSpeedConfig,
|
HfTrainerDeepSpeedConfig,
|
||||||
@@ -71,16 +67,7 @@ from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrap
|
|||||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
MULTIMODAL_AUTO_MODEL_MAPPING = {
|
|
||||||
"mllama": MllamaForConditionalGeneration,
|
|
||||||
"llava": LlavaForConditionalGeneration,
|
|
||||||
"qwen2_vl": Qwen2VLForConditionalGeneration,
|
|
||||||
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
|
|
||||||
"mistral3": Mistral3ForConditionalGeneration,
|
|
||||||
"gemma3": Gemma3ForConditionalGeneration,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# copied from accelerator.FullyShardedDataParallelPlugin
|
# copied from accelerator.FullyShardedDataParallelPlugin
|
||||||
@@ -109,21 +96,7 @@ def get_module_class_from_name(module, name):
|
|||||||
|
|
||||||
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
||||||
if cfg.is_multimodal:
|
if cfg.is_multimodal:
|
||||||
if hasattr(model_config, "text_config"):
|
|
||||||
model_config = model_config.text_config
|
model_config = model_config.text_config
|
||||||
model_config.use_cache = False
|
|
||||||
elif hasattr(model_config, "get_text_config"):
|
|
||||||
model_config = model_config.get_text_config()
|
|
||||||
model_config.use_cache = False
|
|
||||||
|
|
||||||
# check if image_size is not set and load image size from model config if available
|
|
||||||
if (
|
|
||||||
cfg.image_size is None
|
|
||||||
and hasattr(model_config, "vision_config")
|
|
||||||
and hasattr(model_config.vision_config, "image_size")
|
|
||||||
):
|
|
||||||
cfg.image_size = model_config.vision_config.image_size
|
|
||||||
LOG.debug(f"Loaded image size: {cfg.image_size} from model config")
|
|
||||||
|
|
||||||
quant_config_exists = (
|
quant_config_exists = (
|
||||||
hasattr(model_config, "quantization_config")
|
hasattr(model_config, "quantization_config")
|
||||||
@@ -462,31 +435,6 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
|||||||
**processor_kwargs,
|
**processor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Attempt to load image size from processor if available
|
|
||||||
if (
|
|
||||||
cfg.image_size is None
|
|
||||||
and hasattr(processor, "size")
|
|
||||||
and any(dim in processor.size for dim in ["width", "height"])
|
|
||||||
):
|
|
||||||
im_width = None
|
|
||||||
im_height = None
|
|
||||||
if "width" in processor.size:
|
|
||||||
im_width = processor.size["width"]
|
|
||||||
if "height" in processor.size:
|
|
||||||
im_height = processor.size["height"]
|
|
||||||
|
|
||||||
# If both width and height are set, use a tuple
|
|
||||||
if im_width is not None and im_height is not None:
|
|
||||||
cfg.image_size = (im_width, im_height)
|
|
||||||
# If only width is set, use as integer
|
|
||||||
elif im_width is not None:
|
|
||||||
cfg.image_size = im_width
|
|
||||||
# If only height is set, use as integer
|
|
||||||
elif im_height is not None:
|
|
||||||
cfg.image_size = im_height
|
|
||||||
|
|
||||||
LOG.debug(f"Loaded image size: {cfg.image_size} from processor")
|
|
||||||
|
|
||||||
return processor
|
return processor
|
||||||
|
|
||||||
|
|
||||||
@@ -524,15 +472,11 @@ class ModelLoader:
|
|||||||
# init model config
|
# init model config
|
||||||
self.model_config = load_model_config(cfg)
|
self.model_config = load_model_config(cfg)
|
||||||
if cfg.is_multimodal:
|
if cfg.is_multimodal:
|
||||||
if hasattr(self.model_config, "text_config"):
|
|
||||||
self.text_model_config = self.model_config.text_config
|
self.text_model_config = self.model_config.text_config
|
||||||
else:
|
|
||||||
# for qwen2_vl
|
|
||||||
self.text_model_config = self.model_config.get_text_config()
|
|
||||||
else:
|
else:
|
||||||
self.text_model_config = self.model_config
|
self.text_model_config = self.model_config
|
||||||
|
|
||||||
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
self.AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||||
|
|
||||||
def apply_patches(self) -> None:
|
def apply_patches(self) -> None:
|
||||||
# load any patches from plugins
|
# load any patches from plugins
|
||||||
@@ -603,14 +547,6 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_self_attn_lora(self.cfg)
|
patch_self_attn_lora(self.cfg)
|
||||||
|
|
||||||
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
|
||||||
|
|
||||||
# Initialize ring attn for sequence parallelism. This must be done after
|
|
||||||
# model init but before the first forward pass, since it modifies flash
|
|
||||||
# attn to use ring comm for SP training across multiple GPUs.
|
|
||||||
register_ring_attn(self.cfg.sequence_parallel_degree)
|
|
||||||
|
|
||||||
def patch_attention(self) -> None:
|
def patch_attention(self) -> None:
|
||||||
if hasattr(self.model_config, "model_type"):
|
if hasattr(self.model_config, "model_type"):
|
||||||
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
||||||
@@ -667,7 +603,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_self_attn_lora()
|
patch_self_attn_lora()
|
||||||
|
|
||||||
def patch_llama_derived_model(self):
|
def patch_llama_derived_model(self) -> None:
|
||||||
"""Modify all llama derived models in one block"""
|
"""Modify all llama derived models in one block"""
|
||||||
self.patch_loss_llama()
|
self.patch_loss_llama()
|
||||||
|
|
||||||
@@ -717,15 +653,24 @@ class ModelLoader:
|
|||||||
"Shifted-sparse attention not currently implemented without flash attention."
|
"Shifted-sparse attention not currently implemented without flash attention."
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_auto_model_loader(self):
|
def set_auto_model_loader(self) -> None:
|
||||||
"""
|
"""set self.AutoModelLoader
|
||||||
Set self.auto_model_loader. Defaults to `transformers.AutoModelForCausalLM`
|
- default value: AutoModelForCausalLM (set at __init__)
|
||||||
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
|
- when using a multi modality model, self.AutoModelLoader should
|
||||||
should be set according to the type of the model.
|
be set according to model type of the model
|
||||||
"""
|
"""
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
|
if self.model_config.model_type == "llava":
|
||||||
self.model_config.model_type, AutoModelForVision2Seq
|
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
||||||
|
LlavaForConditionalGeneration
|
||||||
|
)
|
||||||
|
elif self.model_config.model_type == "mllama":
|
||||||
|
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
||||||
|
MllamaForConditionalGeneration
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.AutoModelLoader = (
|
||||||
|
AutoModelForVision2Seq # pylint: disable=invalid-name
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_device_map_config(self) -> None:
|
def set_device_map_config(self) -> None:
|
||||||
@@ -750,7 +695,7 @@ class ModelLoader:
|
|||||||
from accelerate import infer_auto_device_map
|
from accelerate import infer_auto_device_map
|
||||||
|
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model_canvas = self.auto_model_loader.from_config(
|
model_canvas = self.AutoModelLoader.from_config(
|
||||||
self.model_config,
|
self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
)
|
)
|
||||||
@@ -971,23 +916,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
|
self.model = self.AutoModelLoader.from_pretrained(
|
||||||
# Load model with random initialization if specified
|
|
||||||
if self.cfg.random_init_weights:
|
|
||||||
# AutoModel classes support the from_config method
|
|
||||||
if self.auto_model_loader in [
|
|
||||||
AutoModelForCausalLM,
|
|
||||||
AutoModelForVision2Seq,
|
|
||||||
]:
|
|
||||||
self.model = self.auto_model_loader.from_config(
|
|
||||||
config=self.model_config,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.model = self.auto_model_loader(
|
|
||||||
config=self.model_config,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
**self.model_kwargs,
|
**self.model_kwargs,
|
||||||
@@ -1029,7 +958,7 @@ class ModelLoader:
|
|||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
self.model = self.AutoModelLoader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
@@ -1062,7 +991,7 @@ class ModelLoader:
|
|||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
self.model = self.AutoModelLoader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
@@ -1082,7 +1011,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
self.model = self.AutoModelLoader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
@@ -1245,9 +1174,7 @@ class ModelLoader:
|
|||||||
)
|
)
|
||||||
):
|
):
|
||||||
resize_kwargs = {}
|
resize_kwargs = {}
|
||||||
if self.cfg.mean_resizing_embeddings is not None and not (
|
if self.cfg.mean_resizing_embeddings is not None:
|
||||||
self.model_config.model_type == "llava"
|
|
||||||
):
|
|
||||||
resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
|
resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
|
||||||
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
|
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
|
||||||
else:
|
else:
|
||||||
@@ -1380,7 +1307,7 @@ def load_model(
|
|||||||
"""
|
"""
|
||||||
Load a model for a given configuration and tokenizer.
|
Load a model for a given configuration and tokenizer.
|
||||||
"""
|
"""
|
||||||
model_loader = ModelLoader(
|
loader = ModelLoader(
|
||||||
cfg,
|
cfg,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
@@ -1388,7 +1315,7 @@ def load_model(
|
|||||||
reference_model=reference_model,
|
reference_model=reference_model,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return model_loader.load_model()
|
return loader.load_model()
|
||||||
|
|
||||||
|
|
||||||
def load_adapter(model, cfg, adapter, inference=False):
|
def load_adapter(model, cfg, adapter, inference=False):
|
||||||
|
|||||||
@@ -104,7 +104,9 @@ def allocate(
|
|||||||
|
|
||||||
|
|
||||||
class MultipackBatchSampler(BatchSampler):
|
class MultipackBatchSampler(BatchSampler):
|
||||||
"""Batch sampler class for multipack"""
|
"""
|
||||||
|
Batch Sampler class for multipack
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1,165 +0,0 @@
|
|||||||
"""Pydantic models for datasets-related configuration"""
|
|
||||||
|
|
||||||
from pydantic import BaseModel, model_validator
|
|
||||||
|
|
||||||
from axolotl.utils.schemas.enums import ChatTemplate
|
|
||||||
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
|
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedPrompterType(BaseModel):
|
|
||||||
"""Structure for user defined prompt types"""
|
|
||||||
|
|
||||||
system_prompt: str | None = None
|
|
||||||
system_format: str | None = None
|
|
||||||
field_system: str | None = None
|
|
||||||
field_instruction: str | None = None
|
|
||||||
field_input: str | None = None
|
|
||||||
field_output: str | None = None
|
|
||||||
|
|
||||||
format: str | None = None
|
|
||||||
no_input_format: str | None = None
|
|
||||||
field: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class SFTDataset(BaseModel):
|
|
||||||
"""SFT configuration subset"""
|
|
||||||
|
|
||||||
path: str | None = None
|
|
||||||
split: str | None = None
|
|
||||||
type: str | UserDefinedPrompterType | None = None
|
|
||||||
input_transform: str | None = None
|
|
||||||
shards: int | None = None
|
|
||||||
shards_idx: int | None = None
|
|
||||||
preprocess_shards: int | None = None
|
|
||||||
conversation: str | None = None
|
|
||||||
# Do not make this too strict or it will break the validator to choose different dataset class
|
|
||||||
chat_template: ChatTemplate | str | None = None
|
|
||||||
chat_template_jinja: str | None = None
|
|
||||||
data_files: str | list[str] | None = None
|
|
||||||
input_format: str | None = None
|
|
||||||
name: str | None = None
|
|
||||||
ds_type: str | None = None
|
|
||||||
train_on_split: str | None = None
|
|
||||||
field: str | None = None
|
|
||||||
field_human: str | None = None
|
|
||||||
field_model: str | None = None
|
|
||||||
field_messages: str | None = None
|
|
||||||
# deprecated, use message_property_mappings
|
|
||||||
message_field_role: str | None = None
|
|
||||||
# deprecated, use message_property_mappings
|
|
||||||
message_field_content: str | None = None
|
|
||||||
message_property_mappings: dict[str, str] | None = None
|
|
||||||
message_field_training: str | None = None
|
|
||||||
message_field_training_detail: str | None = None
|
|
||||||
logprobs_field: str | None = None
|
|
||||||
temperature: float | None = None
|
|
||||||
roles_to_train: list[str] | None = None
|
|
||||||
train_on_eos: str | None = None
|
|
||||||
roles: dict[str, list[str]] | None = None
|
|
||||||
drop_system_message: bool | None = None
|
|
||||||
trust_remote_code: bool | None = False
|
|
||||||
revision: str | None = None
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def handle_legacy_message_fields(cls, data):
|
|
||||||
"""Handle backwards compatibility between legacy message field mapping and new property mapping system."""
|
|
||||||
return handle_legacy_message_fields_logic(data)
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
def check_chat_template_config(cls, data):
|
|
||||||
if isinstance(data, BaseModel):
|
|
||||||
data = data.model_dump()
|
|
||||||
|
|
||||||
# Set chat_template to tokenizer_default if not set
|
|
||||||
if data.get("type") == "chat_template" and not data.get("chat_template"):
|
|
||||||
data["chat_template"] = ChatTemplate.tokenizer_default
|
|
||||||
|
|
||||||
# if chat_template is set to jinja, chat_template_jinja is required
|
|
||||||
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
|
||||||
"chat_template_jinja"
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"chat_template_jinja is required when chat_template is set to jinja"
|
|
||||||
)
|
|
||||||
|
|
||||||
# If chat_template_jinja is set, set chat_template to jinja
|
|
||||||
if data.get("chat_template_jinja") and not data.get("chat_template"):
|
|
||||||
data["chat_template"] = ChatTemplate.jinja
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
class PretrainingDataset(BaseModel):
|
|
||||||
"""Pretraining dataset configuration subset"""
|
|
||||||
|
|
||||||
name: str | None = None
|
|
||||||
path: str | None = None
|
|
||||||
split: str | None = "train"
|
|
||||||
text_column: str | None = "text"
|
|
||||||
type: str | None = "pretrain"
|
|
||||||
trust_remote_code: bool | None = False
|
|
||||||
data_files: str | None = None
|
|
||||||
skip: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedDPOType(BaseModel):
|
|
||||||
"""User defined typing for DPO"""
|
|
||||||
|
|
||||||
field_system: str | None = None
|
|
||||||
field_prompt: str | None = None
|
|
||||||
field_chosen: str | None = None
|
|
||||||
field_rejected: str | None = None
|
|
||||||
prompt_format: str | None = None
|
|
||||||
chosen_format: str | None = None
|
|
||||||
rejected_format: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class DPODataset(BaseModel):
|
|
||||||
"""DPO configuration subset"""
|
|
||||||
|
|
||||||
path: str | None = None
|
|
||||||
split: str | None = None
|
|
||||||
type: UserDefinedDPOType | str | None = None
|
|
||||||
data_files: list[str] | None = None
|
|
||||||
revision: str | None = None
|
|
||||||
field_messages: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class StepwiseSupervisedDataset(BaseModel):
|
|
||||||
"""Stepwise supervised dataset configuration subset"""
|
|
||||||
|
|
||||||
path: str | None = None
|
|
||||||
split: str | None = None
|
|
||||||
data_files: list[str] | None = None
|
|
||||||
revision: str | None = None
|
|
||||||
step_separator: str | None = None
|
|
||||||
max_completion_length: int | None = None
|
|
||||||
train_on_last_step_only: bool | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedKTOType(BaseModel):
|
|
||||||
"""User defined typing for KTO"""
|
|
||||||
|
|
||||||
field_system: str | None = None
|
|
||||||
field_prompt: str | None = None
|
|
||||||
field_completion: str | None = None
|
|
||||||
field_label: bool | None = None
|
|
||||||
prompt_format: str | None = None
|
|
||||||
completion_format: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class KTODataset(BaseModel):
|
|
||||||
"""KTO configuration subset"""
|
|
||||||
|
|
||||||
path: str | None = None
|
|
||||||
split: str | None = None
|
|
||||||
type: UserDefinedKTOType | str | None = None
|
|
||||||
data_files: list[str] | None = None
|
|
||||||
trust_remote_code: bool | None = False
|
|
||||||
revision: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
|
|
||||||
@@ -1,68 +0,0 @@
|
|||||||
"""Pydantic models for deprecated and remapped configuration parameters"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class DeprecatedParameters(BaseModel):
|
|
||||||
"""configurations that are deprecated"""
|
|
||||||
|
|
||||||
max_packed_sequence_len: int | None = None
|
|
||||||
rope_scaling: Any | None = None
|
|
||||||
noisy_embedding_alpha: float | None = None
|
|
||||||
dpo_beta: float | None = None
|
|
||||||
evaluation_strategy: str | None = None
|
|
||||||
|
|
||||||
@field_validator("max_packed_sequence_len")
|
|
||||||
@classmethod
|
|
||||||
def validate_max_packed_sequence_len(cls, max_packed_sequence_len):
|
|
||||||
if max_packed_sequence_len:
|
|
||||||
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
|
|
||||||
return max_packed_sequence_len
|
|
||||||
|
|
||||||
@field_validator("rope_scaling")
|
|
||||||
@classmethod
|
|
||||||
def validate_rope_scaling(cls, rope_scaling):
|
|
||||||
if rope_scaling:
|
|
||||||
raise DeprecationWarning(
|
|
||||||
"`rope_scaling` is no longer supported, it should now be be a key under `model_config`"
|
|
||||||
)
|
|
||||||
return rope_scaling
|
|
||||||
|
|
||||||
@field_validator("noisy_embedding_alpha")
|
|
||||||
@classmethod
|
|
||||||
def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha):
|
|
||||||
if noisy_embedding_alpha:
|
|
||||||
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
|
||||||
return noisy_embedding_alpha
|
|
||||||
|
|
||||||
@field_validator("dpo_beta")
|
|
||||||
@classmethod
|
|
||||||
def validate_dpo_beta(cls, dpo_beta):
|
|
||||||
if dpo_beta is not None:
|
|
||||||
LOG.warning("dpo_beta is deprecated, use rl_beta instead")
|
|
||||||
return dpo_beta
|
|
||||||
|
|
||||||
@field_validator("evaluation_strategy")
|
|
||||||
@classmethod
|
|
||||||
def validate_evaluation_strategy(cls, evaluation_strategy):
|
|
||||||
if evaluation_strategy is not None:
|
|
||||||
LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
|
|
||||||
return evaluation_strategy
|
|
||||||
|
|
||||||
|
|
||||||
class RemappedParameters(BaseModel):
|
|
||||||
"""Parameters that have been remapped to other names"""
|
|
||||||
|
|
||||||
overrides_of_model_config: dict[str, Any] | None = Field(
|
|
||||||
default=None, alias="model_config"
|
|
||||||
)
|
|
||||||
overrides_of_model_kwargs: dict[str, Any] | None = Field(
|
|
||||||
default=None, alias="model_kwargs"
|
|
||||||
)
|
|
||||||
type_of_model: str | None = Field(default=None, alias="model_type")
|
|
||||||
revision_of_model: str | None = Field(default=None, alias="model_revision")
|
|
||||||
@@ -1,54 +0,0 @@
|
|||||||
"""Enums for Axolotl input config"""
|
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
|
|
||||||
class RLType(str, Enum):
|
|
||||||
"""RL trainer type configuration subset"""
|
|
||||||
|
|
||||||
dpo = "dpo" # pylint: disable=invalid-name
|
|
||||||
grpo = "grpo" # pylint: disable=invalid-name
|
|
||||||
ipo = "ipo" # pylint: disable=invalid-name
|
|
||||||
orpo = "orpo" # pylint: disable=invalid-name
|
|
||||||
kto = "kto" # pylint: disable=invalid-name
|
|
||||||
simpo = "simpo" # pylint: disable=invalid-name
|
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplate(str, Enum):
|
|
||||||
"""Chat templates configuration subset"""
|
|
||||||
|
|
||||||
alpaca = "alpaca" # pylint: disable=invalid-name
|
|
||||||
chatml = "chatml" # pylint: disable=invalid-name
|
|
||||||
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
|
|
||||||
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
|
|
||||||
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
|
|
||||||
mistral_v7_tekken = "mistral_v7_tekken" # pylint: disable=invalid-name
|
|
||||||
gemma = "gemma" # pylint: disable=invalid-name
|
|
||||||
cohere = "cohere" # pylint: disable=invalid-name
|
|
||||||
llama3 = "llama3" # pylint: disable=invalid-name
|
|
||||||
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
|
||||||
phi_3 = "phi_3" # pylint: disable=invalid-name
|
|
||||||
phi_35 = "phi_35" # pylint: disable=invalid-name
|
|
||||||
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
|
||||||
deepseek_v3 = "deepseek_v3" # pylint: disable=invalid-name
|
|
||||||
jamba = "jamba" # pylint: disable=invalid-name
|
|
||||||
jinja = "jinja" # pylint: disable=invalid-name
|
|
||||||
qwen_25 = "qwen_25" # pylint: disable=invalid-name
|
|
||||||
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
|
||||||
exaone = "exaone" # pylint: disable=invalid-name
|
|
||||||
metharme = "metharme" # pylint: disable=invalid-name
|
|
||||||
pixtral = "pixtral" # pylint: disable=invalid-name
|
|
||||||
llava = "llava" # pylint: disable=invalid-name
|
|
||||||
qwen2_vl = "qwen2_vl" # pylint: disable=invalid-name
|
|
||||||
gemma3 = "gemma3" # pylint: disable=invalid-name
|
|
||||||
|
|
||||||
|
|
||||||
class CustomSupportedOptimizers(str, Enum):
|
|
||||||
"""Custom supported optimizers"""
|
|
||||||
|
|
||||||
optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name
|
|
||||||
ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name
|
|
||||||
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
|
|
||||||
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
|
|
||||||
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
|
|
||||||
muon = "muon" # pylint: disable=invalid-name
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
"""Pydantic models for Axolotl integrations"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class MLFlowConfig(BaseModel):
|
|
||||||
"""MLFlow configuration subset"""
|
|
||||||
|
|
||||||
use_mlflow: bool | None = None
|
|
||||||
mlflow_tracking_uri: str | None = None
|
|
||||||
mlflow_experiment_name: str | None = None
|
|
||||||
mlflow_run_name: str | None = None
|
|
||||||
hf_mlflow_log_artifacts: bool | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class LISAConfig(BaseModel):
|
|
||||||
"""LISA configuration subset"""
|
|
||||||
|
|
||||||
lisa_n_layers: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={"description": "the number of activate layers in LISA"},
|
|
||||||
)
|
|
||||||
lisa_step_interval: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={"description": "how often to switch layers in LISA"},
|
|
||||||
)
|
|
||||||
lisa_layers_attribute: str | None = Field(
|
|
||||||
default="model.layers",
|
|
||||||
json_schema_extra={"description": "path under the model to access the layers"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class WandbConfig(BaseModel):
|
|
||||||
"""Wandb configuration subset"""
|
|
||||||
|
|
||||||
use_wandb: bool | None = None
|
|
||||||
wandb_name: str | None = None
|
|
||||||
wandb_run_id: str | None = None
|
|
||||||
wandb_mode: str | None = None
|
|
||||||
wandb_project: str | None = None
|
|
||||||
wandb_entity: str | None = None
|
|
||||||
wandb_watch: str | None = None
|
|
||||||
wandb_log_model: str | None = None
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_wandb_run(cls, data):
|
|
||||||
if data.get("wandb_run_id") and not data.get("wandb_name"):
|
|
||||||
data["wandb_name"] = data.get("wandb_run_id")
|
|
||||||
|
|
||||||
LOG.warning(
|
|
||||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
class CometConfig(BaseModel):
|
|
||||||
"""Comet configuration subset"""
|
|
||||||
|
|
||||||
use_comet: bool | None = None
|
|
||||||
comet_api_key: str | None = None
|
|
||||||
comet_workspace: str | None = None
|
|
||||||
comet_project_name: str | None = None
|
|
||||||
comet_experiment_key: str | None = None
|
|
||||||
comet_mode: str | None = None
|
|
||||||
comet_online: bool | None = None
|
|
||||||
comet_experiment_config: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class GradioConfig(BaseModel):
|
|
||||||
"""Gradio configuration subset"""
|
|
||||||
|
|
||||||
gradio_title: str | None = None
|
|
||||||
gradio_share: bool | None = None
|
|
||||||
gradio_server_name: str | None = None
|
|
||||||
gradio_server_port: int | None = None
|
|
||||||
gradio_max_new_tokens: int | None = None
|
|
||||||
gradio_temperature: float | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class RayConfig(BaseModel):
|
|
||||||
"""Ray launcher configuration subset"""
|
|
||||||
|
|
||||||
use_ray: bool = Field(default=False)
|
|
||||||
ray_run_name: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"help": "The training results will be saved at `saves/ray_run_name`."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
ray_num_workers: int = Field(
|
|
||||||
default=1,
|
|
||||||
json_schema_extra={
|
|
||||||
"help": "The number of workers for Ray training. Default is 1 worker."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
resources_per_worker: dict = Field(
|
|
||||||
default_factory=lambda: {"GPU": 1},
|
|
||||||
json_schema_extra={
|
|
||||||
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
@@ -1,55 +0,0 @@
|
|||||||
"""Pydantic models for model input / output, etc. configuration"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelInputConfig(BaseModel):
|
|
||||||
"""Model configuration subset"""
|
|
||||||
|
|
||||||
model_config = {"protected_namespaces": ()}
|
|
||||||
|
|
||||||
base_model: str
|
|
||||||
base_model_config: str | None = None
|
|
||||||
cls_model_config: str | None = None
|
|
||||||
tokenizer_config: str | None = None
|
|
||||||
tokenizer_use_fast: bool | None = None
|
|
||||||
tokenizer_legacy: bool | None = None
|
|
||||||
tokenizer_type: str | None = Field(
|
|
||||||
default=None, json_schema_extra={"description": "transformers tokenizer class"}
|
|
||||||
)
|
|
||||||
processor_type: str | None = Field(
|
|
||||||
default=None, json_schema_extra={"description": "transformers processor class"}
|
|
||||||
)
|
|
||||||
trust_remote_code: bool | None = None
|
|
||||||
|
|
||||||
@field_validator("trust_remote_code")
|
|
||||||
@classmethod
|
|
||||||
def hint_trust_remote_code(cls, trust_remote_code):
|
|
||||||
if trust_remote_code:
|
|
||||||
LOG.warning(
|
|
||||||
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
|
||||||
)
|
|
||||||
return trust_remote_code
|
|
||||||
|
|
||||||
|
|
||||||
class ModelOutputConfig(BaseModel):
|
|
||||||
"""model save configuration subset"""
|
|
||||||
|
|
||||||
output_dir: str = Field(default="./model-out")
|
|
||||||
hub_model_id: str | None = None
|
|
||||||
hub_strategy: str | None = None
|
|
||||||
save_safetensors: bool | None = True
|
|
||||||
|
|
||||||
|
|
||||||
class SpecialTokensConfig(BaseModel):
|
|
||||||
"""Special tokens configuration subset"""
|
|
||||||
|
|
||||||
bos_token: str | None = None
|
|
||||||
eos_token: str | None = None
|
|
||||||
pad_token: str | None = None
|
|
||||||
unk_token: str | None = None
|
|
||||||
additional_special_tokens: list[str] | None = None
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
"""Pydantic models for multimodal-related configuration"""
|
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from PIL.Image import Resampling
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
|
||||||
|
|
||||||
|
|
||||||
class MultiModalConfig(BaseModel):
|
|
||||||
"""Multi-modal configuration subset"""
|
|
||||||
|
|
||||||
image_size: int | tuple[int, int] | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": (
|
|
||||||
"The size of the image to resize to. It can be an integer (resized into padded-square image) or a tuple (width, height)."
|
|
||||||
"If not provided, we will attempt to load from preprocessor.size, otherwise, images won't be resized."
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
image_resize_algorithm: (
|
|
||||||
Literal["bilinear", "bicubic", "lanczos"] | Resampling | None
|
|
||||||
) = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "The resampling algorithm to use for image resizing. Default is bilinear. Please refer to PIL.Image.Resampling for more details."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@field_validator("image_resize_algorithm", mode="before")
|
|
||||||
@classmethod
|
|
||||||
def convert_image_resize_algorithm(cls, image_resize_algorithm):
|
|
||||||
"""
|
|
||||||
Convert the image resize algorithm to a PIL.Image.Resampling enum.
|
|
||||||
"""
|
|
||||||
if isinstance(image_resize_algorithm, str):
|
|
||||||
image_resize_algorithm = image_resize_algorithm.lower()
|
|
||||||
if image_resize_algorithm == "bilinear":
|
|
||||||
image_resize_algorithm = Resampling.BILINEAR
|
|
||||||
elif image_resize_algorithm == "bicubic":
|
|
||||||
image_resize_algorithm = Resampling.BICUBIC
|
|
||||||
elif image_resize_algorithm == "lanczos":
|
|
||||||
image_resize_algorithm = Resampling.LANCZOS
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid image resize algorithm: {image_resize_algorithm}"
|
|
||||||
)
|
|
||||||
return image_resize_algorithm
|
|
||||||
@@ -1,132 +0,0 @@
|
|||||||
"""Pydantic models for PEFT-related configuration"""
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
||||||
|
|
||||||
|
|
||||||
class LoftQConfig(BaseModel):
|
|
||||||
"""LoftQ configuration subset"""
|
|
||||||
|
|
||||||
loftq_bits: int = Field(
|
|
||||||
default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
|
|
||||||
)
|
|
||||||
# loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
|
|
||||||
|
|
||||||
|
|
||||||
class PeftConfig(BaseModel):
|
|
||||||
"""peftq configuration subset"""
|
|
||||||
|
|
||||||
loftq_config: LoftQConfig | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class LoraConfig(BaseModel):
|
|
||||||
"""Peft / LoRA configuration subset"""
|
|
||||||
|
|
||||||
load_in_8bit: bool | None = Field(default=False)
|
|
||||||
load_in_4bit: bool | None = Field(default=False)
|
|
||||||
|
|
||||||
adapter: str | None = None
|
|
||||||
lora_model_dir: str | None = None
|
|
||||||
lora_r: int | None = None
|
|
||||||
lora_alpha: int | None = None
|
|
||||||
lora_fan_in_fan_out: bool | None = None
|
|
||||||
lora_target_modules: str | list[str] | None = None
|
|
||||||
lora_target_linear: bool | None = None
|
|
||||||
lora_modules_to_save: list[str] | None = None
|
|
||||||
lora_dropout: float | None = 0.0
|
|
||||||
peft_layers_to_transform: list[int] | None = None
|
|
||||||
peft_layers_pattern: list[str] | None = None
|
|
||||||
peft: PeftConfig | None = None
|
|
||||||
peft_use_dora: bool | None = None
|
|
||||||
peft_use_rslora: bool | None = None
|
|
||||||
peft_layer_replication: list[tuple[int, int]] | None = None
|
|
||||||
peft_init_lora_weights: bool | str | None = None
|
|
||||||
|
|
||||||
qlora_sharded_model_loading: bool | None = Field(
|
|
||||||
default=False,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "load qlora model in sharded format for FSDP using answer.ai technique."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
lora_on_cpu: bool | None = None
|
|
||||||
gptq: bool | None = None
|
|
||||||
bnb_config_kwargs: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
loraplus_lr_ratio: float | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
loraplus_lr_embedding: float | None = Field(
|
|
||||||
default=1e-6,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "loraplus learning rate for lora embedding layers."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
merge_lora: bool | None = None
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def validate_adapter(cls, data):
|
|
||||||
if (
|
|
||||||
not data.get("adapter")
|
|
||||||
and not data.get("inference")
|
|
||||||
and (data.get("load_in_8bit") or data.get("load_in_4bit"))
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
|
|
||||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def validate_qlora(self):
|
|
||||||
if self.adapter == "qlora":
|
|
||||||
if self.merge_lora:
|
|
||||||
# can't merge qlora if loaded in 8bit or 4bit
|
|
||||||
if self.load_in_8bit:
|
|
||||||
raise ValueError("Can't merge qlora if loaded in 8bit")
|
|
||||||
|
|
||||||
if self.gptq:
|
|
||||||
raise ValueError("Can't merge qlora if gptq")
|
|
||||||
|
|
||||||
if self.load_in_4bit:
|
|
||||||
raise ValueError("Can't merge qlora if loaded in 4bit")
|
|
||||||
|
|
||||||
else:
|
|
||||||
if self.load_in_8bit:
|
|
||||||
raise ValueError("Can't load qlora in 8bit")
|
|
||||||
|
|
||||||
if self.gptq:
|
|
||||||
raise ValueError("Can't load qlora if gptq")
|
|
||||||
|
|
||||||
if not self.load_in_4bit:
|
|
||||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
|
||||||
return self
|
|
||||||
|
|
||||||
@field_validator("loraplus_lr_embedding")
|
|
||||||
@classmethod
|
|
||||||
def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding):
|
|
||||||
if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str):
|
|
||||||
loraplus_lr_embedding = float(loraplus_lr_embedding)
|
|
||||||
return loraplus_lr_embedding
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def validate_lora_dropout(cls, data):
|
|
||||||
if data.get("adapter") is not None and data.get("lora_dropout") is None:
|
|
||||||
data["lora_dropout"] = 0.0
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
class ReLoRAConfig(BaseModel):
|
|
||||||
"""ReLoRA configuration subset"""
|
|
||||||
|
|
||||||
relora_steps: int | None = None
|
|
||||||
relora_warmup_steps: int | None = None
|
|
||||||
relora_anneal_steps: int | None = None
|
|
||||||
relora_prune_ratio: float | None = None
|
|
||||||
relora_cpu_offload: bool | None = None
|
|
||||||
@@ -1,99 +0,0 @@
|
|||||||
"""Pydantic models for training hyperparameters"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
|
||||||
from transformers import SchedulerType
|
|
||||||
from transformers.training_args import OptimizerNames
|
|
||||||
|
|
||||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LrGroup(BaseModel):
|
|
||||||
"""Custom learning rate group configuration"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
modules: list[str]
|
|
||||||
lr: float
|
|
||||||
|
|
||||||
|
|
||||||
class HyperparametersConfig(BaseModel):
|
|
||||||
"""Training hyperparams configuration subset"""
|
|
||||||
|
|
||||||
gradient_accumulation_steps: int | None = Field(default=1)
|
|
||||||
micro_batch_size: int | None = Field(
|
|
||||||
default=1,
|
|
||||||
json_schema_extra={"description": "per gpu micro batch size for training"},
|
|
||||||
)
|
|
||||||
batch_size: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Total batch size, we do not recommended setting this manually"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
eval_batch_size: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
auto_find_batch_size: bool | None = None
|
|
||||||
|
|
||||||
train_on_inputs: bool | None = False
|
|
||||||
group_by_length: bool | None = None
|
|
||||||
|
|
||||||
learning_rate: str | float
|
|
||||||
embedding_lr: float | None = None
|
|
||||||
embedding_lr_scale: float | None = None
|
|
||||||
weight_decay: float | None = 0.0
|
|
||||||
optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = (
|
|
||||||
OptimizerNames.ADAMW_TORCH_FUSED
|
|
||||||
)
|
|
||||||
optim_args: (str | dict[str, Any]) | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
|
|
||||||
)
|
|
||||||
optim_target_modules: (list[str] | Literal["all_linear"]) | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "The target modules to optimize, i.e. the module names that you would like to train."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
torchdistx_path: str | None = None
|
|
||||||
lr_scheduler: (SchedulerType | Literal["one_cycle"] | Literal["rex"]) | None = (
|
|
||||||
SchedulerType.COSINE
|
|
||||||
)
|
|
||||||
lr_scheduler_kwargs: dict[str, Any] | None = None
|
|
||||||
lr_quadratic_warmup: bool | None = None
|
|
||||||
cosine_min_lr_ratio: float | None = None
|
|
||||||
cosine_constant_lr_ratio: float | None = None
|
|
||||||
lr_div_factor: float | None = None
|
|
||||||
lr_groups: list[LrGroup] | None = None
|
|
||||||
|
|
||||||
adam_epsilon: float | None = None
|
|
||||||
adam_beta1: float | None = None
|
|
||||||
adam_beta2: float | None = None
|
|
||||||
max_grad_norm: float | None = None
|
|
||||||
num_epochs: float = Field(default=1.0)
|
|
||||||
|
|
||||||
@field_validator("batch_size")
|
|
||||||
@classmethod
|
|
||||||
def hint_batch_size_set(cls, batch_size):
|
|
||||||
if batch_size:
|
|
||||||
LOG.warning(
|
|
||||||
"%s\n%s",
|
|
||||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
|
||||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
|
||||||
)
|
|
||||||
return batch_size
|
|
||||||
|
|
||||||
@field_validator("learning_rate")
|
|
||||||
@classmethod
|
|
||||||
def convert_learning_rate(cls, learning_rate):
|
|
||||||
if learning_rate and isinstance(learning_rate, str):
|
|
||||||
learning_rate = float(learning_rate)
|
|
||||||
return learning_rate
|
|
||||||
@@ -1,79 +0,0 @@
|
|||||||
"""Utilities for Axolotl Pydantic models"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def handle_legacy_message_fields_logic(data: dict) -> dict:
|
|
||||||
"""
|
|
||||||
Handle backwards compatibility between legacy message field mapping and new property mapping system.
|
|
||||||
|
|
||||||
Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options:
|
|
||||||
- message_field_role: Mapped to the role field
|
|
||||||
- message_field_content: Mapped to the content field
|
|
||||||
|
|
||||||
The new system uses message_property_mappings to support arbitrary field mappings:
|
|
||||||
message_property_mappings:
|
|
||||||
role: source_role_field
|
|
||||||
content: source_content_field
|
|
||||||
additional_field: source_field
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: Dictionary containing configuration data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated dictionary with message field mappings consolidated
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If there are conflicts between legacy and new mappings
|
|
||||||
"""
|
|
||||||
data = data.copy() # Create a copy to avoid modifying the original
|
|
||||||
|
|
||||||
if data.get("message_property_mappings") is None:
|
|
||||||
data["message_property_mappings"] = {}
|
|
||||||
|
|
||||||
# Check for conflicts and handle role
|
|
||||||
if "message_field_role" in data:
|
|
||||||
LOG.warning(
|
|
||||||
"message_field_role is deprecated, use message_property_mappings instead. "
|
|
||||||
f"Example: message_property_mappings: {{role: {data['message_field_role']}}}"
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
"role" in data["message_property_mappings"]
|
|
||||||
and data["message_property_mappings"]["role"] != data["message_field_role"]
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"Conflicting message role fields: message_field_role='{data['message_field_role']}' "
|
|
||||||
f"conflicts with message_property_mappings.role='{data['message_property_mappings']['role']}'"
|
|
||||||
)
|
|
||||||
data["message_property_mappings"]["role"] = data["message_field_role"] or "role"
|
|
||||||
|
|
||||||
del data["message_field_role"]
|
|
||||||
elif "role" not in data["message_property_mappings"]:
|
|
||||||
data["message_property_mappings"]["role"] = "role"
|
|
||||||
|
|
||||||
# Check for conflicts and handle content
|
|
||||||
if "message_field_content" in data:
|
|
||||||
LOG.warning(
|
|
||||||
"message_field_content is deprecated, use message_property_mappings instead. "
|
|
||||||
f"Example: message_property_mappings: {{content: {data['message_field_content']}}}"
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
"content" in data["message_property_mappings"]
|
|
||||||
and data["message_property_mappings"]["content"]
|
|
||||||
!= data["message_field_content"]
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
|
|
||||||
f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
|
|
||||||
)
|
|
||||||
data["message_property_mappings"]["content"] = (
|
|
||||||
data["message_field_content"] or "content"
|
|
||||||
)
|
|
||||||
|
|
||||||
del data["message_field_content"]
|
|
||||||
elif "content" not in data["message_property_mappings"]:
|
|
||||||
data["message_property_mappings"]["content"] = "content"
|
|
||||||
|
|
||||||
return data
|
|
||||||
@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Add position_id column (PoSE)",
|
desc="Add position_id column (PoSE)",
|
||||||
)
|
)
|
||||||
elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
|
elif cfg.sample_packing:
|
||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
||||||
@@ -356,7 +356,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
**filter_map_kwargs,
|
**filter_map_kwargs,
|
||||||
**drop_long_kwargs,
|
**drop_long_kwargs,
|
||||||
)
|
)
|
||||||
if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
|
if cfg.eval_sample_packing is not False:
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
add_position_ids,
|
add_position_ids,
|
||||||
@@ -443,7 +443,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
* cfg.sequence_parallel_degree
|
|
||||||
)
|
)
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
|
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
|
||||||
@@ -474,11 +473,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
||||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||||
# on the agreed on value for sample_packing_eff_est
|
# on the agreed on value for sample_packing_eff_est
|
||||||
total_num_steps = int(
|
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
|
||||||
math.floor(
|
|
||||||
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def calc_sample_packing_eff_est(estimates: List[float]):
|
def calc_sample_packing_eff_est(estimates: List[float]):
|
||||||
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
||||||
@@ -499,12 +494,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
total_num_steps = int(
|
total_num_steps = int(
|
||||||
math.ceil(
|
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||||
len(train_dataset)
|
|
||||||
* cfg.num_epochs
|
|
||||||
* cfg.sequence_parallel_degree
|
|
||||||
/ cfg.batch_size
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
||||||
return total_num_steps
|
return total_num_steps
|
||||||
|
|||||||
90
styles.css
90
styles.css
@@ -14,7 +14,7 @@
|
|||||||
h1 {
|
h1 {
|
||||||
font-family: var(--font-title);
|
font-family: var(--font-title);
|
||||||
font-weight: 400;
|
font-weight: 400;
|
||||||
font-size: 3rem;
|
font-size: 5rem;
|
||||||
line-height: 1.1;
|
line-height: 1.1;
|
||||||
letter-spacing: -0.05em;
|
letter-spacing: -0.05em;
|
||||||
font-feature-settings: "ss01" on;
|
font-feature-settings: "ss01" on;
|
||||||
@@ -24,7 +24,7 @@ h1 {
|
|||||||
h2 {
|
h2 {
|
||||||
font-family: var(--font-title);
|
font-family: var(--font-title);
|
||||||
font-weight: 500;
|
font-weight: 500;
|
||||||
font-size: 1.5rem;
|
font-size: 2rem;
|
||||||
line-height: 1.2;
|
line-height: 1.2;
|
||||||
letter-spacing: -0.03em;
|
letter-spacing: -0.03em;
|
||||||
font-feature-settings: "ss01" on;
|
font-feature-settings: "ss01" on;
|
||||||
@@ -35,7 +35,7 @@ h3,
|
|||||||
h4 {
|
h4 {
|
||||||
font-family: var(--font-body);
|
font-family: var(--font-body);
|
||||||
font-weight: 400;
|
font-weight: 400;
|
||||||
font-size: 1.25rem;
|
font-size: 1.5rem;
|
||||||
line-height: 1.5;
|
line-height: 1.5;
|
||||||
letter-spacing: -0.02em;
|
letter-spacing: -0.02em;
|
||||||
}
|
}
|
||||||
@@ -191,87 +191,3 @@ code span.er {
|
|||||||
color: #5cb85c !important;
|
color: #5cb85c !important;
|
||||||
text-decoration: none !important;
|
text-decoration: none !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* API Documentation Styling */
|
|
||||||
|
|
||||||
/* Improve docstring section rendering */
|
|
||||||
.level3 p {
|
|
||||||
white-space: pre-line !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Format docstring sections */
|
|
||||||
.level3 p strong {
|
|
||||||
display: block;
|
|
||||||
margin-top: 1em;
|
|
||||||
font-weight: bold;
|
|
||||||
color: var(--cyan);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Add spacing after sections */
|
|
||||||
.level3 p:has(strong) {
|
|
||||||
margin-bottom: 0.5em;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Format Args and Returns sections */
|
|
||||||
p:has(code) {
|
|
||||||
line-height: 1.6;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Function signatures */
|
|
||||||
.sourceCode {
|
|
||||||
margin-bottom: 1.5em;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Parameter tables */
|
|
||||||
.doc-section-parameters table,
|
|
||||||
.doc-section-returns table {
|
|
||||||
margin-top: 1em;
|
|
||||||
margin-bottom: 1.5em;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Make parameter and returns headers smaller */
|
|
||||||
h2.anchored[data-anchor-id="parameters"],
|
|
||||||
h2.anchored[data-anchor-id="returns"],
|
|
||||||
.doc-section-parameters h4,
|
|
||||||
.doc-section-returns h4 {
|
|
||||||
font-size: 1.25rem;
|
|
||||||
margin-top: 2rem;
|
|
||||||
margin-bottom: 1rem;
|
|
||||||
color: var(--lime);
|
|
||||||
border-bottom: 1px solid var(--lime);
|
|
||||||
padding-bottom: 0.3rem;
|
|
||||||
font-family: var(--font-body);
|
|
||||||
font-weight: 500;
|
|
||||||
letter-spacing: normal;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Style documentation tables */
|
|
||||||
table {
|
|
||||||
width: 100%;
|
|
||||||
margin-bottom: 1.5rem;
|
|
||||||
border-collapse: collapse;
|
|
||||||
}
|
|
||||||
|
|
||||||
table th {
|
|
||||||
background-color: #1a1a1a;
|
|
||||||
padding: 0.5rem 1rem;
|
|
||||||
border-bottom: 2px solid var(--greige-600);
|
|
||||||
text-align: left;
|
|
||||||
}
|
|
||||||
|
|
||||||
table td {
|
|
||||||
padding: 0.5rem 1rem;
|
|
||||||
border-bottom: 1px solid var(--greige-600);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Code in table cells */
|
|
||||||
table td code {
|
|
||||||
background-color: transparent !important;
|
|
||||||
padding: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Improve spacing in parameter and return tables */
|
|
||||||
.doc-section-parameters,
|
|
||||||
.doc-section-returns {
|
|
||||||
margin-top: 1rem;
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ def test_swiglu_mlp_integration(small_llama_model):
|
|||||||
def test_geglu_model_integration():
|
def test_geglu_model_integration():
|
||||||
"""Test GeGLU activation with Gemma model."""
|
"""Test GeGLU activation with Gemma model."""
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="auto"
|
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda"
|
||||||
)
|
)
|
||||||
peft_config = get_peft_config(
|
peft_config = get_peft_config(
|
||||||
{
|
{
|
||||||
@@ -347,7 +347,7 @@ def test_model_architecture(model_config):
|
|||||||
"""Test LoRA kernel patches across different model architectures."""
|
"""Test LoRA kernel patches across different model architectures."""
|
||||||
# Load model with appropriate dtype
|
# Load model with appropriate dtype
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_config["name"], torch_dtype=model_config["dtype"], device_map="auto"
|
model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply LoRA configuration
|
# Apply LoRA configuration
|
||||||
@@ -408,7 +408,7 @@ def test_kernel_training_integration():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
model, _, _ = load_model_and_tokenizer(cfg=cfg)
|
model, _ = load_model_and_tokenizer(cfg=cfg)
|
||||||
|
|
||||||
# Verify correct activation function
|
# Verify correct activation function
|
||||||
layer = model.model.model.layers[0]
|
layer = model.model.model.layers[0]
|
||||||
|
|||||||
@@ -1,209 +0,0 @@
|
|||||||
"""Tests for sequence parallelism functionality."""
|
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name,unused-argument
|
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from accelerate.state import PartialState
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import (
|
|
||||||
get_ring_attn_group,
|
|
||||||
set_ring_attn_group,
|
|
||||||
)
|
|
||||||
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def partial_state():
|
|
||||||
"""Create a real PartialState instance for testing."""
|
|
||||||
state = PartialState()
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="cfg")
|
|
||||||
def fixture_cfg():
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"learning_rate": 1e-3,
|
|
||||||
"output_dir": "./model-out",
|
|
||||||
"sequence_len": 512,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
|
|
||||||
class TestSequenceParallelHelpers:
|
|
||||||
"""Test helper functions used in sequence parallelism."""
|
|
||||||
|
|
||||||
def test_adjust_position_ids_for_slice(self, partial_state):
|
|
||||||
"""Test position_ids adjustment for sequence slices."""
|
|
||||||
# Create sample position_ids with multiple sequences
|
|
||||||
position_ids = torch.tensor(
|
|
||||||
[
|
|
||||||
# First sequence with 2 samples
|
|
||||||
[0, 1, 2, 3, 4, 0, 1, 2, 3],
|
|
||||||
# Second sequence with 3 samples
|
|
||||||
[0, 1, 2, 0, 1, 2, 3, 0, 1],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Adjust as if this was the second slice (start_idx = 4)
|
|
||||||
adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4)
|
|
||||||
|
|
||||||
# For first sequence: [0,1,2,3,4,0,1,2,3] -> [-4,-3,-2,-1,0,-4,-3,-2,-1]
|
|
||||||
# For second sequence: [0,1,2,0,1,2,3,0,1] -> [-4,-3,-2,-4,-3,-2,-1,-4,-3]
|
|
||||||
expected_first_seq = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3]) - 4
|
|
||||||
expected_second_seq = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1]) - 4
|
|
||||||
|
|
||||||
assert torch.all(adjusted[0] == expected_first_seq)
|
|
||||||
assert torch.all(adjusted[1] == expected_second_seq)
|
|
||||||
|
|
||||||
|
|
||||||
class TestRingAttention:
|
|
||||||
"""Tests for the ring attention functionality."""
|
|
||||||
|
|
||||||
@patch("torch.distributed.get_rank")
|
|
||||||
@patch("torch.distributed.get_world_size")
|
|
||||||
def test_get_ring_attn_group_no_registration(
|
|
||||||
self, mock_world_size, mock_rank, partial_state
|
|
||||||
):
|
|
||||||
"""Test that get_ring_attn_group returns None when no group has been registered."""
|
|
||||||
# Setup mocks
|
|
||||||
mock_world_size.return_value = 4
|
|
||||||
mock_rank.return_value = 0
|
|
||||||
|
|
||||||
# Get the group without registration
|
|
||||||
group = get_ring_attn_group()
|
|
||||||
|
|
||||||
# Verify that None was returned
|
|
||||||
assert group is None
|
|
||||||
|
|
||||||
@patch("torch.distributed.new_group")
|
|
||||||
@patch("torch.distributed.get_rank")
|
|
||||||
@patch("torch.distributed.get_world_size")
|
|
||||||
def test_register_ring_attn(
|
|
||||||
self, mock_world_size, mock_rank, mock_new_group, partial_state
|
|
||||||
):
|
|
||||||
"""Test that ring attention groups are created correctly."""
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
|
||||||
|
|
||||||
# Setup mocks
|
|
||||||
mock_world_size.return_value = 8 # 8 GPUs total
|
|
||||||
mock_rank.return_value = 3 # GPU #3
|
|
||||||
mock_group = MagicMock()
|
|
||||||
mock_new_group.return_value = mock_group
|
|
||||||
|
|
||||||
# Call register_ring_attn with size 4
|
|
||||||
register_ring_attn(sequence_parallel_degree=4)
|
|
||||||
|
|
||||||
# Verify the number of calls without examining the arguments
|
|
||||||
assert mock_new_group.call_count == 2
|
|
||||||
|
|
||||||
# Verify that new_group was called
|
|
||||||
mock_new_group.assert_called()
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
set_ring_attn_group(None)
|
|
||||||
|
|
||||||
|
|
||||||
# Mock a simplified DataCollator test
|
|
||||||
@patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group")
|
|
||||||
@patch("torch.distributed.get_rank")
|
|
||||||
@patch("torch.distributed.get_world_size")
|
|
||||||
def test_sequence_parallel_slicing(
|
|
||||||
mock_world_size, mock_rank, mock_get_group, partial_state
|
|
||||||
):
|
|
||||||
"""Test the basic sequence slicing logic without full collator instantiation."""
|
|
||||||
# Setup mocks
|
|
||||||
mock_get_group.return_value = MagicMock()
|
|
||||||
mock_rank.return_value = 1 # Second GPU
|
|
||||||
mock_world_size.return_value = 4 # 4 GPUs total
|
|
||||||
|
|
||||||
# Create a sample batch
|
|
||||||
batch = {
|
|
||||||
"input_ids": torch.tensor(
|
|
||||||
[
|
|
||||||
[101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112],
|
|
||||||
[201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212],
|
|
||||||
]
|
|
||||||
),
|
|
||||||
"attention_mask": torch.ones(2, 12),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Simplified slicing logic from SequenceParallelDataCollator
|
|
||||||
def slice_batch(batch, rank, world_size):
|
|
||||||
result = {}
|
|
||||||
for key in batch:
|
|
||||||
seq_len = batch[key].shape[1]
|
|
||||||
slice_size = seq_len // world_size
|
|
||||||
start_idx = rank * slice_size
|
|
||||||
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
|
|
||||||
result[key] = batch[key][:, start_idx:end_idx]
|
|
||||||
return result
|
|
||||||
|
|
||||||
# Slice the batch
|
|
||||||
result = slice_batch(
|
|
||||||
batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check slicing
|
|
||||||
assert result["input_ids"].shape == (2, 3) # 12 tokens / 4 GPUs = 3 tokens per GPU
|
|
||||||
expected_input_ids = torch.tensor(
|
|
||||||
[
|
|
||||||
[104, 105, 106], # Second slice of first sequence
|
|
||||||
[204, 205, 206], # Second slice of second sequence
|
|
||||||
]
|
|
||||||
)
|
|
||||||
assert torch.all(result["input_ids"] == expected_input_ids)
|
|
||||||
|
|
||||||
|
|
||||||
@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()})
|
|
||||||
def test_config_validation_with_valid_inputs(cfg):
|
|
||||||
"""Test that valid sequence parallelism configurations pass validation."""
|
|
||||||
# Import the actual model class with appropriate mocks
|
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
|
||||||
|
|
||||||
# Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
|
|
||||||
cfg = cfg | {
|
|
||||||
"sequence_parallel_degree": 2,
|
|
||||||
"flash_attention": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Should validate without errors
|
|
||||||
config = AxolotlInputConfig(**cfg)
|
|
||||||
assert config.sequence_parallel_degree == 2
|
|
||||||
assert config.flash_attention is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_config_validation_with_invalid_inputs(cfg):
|
|
||||||
"""Test that invalid sequence parallelism configurations fail validation."""
|
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
|
||||||
|
|
||||||
# Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
|
|
||||||
cfg = cfg | {
|
|
||||||
"sequence_parallel_degree": 2,
|
|
||||||
"flash_attention": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Should raise ValidationError
|
|
||||||
with pytest.raises(ValueError) as excinfo:
|
|
||||||
AxolotlInputConfig(**cfg)
|
|
||||||
|
|
||||||
# Verify error message
|
|
||||||
assert "flash_attention: true must be set" in str(excinfo.value)
|
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
E2E tests for deepseekv3
|
E2E tests for lora llama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|||||||
@@ -1,133 +0,0 @@
|
|||||||
"""
|
|
||||||
E2E tests for gemma2
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
|
||||||
from axolotl.common.datasets import load_datasets
|
|
||||||
from axolotl.train import train
|
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestGemma2:
|
|
||||||
"""
|
|
||||||
Test case for Gemma2 models
|
|
||||||
"""
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"sample_packing",
|
|
||||||
[True, False],
|
|
||||||
)
|
|
||||||
def test_lora_gemma2(self, temp_dir, sample_packing):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "axolotl-ai-co/gemma-2-33M",
|
|
||||||
"trust_remote_code": True,
|
|
||||||
"sample_packing": sample_packing,
|
|
||||||
"flash_attention": True,
|
|
||||||
"sequence_len": 2048,
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 8,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"val_set_size": 0,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mlabonne/FineTome-100k",
|
|
||||||
"type": "chat_template",
|
|
||||||
"field_messages": "conversations",
|
|
||||||
"message_property_mappings": {
|
|
||||||
"role": "from",
|
|
||||||
"content": "value",
|
|
||||||
},
|
|
||||||
"drop_system_message": True,
|
|
||||||
"split": "train[:1%]",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"special_tokens": {
|
|
||||||
"bos_token": "<bos>",
|
|
||||||
"eos_token": "<eos>",
|
|
||||||
},
|
|
||||||
"chat_template": "gemma", # gemma2's template is same as gemma
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 4,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_bnb_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"max_steps": 5,
|
|
||||||
"save_safetensors": True,
|
|
||||||
"bf16": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"sample_packing",
|
|
||||||
[True, False],
|
|
||||||
)
|
|
||||||
def test_fft_gemma2(self, temp_dir, sample_packing):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "axolotl-ai-co/gemma-2-33M",
|
|
||||||
"trust_remote_code": True,
|
|
||||||
"sample_packing": sample_packing,
|
|
||||||
"flash_attention": True,
|
|
||||||
"sequence_len": 2048,
|
|
||||||
"val_set_size": 0,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mlabonne/FineTome-100k",
|
|
||||||
"type": "chat_template",
|
|
||||||
"field_messages": "conversations",
|
|
||||||
"message_property_mappings": {
|
|
||||||
"role": "from",
|
|
||||||
"content": "value",
|
|
||||||
},
|
|
||||||
"split": "train[:1%]",
|
|
||||||
"drop_system_message": True,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"chat_template": "gemma", # gemma2's template is same as gemma
|
|
||||||
"special_tokens": {
|
|
||||||
"bos_token": "<bos>",
|
|
||||||
"eos_token": "<eos>",
|
|
||||||
},
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 4,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_bnb_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"max_steps": 5,
|
|
||||||
"save_safetensors": True,
|
|
||||||
"bf16": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
|
||||||
@@ -1,131 +0,0 @@
|
|||||||
"""
|
|
||||||
E2E tests for gemma3_text
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
|
||||||
from axolotl.common.datasets import load_datasets
|
|
||||||
from axolotl.train import train
|
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestGemma3Text:
|
|
||||||
"""
|
|
||||||
Test case for Gemma3Text models
|
|
||||||
"""
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"sample_packing",
|
|
||||||
[True, False],
|
|
||||||
)
|
|
||||||
def test_lora_gemma3_text(self, temp_dir, sample_packing):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "axolotl-ai-co/gemma-3-34M",
|
|
||||||
"trust_remote_code": True,
|
|
||||||
"sample_packing": sample_packing,
|
|
||||||
"flash_attention": True,
|
|
||||||
"sequence_len": 2048,
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 8,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"val_set_size": 0,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mlabonne/FineTome-100k",
|
|
||||||
"type": "chat_template",
|
|
||||||
"field_messages": "conversations",
|
|
||||||
"message_property_mappings": {
|
|
||||||
"role": "from",
|
|
||||||
"content": "value",
|
|
||||||
},
|
|
||||||
"split": "train[:1%]",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"special_tokens": {
|
|
||||||
"bos_token": "<bos>",
|
|
||||||
"eos_token": "<eos>",
|
|
||||||
},
|
|
||||||
"chat_template": "gemma3",
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 4,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_bnb_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"max_steps": 5,
|
|
||||||
"save_safetensors": True,
|
|
||||||
"bf16": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"sample_packing",
|
|
||||||
[True, False],
|
|
||||||
)
|
|
||||||
def test_fft_gemma3_text(self, temp_dir, sample_packing):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "axolotl-ai-co/gemma-3-34M",
|
|
||||||
"trust_remote_code": True,
|
|
||||||
"sample_packing": sample_packing,
|
|
||||||
"flash_attention": True,
|
|
||||||
"sequence_len": 2048,
|
|
||||||
"val_set_size": 0,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mlabonne/FineTome-100k",
|
|
||||||
"type": "chat_template",
|
|
||||||
"field_messages": "conversations",
|
|
||||||
"message_property_mappings": {
|
|
||||||
"role": "from",
|
|
||||||
"content": "value",
|
|
||||||
},
|
|
||||||
"split": "train[:1%]",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"chat_template": "gemma3",
|
|
||||||
"special_tokens": {
|
|
||||||
"bos_token": "<bos>",
|
|
||||||
"eos_token": "<eos>",
|
|
||||||
},
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 4,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_bnb_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"max_steps": 5,
|
|
||||||
"save_safetensors": True,
|
|
||||||
"bf16": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
|
||||||
@@ -54,7 +54,7 @@ class TestCustomSchedulers(unittest.TestCase):
|
|||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_hf",
|
||||||
"max_steps": 20,
|
"max_steps": 20,
|
||||||
"lr_scheduler": "rex",
|
"lr_scheduler": "rex",
|
||||||
"warmup_steps": 5,
|
"warmup_steps": 5,
|
||||||
|
|||||||
@@ -11,10 +11,10 @@ from pydantic import ValidationError
|
|||||||
|
|
||||||
from axolotl.utils import is_comet_available
|
from axolotl.utils import is_comet_available
|
||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||||
from axolotl.utils.models import check_model_config
|
from axolotl.utils.models import check_model_config
|
||||||
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
|
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
warnings.filterwarnings("error")
|
warnings.filterwarnings("error")
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
|
|||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from axolotl.utils.config import normalize_config
|
|
||||||
from axolotl.utils.data import prepare_dataset
|
from axolotl.utils.data import prepare_dataset
|
||||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
||||||
@@ -263,7 +262,6 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
|||||||
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
||||||
self.cfg_1 = DictDefault(
|
self.cfg_1 = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "huggyllama/llama-7b",
|
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"dataset_exact_deduplication": True,
|
"dataset_exact_deduplication": True,
|
||||||
@@ -284,7 +282,6 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
|||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
normalize_config(self.cfg_1)
|
|
||||||
|
|
||||||
def test_prepare_dataset_with_deduplication_train(self):
|
def test_prepare_dataset_with_deduplication_train(self):
|
||||||
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""
|
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ from typing import Optional
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.schemas.datasets import ChatTemplate
|
|
||||||
|
|
||||||
warnings.filterwarnings("error")
|
warnings.filterwarnings("error")
|
||||||
|
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ class TestModelsUtils:
|
|||||||
|
|
||||||
def test_message_property_mapping(self):
|
def test_message_property_mapping(self):
|
||||||
"""Test message property mapping configuration validation"""
|
"""Test message property mapping configuration validation"""
|
||||||
from axolotl.utils.schemas.datasets import SFTDataset
|
from axolotl.utils.config.models.input.v0_4_1 import SFTDataset
|
||||||
|
|
||||||
# Test legacy fields are mapped orrectly
|
# Test legacy fields are mapped orrectly
|
||||||
dataset = SFTDataset(
|
dataset = SFTDataset(
|
||||||
|
|||||||
Reference in New Issue
Block a user