Compare commits

Comparing `feat/soap-...` and `quartodoc` — 18 Commits

| SHA1 |
|---|
| 0bffef25d0 |
| 94c00c1d04 |
| ddd84d7c65 |
| 42bdf0bd74 |
| b03d96a228 |
| 2653f170fc |
| 3bfcce9f0a |
| 8feb746953 |
| a563815fe7 |
| 81f2203151 |
| 5b7e688fc5 |
| 5134aa66cd |
| ba9a867adb |
| c618f42c39 |
| fc1f985296 |
| a5e37f183c |
| e6a7bbe9ff |
| e4fd7aad0b |
.github/workflows/docs.yml (vendored) — 2 changes

@@ -23,7 +23,7 @@ jobs:
       - name: Install dependencies
         run: |
           python3 -m pip install jupyter quartodoc
-          python3 -m pip install -e . --no-deps
+          python3 -m pip install -e .
       - name: Build autodoc
         run: quartodoc build
      - name: Publish to GitHub Pages (and render)
.github/workflows/tests-nightly.yml (vendored) — 2 changes

@@ -136,4 +136,4 @@ jobs:
           echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
       - name: Run tests job on Modal
         run: |
-          modal run cicd.e2e_tests
+          modal run cicd.tests
.github/workflows/tests.yml (vendored) — 17 changes

@@ -63,7 +63,7 @@ jobs:
           path: |
             /home/runner/.cache/huggingface/hub/datasets--*
             /home/runner/.cache/huggingface/hub/models--*
-          key: ${{ runner.os }}-hf-hub-cache-v2
+          key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
 
       - name: Setup Python
         uses: actions/setup-python@v5
@@ -98,9 +98,8 @@ jobs:
 
       - name: Run tests
         run: |
-          pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
+          pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
           pytest -v tests/patched/
-          pytest -v tests/cli/
 
       - name: cleanup pip cache
         run: |
@@ -137,7 +136,7 @@ jobs:
           path: |
             /home/runner/.cache/huggingface/hub/datasets--*
             /home/runner/.cache/huggingface/hub/models--*
-          key: ${{ runner.os }}-hf-hub-cache-v2
+          key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
 
       - name: Setup Python
         uses: actions/setup-python@v5
@@ -171,14 +170,10 @@ jobs:
         run: |
           axolotl --help
 
-      - name: Show HF cache
-        run: huggingface-cli scan-cache
-
       - name: Run tests
         run: |
-          pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
+          pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
           pytest -v tests/patched/
-          pytest -v tests/cli/
 
       - name: cleanup pip cache
         run: |
@@ -232,7 +227,7 @@ jobs:
           echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
       - name: Run tests job on Modal
         run: |
-          modal run cicd.e2e_tests
+          modal run cicd.tests
 
   docker-e2e-tests:
     if: github.repository_owner == 'axolotl-ai-cloud'
@@ -279,4 +274,4 @@ jobs:
           echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
       - name: Run tests job on Modal
         run: |
-          modal run cicd.e2e_tests
+          modal run cicd.tests
@@ -1,4 +1,3 @@
 [settings]
 profile=black
 known_third_party=wandb,comet_ml
-known_local_folder=src,tests
@@ -133,7 +133,6 @@ quartodoc:
       - utils.schemas.datasets
       - utils.schemas.peft
       - utils.schemas.trl
-      - utils.schemas.multimodal
       - utils.schemas.integrations
       - utils.schemas.enums
       - utils.schemas.utils
@@ -33,9 +33,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
 
 RUN pip install packaging==23.2 setuptools==75.8.0
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
+        pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
     fi
 
 RUN python scripts/unsloth_install.py | sh
@@ -3,10 +3,9 @@ set -e
 
 python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
 
-pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/
+pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
 pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
 pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
 pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
 pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
-pytest -v --durations=10 /workspace/axolotl/tests/cli
-pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/
+pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
@@ -32,9 +32,6 @@ tokenizer_legacy:
 resize_token_embeddings_to_32x:
 # Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
 shrink_embeddings:
-# Whether to load the model with randomly initialized weights. Useful for
-# pre-training a model from scratch or debugging purposes.
-random_init_weights:
 
 # (Internal use only)
 # Used to identify which the model is based on
@@ -466,7 +463,6 @@ auto_find_batch_size: # Optional[bool]
 
 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
 eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
-do_causal_lm_eval: # Whether to run causal language model evaluation for metrics in `eval_causal_lm_metrics`.
 eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
 
 profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
@@ -507,58 +503,36 @@ lr_div_factor: # Learning rate div factor
 
 # Specify optimizer
 # Valid values are driven by the Transformers OptimizerNames class, see:
-# https://github.com/huggingface/transformers/blob/cbf924b76c03828101a34069a96d209314114fd5/src/transformers/training_args.py#L144-L189
+# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
 #
 # Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
 # torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
 # in the examples/ for your model and fine-tuning use case.
 #
 # Valid values for 'optimizer' include:
+# - adamw_hf
 # - adamw_torch
 # - adamw_torch_fused
 # - adamw_torch_xla
-# - adamw_torch_npu_fused
 # - adamw_apex_fused
 # - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
 # - adafactor
 # - adamw_anyprecision
-# - adamw_torch_4bit
-# - ademamix
 # - sgd
 # - adagrad
 # - adamw_bnb_8bit
-# - adamw_8bit # alias for adamw_bnb_8bit
-# - ademamix_8bit
 # - lion_8bit
 # - lion_32bit
 # - paged_adamw_32bit
 # - paged_adamw_8bit
-# - paged_ademamix_32bit
-# - paged_ademamix_8bit
 # - paged_lion_32bit
 # - paged_lion_8bit
-# - rmsprop
-# - rmsprop_bnb
-# - rmsprop_bnb_8bit
-# - rmsprop_bnb_32bit
 # - galore_adamw
 # - galore_adamw_8bit
 # - galore_adafactor
 # - galore_adamw_layerwise
 # - galore_adamw_8bit_layerwise
 # - galore_adafactor_layerwise
-# - lomo
-# - adalomo
-# - grokadamw
-# - schedule_free_adamw
-# - schedule_free_sgd
-# - apollo_adamw
-# - apollo_adamw_layerwise
-#
-# Additional custom optimizers include:
-# - optimi_adamw
-# - ao_adamw_8bit
-# - ao_adamw_fp8
 optimizer:
 # Dictionary of arguments to pass to the optimizer
 optim_args:
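For orientation, a minimal sketch of how the `optimizer` and `optim_args` keys listed in the hunk above typically appear together in a config; the specific optimizer name and argument values here are illustrative assumptions, not taken from this diff:

```yaml
# choose one of the optimizer names listed above
optimizer: adamw_torch_fused
# optional keyword arguments forwarded to the optimizer (illustrative values)
optim_args:
  betas: [0.9, 0.999]
  eps: 1.0e-8
```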
@@ -610,14 +584,6 @@ resume_from_checkpoint:
 # Be careful with this being turned on between different models.
 auto_resume_from_checkpoints: false
 
-## Multimodal section
-# int | tuple[int, int] | None . Size to resize images to, width x height.
-# Will read from model/processor config if not set.
-image_size:
-# str. Algorithm to use for image resizing. "bilinear", "bicubic", "lanczos". Default is "bilinear".
-image_resize_algorithm: 'bilinear'
-## End of multimodal section
-
 # Don't mess with this, it's here for accelerate and torchrun
 local_rank:
 
@@ -651,14 +617,6 @@ ddp_timeout:
 ddp_bucket_cap_mb:
 ddp_broadcast_buffers:
 
-# Sequence parallelism
-# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
-# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
-# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
-# subsequences, or set to 4 to split into four equal-sized subsequences.
-# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
-sequence_parallel_degree:
-
 # Path to torch distx for optim 'adamw_anyprecision'
 torchdistx_path:
 
@@ -103,7 +103,8 @@ This uses the same tags as the [`main` image](#sec-main-tags).
 
 - `JUPYTER_DISABLE`: Disable Jupyter lab.
 - `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
-- `PUBLIC_KEY` / `SSH_KEY`: Add a public key for the SSH service.
+- `PUBLIC_KEY`: Add a public key for the SSH service.
+- `SSH_KEY`: Add a private key for the SSH service.
 
 #### Volume mounts
 
@@ -37,10 +37,6 @@ description: Frequently asked questions
 
 > A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
 
-**Q: How to know the value to use for `fsdp_transformer_layer_cls_to_wrap`?**
-
-> A: This is the class name of the transformer layer to wrap with FSDP. For example, for `LlamaForCausalLM`, the value is `LlamaDecoderLayer`. To find this for a specific model, check the model's `PreTrainedModel` definition and look for `_no_split_modules` variable in the `modeling_<model_name>.py` file within `transformers` library.
-
 ### Chat templates
 
 **Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
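For the removed FSDP question above, a minimal sketch of where that class name usually goes in a config; the exact key layout is an assumption based on the answer's wording, not something this diff shows:

```yaml
# assumed layout; for LlamaForCausalLM the wrap class named in the answer is LlamaDecoderLayer
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```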
@@ -1,171 +1,28 @@
|
|||||||
---
|
# MultiModal / Vision Language Models (BETA)
|
||||||
title: MultiModal / Vision Language Models (BETA)
|
|
||||||
format:
|
|
||||||
html:
|
|
||||||
toc: true
|
|
||||||
toc-depth: 3
|
|
||||||
---
|
|
||||||
|
|
||||||
## Supported Models
|
### Supported Models
|
||||||
|
|
||||||
- [Mllama](#sec-mllama)
|
- Mllama, i.e. llama with vision models
|
||||||
- [Pixtral](#sec-pixtral)
|
|
||||||
- [Llava-1.5](#sec-llava-15)
|
|
||||||
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
|
||||||
- [Gemma-3](#sec-gemma-3)
|
|
||||||
- [Qwen2-VL](#sec-qwen2-vl)
|
|
||||||
- [Qwen2.5-VL](#sec-qwen25-vl)
|
|
||||||
|
|
||||||
## Usage
|
### Usage
|
||||||
|
|
||||||
Multimodal support is limited and doesn't have full feature parity.
|
Currently multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama w/ LoRA,
|
||||||
|
you'll need to use the following in YAML in combination with the rest of the required hyperparams.
|
||||||
Here are the hyperparams you'll need to use to finetune a multimodal model.
|
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
|
||||||
processor_type: AutoProcessor
|
processor_type: AutoProcessor
|
||||||
|
|
||||||
skip_prepare_dataset: true
|
skip_prepare_dataset: true
|
||||||
remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
|
|
||||||
sample_packing: false # not yet supported with multimodal
|
|
||||||
|
|
||||||
chat_template: # see in next section
|
chat_template: llama3_2_vision
|
||||||
|
|
||||||
# example dataset
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
type: chat_template
|
type: chat_template
|
||||||
split: train[:1%]
|
split: train[:1%]
|
||||||
field_messages: messages
|
field_messages: messages
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
# (optional) if doing lora, only finetune the Language model,
|
# only finetune the Language model, leave the vision model and vision tower frozen
|
||||||
# leave the vision model and vision tower frozen
|
|
||||||
# load_in_8bit: true
|
|
||||||
adapter: lora
|
|
||||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
# (optional) if you want to resize images to a set size
|
|
||||||
image_size: 512
|
|
||||||
image_resize_algorithm: bilinear
|
|
||||||
```
|
|
||||||
|
|
||||||
Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.
|
|
||||||
|
|
||||||
::: {.callout-warning}
|
|
||||||
Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
|
|
||||||
:::
|
|
||||||
|
|
||||||
### Mllama {#sec-mllama}
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
base_model: meta-llama/Llama-3.2-11B-Vision-Instruct
|
|
||||||
|
|
||||||
chat_template: llama3_2_vision
|
|
||||||
```
|
|
||||||
|
|
||||||
### Pixtral {#sec-pixtral}
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
base_model: mistralai/Pixtral-12B-2409
|
|
||||||
|
|
||||||
chat_template: pixtral
|
|
||||||
```
|
|
||||||
|
|
||||||
### Llava-1.5 {#sec-llava-15}
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
base_model: llava-hf/llava-1.5-7b-hf
|
|
||||||
|
|
||||||
chat_template: llava
|
|
||||||
```
|
|
||||||
|
|
||||||
### Mistral-Small-3.1 {#sec-mistral-small-31}
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
|
||||||
|
|
||||||
chat_template: mistral_v7_tekken
|
|
||||||
```
|
|
||||||
|
|
||||||
### Gemma-3 {#sec-gemma-3}
|
|
||||||
|
|
||||||
::: {.callout-tip}
|
|
||||||
The Gemma3-1B model is a text-only model, so please train as regular text model.
|
|
||||||
:::
|
|
||||||
|
|
||||||
For multi-modal 4B/12B/27B models, use the following config:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
base_model: google/gemma-3-4b-it
|
|
||||||
|
|
||||||
chat_template: gemma3
|
|
||||||
```
|
|
||||||
|
|
||||||
### Qwen2-VL {#sec-qwen2-vl}
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
base_model: Qwen/Qwen2-VL-7B-Instruct
|
|
||||||
|
|
||||||
chat_template: qwen2_vl
|
|
||||||
```
|
|
||||||
|
|
||||||
### Qwen2.5-VL {#sec-qwen25-vl}
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
base_model: Qwen/Qwen2.5-VL-7B-Instruct
|
|
||||||
|
|
||||||
chat_template: qwen2_vl # same as qwen2-vl
|
|
||||||
```
|
|
||||||
|
|
||||||
## Dataset Format
|
|
||||||
|
|
||||||
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
|
|
||||||
|
|
||||||
- A message is a list of `role` and `content`.
|
|
||||||
- `role` can be `system`, `user`, `assistant`, etc.
|
|
||||||
- `content` is a list of `type` and (`text` or `image` or `path` or `url` or `base64`).
|
|
||||||
|
|
||||||
::: {.callout-note}
|
|
||||||
For backwards compatibility:
|
|
||||||
|
|
||||||
- If the dataset has a `images` or `image` column of `list[Image]`, it will be appended to the first `content` list as `{"type": "image", "image": ...}`. However, if the content already has a `{"type": "image"}` but no `image` key, it will be set the `image` key.
|
|
||||||
- If `content` is a string, it will be converted to a list with `type` as `text`.
|
|
||||||
:::
|
|
||||||
|
|
||||||
::: {.callout-tip}
|
|
||||||
For image loading, you can use the following keys within `content` alongside `"type": "image"`:
|
|
||||||
|
|
||||||
- `"path": "/path/to/image.jpg"`
|
|
||||||
- `"url": "https://example.com/image.jpg"`
|
|
||||||
- `"base64": "..."`
|
|
||||||
- `"image": PIL.Image`
|
|
||||||
:::
|
|
||||||
|
|
||||||
Here is an example of a multi-modal dataset:
|
|
||||||
```json
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": "You are a helpful assistant."}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
|
|
||||||
{"type": "text", "text": "Describe this image in detail."}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": "The image is a bee."}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -1,90 +0,0 @@
----
-title: Sequence Parallelism
-description: Train with long sequences split across multiple GPUs.
----
-
-# Sequence Parallelism
-
-Sequence parallelism is a technique that splits sequences across multiple GPUs,
-allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
-GPU processes a different portion of the sequence, and the results are aggregated
-through a ring communication pattern.
-
-## When to Use Sequence Parallelism
-
-Use sequence parallelism when:
-
-- You need to train with sequence lengths that don't fit into a single GPU's memory
-- You have multiple GPUs available
-- You're experiencing OOM (Out Of Memory) errors with long sequences
-
-## Configuration
-
-To enable sequence parallelism, add the following to your configuration file:
-
-```yaml
-# Set to a divisor (> 1) of the number of GPUs available
-sequence_parallel_degree: 4 # Split sequences across 4 GPUs
-```
-
-The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
-
-- With 8 GPUs, valid values would be 2, 4, or 8
-- With 4 GPUs, valid values would be 2 or 4
-
-## Implementation Details
-
-When sequence parallelism is enabled:
-
-1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
-2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
-3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
-4. The trainer uses special ring communication patterns for attention operations
-
-## Requirements
-
-To use sequence parallelism, you need:
-
-- Multiple GPUs (at least 2)
-- The `ring-flash-attn` package. Install with:
-    - `pip install axolotl[ring-flash-attn]` (preferred)
-    - `pip install ring-flash-attn>=0.1.4`
-
-## Limitations
-
-- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
-- May have a small performance overhead due to communication between GPUs
-
-## Example
-
-```yaml
-# Example config with sequence parallelism
-base_model: meta-llama/Llama-3-8B-Instruct
-sequence_len: 8192
-sequence_parallel_degree: 2 # Split each sequence into 4 parts
-flash_attention: true # Required with sequence parallelism
-...
-```
-
-This will train the Llama 3 8B model with 8K context length, with each sequence split
-into 2 subsequences of length 4096 across 2 GPUs.
-
-## Sample Packing with Sequence Parallelism
-
-Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
-
-1. Samples are first packed together
-2. The packed sequences are then divided across GPUs in the sequence parallel group
-3. Position IDs are automatically adjusted to maintain proper relative positions
-
-## Effect on Batch Size
-
-When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
-
-- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
-- The number of batches processed per step decreases
-
-For example:
-- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
-- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
-- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
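To cross-check the batch-size arithmetic described in the removed page above, here is a small sketch of the relevant keys for the 8-GPU, degree-4 case it mentions (assembled here for illustration; the launch setup is assumed):

```yaml
# 8 GPUs split into two sequence-parallel groups of 4 GPUs each (assumed launch setup)
sequence_parallel_degree: 4
micro_batch_size: 2   # per-GPU batch size
flash_attention: true # required for sequence parallelism
# batches per step: 8 GPUs / 4 = 2 groups; global batch size: 2 groups * 2 = 4 (down from 16)
```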
@@ -1,71 +0,0 @@
|
|||||||
base_model: CohereForAI/c4ai-command-r7b-12-2024
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
# huggingface repo
|
|
||||||
chat_template: cohere
|
|
||||||
datasets:
|
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
|
||||||
type: chat_template
|
|
||||||
field_messages: conversations
|
|
||||||
message_property_mappings:
|
|
||||||
role: from
|
|
||||||
content: value
|
|
||||||
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
eval_sample_packing: false
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 4
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch:
|
|
||||||
eval_table_size:
|
|
||||||
eval_max_new_tokens: 128
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
@@ -1,74 +0,0 @@
|
|||||||
base_model: google/gemma-3-1b-it
|
|
||||||
# optionally might have model_type or tokenizer_type
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
# huggingface repo
|
|
||||||
chat_template: gemma3
|
|
||||||
datasets:
|
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
|
||||||
type: chat_template
|
|
||||||
field_messages: conversations
|
|
||||||
message_property_mappings:
|
|
||||||
role: from
|
|
||||||
content: value
|
|
||||||
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
eval_sample_packing: false
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 4
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: auto
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch:
|
|
||||||
eval_table_size:
|
|
||||||
eval_max_new_tokens: 128
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
base_model: google/gemma-3-4b-it
|
|
||||||
processor_type: AutoProcessor
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
# these 3 lines are needed for now to handle vision chat templates w images
|
|
||||||
skip_prepare_dataset: true
|
|
||||||
remove_unused_columns: false
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
chat_template: gemma3
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
|
||||||
type: chat_template
|
|
||||||
split: train[:1%]
|
|
||||||
field_messages: messages
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.01
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
pad_to_sequence_len: false
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: true
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
eager_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
@@ -19,6 +19,7 @@ val_set_size: 0.0
 output_dir: ./outputs/lora-out
 
 dataset_exact_deduplication: true
+test_value: true
 
 sequence_len: 4096
 sample_packing: true
@@ -1,63 +0,0 @@
|
|||||||
base_model: llava-hf/llava-1.5-7b-hf
|
|
||||||
processor_type: AutoProcessor
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
# these 3 lines are needed for now to handle vision chat templates w images
|
|
||||||
skip_prepare_dataset: true
|
|
||||||
remove_unused_columns: false
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
chat_template: llava
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
|
||||||
type: chat_template
|
|
||||||
split: train[:1%]
|
|
||||||
field_messages: messages
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 8192
|
|
||||||
pad_to_sequence_len: false
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: true
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
eager_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
@@ -1,66 +0,0 @@
|
|||||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
|
||||||
processor_type: AutoProcessor
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
load_in_8bit: true
|
|
||||||
|
|
||||||
# these 3 lines are needed for now to handle vision chat templates w images
|
|
||||||
skip_prepare_dataset: true
|
|
||||||
remove_unused_columns: false
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
chat_template: mistral_v7_tekken
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
|
||||||
type: chat_template
|
|
||||||
split: train[:1%]
|
|
||||||
field_messages: messages
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.01
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
pad_to_sequence_len: false
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: true
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet.
|
|
||||||
eager_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
base_model: mistral-community/pixtral-12b
|
|
||||||
processor_type: AutoProcessor
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
# these 3 lines are needed for now to handle vision chat templates w images
|
|
||||||
skip_prepare_dataset: true
|
|
||||||
remove_unused_columns: false
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
chat_template: pixtral
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
|
||||||
type: chat_template
|
|
||||||
split: train[:1%]
|
|
||||||
field_messages: messages
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 8192
|
|
||||||
pad_to_sequence_len: false
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: true
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet
|
|
||||||
eager_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
pad_token: <pad>
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
base_model: Qwen/Qwen2-VL-7B-Instruct
|
|
||||||
processor_type: AutoProcessor
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
# these 3 lines are needed for now to handle vision chat templates w images
|
|
||||||
skip_prepare_dataset: true
|
|
||||||
remove_unused_columns: false
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
chat_template: qwen2_vl
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
|
||||||
type: chat_template
|
|
||||||
split: train[:1%]
|
|
||||||
field_messages: messages
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 8192
|
|
||||||
pad_to_sequence_len: false
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: true
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
eager_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
@@ -4,18 +4,19 @@
 bitsandbytes==0.45.3
 triton>=3.0.0
 mamba-ssm==1.2.0.post1
+flash-attn==2.7.4.post1
 xformers>=0.0.23.post1
 autoawq==0.2.7.post3
-liger-kernel==0.5.5
+liger-kernel==0.5.3
 # END section
 
 packaging==23.2
 
 peft==0.15.0
-transformers==4.50.0
+transformers==4.49.0
 tokenizers>=0.21.1
 accelerate==1.5.2
-datasets==3.5.0
+datasets==3.4.1
 deepspeed==0.16.4
 trl==0.15.1
 
@@ -35,7 +36,6 @@ einops
 colorama
 numba
 numpy>=1.24.4,<=2.0.1
-
 # qlora things
 evaluate==0.4.1
 scipy
requirements_env.txt — 315 additions (new file)

@@ -0,0 +1,315 @@
|
|||||||
|
accelerate==0.34.1
|
||||||
|
addict==2.4.0
|
||||||
|
aiofiles==23.2.1
|
||||||
|
aiohttp==3.9.0
|
||||||
|
aiosignal==1.3.1
|
||||||
|
aiostream==0.5.2
|
||||||
|
alembic==1.13.1
|
||||||
|
annotated-types==0.6.0
|
||||||
|
annoy==1.17.3
|
||||||
|
ansible==6.7.0
|
||||||
|
ansible-core==2.13.13
|
||||||
|
ansible-vault==2.1.0
|
||||||
|
anyio==3.7.1
|
||||||
|
appdirs==1.4.4
|
||||||
|
art==6.0
|
||||||
|
asgiref==3.7.2
|
||||||
|
async-timeout==4.0.2
|
||||||
|
attrdict==2.0.1
|
||||||
|
attrs==22.2.0
|
||||||
|
awscli==1.32.75
|
||||||
|
-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
|
||||||
|
backoff==2.2.1
|
||||||
|
base58==2.1.1
|
||||||
|
beartype==0.17.2
|
||||||
|
bitnet==0.2.1
|
||||||
|
bitsandbytes==0.42.0
|
||||||
|
bittensor==6.7.0
|
||||||
|
black==23.7.0
|
||||||
|
blinker==1.7.0
|
||||||
|
boto3==1.34.75
|
||||||
|
botocore==1.34.75
|
||||||
|
cachetools==5.3.3
|
||||||
|
cachy==0.1.1
|
||||||
|
certifi==2023.7.22
|
||||||
|
cffi==1.16.0
|
||||||
|
cfgv==3.3.1
|
||||||
|
chai-guanaco==1.2.4
|
||||||
|
charset-normalizer==3.2.0
|
||||||
|
cleo==0.6.8
|
||||||
|
click==8.1.7
|
||||||
|
cloudpickle==2.0.0
|
||||||
|
cohere==4.11.2
|
||||||
|
colorama==0.4.4
|
||||||
|
coloredlogs==15.0.1
|
||||||
|
CoLT5-attention==0.10.20
|
||||||
|
contextlib2==21.6.0
|
||||||
|
contourpy==1.2.0
|
||||||
|
cryptography==41.0.3
|
||||||
|
cycler==0.12.1
|
||||||
|
cytoolz==0.12.3
|
||||||
|
databricks-cli==0.18.0
|
||||||
|
dataclasses-json==0.5.7
|
||||||
|
datasets==2.11.0
|
||||||
|
ddt==1.6.0
|
||||||
|
decorator==5.1.1
|
||||||
|
deepspeed==0.15.0
|
||||||
|
# Editable Git install with no remote (dialogpt==0.1)
|
||||||
|
-e /Users/wing/Projects/ml/dialogpt/src
|
||||||
|
dill==0.3.6
|
||||||
|
distlib==0.3.6
|
||||||
|
docker==7.0.0
|
||||||
|
docker-pycreds==0.4.0
|
||||||
|
docstring-parser==0.15
|
||||||
|
docutils==0.16
|
||||||
|
ecdsa==0.18.0
|
||||||
|
einops==0.7.0
|
||||||
|
einops-exts==0.0.4
|
||||||
|
einx==0.1.3
|
||||||
|
entrypoints==0.4
|
||||||
|
eth-hash==0.6.0
|
||||||
|
eth-keys==0.5.0
|
||||||
|
eth-typing==4.0.0
|
||||||
|
eth-utils==2.3.1
|
||||||
|
evaluate==0.4.0
|
||||||
|
exceptiongroup==1.1.1
|
||||||
|
fastapi==0.109.2
|
||||||
|
fastcore==1.5.29
|
||||||
|
ffmpy==0.4.0
|
||||||
|
filelock==3.12.2
|
||||||
|
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
|
||||||
|
fire==0.5.0
|
||||||
|
first==2.0.2
|
||||||
|
flake8==7.0.0
|
||||||
|
Flask==3.0.1
|
||||||
|
fonttools==4.47.2
|
||||||
|
frozendict==2.4.1
|
||||||
|
frozenlist==1.3.3
|
||||||
|
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
||||||
|
fsspec==2023.6.0
|
||||||
|
fuzzywuzzy==0.18.0
|
||||||
|
gitdb==4.0.10
|
||||||
|
GitPython==3.1.31
|
||||||
|
google-pasta==0.2.0
|
||||||
|
gradio==4.42.0
|
||||||
|
gradio_client==1.3.0
|
||||||
|
greenlet==2.0.2
|
||||||
|
grpclib==0.4.7
|
||||||
|
gunicorn==21.2.0
|
||||||
|
h11==0.14.0
|
||||||
|
h2==4.1.0
|
||||||
|
hpack==4.0.0
|
||||||
|
httpcore==0.17.3
|
||||||
|
httpx==0.24.1
|
||||||
|
huggingface-hub==0.23.4
|
||||||
|
humanfriendly==10.0
|
||||||
|
hyperframe==6.0.1
|
||||||
|
identify==2.5.24
|
||||||
|
idna==3.4
|
||||||
|
immutables==0.20
|
||||||
|
importlib-metadata==6.7.0
|
||||||
|
importlib-resources==6.1.1
|
||||||
|
inflection==0.5.1
|
||||||
|
iniconfig==2.0.0
|
||||||
|
itsdangerous==2.1.2
|
||||||
|
Jinja2==3.1.2
|
||||||
|
jmespath==1.0.1
|
||||||
|
joblib==1.3.2
|
||||||
|
jsonlines==3.1.0
|
||||||
|
jsonschema==2.6.0
|
||||||
|
kiwisolver==1.4.5
|
||||||
|
langchain==0.0.144
|
||||||
|
Levenshtein==0.24.0
|
||||||
|
libcst==1.1.0
|
||||||
|
liger-kernel==0.0.0
|
||||||
|
lion-pytorch==0.1.2
|
||||||
|
llama-cpp-python==0.1.36
|
||||||
|
llvmlite==0.40.1
|
||||||
|
local-attention==1.9.0
|
||||||
|
loguru==0.7.0
|
||||||
|
Mako==1.3.2
|
||||||
|
Markdown==3.5.2
|
||||||
|
markdown-it-py==3.0.0
|
||||||
|
markdown2==2.4.10
|
||||||
|
MarkupSafe==2.1.2
|
||||||
|
marshmallow==3.19.0
|
||||||
|
marshmallow-enum==1.5.1
|
||||||
|
matplotlib==3.8.2
|
||||||
|
mccabe==0.7.0
|
||||||
|
mdurl==0.1.2
|
||||||
|
MEGABYTE-pytorch==0.0.7
|
||||||
|
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
|
||||||
|
mlflow==2.10.0
|
||||||
|
modal==0.62.77
|
||||||
|
more-itertools==10.2.0
|
||||||
|
mpmath==1.2.1
|
||||||
|
msgpack==1.0.7
|
||||||
|
msgpack-numpy-opentensor==0.5.0
|
||||||
|
multidict==6.0.4
|
||||||
|
multiprocess==0.70.14
|
||||||
|
munch==2.5.0
|
||||||
|
mypy==1.3.0
|
||||||
|
mypy-extensions==1.0.0
|
||||||
|
nest-asyncio==1.6.0
|
||||||
|
netaddr==0.10.1
|
||||||
|
networkx==3.0rc1
|
||||||
|
nh3==0.2.14
|
||||||
|
nodeenv==1.8.0
|
||||||
|
nomic==2.0.2
|
||||||
|
numba==0.57.1
|
||||||
|
numexpr==2.8.4
|
||||||
|
numpy==1.24.4
|
||||||
|
oauthlib==3.2.2
|
||||||
|
openai==0.27.4
|
||||||
|
openapi==1.1.0
|
||||||
|
openapi-schema-pydantic==1.2.4
|
||||||
|
optimum==1.8.6
|
||||||
|
orjson==3.10.7
|
||||||
|
packaging==23.1
|
||||||
|
pandas==2.0.0
|
||||||
|
parameterized==0.9.0
|
||||||
|
password-strength==0.0.3.post2
|
||||||
|
pastel==0.1.1
|
||||||
|
pathos==0.3.0
|
||||||
|
pathspec==0.11.1
|
||||||
|
pathtools==0.1.2
|
||||||
|
peft==0.11.1
|
||||||
|
pendulum==3.0.0
|
||||||
|
Pillow==9.5.0
|
||||||
|
pip-tools==1.11.0
|
||||||
|
platformdirs==3.2.0
|
||||||
|
pluggy==1.4.0
|
||||||
|
poetry==0.7.1
|
||||||
|
pox==0.3.2
|
||||||
|
ppft==1.7.6.6
|
||||||
|
pre-commit==3.3.2
|
||||||
|
prettytable==3.10.0
|
||||||
|
prompt-toolkit==3.0.39
|
||||||
|
protobuf==3.20.2
|
||||||
|
protobuf3-to-dict==0.1.5
|
||||||
|
psutil==5.9.5
|
||||||
|
psycopg==3.1.18
|
||||||
|
PuLP==2.8.0
|
||||||
|
py==1.11.0
|
||||||
|
py-bip39-bindings==0.1.11
|
||||||
|
py-cpuinfo==9.0.0
|
||||||
|
py-ed25519-zebra-bindings==1.0.1
|
||||||
|
py-sr25519-bindings==0.2.0
|
||||||
|
pyarrow==11.0.0
|
||||||
|
pyasn1==0.6.0
|
||||||
|
pycodestyle==2.11.1
|
||||||
|
pycparser==2.21
|
||||||
|
pycryptodome==3.20.0
|
||||||
|
pydantic==2.5.3
|
||||||
|
pydantic_core==2.14.6
|
||||||
|
pydub==0.25.1
|
||||||
|
pyfiglet==0.8.post1
|
||||||
|
pyflakes==3.2.0
|
||||||
|
Pygments==2.15.1
|
||||||
|
PyJWT==2.8.0
|
||||||
|
pylev==1.4.0
|
||||||
|
PyNaCl==1.5.0
|
||||||
|
pynvml==11.5.0
|
||||||
|
pyparsing==2.4.7
|
||||||
|
pyrsistent==0.14.11
|
||||||
|
pytest==8.0.2
|
||||||
|
pytest-asyncio==0.23.4
|
||||||
|
python-dateutil==2.8.2
|
||||||
|
python-dotenv==1.0.1
|
||||||
|
python-Levenshtein==0.24.0
|
||||||
|
python-multipart==0.0.9
|
||||||
|
pytz==2023.3
|
||||||
|
PyYAML==6.0.1
|
||||||
|
querystring-parser==1.2.4
|
||||||
|
rapidfuzz==3.6.1
|
||||||
|
regex==2023.6.3
|
||||||
|
requests==2.31.0
|
||||||
|
requests-toolbelt==0.8.0
|
||||||
|
resolvelib==0.8.1
|
||||||
|
responses==0.18.0
|
||||||
|
retry==0.9.2
|
||||||
|
rich==13.7.0
|
||||||
|
rsa==4.7.2
|
||||||
|
ruff==0.6.3
|
||||||
|
s3transfer==0.10.1
|
||||||
|
safetensors==0.4.5
|
||||||
|
sagemaker==2.148.0
|
||||||
|
scalecodec==1.2.7
|
||||||
|
schedulefree==1.2.1
|
||||||
|
schema==0.7.5
|
||||||
|
scikit-learn==1.4.0
|
||||||
|
scipy==1.9.3
|
||||||
|
seaborn==0.13.2
|
||||||
|
semantic-version==2.10.0
|
||||||
|
sentencepiece==0.2.0
|
||||||
|
sentry-sdk==1.19.1
|
||||||
|
setproctitle==1.3.2
|
||||||
|
shellingham==1.5.4
|
||||||
|
shortuuid==1.0.11
|
||||||
|
shtab==1.6.5
|
||||||
|
sigtools==4.0.1
|
||||||
|
six==1.16.0
|
||||||
|
skypilot==0.4.1
|
||||||
|
smdebug-rulesconfig==1.0.1
|
||||||
|
smmap==5.0.0
|
||||||
|
sniffio==1.3.0
|
||||||
|
SQLAlchemy==1.4.47
|
||||||
|
sqlparse==0.4.4
|
||||||
|
starlette==0.36.3
|
||||||
|
substrate-interface==1.5.2
|
||||||
|
svgwrite==1.4.3
|
||||||
|
sympy==1.11.1
|
||||||
|
synchronicity==0.6.7
|
||||||
|
tabulate==0.9.0
|
||||||
|
tblib==1.7.0
|
||||||
|
tenacity==8.2.2
|
||||||
|
tensor-parallel==2.0.0
|
||||||
|
termcolor==2.2.0
|
||||||
|
text2art==0.2.0
|
||||||
|
threadpoolctl==3.2.0
|
||||||
|
tiktoken==0.6.0
|
||||||
|
time-machine==2.14.1
|
||||||
|
timm==0.9.16
|
||||||
|
tokenizers==0.19.1
|
||||||
|
tokenmonster==1.1.12
|
||||||
|
toml==0.9.6
|
||||||
|
tomli==2.0.1
|
||||||
|
tomlkit==0.12.0
|
||||||
|
toolz==0.12.1
|
||||||
|
torch==2.2.0
|
||||||
|
torchdata==0.6.1
|
||||||
|
torchdiffeq==0.2.3
|
||||||
|
TorchFix==0.4.0
|
||||||
|
torchtext==0.15.2
|
||||||
|
torchvision==0.17.0
|
||||||
|
tqdm==4.66.2
|
||||||
|
transformers==4.44.2
|
||||||
|
trl==0.9.6
|
||||||
|
typer==0.12.5
|
||||||
|
types-certifi==2021.10.8.3
|
||||||
|
types-requests==2.31.0.20240125
|
||||||
|
types-setuptools==69.0.0.20240125
|
||||||
|
types-toml==0.10.8.7
|
||||||
|
typing==3.7.4.3
|
||||||
|
typing-inspect==0.8.0
|
||||||
|
typing_extensions==4.9.0
|
||||||
|
tyro==0.5.18
|
||||||
|
tzdata==2023.3
|
||||||
|
unique-names-generator==1.0.2
|
||||||
|
urllib3==2.2.2
|
||||||
|
uvicorn==0.22.0
|
||||||
|
vector_quantize_pytorch==1.14.1
|
||||||
|
virtualenv==20.23.0
|
||||||
|
voyager==2.0.2
|
||||||
|
wandb==0.16.2
|
||||||
|
watchfiles==0.21.0
|
||||||
|
wavedrom==2.0.3.post3
|
||||||
|
wcwidth==0.2.6
|
||||||
|
websocket-client==1.7.0
|
||||||
|
websockets==12.0
|
||||||
|
Werkzeug==3.0.1
|
||||||
|
wonderwords==2.2.0
|
||||||
|
xxhash==3.2.0
|
||||||
|
yarl==1.8.2
|
||||||
|
zetascale==2.2.7
|
||||||
|
zipp==3.15.0
|
||||||
setup.py — 22 changes

@@ -16,7 +16,13 @@ def parse_requirements():
     with open("./requirements.txt", encoding="utf-8") as requirements_file:
         lines = [r.strip() for r in requirements_file.readlines()]
         for line in lines:
-            is_extras = "deepspeed" in line or "mamba-ssm" in line
+            is_extras = (
+                "flash-attn" in line
+                or "flash-attention" in line
+                or "deepspeed" in line
+                or "mamba-ssm" in line
+                or "lion-pytorch" in line
+            )
             if line.startswith("--extra-index-url"):
                 # Handle custom index URLs
                 _, url = line.split()
@@ -33,6 +39,7 @@ def parse_requirements():
             "bitsandbytes",
             "triton",
             "mamba-ssm",
+            "flash-attn",
             "xformers",
             "autoawq",
             "liger-kernel",
@@ -117,8 +124,9 @@ setup(
         ],
     },
     extras_require={
-        "flash-attn": ["flash-attn==2.7.4.post1"],
-        "ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
+        "flash-attn": [
+            "flash-attn==2.7.4.post1",
+        ],
         "deepspeed": [
             "deepspeed==0.16.4",
             "deepspeed-kernels",
@@ -133,15 +141,15 @@ setup(
         "mlflow": [
             "mlflow",
         ],
+        "lion-pytorch": [
+            "lion-pytorch==0.1.2",
+        ],
         "galore": [
             "galore_torch",
         ],
-        "apollo": [
-            "apollo-torch",
-        ],
         "optimizers": [
             "galore_torch",
-            "apollo-torch",
+            "lion-pytorch==0.1.2",
             "lomo-optim==0.1.1",
             "torch-optimi==0.2.1",
         ],
@@ -56,7 +56,7 @@ def do_inference(
         cfg: Dictionary mapping `axolotl` config keys to values.
         cli_args: Inference-specific CLI arguments.
     """
-    model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)
+    model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
     prompter = cli_args.prompter
 
     prompter_module = None
@@ -151,7 +151,7 @@ def do_inference_gradio(
     """
     import gradio as gr
 
-    model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)
+    model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
     prompter = cli_args.prompter
 
     prompter_module = None
|||||||
@@ -27,7 +27,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
|||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
|
|
||||||
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
|
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
|
||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
LOG.info("Running merge of LoRA with base model...")
|
LOG.info("Running merge of LoRA with base model...")
|
||||||
@@ -44,9 +44,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
|||||||
)
|
)
|
||||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||||
|
|
||||||
if processor:
|
|
||||||
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -17,14 +17,13 @@ from axolotl.cli.config import load_cfg
|
|||||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
|
||||||
from axolotl.utils.config import normalize_config, resolve_dtype
|
from axolotl.utils.config import normalize_config, resolve_dtype
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||||
"""
|
"""
|
||||||
Trains a `transformers` model by first loading the dataset(s) specified in the
|
Trains a `transformers` model by first loading the dataset(s) specified in the
|
||||||
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
|
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
|
||||||
@@ -34,9 +33,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
|||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
cli_args: Training-specific CLI arguments.
|
cli_args: Training-specific CLI arguments.
|
||||||
"""
|
"""
|
||||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
|
||||||
set_pytorch_cuda_alloc_conf()
|
|
||||||
|
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
@@ -48,13 +44,16 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
|||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
del model, tokenizer, trainer
|
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
|
||||||
|
del model
|
||||||
|
del tokenizer
|
||||||
|
del trainer
|
||||||
|
|
||||||
plugin_manager.post_train_unload(cfg)
|
plugin_manager.post_train_unload(cfg)
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Parses `axolotl` config, CLI args, and calls `do_train`.
|
Parses `axolotl` config, CLI args, and calls `do_train`.
|
||||||
|
|
||||||
|
|||||||
@@ -13,16 +13,11 @@ from typing import Any, Callable, Type, Union, get_args, get_origin
|
|||||||
import click
|
import click
|
||||||
import requests
|
import requests
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from transformers import (
|
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||||
PreTrainedModel,
|
|
||||||
PreTrainedTokenizer,
|
|
||||||
PreTrainedTokenizerFast,
|
|
||||||
ProcessorMixin,
|
|
||||||
)
|
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
|
|
||||||
configure_logging()
|
configure_logging()
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@@ -300,13 +295,9 @@ def load_model_and_tokenizer(
|
|||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
inference: bool = False,
|
inference: bool = False,
|
||||||
) -> tuple[
|
) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
|
||||||
PreTrainedModel,
|
|
||||||
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
|
|
||||||
ProcessorMixin | None,
|
|
||||||
]:
|
|
||||||
"""
|
"""
|
||||||
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
|
Helper function for loading a model and tokenizer specified in the given `axolotl`
|
||||||
config.
|
config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -314,7 +305,7 @@ def load_model_and_tokenizer(
|
|||||||
inference: Boolean denoting inference mode.
|
inference: Boolean denoting inference mode.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
|
`transformers` model and tokenizer.
|
||||||
"""
|
"""
|
||||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
@@ -322,9 +313,4 @@ def load_model_and_tokenizer(
|
|||||||
LOG.info("loading model...")
|
LOG.info("loading model...")
|
||||||
model, _ = load_model(cfg, tokenizer, inference=inference)
|
model, _ = load_model(cfg, tokenizer, inference=inference)
|
||||||
|
|
||||||
processor = None
|
return model, tokenizer
|
||||||
if cfg.is_multimodal:
|
|
||||||
LOG.info("loading processor...")
|
|
||||||
processor = load_processor(cfg, tokenizer)
|
|
||||||
|
|
||||||
return model, tokenizer, processor
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ from transformers import (
|
|||||||
from transformers.training_args import OptimizerNames
|
from transformers.training_args import OptimizerNames
|
||||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||||
|
|
||||||
from axolotl.core.trainers import (
|
from axolotl.core.trainers.base import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
AxolotlKTOTrainer,
|
AxolotlKTOTrainer,
|
||||||
AxolotlMambaTrainer,
|
AxolotlMambaTrainer,
|
||||||
@@ -60,7 +60,6 @@ from axolotl.core.training_args import (
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||||
from axolotl.processing_strategies import get_processing_strategy
|
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
@@ -663,11 +662,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
optimizer_cls = MuonOptimizerFactory
|
optimizer_cls = MuonOptimizerFactory
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
elif self.cfg.optimizer == "soap":
|
|
||||||
from axolotl.utils.optimizers.soap import SOAP
|
|
||||||
|
|
||||||
optimizer_cls = SOAP
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
|
||||||
elif self.cfg.optimizer == "optimi_adamw":
|
elif self.cfg.optimizer == "optimi_adamw":
|
||||||
from optimi import AdamW
|
from optimi import AdamW
|
||||||
|
|
||||||
@@ -753,12 +747,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.accelerator_config
|
self.cfg.accelerator_config
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.image_size:
|
|
||||||
training_arguments_kwargs["image_size"] = self.cfg.image_size
|
|
||||||
if self.cfg.image_resize_algorithm:
|
|
||||||
training_arguments_kwargs["image_resize_algorithm"] = (
|
|
||||||
self.cfg.image_resize_algorithm
|
|
||||||
)
|
|
||||||
if self.cfg.kd_ce_alpha is not None:
|
if self.cfg.kd_ce_alpha is not None:
|
||||||
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
||||||
if self.cfg.kd_alpha is not None:
|
if self.cfg.kd_alpha is not None:
|
||||||
@@ -774,10 +762,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.kd_top_k_before_softmax
|
self.cfg.kd_top_k_before_softmax
|
||||||
)
|
)
|
||||||
|
|
||||||
training_arguments_kwargs["sequence_parallel_degree"] = (
|
|
||||||
self.cfg.sequence_parallel_degree
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_args_cls = AxolotlRewardConfig
|
training_args_cls = AxolotlRewardConfig
|
||||||
elif self.cfg.process_reward_model:
|
elif self.cfg.process_reward_model:
|
||||||
@@ -861,10 +845,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||||
):
|
):
|
||||||
if training_args.pretraining:
|
if training_args.pretraining:
|
||||||
if (
|
if self.cfg.pretraining_sample_concatenation is False:
|
||||||
self.cfg.pretraining_sample_concatenation is False
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
or self.cfg.micro_batch_size > 1
|
if self.cfg.micro_batch_size > 1:
|
||||||
):
|
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -892,7 +875,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if "max_length" in kwargs:
|
if "max_length" in kwargs:
|
||||||
kwargs.pop("max_length")
|
kwargs.pop("max_length")
|
||||||
elif use_batch_sampler_collator:
|
elif use_batch_sampler_collator:
|
||||||
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or (
|
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||||
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
|
elif (
|
||||||
self.cfg.model_config_type in ["llama"]
|
self.cfg.model_config_type in ["llama"]
|
||||||
and self.cfg.flash_attention is not True
|
and self.cfg.flash_attention is not True
|
||||||
):
|
):
|
||||||
@@ -902,13 +887,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
if self.cfg.processor_type and self.processor:
|
if self.cfg.processor_type and self.processor:
|
||||||
collator = MultiModalChatDataCollator
|
collator = MultiModalChatDataCollator
|
||||||
kwargs["processing_strategy"] = get_processing_strategy(
|
kwargs["processor"] = self.processor
|
||||||
self.processor,
|
kwargs["chat_template"] = training_args.chat_template
|
||||||
training_args.chat_template,
|
|
||||||
self.cfg.chat_template,
|
|
||||||
image_size=training_args.image_size,
|
|
||||||
image_resize_algorithm=training_args.image_resize_algorithm,
|
|
||||||
)
|
|
||||||
elif self.cfg.batch_flattening:
|
elif self.cfg.batch_flattening:
|
||||||
collator = DataCollatorWithFlattening
|
collator = DataCollatorWithFlattening
|
||||||
collator_args.pop(0)
|
collator_args.pop(0)
|
||||||
@@ -928,8 +908,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
collator = DataCollatorForSeq2Seq
|
collator = DataCollatorForSeq2Seq
|
||||||
|
|
||||||
kwargs["return_tensors"] = "pt"
|
kwargs["return_tensors"] = "pt"
|
||||||
if issubclass(collator, DataCollatorForSeq2Seq):
|
|
||||||
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
|
|
||||||
|
|
||||||
return collator(
|
return collator(
|
||||||
*collator_args,
|
*collator_args,
|
||||||
|
|||||||
@@ -1,18 +0,0 @@
|
|||||||
"""Init for axolotl.core.trainers"""
|
|
||||||
|
|
||||||
# pylint: disable=unused-import
|
|
||||||
# flake8: noqa
|
|
||||||
|
|
||||||
from .base import AxolotlTrainer
|
|
||||||
from .dpo.trainer import AxolotlDPOTrainer
|
|
||||||
from .grpo.trainer import AxolotlGRPOTrainer
|
|
||||||
from .mamba import AxolotlMambaTrainer
|
|
||||||
from .relora import ReLoRATrainer
|
|
||||||
from .trl import (
|
|
||||||
AxolotlCPOTrainer,
|
|
||||||
AxolotlKTOTrainer,
|
|
||||||
AxolotlORPOTrainer,
|
|
||||||
AxolotlPRMTrainer,
|
|
||||||
AxolotlRewardTrainer,
|
|
||||||
TRLPPOTrainer,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,47 +1,365 @@
|
|||||||
"""Module for customized trainers"""
|
"""
|
||||||
|
module for customized trainers
|
||||||
# pylint: disable=too-many-lines
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# pylint: disable=too-many-lines
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Literal
|
from typing import Dict, Literal, Optional
|
||||||
|
|
||||||
import datasets
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import (
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
BatchSampler,
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
DataLoader,
|
|
||||||
RandomSampler,
|
|
||||||
Sampler,
|
|
||||||
SequentialSampler,
|
|
||||||
)
|
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||||
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
|
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
from typing_extensions import override
|
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import (
|
from axolotl.integrations.base import BaseOptimizerFactory
|
||||||
OptimizerMixin,
|
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||||
SchedulerMixin,
|
|
||||||
SequenceParallelMixin,
|
|
||||||
)
|
|
||||||
from axolotl.core.trainers.utils import (
|
|
||||||
sanitize_kwargs_for_ds_tagging,
|
|
||||||
sanitize_kwargs_for_tagging,
|
|
||||||
)
|
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
from axolotl.utils.schedulers import (
|
||||||
|
RexLR,
|
||||||
|
get_cosine_schedule_with_min_lr,
|
||||||
|
get_cosine_schedule_with_quadratic_warmup,
|
||||||
|
get_cosine_schedule_with_warmup_decay_constant,
|
||||||
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
if is_sagemaker_mp_enabled():
|
||||||
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
|
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
||||||
"""Extend the base Trainer for axolotl helpers"""
|
if isinstance(tag_names, str):
|
||||||
|
tag_names = [tag_names]
|
||||||
|
|
||||||
|
if kwargs is not None:
|
||||||
|
if "tags" not in kwargs:
|
||||||
|
kwargs["tags"] = tag_names
|
||||||
|
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
||||||
|
kwargs["tags"].extend(tag_names)
|
||||||
|
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
||||||
|
tag_names.append(kwargs["tags"])
|
||||||
|
kwargs["tags"] = tag_names
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
||||||
|
if isinstance(dataset_tags, str):
|
||||||
|
dataset_tags = [dataset_tags]
|
||||||
|
|
||||||
|
if (dataset_tags is not None) and (kwargs is not None):
|
||||||
|
if "dataset_tags" not in kwargs:
|
||||||
|
kwargs["dataset_tags"] = dataset_tags
|
||||||
|
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
||||||
|
kwargs["dataset_tags"].extend(dataset_tags)
|
||||||
|
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
||||||
|
dataset_tags.append(kwargs["dataset_tags"])
|
||||||
|
kwargs["dataset_tags"] = dataset_tags
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerMixin(Trainer):
|
||||||
|
"""
|
||||||
|
Mixin class for scheduler setup in CausalTrainer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|
||||||
|
def create_scheduler(
|
||||||
|
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||||
|
passed as an argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_training_steps (int): The number of training steps to do.
|
||||||
|
optimizer (torch.optim.Optimizer): The training optimizer
|
||||||
|
"""
|
||||||
|
use_cosine_quadratic = (
|
||||||
|
self.args.lr_scheduler_type == "cosine"
|
||||||
|
and self.args.lr_quadratic_warmup is True
|
||||||
|
)
|
||||||
|
|
||||||
|
use_cosine_min_lr = (
|
||||||
|
self.args.lr_scheduler_type == "cosine"
|
||||||
|
and self.args.cosine_min_lr_ratio is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||||
|
# fmt: on
|
||||||
|
if self.args.alternate_lr_scheduler_type == "one_cycle":
|
||||||
|
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||||
|
pct_start = num_warmup_steps / num_training_steps
|
||||||
|
extra_lr_kwargs = {}
|
||||||
|
if "pct_start" not in self.args.lr_scheduler_kwargs:
|
||||||
|
extra_lr_kwargs["pct_start"] = pct_start
|
||||||
|
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
|
||||||
|
extra_lr_kwargs["anneal_strategy"] = "cos"
|
||||||
|
|
||||||
|
self.lr_scheduler = OneCycleLR(
|
||||||
|
optimizer,
|
||||||
|
max_lr=self.args.learning_rate,
|
||||||
|
total_steps=num_training_steps,
|
||||||
|
**extra_lr_kwargs,
|
||||||
|
**self.args.lr_scheduler_kwargs,
|
||||||
|
)
|
||||||
|
elif self.args.alternate_lr_scheduler_type == "rex":
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
|
||||||
|
self.lr_scheduler = RexLR(
|
||||||
|
optimizer=optimizer,
|
||||||
|
max_lr=self.args.learning_rate,
|
||||||
|
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
||||||
|
total_steps=num_training_steps,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
)
|
||||||
|
elif use_cosine_quadratic:
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||||
|
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
)
|
||||||
|
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
|
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
||||||
|
)
|
||||||
|
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||||
|
else:
|
||||||
|
if use_cosine_quadratic:
|
||||||
|
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
|
return self.lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizerMixin(Trainer):
|
||||||
|
"""
|
||||||
|
Mixin class for shared handling of building custom optimizers
|
||||||
|
"""
|
||||||
|
|
||||||
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|
||||||
|
def create_optimizer_grouped_parameters(
|
||||||
|
self, opt_model, optimizer_kwargs
|
||||||
|
) -> list[dict]:
|
||||||
|
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||||
|
params: dict = {
|
||||||
|
"to_weight_decay": {}, # LayerNorm and bias
|
||||||
|
"embeddings": {}, # lm_head, embed_tokens,
|
||||||
|
"no_weight_decay": {},
|
||||||
|
}
|
||||||
|
lr_groups_lookup = {}
|
||||||
|
lr_groups_learning_rates = {}
|
||||||
|
if self.args.lr_groups:
|
||||||
|
for lr_group in self.args.lr_groups:
|
||||||
|
group_name = lr_group["name"]
|
||||||
|
group_modules = lr_group["modules"]
|
||||||
|
for module in group_modules:
|
||||||
|
lr_groups_lookup[module] = group_name
|
||||||
|
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
||||||
|
params[f"to_weight_decay_{group_name}"] = {}
|
||||||
|
|
||||||
|
for name, param in opt_model.named_parameters():
|
||||||
|
if not param.requires_grad:
|
||||||
|
continue
|
||||||
|
if name.endswith("modules_to_save.default.weight") or any(
|
||||||
|
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||||
|
):
|
||||||
|
params["embeddings"][name] = param
|
||||||
|
elif name in decay_parameters:
|
||||||
|
lr_group_modules = [
|
||||||
|
group_modules
|
||||||
|
for group_modules in lr_groups_lookup
|
||||||
|
if group_modules in name
|
||||||
|
]
|
||||||
|
if lr_groups_lookup and any(lr_group_modules):
|
||||||
|
lr_group_module = lr_group_modules[0]
|
||||||
|
group_name = lr_groups_lookup[lr_group_module]
|
||||||
|
params[f"to_weight_decay_{group_name}"][name] = param
|
||||||
|
else:
|
||||||
|
params["to_weight_decay"][name] = param
|
||||||
|
else:
|
||||||
|
params["no_weight_decay"][name] = param
|
||||||
|
optimizer_grouped_parameters = []
|
||||||
|
if params["to_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["to_weight_decay"].values()),
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["embeddings"]:
|
||||||
|
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||||
|
if self.args.embedding_lr_scale:
|
||||||
|
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||||
|
elif self.args.embedding_lr:
|
||||||
|
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["embeddings"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": lr,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["no_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["no_weight_decay"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
for group_name, group_lr in lr_groups_learning_rates.items():
|
||||||
|
if params[f"to_weight_decay_{group_name}"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(
|
||||||
|
params[f"to_weight_decay_{group_name}"].values()
|
||||||
|
),
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
"lr": group_lr,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return optimizer_grouped_parameters
|
||||||
|
|
||||||
|
def create_optimizer(self):
|
||||||
|
if (
|
||||||
|
self.args.loraplus_lr_ratio is None
|
||||||
|
and self.args.embedding_lr_scale is None
|
||||||
|
and self.args.embedding_lr is None
|
||||||
|
and self.args.lr_groups is None
|
||||||
|
and self.optimizer_cls_and_kwargs is None
|
||||||
|
):
|
||||||
|
return super().create_optimizer()
|
||||||
|
|
||||||
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
|
|
||||||
|
if (
|
||||||
|
not self.optimizer
|
||||||
|
and self.optimizer_cls_and_kwargs is not None
|
||||||
|
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
||||||
|
):
|
||||||
|
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||||
|
self.optimizer = optimizer_factory_cls()(
|
||||||
|
opt_model, self.args, **optimizer_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.optimizer:
|
||||||
|
if self.optimizer_cls_and_kwargs is not None:
|
||||||
|
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||||
|
else:
|
||||||
|
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
||||||
|
self.args, opt_model
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
||||||
|
opt_model, optimizer_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.loraplus_lr_ratio is not None:
|
||||||
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
|
loraplus_lr_embedding = getattr(
|
||||||
|
self.args, "loraplus_lr_embedding", 1e-6
|
||||||
|
)
|
||||||
|
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
opt_model,
|
||||||
|
optimizer_cls,
|
||||||
|
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||||
|
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||||
|
**optimizer_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
|
# e.g. for GaLore optimizer.
|
||||||
|
if "params" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||||
|
|
||||||
|
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
|
# e.g. for LOMO optimizer.
|
||||||
|
if "model" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||||
|
|
||||||
|
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||||
|
# to avoid arguments conflicts.
|
||||||
|
if "optimizer_dict" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
||||||
|
"optimizer_dict"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.optimizer = optimizer_cls(
|
||||||
|
optimizer_grouped_parameters, **optimizer_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if optimizer_cls.__name__ == "Adam8bit":
|
||||||
|
import bitsandbytes
|
||||||
|
|
||||||
|
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||||
|
|
||||||
|
skipped = 0
|
||||||
|
for module in opt_model.modules():
|
||||||
|
if isinstance(module, nn.Embedding):
|
||||||
|
skipped += sum(
|
||||||
|
{
|
||||||
|
p.data_ptr(): p.numel() for p in module.parameters()
|
||||||
|
}.values()
|
||||||
|
)
|
||||||
|
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||||
|
manager.register_module_override(
|
||||||
|
module, "weight", {"optim_bits": 32}
|
||||||
|
)
|
||||||
|
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||||
|
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||||
|
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.optimizer
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.optimizer
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||||
|
"""
|
||||||
|
Extend the base Trainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
tag_names = ["axolotl"]
|
tag_names = ["axolotl"]
|
||||||
@@ -58,18 +376,12 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
self.eval_data_collator = eval_data_collator
|
self.eval_data_collator = eval_data_collator
|
||||||
self.dataset_tags = dataset_tags
|
self.dataset_tags = dataset_tags
|
||||||
self._signature_columns = None # workaround for pylint
|
self._signature_columns = None # workaround for pylint
|
||||||
|
|
||||||
super().__init__(*_args, **kwargs)
|
super().__init__(*_args, **kwargs)
|
||||||
|
|
||||||
self.train_data_collator = self.data_collator
|
self.train_data_collator = self.data_collator
|
||||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
|
|
||||||
# Initialize sequence parallelism if enabled
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
self._setup_sequence_parallel()
|
|
||||||
|
|
||||||
def _wrap_model(self, model, training=True, dataloader=None):
|
def _wrap_model(self, model, training=True, dataloader=None):
|
||||||
if self.args.torch_compile:
|
if self.args.torch_compile:
|
||||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||||
@@ -82,247 +394,142 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
)
|
)
|
||||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||||
|
|
||||||
def _create_multipack_sampler(
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||||
self, base_sampler: Sampler, dataset: Dataset
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
) -> MultipackBatchSampler:
|
if self.args.multipack_real_batches:
|
||||||
"""
|
batch_size = self.args.per_device_train_batch_size
|
||||||
Helper method to create a `MultipackBatchSampler` for multipacking sequences
|
batch_max_len = self.args.max_seq_length
|
||||||
for training.
|
else:
|
||||||
|
batch_size = 1
|
||||||
|
train_batch_size = (
|
||||||
|
self.state.train_batch_size or self.args.per_device_train_batch_size
|
||||||
|
)
|
||||||
|
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||||
|
|
||||||
Args:
|
if self.args.curriculum_sampling:
|
||||||
base_sampler: Sampler to wrap with `MultipackBatchSampler`.
|
sampler = SequentialSampler(self.train_dataset)
|
||||||
dataset: Dataset to sample from.
|
else:
|
||||||
|
sampler = RandomSampler(self.train_dataset)
|
||||||
|
|
||||||
Returns:
|
return MultipackBatchSampler(
|
||||||
Multipack (sample packing) batch sampler.
|
sampler,
|
||||||
"""
|
lengths=get_dataset_lengths(self.train_dataset),
|
||||||
if self.args.multipack_real_batches:
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
batch_size = self.args.per_device_train_batch_size
|
batch_max_len=batch_max_len,
|
||||||
batch_max_len = self.args.max_seq_length
|
batch_size=batch_size,
|
||||||
else:
|
group_size=self.args.sample_packing_group_size,
|
||||||
batch_size = 1
|
bin_size=self.args.sample_packing_bin_size,
|
||||||
train_batch_size = (
|
drop_last=True,
|
||||||
self.state.train_batch_size or self.args.per_device_train_batch_size
|
|
||||||
)
|
)
|
||||||
batch_max_len = train_batch_size * self.args.max_seq_length
|
if self.args.curriculum_sampling:
|
||||||
|
return SequentialSampler(self.train_dataset)
|
||||||
|
return super()._get_train_sampler()
|
||||||
|
|
||||||
return MultipackBatchSampler(
|
def _get_eval_sampler(
|
||||||
base_sampler,
|
self, eval_dataset: Dataset
|
||||||
lengths=get_dataset_lengths(dataset),
|
) -> Optional[torch.utils.data.Sampler]:
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||||
batch_max_len=batch_max_len,
|
if self.args.multipack_real_batches:
|
||||||
batch_size=batch_size,
|
batch_size = self.args.per_device_eval_batch_size
|
||||||
drop_last=True,
|
batch_max_len = self.args.max_seq_length
|
||||||
)
|
else:
|
||||||
|
batch_size = 1
|
||||||
def _get_train_sampler(self) -> Sampler | None:
|
batch_max_len = (
|
||||||
"""
|
self.args.per_device_eval_batch_size * self.args.max_seq_length
|
||||||
Helper method to get the sampler for training. Handles cases for sequence
|
)
|
||||||
parallelism, sample packing, and curriculum sampling (sequential).
|
return MultipackBatchSampler(
|
||||||
|
SequentialSampler(eval_dataset),
|
||||||
Returns:
|
lengths=get_dataset_lengths(self.eval_dataset),
|
||||||
If the dataset is non-empty, a sampler is returned, the type of which
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
depends on the passed training args.
|
batch_max_len=batch_max_len,
|
||||||
"""
|
batch_size=batch_size,
|
||||||
use_sample_packing = self.args.sample_packing and not self.args.pretraining
|
group_size=self.args.sample_packing_group_size,
|
||||||
|
bin_size=self.args.sample_packing_bin_size,
|
||||||
# Determine the base sampler first
|
drop_last=True,
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
base_sampler = self._sp_get_train_sampler(self.train_dataset)
|
|
||||||
elif self.args.curriculum_sampling:
|
|
||||||
base_sampler = SequentialSampler(self.train_dataset)
|
|
||||||
elif use_sample_packing:
|
|
||||||
base_sampler = RandomSampler(self.train_dataset)
|
|
||||||
else:
|
|
||||||
# Default to parent class implementation for standard random sampling
|
|
||||||
return super()._get_train_sampler()
|
|
||||||
|
|
||||||
# Apply multipack wrapper if needed
|
|
||||||
if use_sample_packing:
|
|
||||||
return self._create_multipack_sampler(
|
|
||||||
base_sampler=base_sampler,
|
|
||||||
dataset=self.train_dataset,
|
|
||||||
)
|
)
|
||||||
|
return super()._get_eval_sampler(eval_dataset)
|
||||||
|
|
||||||
return base_sampler
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
|
train_dataset = self.train_dataset
|
||||||
|
if "length" in train_dataset.features.keys():
|
||||||
|
train_dataset = train_dataset.remove_columns(["length"])
|
||||||
|
data_collator = self.data_collator
|
||||||
|
dataloader_params = {
|
||||||
|
"batch_size": self._train_batch_size,
|
||||||
|
"collate_fn": data_collator,
|
||||||
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
|
}
|
||||||
|
if self.args.dataloader_prefetch_factor:
|
||||||
|
dataloader_params["prefetch_factor"] = (
|
||||||
|
self.args.dataloader_prefetch_factor
|
||||||
|
)
|
||||||
|
|
||||||
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
|
sampler = self._get_train_sampler()
|
||||||
"""
|
|
||||||
Helper method to get the sampler for evaluation. Handles sequence parallelism
|
|
||||||
and sample packing cases.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
If the dataset is non-empty, a sampler is returned, the type of which
|
|
||||||
depends on the passed training args.
|
|
||||||
"""
|
|
||||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
|
||||||
|
|
||||||
# Multipacking enabled if training is enabled and eval is not explicitly disabled
|
|
||||||
use_multipack = (
|
|
||||||
self.args.sample_packing and self.args.eval_sample_packing is not False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine the base sampler
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
base_sampler = self._sp_get_eval_sampler(eval_dataset)
|
|
||||||
elif use_multipack:
|
|
||||||
base_sampler = SequentialSampler(eval_dataset)
|
|
||||||
else:
|
|
||||||
return super()._get_eval_sampler(eval_dataset)
|
|
||||||
|
|
||||||
# Apply multipack wrapper if needed
|
|
||||||
if use_multipack:
|
|
||||||
return self._create_multipack_sampler(
|
|
||||||
base_sampler=base_sampler,
|
|
||||||
dataset=eval_dataset,
|
|
||||||
)
|
|
||||||
|
|
||||||
return base_sampler
|
|
||||||
|
|
||||||
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
|
|
||||||
"""Create common dataloader parameters for train or eval."""
|
|
||||||
batch_size = custom_batch_size or (
|
|
||||||
self.args.eval_batch_size if is_eval else self._train_batch_size
|
|
||||||
)
|
|
||||||
|
|
||||||
params = {
|
|
||||||
"batch_size": batch_size,
|
|
||||||
"collate_fn": self.data_collator,
|
|
||||||
"num_workers": self.args.dataloader_num_workers,
|
|
||||||
"pin_memory": self.args.dataloader_pin_memory,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add persistent workers only for training
|
|
||||||
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
|
|
||||||
params["persistent_workers"] = self.args.dataloader_persistent_workers
|
|
||||||
|
|
||||||
# Add prefetch factor if specified
|
|
||||||
if self.args.dataloader_prefetch_factor:
|
|
||||||
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
|
||||||
|
|
||||||
return params
|
|
||||||
|
|
||||||
def _prepare_dataloader(
|
|
||||||
self, dataset, sampler, is_eval=False, custom_batch_size=None
|
|
||||||
):
|
|
||||||
"""Prepare a dataloader with the given dataset and sampler."""
|
|
||||||
# Get base parameters
|
|
||||||
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
|
|
||||||
|
|
||||||
# Add sampler configuration
|
|
||||||
if not isinstance(dataset, torch.utils.data.IterableDataset):
|
|
||||||
if isinstance(sampler, BatchSampler):
|
if isinstance(sampler, BatchSampler):
|
||||||
# batch_size and batch_sampler are mutually exclusive
|
|
||||||
dataloader_params["batch_sampler"] = sampler
|
dataloader_params["batch_sampler"] = sampler
|
||||||
del dataloader_params["batch_size"]
|
del dataloader_params["batch_size"]
|
||||||
else:
|
else:
|
||||||
dataloader_params["sampler"] = sampler
|
dataloader_params["sampler"] = sampler
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
dataloader_params["worker_init_fn"] = seed_worker
|
||||||
|
|
||||||
if not is_eval:
|
|
||||||
dataloader_params["worker_init_fn"] = seed_worker
|
|
||||||
|
|
||||||
# Create the dataloader
|
|
||||||
dataloader = DataLoader(dataset, **dataloader_params)
|
|
||||||
|
|
||||||
if self.args.sample_packing and (
|
|
||||||
(not is_eval and not self.args.pretraining)
|
|
||||||
or (is_eval and self.args.eval_sample_packing is not False)
|
|
||||||
):
|
|
||||||
self.accelerator.even_batches = False
|
self.accelerator.even_batches = False
|
||||||
|
return self.accelerator.prepare_data_loader(
|
||||||
# Return unprepared dataloader if using sequence parallelism
|
DataLoader(train_dataset, **dataloader_params)
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
return dataloader
|
|
||||||
|
|
||||||
# Otherwise prepare with accelerator
|
|
||||||
return self.accelerator.prepare_data_loader(dataloader)
|
|
||||||
|
|
||||||
def get_train_dataloader(self) -> DataLoader:
|
|
||||||
"""Get dataloader for training"""
|
|
||||||
train_dataset = self.train_dataset
|
|
||||||
data_collator = self.data_collator # type: ignore
|
|
||||||
|
|
||||||
# Handle dataset preprocessing
|
|
||||||
if isinstance(train_dataset, datasets.Dataset):
|
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
|
||||||
train_dataset = train_dataset.remove_columns(["length"])
|
|
||||||
if not self.args.sample_packing or self.args.pretraining:
|
|
||||||
train_dataset = self._remove_unused_columns(
|
|
||||||
train_dataset, description="training"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
|
||||||
data_collator,
|
|
||||||
description="training",
|
|
||||||
)
|
)
|
||||||
|
return super().get_train_dataloader()
|
||||||
|
|
||||||
# Get sampler and create dataloader
|
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||||
sampler = self._get_train_sampler()
|
|
||||||
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
|
|
||||||
|
|
||||||
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
|
|
||||||
"""Get dataloader for evaluation"""
|
|
||||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
|
||||||
|
|
||||||
# Handle special case: sample packing is enabled but eval_sample_packing is False
|
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.eval_data_collator
|
self.eval_data_collator
|
||||||
)
|
)
|
||||||
if "length" in eval_dataset.column_names:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||||
dataloader = super().get_eval_dataloader(eval_dataset)
|
dataloader = super().get_eval_dataloader(eval_dataset)
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.train_data_collator
|
self.train_data_collator
|
||||||
)
|
)
|
||||||
|
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
# Handle sample packing or sequence parallelism
|
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||||
if (
|
eval_dataset = (
|
||||||
self.args.sample_packing
|
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
and self.args.eval_sample_packing is not False
|
|
||||||
or self.args.sequence_parallel_degree > 1
|
|
||||||
):
|
|
||||||
# Get appropriate data collator
|
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
|
||||||
self.eval_data_collator
|
|
||||||
if hasattr(self, "eval_data_collator") and self.eval_data_collator
|
|
||||||
else self.data_collator
|
|
||||||
)
|
|
||||||
if "length" in eval_dataset.column_names:
|
|
||||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
|
||||||
|
|
||||||
# Handle dataset preprocessing for SP
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
if isinstance(eval_dataset, datasets.Dataset):
|
|
||||||
eval_dataset = self._remove_unused_columns(
|
|
||||||
eval_dataset, description="evaluation"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
|
||||||
self.data_collator, description="evaluation"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
|
|
||||||
batch_size = (
|
|
||||||
self.args.eval_batch_size
|
|
||||||
if self.args.sample_packing
|
|
||||||
else self.args.per_device_eval_batch_size
|
|
||||||
)
|
|
||||||
sampler = self._get_eval_sampler(eval_dataset)
|
|
||||||
dataloader = self._prepare_dataloader(
|
|
||||||
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return dataloader
|
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||||
|
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||||
|
data_collator = self.data_collator
|
||||||
|
dataloader_params = {
|
||||||
|
"batch_size": self.args.eval_batch_size,
|
||||||
|
"collate_fn": data_collator,
|
||||||
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
|
}
|
||||||
|
if self.args.dataloader_prefetch_factor:
|
||||||
|
dataloader_params["prefetch_factor"] = (
|
||||||
|
self.args.dataloader_prefetch_factor
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(eval_sampler, BatchSampler):
|
||||||
|
dataloader_params["batch_sampler"] = eval_sampler
|
||||||
|
del dataloader_params["batch_size"]
|
||||||
|
else:
|
||||||
|
dataloader_params["sampler"] = eval_sampler
|
||||||
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
|
||||||
|
self.accelerator.even_batches = False
|
||||||
|
return self.accelerator.prepare_data_loader(
|
||||||
|
DataLoader(eval_dataset, **dataloader_params)
|
||||||
|
)
|
||||||
|
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
|
|
||||||
def _get_bench_sampler(
|
def _get_bench_sampler(
|
||||||
self, bench_dataset: Dataset
|
self, bench_dataset: Dataset
|
||||||
) -> torch.utils.data.Sampler | None:
|
) -> Optional[torch.utils.data.Sampler]:
|
||||||
if self.args.world_size <= 1:
|
if self.args.world_size <= 1:
|
||||||
return SequentialSampler(bench_dataset)
|
return SequentialSampler(bench_dataset)
|
||||||
return None
|
return None
|
||||||
@@ -347,7 +554,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
return DataLoader(bench_dataset, **dataloader_params)
|
return DataLoader(bench_dataset, **dataloader_params)
|
||||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||||
|
|
||||||
@override
|
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||||
):
|
):
|
||||||
@@ -364,7 +570,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
return_outputs=return_outputs,
|
return_outputs=return_outputs,
|
||||||
num_items_in_batch=num_items_in_batch,
|
num_items_in_batch=num_items_in_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
return super().compute_loss(
|
return super().compute_loss(
|
||||||
model,
|
model,
|
||||||
inputs,
|
inputs,
|
||||||
@@ -539,10 +744,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
"""
|
"""
|
||||||
kwargs = sanitize_kwargs_for_ds_tagging(
|
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
)
|
)
|
||||||
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
@@ -559,13 +764,15 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Log `logs` on the various objects watching training, including stored metrics.
|
Log `logs` on the various objects watching training, including stored metrics.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logs: The values to log.
|
logs (`Dict[str, float]`):
|
||||||
start_time: The start of training.
|
The values to log.
|
||||||
|
start_time (`Optional[float]`):
|
||||||
|
The start of training.
|
||||||
"""
|
"""
|
||||||
# logs either has 'loss' or 'eval_loss'
|
# logs either has 'loss' or 'eval_loss'
|
||||||
train_eval = "train" if "loss" in logs else "eval"
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
@@ -577,7 +784,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
return super().log(logs, start_time)
|
return super().log(logs, start_time)
|
||||||
|
|
||||||
def store_metrics(
|
def store_metrics(
|
||||||
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||||
) -> None:
|
) -> None:
|
||||||
for key, value in metrics.items():
|
for key, value in metrics.items():
|
||||||
self._stored_metrics[train_eval][key].append(value)
|
self._stored_metrics[train_eval][key].append(value)
|
||||||
@@ -590,26 +797,110 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, **kwargs)
|
||||||
|
|
||||||
def training_step(
|
|
||||||
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
|
"""
|
||||||
|
Mamba specific trainer to handle loss calculation
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "mamba"]
|
||||||
|
|
||||||
|
def compute_loss(
|
||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model,
|
||||||
inputs: dict[str, torch.Tensor | Any],
|
inputs,
|
||||||
num_items_in_batch: int | None = None,
|
return_outputs=False, # pylint: disable=unused-argument
|
||||||
) -> torch.Tensor:
|
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||||
"""
|
):
|
||||||
Perform a training step on a batch of inputs. Overrides the
|
input_ids = inputs.pop("input_ids")
|
||||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
lm_logits = model(input_ids).logits
|
||||||
enabled.
|
|
||||||
|
|
||||||
Args:
|
labels = input_ids.to(lm_logits.device)
|
||||||
model: Model to perform training step for.
|
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||||
inputs: Dictionary mapping.
|
labels = labels[:, 1:].contiguous()
|
||||||
"""
|
|
||||||
# Set up sequence parallelism for this step if enabled
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
self._update_ring_flash_attn_params(inputs)
|
|
||||||
|
|
||||||
# Proceed with normal training step
|
loss_fct = torch.nn.CrossEntropyLoss()
|
||||||
loss = super().training_step(model, inputs, num_items_in_batch)
|
lm_loss = loss_fct(
|
||||||
|
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
||||||
|
)
|
||||||
|
|
||||||
return loss
|
return lm_loss
|
||||||
|
|
||||||
|
|
||||||
|
class ReLoRATrainer(AxolotlTrainer):
|
||||||
|
"""
|
||||||
|
Trainer subclass that uses the OneCycleLR scheduler
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "relora"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.lr_scheduler = None
|
||||||
|
|
||||||
|
def create_scheduler(
|
||||||
|
self,
|
||||||
|
num_training_steps: int,
|
||||||
|
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||||
|
):
|
||||||
|
optimizer = self.optimizer if optimizer is None else optimizer
|
||||||
|
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||||
|
|
||||||
|
if self.args.relora_steps:
|
||||||
|
warmup_steps = (
|
||||||
|
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||||
|
)
|
||||||
|
anneal_steps = (
|
||||||
|
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
||||||
|
)
|
||||||
|
self.lr_scheduler = ReLoRAScheduler(
|
||||||
|
optimizer,
|
||||||
|
lr_scheduler,
|
||||||
|
self.args.relora_steps,
|
||||||
|
anneal_steps,
|
||||||
|
warmup_steps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.lr_scheduler = lr_scheduler
|
||||||
|
|
||||||
|
return self.lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base ORPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base KTOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base CPOTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "cpo"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base RewardTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "reward"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
||||||
|
"""
|
||||||
|
Extend the base trl.PRMTrainer for axolotl helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "prm"]
|
||||||
|
|||||||
@@ -13,10 +13,10 @@ from transformers import Trainer
|
|||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import SchedulerMixin
|
from axolotl.core.trainers.base import (
|
||||||
from axolotl.core.trainers.utils import (
|
SchedulerMixin,
|
||||||
sanitize_kwargs_for_ds_tagging,
|
_sanitize_kwargs_for_ds_tagging,
|
||||||
sanitize_kwargs_for_tagging,
|
_sanitize_kwargs_for_tagging,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
@@ -74,10 +74,10 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
"""
|
"""
|
||||||
kwargs = sanitize_kwargs_for_ds_tagging(
|
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
)
|
)
|
||||||
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -1,32 +0,0 @@
"""Module for mamba trainer"""

import torch

from axolotl.core.trainers.base import AxolotlTrainer


class AxolotlMambaTrainer(AxolotlTrainer):
    """Mamba specific trainer to handle loss calculation"""

    tag_names = ["axolotl", "mamba"]

    def compute_loss(
        self,
        model,
        inputs,
        return_outputs=False,  # pylint: disable=unused-argument
        num_items_in_batch=None,  # pylint: disable=unused-argument
    ):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids).logits

        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
        )

        return lm_loss
@@ -1,8 +0,0 @@
"""Init for axolotl.core.trainers.mixins"""

# pylint: disable=unused-import
# flake8: noqa

from .optimizer import OptimizerMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin
@@ -1,201 +0,0 @@
|
|||||||
"""Module for Axolotl trainer optimizer mixin"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
|
||||||
from torch import nn
|
|
||||||
from transformers.trainer import Trainer
|
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
|
||||||
|
|
||||||
from axolotl.integrations.base import BaseOptimizerFactory
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
|
||||||
import smdistributed.modelparallel.torch as smp
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class OptimizerMixin(Trainer):
|
|
||||||
"""Mixin class for shared handling of building custom optimizers"""
|
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
|
||||||
|
|
||||||
def create_optimizer_grouped_parameters(
|
|
||||||
self, opt_model, optimizer_kwargs
|
|
||||||
) -> list[dict]:
|
|
||||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
|
||||||
params: dict = {
|
|
||||||
"to_weight_decay": {}, # LayerNorm and bias
|
|
||||||
"embeddings": {}, # lm_head, embed_tokens,
|
|
||||||
"no_weight_decay": {},
|
|
||||||
}
|
|
||||||
lr_groups_lookup = {}
|
|
||||||
lr_groups_learning_rates = {}
|
|
||||||
if self.args.lr_groups:
|
|
||||||
for lr_group in self.args.lr_groups:
|
|
||||||
group_name = lr_group["name"]
|
|
||||||
group_modules = lr_group["modules"]
|
|
||||||
for module in group_modules:
|
|
||||||
lr_groups_lookup[module] = group_name
|
|
||||||
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
|
||||||
params[f"to_weight_decay_{group_name}"] = {}
|
|
||||||
|
|
||||||
for name, param in opt_model.named_parameters():
|
|
||||||
if not param.requires_grad:
|
|
||||||
continue
|
|
||||||
if name.endswith("modules_to_save.default.weight") or any(
|
|
||||||
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
|
||||||
):
|
|
||||||
params["embeddings"][name] = param
|
|
||||||
elif name in decay_parameters:
|
|
||||||
lr_group_modules = [
|
|
||||||
group_modules
|
|
||||||
for group_modules in lr_groups_lookup
|
|
||||||
if group_modules in name
|
|
||||||
]
|
|
||||||
if lr_groups_lookup and any(lr_group_modules):
|
|
||||||
lr_group_module = lr_group_modules[0]
|
|
||||||
group_name = lr_groups_lookup[lr_group_module]
|
|
||||||
params[f"to_weight_decay_{group_name}"][name] = param
|
|
||||||
else:
|
|
||||||
params["to_weight_decay"][name] = param
|
|
||||||
else:
|
|
||||||
params["no_weight_decay"][name] = param
|
|
||||||
optimizer_grouped_parameters = []
|
|
||||||
if params["to_weight_decay"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["to_weight_decay"].values()),
|
|
||||||
"weight_decay": self.args.weight_decay,
|
|
||||||
"lr": optimizer_kwargs["lr"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if params["embeddings"]:
|
|
||||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
|
||||||
if self.args.embedding_lr_scale:
|
|
||||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
|
||||||
elif self.args.embedding_lr:
|
|
||||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["embeddings"].values()),
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"lr": lr,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if params["no_weight_decay"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["no_weight_decay"].values()),
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"lr": optimizer_kwargs["lr"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
for group_name, group_lr in lr_groups_learning_rates.items():
|
|
||||||
if params[f"to_weight_decay_{group_name}"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(
|
|
||||||
params[f"to_weight_decay_{group_name}"].values()
|
|
||||||
),
|
|
||||||
"weight_decay": self.args.weight_decay,
|
|
||||||
"lr": group_lr,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return optimizer_grouped_parameters
|
|
||||||
|
|
||||||
def create_optimizer(self):
|
|
||||||
if (
|
|
||||||
self.args.loraplus_lr_ratio is None
|
|
||||||
and self.args.embedding_lr_scale is None
|
|
||||||
and self.args.embedding_lr is None
|
|
||||||
and self.args.lr_groups is None
|
|
||||||
and self.optimizer_cls_and_kwargs is None
|
|
||||||
):
|
|
||||||
return super().create_optimizer()
|
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
|
||||||
|
|
||||||
if (
|
|
||||||
not self.optimizer
|
|
||||||
and self.optimizer_cls_and_kwargs is not None
|
|
||||||
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
|
||||||
):
|
|
||||||
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
|
||||||
self.optimizer = optimizer_factory_cls()(
|
|
||||||
opt_model, self.args, **optimizer_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.optimizer:
|
|
||||||
if self.optimizer_cls_and_kwargs is not None:
|
|
||||||
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
|
||||||
else:
|
|
||||||
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
|
||||||
self.args, opt_model
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
|
||||||
opt_model, optimizer_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.loraplus_lr_ratio is not None:
|
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
|
||||||
loraplus_lr_embedding = getattr(
|
|
||||||
self.args, "loraplus_lr_embedding", 1e-6
|
|
||||||
)
|
|
||||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
opt_model,
|
|
||||||
optimizer_cls,
|
|
||||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
|
||||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
|
||||||
**optimizer_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
|
||||||
# e.g. for GaLore optimizer.
|
|
||||||
if "params" in optimizer_kwargs:
|
|
||||||
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
|
||||||
|
|
||||||
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
|
||||||
# e.g. for LOMO optimizer.
|
|
||||||
if "model" in optimizer_kwargs:
|
|
||||||
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
|
||||||
|
|
||||||
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
|
||||||
# to avoid arguments conflicts.
|
|
||||||
if "optimizer_dict" in optimizer_kwargs:
|
|
||||||
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
|
||||||
"optimizer_dict"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.optimizer = optimizer_cls(
|
|
||||||
optimizer_grouped_parameters, **optimizer_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if optimizer_cls.__name__ == "Adam8bit":
|
|
||||||
import bitsandbytes
|
|
||||||
|
|
||||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
|
||||||
|
|
||||||
skipped = 0
|
|
||||||
for module in opt_model.modules():
|
|
||||||
if isinstance(module, nn.Embedding):
|
|
||||||
skipped += sum(
|
|
||||||
{
|
|
||||||
p.data_ptr(): p.numel() for p in module.parameters()
|
|
||||||
}.values()
|
|
||||||
)
|
|
||||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
|
||||||
manager.register_module_override(
|
|
||||||
module, "weight", {"optim_bits": 32}
|
|
||||||
)
|
|
||||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
|
||||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
self.optimizer
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.optimizer
|
|
||||||
@@ -1,113 +0,0 @@
"""Module for Axolotl trainer scheduler mixin"""

import logging

import torch
from torch.optim.lr_scheduler import OneCycleLR
from transformers.trainer import Trainer

from axolotl.utils.schedulers import (
    RexLR,
    get_cosine_schedule_with_min_lr,
    get_cosine_schedule_with_quadratic_warmup,
    get_cosine_schedule_with_warmup_decay_constant,
)

LOG = logging.getLogger(__name__)


class SchedulerMixin(Trainer):
    """
    Mixin class for scheduler setup in CausalTrainer.
    """

    args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]

    def create_scheduler(
        self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
    ):
        """
        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
        passed as an argument.

        Args:
            num_training_steps (int): The number of training steps to do.
            optimizer (torch.optim.Optimizer): The training optimizer
        """
        use_cosine_quadratic = (
            self.args.lr_scheduler_type == "cosine"
            and self.args.lr_quadratic_warmup is True
        )

        use_cosine_min_lr = (
            self.args.lr_scheduler_type == "cosine"
            and self.args.cosine_min_lr_ratio is not None
        )

        # fmt: off
        if self.lr_scheduler is None:  # type: ignore # pylint: disable=access-member-before-definition
            # fmt: on
            if self.args.alternate_lr_scheduler_type == "one_cycle":
                num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
                pct_start = num_warmup_steps / num_training_steps
                extra_lr_kwargs = {}
                if "pct_start" not in self.args.lr_scheduler_kwargs:
                    extra_lr_kwargs["pct_start"] = pct_start
                if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
                    extra_lr_kwargs["anneal_strategy"] = "cos"

                self.lr_scheduler = OneCycleLR(
                    optimizer,
                    max_lr=self.args.learning_rate,
                    total_steps=num_training_steps,
                    **extra_lr_kwargs,
                    **self.args.lr_scheduler_kwargs,
                )
            elif self.args.alternate_lr_scheduler_type == "rex":
                if use_cosine_min_lr:
                    assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"

                self.lr_scheduler = RexLR(
                    optimizer=optimizer,
                    max_lr=self.args.learning_rate,
                    min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
                    total_steps=num_training_steps,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                )
            elif use_cosine_quadratic:
                if use_cosine_min_lr:
                    LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")

                self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
                    optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                )
            elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
                assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
                assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
                self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant(  # pylint: disable=attribute-defined-outside-init
                    optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                    min_lr_ratio=self.args.cosine_min_lr_ratio,
                    constant_lr_ratio=self.args.cosine_constant_lr_ratio,
                )
            elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
                assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
                self.lr_scheduler = get_cosine_schedule_with_min_lr(  # pylint: disable=attribute-defined-outside-init
                    optimizer,
                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                    num_training_steps=num_training_steps,
                    min_lr_ratio=self.args.cosine_min_lr_ratio,
                )
            else:
                return super().create_scheduler(num_training_steps, optimizer=optimizer)
        else:
            if use_cosine_quadratic:
                LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")

            if use_cosine_min_lr:
                LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")

        return self.lr_scheduler
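Note: the cosine-with-minimum-LR helpers imported in the scheduler mixin above live in `axolotl.utils.schedulers` and are not part of this diff. As a rough sketch of the schedule shape they provide (an illustration under that assumption, with a hypothetical `cosine_min_lr_factor` helper, not the library code), the per-step LR multiplier can be written as:

```python
import math


def cosine_min_lr_factor(
    step: int, warmup_steps: int, total_steps: int, min_lr_ratio: float
) -> float:
    """Illustrative LR multiplier: linear warmup, then cosine decay down to min_lr_ratio."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes 1 -> 0 over training
    return min_lr_ratio + (1.0 - min_lr_ratio) * cosine  # floors at min_lr_ratio, not 0
```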
@@ -1,131 +0,0 @@
|
|||||||
"""Module for Axolotl trainer sequence parallelism mixin"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from datasets import Dataset
|
|
||||||
from torch.utils.data import DistributedSampler, Sampler
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from ring_flash_attn import update_ring_flash_attn_params
|
|
||||||
except ImportError:
|
|
||||||
# We pass silently here, but raise an ImportError in our Axolotl config validation
|
|
||||||
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SequenceParallelMixin:
|
|
||||||
"""
|
|
||||||
Mixin class for sequence parallelism support in trainers.
|
|
||||||
|
|
||||||
This mixin provides functionality for handling sequence parallelism,
|
|
||||||
including creating appropriate samplers, managing data partitioning,
|
|
||||||
and updating ring flash attention parameters during training.
|
|
||||||
"""
|
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
|
||||||
|
|
||||||
def _setup_sequence_parallel(self):
|
|
||||||
"""Set up sequence parallelism environment."""
|
|
||||||
self.ring_attn_group = get_ring_attn_group()
|
|
||||||
|
|
||||||
def _create_sequence_parallel_sampler(
|
|
||||||
self,
|
|
||||||
dataset: Dataset,
|
|
||||||
shuffle: bool = True,
|
|
||||||
is_eval: bool = False,
|
|
||||||
) -> DistributedSampler:
|
|
||||||
"""
|
|
||||||
Helper method to create sampler for sequence parallelism (SP).
|
|
||||||
|
|
||||||
We create a distributed sampler with rank equal to the SP group ID, which
|
|
||||||
means that all ranks in the SP group receive the same sample / set of samples
|
|
||||||
per training step. We also set the number of replicas equal to the number of
|
|
||||||
SP groups, which is a bit of a hack / unintended use, but works!
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset: Dataset to sample from.
|
|
||||||
shuffle: Whether to shuffle the dataset.
|
|
||||||
is_eval: Whether we are creating a sampler for evaluation or training.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Distributed sampler.
|
|
||||||
"""
|
|
||||||
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
|
||||||
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
|
||||||
|
|
||||||
return DistributedSampler(
|
|
||||||
dataset,
|
|
||||||
num_replicas=num_sp_groups,
|
|
||||||
rank=sp_group_id,
|
|
||||||
seed=self.args.seed if shuffle else None,
|
|
||||||
shuffle=shuffle,
|
|
||||||
drop_last=not is_eval,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
|
|
||||||
"""
|
|
||||||
Get a training sampler configured for sequence parallelism.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset: The training dataset
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured sequence parallel sampler.
|
|
||||||
"""
|
|
||||||
return self._create_sequence_parallel_sampler(
|
|
||||||
dataset,
|
|
||||||
shuffle=not self.args.curriculum_sampling,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
|
||||||
"""
|
|
||||||
Get an evaluation sampler configured for sequence parallelism.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
eval_dataset: The evaluation dataset.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured sequence parallel sampler.
|
|
||||||
"""
|
|
||||||
return self._create_sequence_parallel_sampler(
|
|
||||||
eval_dataset, shuffle=False, is_eval=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
|
|
||||||
"""
|
|
||||||
Calculate the cu_seqlens for the current forward pass and pass the value to
|
|
||||||
the substituted ring_flash_attn. This is accomplished by using the passed
|
|
||||||
`input_ids`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs: Current batch of inputs.
|
|
||||||
"""
|
|
||||||
# At this point, inputs should already be partitioned by the sequence
|
|
||||||
# parallel data collator
|
|
||||||
batch_size = inputs["input_ids"].shape[0]
|
|
||||||
seq_len = inputs["input_ids"].shape[1]
|
|
||||||
packed_seq_lens = [seq_len] * batch_size
|
|
||||||
|
|
||||||
# Calculate the full sequence length across all GPUs in this SP group
|
|
||||||
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
|
||||||
|
|
||||||
cu_seqlens = torch.cumsum(
|
|
||||||
torch.tensor(
|
|
||||||
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
|
||||||
),
|
|
||||||
dim=-1,
|
|
||||||
dtype=torch.int32,
|
|
||||||
)
|
|
||||||
cu_seqlens = F.pad(
|
|
||||||
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
|
||||||
)
|
|
||||||
|
|
||||||
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
|
||||||
@@ -1,43 +0,0 @@
"""Module for ReLoRA trainer"""

import torch

from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.monkeypatch.relora import ReLoRAScheduler


class ReLoRATrainer(AxolotlTrainer):
    """Trainer subclass that uses the `OneCycleLR` scheduler"""

    tag_names = ["axolotl", "relora"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lr_scheduler = None

    def create_scheduler(
        self,
        num_training_steps: int,
        optimizer: torch.optim.Optimizer | None = None,
    ):
        optimizer = self.optimizer if optimizer is None else optimizer
        lr_scheduler = super().create_scheduler(num_training_steps, optimizer)

        if self.args.relora_steps:
            warmup_steps = (
                self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
            )
            anneal_steps = (
                self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
            )
            self.lr_scheduler = ReLoRAScheduler(
                optimizer,
                lr_scheduler,
                self.args.relora_steps,
                anneal_steps,
                warmup_steps,
            )
        else:
            self.lr_scheduler = lr_scheduler

        return self.lr_scheduler
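For context on the `relora_steps`, `warmup_steps`, and `anneal_steps` values handed to `ReLoRAScheduler` above: the scheduler wraps the base LR schedule and re-warms the learning rate after each ReLoRA restart. A minimal sketch of that per-cycle modulation (an assumption about the general shape, not the actual `ReLoRAScheduler` implementation):

```python
def relora_cycle_scale(step: int, relora_steps: int, warmup_steps: int) -> float:
    """Illustrative multiplier applied on top of the wrapped scheduler (assumed behavior)."""
    step_in_cycle = step % relora_steps
    if step_in_cycle < warmup_steps:
        return step_in_cycle / max(1, warmup_steps)  # ramp back up after the restart
    return 1.0  # otherwise defer to the wrapped scheduler
```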
@@ -1,25 +1,16 @@
-"""Module for TRL PPO trainer"""
-
-from typing import Literal, Union
+"""
+module for TRL PPO training
+"""

import torch
from tqdm import tqdm
-from trl import (
-    CPOTrainer,
-    KTOTrainer,
-    ORPOTrainer,
-    PPOTrainer,
-    PRMTrainer,
-    RewardTrainer,
-)
-
-from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
+from trl import PPOTrainer


class TRLPPOTrainer(PPOTrainer):
-    """Wrapper for TRL PPO trainer to handle customizations"""
-
-    tag_names = ["axolotl", "ppo"]
+    """
+    wrapper for ppo trainer to handle customizations
+    """

    def train(
        self,
@@ -40,7 +31,9 @@ class TRLPPOTrainer(PPOTrainer):
            "batch_size": 16,
        }

-        for _, batch in tqdm(enumerate(self.dataloader)):
+        for epoch, batch in tqdm(  # pylint: disable=unused-variable
+            enumerate(self.dataloader)
+        ):
            query_tensors = batch["input_ids"]

            # generate model response
@@ -72,189 +65,3 @@ class TRLPPOTrainer(PPOTrainer):
|
|||||||
rewards,
|
rewards,
|
||||||
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base ORPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "orpo"]
|
|
||||||
|
|
||||||
def get_batch_loss_metrics(
|
|
||||||
self,
|
|
||||||
model,
|
|
||||||
batch: dict[str, Union[list, torch.LongTensor]],
|
|
||||||
train_eval: Literal["train", "eval"] = "train",
|
|
||||||
):
|
|
||||||
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
|
|
||||||
|
|
||||||
# TODO remove once https://github.com/huggingface/trl/pull/3069 is included in a trl release
|
|
||||||
|
|
||||||
metrics = {}
|
|
||||||
|
|
||||||
forward_output = self.concatenated_forward(model, batch)
|
|
||||||
(
|
|
||||||
policy_chosen_logps,
|
|
||||||
policy_rejected_logps,
|
|
||||||
policy_chosen_logits,
|
|
||||||
policy_rejected_logits,
|
|
||||||
policy_nll_loss,
|
|
||||||
) = forward_output[:5]
|
|
||||||
if self.aux_loss_enabled:
|
|
||||||
aux_loss = forward_output[5]
|
|
||||||
|
|
||||||
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = (
|
|
||||||
self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
|
|
||||||
)
|
|
||||||
# full ORPO loss
|
|
||||||
loss = policy_nll_loss - losses.mean()
|
|
||||||
|
|
||||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
|
||||||
|
|
||||||
prefix = "eval_" if train_eval == "eval" else ""
|
|
||||||
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(
|
|
||||||
chosen_rewards
|
|
||||||
).mean()
|
|
||||||
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(
|
|
||||||
rejected_rewards
|
|
||||||
).mean()
|
|
||||||
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(
|
|
||||||
reward_accuracies
|
|
||||||
).mean()
|
|
||||||
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
|
|
||||||
chosen_rewards - rejected_rewards
|
|
||||||
).mean()
|
|
||||||
metrics[f"{prefix}logps/rejected"] = (
|
|
||||||
self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}logps/chosen"] = (
|
|
||||||
self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
|
|
||||||
policy_rejected_logits.detach().mean()
|
|
||||||
).mean()
|
|
||||||
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
|
|
||||||
policy_chosen_logits.detach().mean()
|
|
||||||
).mean()
|
|
||||||
metrics[f"{prefix}nll_loss"] = (
|
|
||||||
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}log_odds_ratio"] = (
|
|
||||||
self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}log_odds_chosen"] = (
|
|
||||||
self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
|
|
||||||
)
|
|
||||||
for k, v in metrics.items():
|
|
||||||
metrics[k] = v.item()
|
|
||||||
if self.aux_loss_enabled:
|
|
||||||
loss += self.aux_loss_coef * aux_loss
|
|
||||||
|
|
||||||
return loss, metrics
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base KTOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "kto"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base CPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "cpo"]
|
|
||||||
|
|
||||||
def get_batch_loss_metrics(
|
|
||||||
self,
|
|
||||||
model,
|
|
||||||
batch: dict[str, Union[list, torch.LongTensor]],
|
|
||||||
train_eval: Literal["train", "eval"] = "train",
|
|
||||||
):
|
|
||||||
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
|
|
||||||
metrics = {}
|
|
||||||
|
|
||||||
forward_output = self.concatenated_forward(model, batch)
|
|
||||||
(
|
|
||||||
policy_chosen_logps,
|
|
||||||
policy_rejected_logps,
|
|
||||||
policy_chosen_logits,
|
|
||||||
policy_rejected_logits,
|
|
||||||
policy_nll_loss,
|
|
||||||
) = forward_output[:5]
|
|
||||||
if self.aux_loss_enabled:
|
|
||||||
aux_loss = forward_output[5]
|
|
||||||
|
|
||||||
losses, chosen_rewards, rejected_rewards = self.cpo_loss(
|
|
||||||
policy_chosen_logps,
|
|
||||||
policy_rejected_logps,
|
|
||||||
)
|
|
||||||
|
|
||||||
loss = losses.mean() + self.cpo_alpha * policy_nll_loss
|
|
||||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
|
||||||
|
|
||||||
prefix = "eval_" if train_eval == "eval" else ""
|
|
||||||
metrics[f"{prefix}rewards/chosen"] = (
|
|
||||||
self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}rewards/rejected"] = (
|
|
||||||
self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}rewards/accuracies"] = (
|
|
||||||
self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}rewards/margins"] = (
|
|
||||||
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards)
|
|
||||||
.mean()
|
|
||||||
.item()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}logps/rejected"] = (
|
|
||||||
self.accelerator.gather_for_metrics(policy_rejected_logps)
|
|
||||||
.detach()
|
|
||||||
.mean()
|
|
||||||
.item()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}logps/chosen"] = (
|
|
||||||
self.accelerator.gather_for_metrics(policy_chosen_logps)
|
|
||||||
.detach()
|
|
||||||
.mean()
|
|
||||||
.item()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}logits/rejected"] = (
|
|
||||||
self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean())
|
|
||||||
.mean()
|
|
||||||
.item()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}logits/chosen"] = (
|
|
||||||
self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean())
|
|
||||||
.mean()
|
|
||||||
.item()
|
|
||||||
)
|
|
||||||
metrics[f"{prefix}nll_loss"] = (
|
|
||||||
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.aux_loss_enabled:
|
|
||||||
loss += self.aux_loss_coef * aux_loss
|
|
||||||
|
|
||||||
return loss, metrics
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base RewardTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "reward"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base trl.PRMTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "prm"]
|
|
||||||
|
|||||||
@@ -1,33 +0,0 @@
"""Utils for Axolotl trainers"""


def sanitize_kwargs_for_tagging(tag_names, kwargs=None):
    if isinstance(tag_names, str):
        tag_names = [tag_names]

    if kwargs is not None:
        if "tags" not in kwargs:
            kwargs["tags"] = tag_names
        elif "tags" in kwargs and isinstance(kwargs["tags"], list):
            kwargs["tags"].extend(tag_names)
        elif "tags" in kwargs and isinstance(kwargs["tags"], str):
            tag_names.append(kwargs["tags"])
            kwargs["tags"] = tag_names

    return kwargs


def sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
    if isinstance(dataset_tags, str):
        dataset_tags = [dataset_tags]

    if (dataset_tags is not None) and (kwargs is not None):
        if "dataset_tags" not in kwargs:
            kwargs["dataset_tags"] = dataset_tags
        elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
            kwargs["dataset_tags"].extend(dataset_tags)
        elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
            dataset_tags.append(kwargs["dataset_tags"])
            kwargs["dataset_tags"] = dataset_tags

    return kwargs
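For reference, the helpers removed here (they continue to exist as the underscore-prefixed versions imported from `axolotl.core.trainers.base` above) merge trainer tags into whatever the caller already passed to `push_to_hub`. A small usage sketch with made-up values:

```python
# The caller passed a single string tag; the helper promotes it into a merged list.
kwargs = {"tags": "my-custom-tag"}
kwargs = sanitize_kwargs_for_tagging(tag_names=["axolotl", "dpo"], kwargs=kwargs)
assert kwargs["tags"] == ["axolotl", "dpo", "my-custom-tag"]
```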
@@ -5,7 +5,6 @@ extra axolotl specific training args
from dataclasses import dataclass, field
from typing import Optional

-from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig

@@ -208,33 +207,14 @@ class AxolotlTrainingMixins:
        },
    )

-    sequence_parallel_degree: Optional[int] = field(
-        default=1,
-        metadata={"help": "The number of workers to use in sequence parallelism"},
-    )
-
-    # multi-modal section
-
-    image_size: int | tuple[int, int] | None = field(
-        default=None,
-        metadata={"help": "The size of the image to resize to"},
-    )
-
-    image_resize_algorithm: Resampling | None = field(
-        default=None,
-        metadata={"help": "The algorithm to use for image resizing"},
-    )
-
-    # end of multi-modal section


@dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
    """
    Training arguments for Causal trainer

-    This code is duplicated due to HF TrainingArguments not setting output_dir with a
-    default value so it can't be used as a mixin.
+    This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
+    so it can't be used as a mixin.
    """

@@ -1,6 +1,6 @@
# Cut Cross Entropy

-Cut Cross Entropy (CCE) reduces VRAM usage through optimization on the cross-entropy operation during loss calculation.
+Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy operation during loss calculation.

See https://github.com/apple/ml-cross-entropy

@@ -29,20 +29,6 @@ plugins:
cut_cross_entropy: true
```

-## Supported Models
-
-- llama
-- phi3
-- gemma
-- gemma2
-- gemma3
-- gemma3_text
-- mistral
-- mistral3
-- qwen2
-- cohere
-- cohere2

## Citation

```bib
@@ -25,8 +25,8 @@ import torch

from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
-from axolotl.utils.distributed import zero_only

+from ...utils.distributed import zero_only
from .args import CutCrossEntropyArgs  # pylint: disable=unused-import. # noqa: F401

LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
@@ -72,9 +72,7 @@ class CutCrossEntropyPlugin(BasePlugin):
        if cfg.cut_cross_entropy:
            self._check_requirements()

-            from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
-                cce_patch,
-            )
+            from cut_cross_entropy.transformers import cce_patch

            with zero_only():
                LOG.info(
@@ -1,201 +0,0 @@
|
|||||||
"""Cohere and Cohere2 CCE patch."""
|
|
||||||
|
|
||||||
# This patch is based off transformers 4.50.0.
|
|
||||||
# It patches the forward function for CohereForCausalLM and Cohere2ForCausalLM.
|
|
||||||
# It scales the hidden states by the logit scale in advance instead of the logits as the
|
|
||||||
# operation is done internally and should be mathematically equivalent.
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.cohere.modeling_cohere import (
|
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
COHERE_INPUTS_DOCSTRING,
|
|
||||||
KwargsForCausalLM,
|
|
||||||
)
|
|
||||||
from transformers.processing_utils import Unpack
|
|
||||||
from transformers.utils import (
|
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def cce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
**kwargs: Unpack[KwargsForCausalLM],
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>> from transformers import AutoTokenizer, CohereForCausalLM
|
|
||||||
|
|
||||||
>> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
|
||||||
>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
|
||||||
|
|
||||||
>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
||||||
>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>> # Generate
|
|
||||||
>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
||||||
>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
|
||||||
```"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
# scale weight by logit_scale in-place of logits
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states[:, slice_indices, :],
|
|
||||||
self.lm_head.weight * self.logit_scale,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
||||||
logits = logits * self.logit_scale # main diff from Llama
|
|
||||||
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(
|
|
||||||
logits=logits,
|
|
||||||
labels=labels,
|
|
||||||
vocab_size=self.config.vocab_size,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_cohere(
|
|
||||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
|
||||||
patch_options: PatchOptions,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
global _PATCH_OPTS # pylint: disable=global-statement
|
|
||||||
from transformers.models.cohere import modeling_cohere
|
|
||||||
|
|
||||||
_PATCH_OPTS = patch_options
|
|
||||||
|
|
||||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
|
||||||
assert isinstance(
|
|
||||||
maybe_model, modeling_cohere.CohereForCausalLM
|
|
||||||
), f"Expected a CohereForCausalLM model. Got {type(maybe_model)}."
|
|
||||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
|
||||||
return maybe_model
|
|
||||||
|
|
||||||
modeling_cohere.CohereForCausalLM.forward = cce_forward
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def patch_cohere2(
|
|
||||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
|
||||||
patch_options: PatchOptions,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
global _PATCH_OPTS # pylint: disable=global-statement
|
|
||||||
from transformers.models.cohere2 import modeling_cohere2
|
|
||||||
|
|
||||||
_PATCH_OPTS = patch_options
|
|
||||||
|
|
||||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
|
||||||
assert isinstance(
|
|
||||||
maybe_model, modeling_cohere2.Cohere2ForCausalLM
|
|
||||||
), f"Expected a Cohere2ForCausalLM model. Got {type(maybe_model)}."
|
|
||||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
|
||||||
return maybe_model
|
|
||||||
|
|
||||||
modeling_cohere2.Cohere2ForCausalLM.forward = cce_forward
|
|
||||||
return None
|
|
||||||
@@ -1,175 +0,0 @@
|
|||||||
"""Gemma CCE patch"""
|
|
||||||
|
|
||||||
# This patch is based off transformers 4.50.0.
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.gemma.modeling_gemma import (
|
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
GEMMA_INPUTS_DOCSTRING,
|
|
||||||
KwargsForCausalLM,
|
|
||||||
)
|
|
||||||
from transformers.processing_utils import Unpack
|
|
||||||
from transformers.utils import (
|
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def cce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
**kwargs: Unpack[KwargsForCausalLM],
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import AutoTokenizer, GemmaForCausalLM
|
|
||||||
|
|
||||||
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
|
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
|
|
||||||
|
|
||||||
>>> prompt = "What is your favorite condiment?"
|
|
||||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
||||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"What is your favorite condiment?"
|
|
||||||
```"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states[:, slice_indices, :],
|
|
||||||
self.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(
|
|
||||||
logits=logits,
|
|
||||||
labels=labels,
|
|
||||||
vocab_size=self.config.vocab_size,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_gemma(
|
|
||||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
|
||||||
patch_options: PatchOptions,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
global _PATCH_OPTS # pylint: disable=global-statement
|
|
||||||
from transformers.models.gemma import modeling_gemma
|
|
||||||
|
|
||||||
_PATCH_OPTS = patch_options
|
|
||||||
|
|
||||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
|
||||||
assert isinstance(
|
|
||||||
maybe_model, modeling_gemma.GemmaForCausalLM
|
|
||||||
), f"Expected a GemmaForCausalLM model. Got {type(maybe_model)}."
|
|
||||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
|
||||||
return maybe_model
|
|
||||||
|
|
||||||
modeling_gemma.GemmaForCausalLM.forward = cce_forward
|
|
||||||
return None
|
|
||||||
@@ -1,459 +0,0 @@
|
|||||||
"""Gemma2 and Gemma3 (text and multimodal) CCE patch."""
|
|
||||||
|
|
||||||
# Implementation originally adapted from https://github.com/apple/ml-cross-entropy/pull/29
|
|
||||||
# and updated for transformers 4.50.0.
|
|
||||||
# This is a modified version of the patch that allows for deferred logits calculation for gemma3 and works
|
|
||||||
# with both gemma3 (text and multimodal) models.
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
)
|
|
||||||
from torch import nn
|
|
||||||
from transformers.cache_utils import Cache, HybridCache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.gemma3.modeling_gemma3 import (
|
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
GEMMA3_INPUTS_DOCSTRING,
|
|
||||||
Gemma3CausalLMOutputWithPast,
|
|
||||||
logger,
|
|
||||||
)
|
|
||||||
from transformers.utils import (
|
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
is_torchdynamo_compiling,
|
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def cce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[HybridCache] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
defer_logits_calculation: bool = False,
|
|
||||||
**loss_kwargs,
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
defer_logits_calculation (`bool`, *optional*):
|
|
||||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
|
||||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import AutoTokenizer, Gemma3ForCausalLM
|
|
||||||
|
|
||||||
>>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
|
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
|
|
||||||
|
|
||||||
>>> prompt = "What is your favorite condiment?"
|
|
||||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
||||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"What is your favorite condiment?"
|
|
||||||
```"""
|
|
||||||
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )
    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **loss_kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            softcap=getattr(self.config, "final_logit_softcapping", None),
            **loss_kwargs,
        )
    elif _PATCH_OPTS is not None and defer_logits_calculation:
        # defer logits calculation to the ConditionalGeneration forward
        logits = hidden_states[:, slice_indices, :]
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
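
# Note on the three mutually exclusive branches above (summary only, no added behavior):
#   1. `_PATCH_OPTS.use_lce(...)`       -> the loss is computed straight from `hidden_states`
#      via `apply_lce`, so the full-vocabulary logits matrix is never materialized.
#   2. `defer_logits_calculation=True`  -> `logits` actually carries the sliced hidden states;
#      `cce_forward_multimodal` below applies CCE (or the fallback loss) itself.
#   3. otherwise                        -> the regular `lm_head` projection plus optional
#      final-logit softcapping, with the standard `loss_function` when labels are given.
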
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward_multimodal(
    self,
    input_ids: torch.LongTensor | None = None,
    pixel_values: torch.FloatTensor | None = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
    token_type_ids: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **lm_kwargs,
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration

    >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
    >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")

    >>> prompt = "answer en Where is the cow standing?"
    >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(**inputs, max_length=30)
    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "answer en Where is the cow standing?\nbeach"
    ```"""

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    is_training = token_type_ids is not None and labels is not None

    # Replace image id with PAD if the image token is OOV, to avoid index-errors
    if input_ids is not None and self.config.image_token_index >= self.vocab_size:
        special_image_mask = input_ids == self.config.image_token_index
        llm_input_ids = input_ids.clone()
        llm_input_ids[special_image_mask] = 0
    else:
        llm_input_ids = input_ids  # type: ignore

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(llm_input_ids)

    if cache_position is None:
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0  # type: ignore
        )
        cache_position = torch.arange(  # type: ignore
            past_seen_tokens,
            past_seen_tokens + inputs_embeds.shape[1],
            device=inputs_embeds.device,
        )

    # Merge text and images
    if pixel_values is not None:
        image_features = self.get_image_features(pixel_values)

        if input_ids is None:
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(
                    self.config.image_token_index,
                    dtype=torch.long,
                    device=inputs_embeds.device,
                )
            )
        else:
            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
                -1
            )
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
                inputs_embeds.device
            )

        if (
            not is_torchdynamo_compiling()
            and inputs_embeds[special_image_mask].numel() != image_features.numel()
        ):
            image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
            raise ValueError(
                f"Number of images does not match number of special image tokens in the input text. "
                f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
                "tokens from image embeddings."
            )
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)  # type: ignore

    # mask out pad-token-ids in labels for BC
    if labels is not None and self.pad_token_id in labels:
        logger.warning_once(
            "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
            "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
        )
        labels = torch.where(  # type: ignore
            input_ids == self.pad_token_id, self.config.ignore_index, labels
        )

    causal_mask = self._update_causal_mask(  # pylint: disable=protected-access
        attention_mask,
        token_type_ids,
        past_key_values,
        cache_position,
        inputs_embeds,
        is_training,
    )
    outputs = self.language_model(
        attention_mask=causal_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        logits_to_keep=logits_to_keep,
        defer_logits_calculation=True,  # enable deferred logits calculation
        **lm_kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states,
            self.language_model.lm_head.weight,
            labels,
            _PATCH_OPTS,
            softcap=getattr(self.config, "final_logit_softcapping", None),
            **lm_kwargs,
        )
    else:
        logits = hidden_states
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(
                    logits.device
                )
                shift_logits = shift_logits[
                    shift_attention_mask.to(logits.device) != 0
                ].contiguous()
                shift_labels = shift_labels[
                    shift_attention_mask.to(shift_labels.device) != 0
                ].contiguous()
            else:
                shift_logits = shift_logits.contiguous()
                shift_labels = shift_labels.contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()

            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
            flat_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(flat_logits, flat_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Gemma3CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
    )


def patch_gemma2(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.gemma2 import modeling_gemma2

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_gemma2.Gemma2ForCausalLM
        ), f"Expected a Gemma2ForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_gemma2.Gemma2ForCausalLM.forward = cce_forward
    return None


def patch_gemma3_text(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.gemma3 import modeling_gemma3

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_gemma3.Gemma3ForCausalLM
        ), f"Expected a Gemma3ForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
    return None


def patch_gemma3(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.gemma3 import modeling_gemma3

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_gemma3.Gemma3ForConditionalGeneration
        ), f"Expected a Gemma3ForConditionalGeneration model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)

        # patch the causal model to enable deferred logits calculation
        maybe_model.language_model.forward = MethodType(
            cce_forward, maybe_model.language_model
        )
        return maybe_model

    modeling_gemma3.Gemma3ForConditionalGeneration.forward = cce_forward_multimodal
    # patch the causal model to enable deferred logits calculation
    modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
    return None
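

# --- Illustrative usage (a sketch, not part of the patch module) ---
# Assumes a Gemma3 multimodal checkpoint; the checkpoint name and the PatchOptions
# values below are examples only, not defaults taken from this file.
#
#   from transformers import Gemma3ForConditionalGeneration
#
#   model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
#   opts = PatchOptions(
#       impl="cce",
#       reduction="mean",
#       filter_eps="auto",
#       accum_e_fp32=False,
#       accum_c_fp32=False,
#       filter_e_grad=True,
#       filter_c_grad=True,
#       train_only=False,
#   )
#   model = patch_gemma3(model, opts)
#   # training forward passes now compute the loss through apply_lce instead of a
#   # full-vocabulary lm_head projection; the language_model forward defers logits.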
@@ -1,392 +0,0 @@
|
|||||||
"""Mistral and Mistral3 CCE patch."""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from torch import nn
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.mistral3.modeling_mistral3 import (
|
|
||||||
Mistral3CausalLMOutputWithPast,
|
|
||||||
)
|
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
|
||||||
_CONFIG_FOR_DOC,
|
|
||||||
MISTRAL_INPUTS_DOCSTRING,
|
|
||||||
KwargsForCausalLM,
|
|
||||||
)
|
|
||||||
from transformers.processing_utils import Unpack
|
|
||||||
from transformers.utils import (
|
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
is_torchdynamo_compiling,
|
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
||||||
)
|
|
||||||
def cce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] | None = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
defer_logits_calculation: bool = False,
|
|
||||||
**kwargs: Unpack[KwargsForCausalLM],
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
defer_logits_calculation (`bool`, *optional*):
|
|
||||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
|
||||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import AutoTokenizer, MistralForCausalLM
|
|
||||||
|
|
||||||
>>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf")
|
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf")
|
|
||||||
|
|
||||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
||||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
||||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
|
||||||
```"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states[:, slice_indices, :],
|
|
||||||
self.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
|
||||||
# defer logits calculation to the ConditionalGeneration forward
|
|
||||||
logits = hidden_states[:, slice_indices, :]
|
|
||||||
else:
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(
|
|
||||||
logits=logits,
|
|
||||||
labels=labels,
|
|
||||||
vocab_size=self.config.vocab_size,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def cce_forward_multimodal(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
pixel_values: torch.FloatTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
vision_feature_layer: Optional[Union[int, list[int]]] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
image_sizes: torch.Tensor | None = None,
|
|
||||||
**lm_kwargs,
|
|
||||||
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from PIL import Image
|
|
||||||
>>> import requests
|
|
||||||
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
|
||||||
|
|
||||||
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
|
||||||
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
|
||||||
|
|
||||||
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
|
|
||||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
|
||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
||||||
|
|
||||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
|
||||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"What is the image?The image depicts two cats lying on a pink blanket."
|
|
||||||
```"""
|
|
||||||
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
vision_feature_layer = (
|
|
||||||
vision_feature_layer
|
|
||||||
if vision_feature_layer is not None
|
|
||||||
else self.config.vision_feature_layer
|
|
||||||
)
|
|
||||||
|
|
||||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
|
||||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
if pixel_values is not None and inputs_embeds is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
|
||||||
)
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
|
||||||
|
|
||||||
if pixel_values is not None:
|
|
||||||
image_features = self.get_image_features(
|
|
||||||
pixel_values=pixel_values,
|
|
||||||
vision_feature_layer=vision_feature_layer,
|
|
||||||
image_sizes=image_sizes,
|
|
||||||
)
|
|
||||||
|
|
||||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
|
||||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
|
||||||
inputs_embeds.device
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
not is_torchdynamo_compiling()
|
|
||||||
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
|
||||||
):
|
|
||||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
|
||||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
|
||||||
raise ValueError(
|
|
||||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
|
||||||
)
|
|
||||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
|
|
||||||
|
|
||||||
outputs = self.language_model(
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
logits_to_keep=logits_to_keep,
|
|
||||||
defer_logits_calculation=True, # enable deferred logits calculation
|
|
||||||
**lm_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states,
|
|
||||||
self.language_model.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**lm_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits = hidden_states
|
|
||||||
if labels is not None:
|
|
||||||
# Shift so that tokens < n predict n
|
|
||||||
if attention_mask is not None:
|
|
||||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
|
||||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
|
||||||
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
|
|
||||||
logits.device
|
|
||||||
)
|
|
||||||
shift_logits = logits[..., :-1, :][
|
|
||||||
shift_attention_mask.to(logits.device) != 0
|
|
||||||
].contiguous()
|
|
||||||
shift_labels = labels[..., 1:][
|
|
||||||
shift_attention_mask.to(labels.device) != 0
|
|
||||||
].contiguous()
|
|
||||||
else:
|
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
|
||||||
# Flatten the tokens
|
|
||||||
loss_fct = nn.CrossEntropyLoss()
|
|
||||||
loss = loss_fct(
|
|
||||||
shift_logits.view(-1, shift_logits.size(-1)),
|
|
||||||
shift_labels.view(-1).to(shift_logits.device),
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return Mistral3CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
image_hidden_states=image_features if pixel_values is not None else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_mistral(
|
|
||||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
|
||||||
patch_options: PatchOptions,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
global _PATCH_OPTS # pylint: disable=global-statement
|
|
||||||
from transformers.models.mistral import modeling_mistral
|
|
||||||
|
|
||||||
_PATCH_OPTS = patch_options
|
|
||||||
|
|
||||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
|
||||||
assert isinstance(
|
|
||||||
maybe_model, modeling_mistral.MistralForCausalLM
|
|
||||||
), f"Expected a MistralForCausalLM model. Got {type(maybe_model)}."
|
|
||||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
|
||||||
return maybe_model
|
|
||||||
|
|
||||||
modeling_mistral.MistralForCausalLM.forward = cce_forward
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def patch_mistral3(
|
|
||||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
|
||||||
patch_options: PatchOptions,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
global _PATCH_OPTS # pylint: disable=global-statement
|
|
||||||
from transformers.models.mistral import modeling_mistral
|
|
||||||
from transformers.models.mistral3 import modeling_mistral3
|
|
||||||
|
|
||||||
_PATCH_OPTS = patch_options
|
|
||||||
|
|
||||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
|
||||||
assert isinstance(
|
|
||||||
maybe_model, modeling_mistral3.Mistral3ForConditionalGeneration
|
|
||||||
), f"Expected a Mistral3ForConditionalGeneration model. Got {type(maybe_model)}."
|
|
||||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
|
||||||
|
|
||||||
# patch the causal model to enable deferred logits calculation
|
|
||||||
maybe_model.language_model.forward = MethodType(
|
|
||||||
cce_forward, maybe_model.language_model
|
|
||||||
)
|
|
||||||
return maybe_model
|
|
||||||
|
|
||||||
modeling_mistral3.Mistral3ForConditionalGeneration.forward = cce_forward_multimodal
|
|
||||||
# patch the causal model to enable deferred logits calculation
|
|
||||||
modeling_mistral.MistralForCausalLM.forward = cce_forward
|
|
||||||
return None
|
|
||||||
@@ -1,379 +0,0 @@
|
|||||||
"""Mllama CCE patch."""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.mllama.modeling_mllama import (
|
|
||||||
MLLAMA_INPUTS_DOCSTRING,
|
|
||||||
_prepare_cross_attention_mask,
|
|
||||||
)
|
|
||||||
from transformers.utils import (
|
|
||||||
add_start_docstrings_to_model_forward,
|
|
||||||
replace_return_docstrings,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
|
|
||||||
)
|
|
||||||
def cce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
cross_attention_states: Optional[torch.LongTensor] = None,
|
|
||||||
cross_attention_mask: Optional[torch.LongTensor] = None,
|
|
||||||
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
||||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
defer_logits_calculation: bool = False,
|
|
||||||
**loss_kwargs,
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
defer_logits_calculation (`bool`, *optional*):
|
|
||||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
|
||||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import AutoTokenizer, MllamaForCausalLM
|
|
||||||
|
|
||||||
>>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
|
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")
|
|
||||||
|
|
||||||
>>> prompt = "If I had to write a haiku, it would be:"
|
|
||||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
|
|
||||||
>>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
>>> print(result)
|
|
||||||
If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
|
|
||||||
I love the idea of snowflakes gently falling, each one
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
cross_attention_states=cross_attention_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
cross_attention_mask=cross_attention_mask,
|
|
||||||
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states[:, slice_indices, :],
|
|
||||||
self.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**loss_kwargs,
|
|
||||||
)
|
|
||||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
|
||||||
# defer logits calculation to the ConditionalGeneration forward
|
|
||||||
logits = hidden_states[:, slice_indices, :]
|
|
||||||
else:
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
|
|
||||||
|
|
||||||
loss = None
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(
|
|
||||||
output_type=CausalLMOutputWithPast, config_class="MllamaConfig"
|
|
||||||
)
|
|
||||||
def cce_forward_multimodal(
|
|
||||||
self,
|
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
|
||||||
pixel_values: Optional[torch.FloatTensor] = None,
|
|
||||||
aspect_ratio_mask: Optional[torch.Tensor] = None,
|
|
||||||
aspect_ratio_ids: Optional[torch.Tensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
cross_attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
cross_attention_states: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
**loss_kwargs,
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from PIL import Image
|
|
||||||
>>> import requests
|
|
||||||
>>> from transformers import AutoProcessor, MllamaForConditionalGeneration
|
|
||||||
|
|
||||||
>>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
|
|
||||||
>>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
|
|
||||||
>>> processor = AutoProcessor.from_pretrained(checkpoint)
|
|
||||||
|
|
||||||
>>> prompt = "<|image|>If I had to write a haiku for this one"
|
|
||||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
|
||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
||||||
|
|
||||||
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> output = model.generate(**inputs, max_new_tokens=15)
|
|
||||||
|
|
||||||
>>> prompt_len = inputs.input_ids.shape[-1]
|
|
||||||
>>> generated_ids = output[:, prompt_len:]
|
|
||||||
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
|
||||||
>>> print(generated_text)
|
|
||||||
[', it would be:.\\nA stop sign in Chinatown.\\n']
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
|
||||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
if pixel_values is not None and inputs_embeds is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
|
||||||
)
|
|
||||||
|
|
||||||
if pixel_values is not None and cross_attention_states is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
|
|
||||||
)
|
|
||||||
|
|
||||||
if pixel_values is not None:
|
|
||||||
if aspect_ratio_ids is None:
|
|
||||||
raise ValueError(
|
|
||||||
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
|
|
||||||
)
|
|
||||||
# get vision tokens from vision model
|
|
||||||
vision_outputs = self.vision_model(
|
|
||||||
pixel_values=pixel_values,
|
|
||||||
aspect_ratio_ids=aspect_ratio_ids,
|
|
||||||
aspect_ratio_mask=aspect_ratio_mask,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
return_dict=return_dict,
|
|
||||||
)
|
|
||||||
cross_attention_states = vision_outputs[0]
|
|
||||||
cross_attention_states = self.multi_modal_projector(
|
|
||||||
cross_attention_states
|
|
||||||
).reshape(
|
|
||||||
-1, cross_attention_states.shape[-2], self.hidden_size # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
if cross_attention_mask is not None:
|
|
||||||
cross_attention_mask, full_text_row_masked_out_mask = (
|
|
||||||
_prepare_cross_attention_mask(
|
|
||||||
cross_attention_mask,
|
|
||||||
num_vision_tokens=self.vision_model.num_patches,
|
|
||||||
dtype=self.dtype,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
full_text_row_masked_out_mask = None
|
|
||||||
|
|
||||||
if cross_attention_mask is not None and cache_position is not None:
|
|
||||||
cross_attention_mask = cross_attention_mask[:, :, cache_position]
|
|
||||||
full_text_row_masked_out_mask = full_text_row_masked_out_mask[
|
|
||||||
:, :, cache_position
|
|
||||||
]
|
|
||||||
|
|
||||||
outputs = self.language_model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
cross_attention_states=cross_attention_states,
|
|
||||||
cross_attention_mask=cross_attention_mask,
|
|
||||||
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
use_cache=use_cache,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
logits_to_keep=logits_to_keep,
|
|
||||||
defer_logits_calculation=True, # enable deferred logits calculation
|
|
||||||
**loss_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states,
|
|
||||||
self.language_model.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**loss_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Temporary fix to calculate the loss in main class, as the model's vocab size may be resized
|
|
||||||
logits = hidden_states
|
|
||||||
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(
|
|
||||||
logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
return (loss,) + outputs if loss is not None else outputs
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=outputs.logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_mllama(
|
|
||||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
|
||||||
patch_options: PatchOptions,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
|
|
||||||
global _PATCH_OPTS # pylint: disable=global-statement
|
|
||||||
from transformers.models.mllama import modeling_mllama
|
|
||||||
|
|
||||||
_PATCH_OPTS = patch_options
|
|
||||||
|
|
||||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
|
||||||
assert isinstance(
|
|
||||||
maybe_model, modeling_mllama.MllamaForConditionalGeneration
|
|
||||||
), f"Expected a MllamaForConditionalGeneration model. Got {type(maybe_model)}."
|
|
||||||
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
|
||||||
|
|
||||||
# patch the language model
|
|
||||||
maybe_model.language_model.forward = MethodType(
|
|
||||||
cce_forward, maybe_model.language_model
|
|
||||||
)
|
|
||||||
return maybe_model
|
|
||||||
|
|
||||||
modeling_mllama.MllamaForConditionalGeneration.forward = cce_forward_multimodal
|
|
||||||
|
|
||||||
# patch the causal language model
|
|
||||||
modeling_mllama.MllamaForCausalLM.forward = cce_forward
|
|
||||||
return None
|
|
||||||
@@ -1,85 +0,0 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.

"""Cut Cross Entropy patcher"""

import transformers
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
from cut_cross_entropy.transformers.llama import patch_llama
from cut_cross_entropy.transformers.phi3 import patch_phi3
from cut_cross_entropy.transformers.qwen2 import patch_qwen2
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT

from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
    patch_cohere,
    patch_cohere2,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma import patch_gemma
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
    patch_gemma2,
    patch_gemma3,
    patch_gemma3_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
    patch_mistral,
    patch_mistral3,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama

CUT_CROSS_ENTROPY_MODEL_MAPPING = {
    "llama": patch_llama,
    "mllama": patch_mllama,
    "phi3": patch_phi3,
    "gemma": patch_gemma,
    "gemma2": patch_gemma2,
    "gemma3": patch_gemma3,
    "gemma3_text": patch_gemma3_text,
    "mistral": patch_mistral,
    "mistral3": patch_mistral3,
    "qwen2": patch_qwen2,
    "cohere": patch_cohere,
    "cohere2": patch_cohere2,
}


def cce_patch(
    model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig,
    impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
    reduction: str = "mean",
    filter_eps: float | str | None = "auto",
    accum_e_fp32: bool = False,
    accum_c_fp32: bool = False,
    filter_e_grad: bool = True,
    filter_c_grad: bool = True,
    train_only: bool = False,
) -> TransformersModelT | None:
    if isinstance(impl, LinearCrossEntropyImpl):
        impl = impl.name.lower()

    if impl not in (v.name.lower() for v in LinearCrossEntropyImpl):
        raise ValueError(f"Unknown {impl=}")

    if isinstance(model_type_or_model, transformers.PreTrainedModel):
        model_type = model_type_or_model.config.model_type
    elif isinstance(model_type_or_model, transformers.PretrainedConfig):
        model_type = model_type_or_model.model_type
    else:
        model_type = model_type_or_model

    patch_options = PatchOptions(
        impl=impl,
        reduction=reduction,
        filter_eps=filter_eps,
        accum_e_fp32=accum_e_fp32,
        accum_c_fp32=accum_c_fp32,
        filter_e_grad=filter_e_grad,
        filter_c_grad=filter_c_grad,
        train_only=train_only,
    )

    if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING:
        return CUT_CROSS_ENTROPY_MODEL_MAPPING[model_type](
            model_type_or_model, patch_options
        )

    raise RuntimeError(f"Unknown model type {model_type}")
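

# --- Illustrative usage (a sketch; names outside this module are assumptions) ---
# cce_patch accepts a model-type string, a PretrainedConfig, or an already-loaded model.
# With a string or config it patches the transformers class in place and returns None;
# with a model it patches that instance and returns it.
#
#   cce_patch("gemma3")                        # patch Gemma3ForConditionalGeneration globally
#   model = cce_patch(model, train_only=True)  # patch a loaded instance, training passes only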
@@ -1,40 +0,0 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.

"""Monkeypatch for apply_lce to add softcap."""

import torch
from cut_cross_entropy import linear_cross_entropy
from cut_cross_entropy.transformers.utils import PatchOptions


def apply_lce(
    e: torch.Tensor,
    c: torch.Tensor,
    labels: torch.Tensor,
    opts: PatchOptions,
    bias: torch.Tensor | None = None,
    softcap: float | None = None,
    **loss_kwargs,
) -> torch.Tensor:
    """Monkey patch for apply_lce to support softcap kwarg."""
    num_items_in_batch = loss_kwargs.get("num_items_in_batch", None)
    cce_kwargs = opts.to_kwargs()
    if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean":
        cce_kwargs["reduction"] = "sum"
    else:
        num_items_in_batch = None

    loss = linear_cross_entropy(
        e,
        c,
        labels.to(e.device),
        bias=bias,
        shift=True,
        softcap=softcap,
        **cce_kwargs,
    )

    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch

    return loss
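

# --- Worked example of the reduction handling above (illustrative numbers only) ---
# When the HF Trainer does gradient accumulation it passes `num_items_in_batch`; the code
# above then asks the kernel for a "sum" and divides here instead of taking a per-micro-batch
# mean:
#
#   per-token losses in this micro-batch: [2.0, 1.0, 3.0]  -> sum = 6.0
#   num_items_in_batch = 12  (label tokens across *all* accumulated micro-batches)
#   returned loss = 6.0 / 12 = 0.5
#
# Summing the rescaled micro-batch losses reproduces a true mean over the full batch,
# rather than a mean of means over unevenly sized micro-batches.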
@@ -20,26 +20,6 @@ liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```

## Supported Models

- deepseek_v2
- gemma
- gemma2
- gemma3 (partial support, no support for FLCE yet)
- granite
- jamba
- llama
- mistral
- mixtral
- mllama
- mllama_text_model
- olmo2
- paligemma
- phi3
- qwen2
- qwen2_5_vl
- qwen2_vl

## Citation

```bib
@@ -21,7 +21,6 @@ It is designed to be performant, correct, and light-weight.
 import inspect
 import logging
 import sys
-from functools import partial
 
 from axolotl.integrations.base import BasePlugin
 
@@ -42,18 +41,11 @@ class LigerPlugin(BasePlugin):
     def pre_model_load(self, cfg):
         from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
         from liger_kernel.transformers.functional import liger_cross_entropy
-        from liger_kernel.transformers.geglu import LigerGEGLUMLP
-        from liger_kernel.transformers.layer_norm import LigerLayerNorm
         from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
         from liger_kernel.transformers.rms_norm import LigerRMSNorm
         from liger_kernel.transformers.rope import liger_rotary_pos_emb
         from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
 
-        if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
-            raise ValueError(
-                "Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
-            )
-
         if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
             apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
             liger_fn_sig = inspect.signature(apply_liger_fn)
@@ -90,8 +82,6 @@ class LigerPlugin(BasePlugin):
                 modeling_jamba.JambaRMSNorm = LigerRMSNorm
             if cfg.liger_glu_activation:
                 modeling_jamba.JambaMLP = LigerSwiGLUMLP
-            if cfg.liger_layer_norm:
-                modeling_jamba.nn.LayerNorm = LigerLayerNorm
             if cfg.liger_cross_entropy:
                 from transformers.loss.loss_utils import nn
 
@@ -114,51 +104,13 @@ class LigerPlugin(BasePlugin):
                 # The DeepseekV2 version of RoPE is different than upstream LLaMA.
                 # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
                 logging.warning("Fused liger_rope is not supported for DeepseekV2.")
-            if cfg.liger_glu_activation:
-                logging.warning("liger_glu_activation is not supported for DeepseekV2.")
             if cfg.liger_rms_norm:
                 modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
             if cfg.liger_glu_activation:
                 modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
-            if cfg.liger_layer_norm:
-                modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward
             if cfg.liger_cross_entropy:
                 # We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
                 # nn.CrossEntropyLoss in the forward method.
                 modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
             if cfg.liger_fused_linear_cross_entropy:
                 modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
-        elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
-            from transformers.models.gemma3 import modeling_gemma3
-
-            if cfg.liger_rope:
-                modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
-            if cfg.liger_rms_norm:
-
-                def _liger_rms_norm_wrapper(dim, **kwargs):
-                    "Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
-                    return LigerRMSNorm(hidden_size=dim, **kwargs)
-
-                modeling_gemma3.Gemma3RMSNorm = partial(
-                    _liger_rms_norm_wrapper,
-                    offset=1.0,
-                    casting_mode="gemma",
-                    init_fn="zeros",
-                    in_place=False,
-                )
-            if cfg.liger_glu_activation:
-                modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
-            if cfg.liger_layer_norm:
-                modeling_gemma3.nn.LayerNorm = LigerLayerNorm
-
-            if cfg.liger_cross_entropy:
-                from transformers.loss.loss_utils import nn
-
-                nn.functional.cross_entropy = liger_cross_entropy
-
-            if cfg.liger_fused_linear_cross_entropy:
-                raise NotImplementedError(
-                    "Fused linear cross entropy is not yet supported for Gemma3."
-                )
-        elif cfg.model_config_type in ["deepseek_v3"]:
-            raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
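For context, the pattern used throughout `pre_model_load` above is plain module-attribute monkey-patching: the Liger kernel classes are assigned over the originals on the relevant `transformers` modeling module before the model is instantiated, so construction picks up the fused implementations. A rough sketch of the idea (illustrative only, not the plugin's exact code):

```python
# Illustrative sketch: swap a modeling class before the model is constructed.
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from transformers.models.llama import modeling_llama

# Every LlamaModel built after this line uses the fused Liger RMSNorm layers.
modeling_llama.LlamaRMSNorm = LigerRMSNorm
```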
@@ -1,89 +0,0 @@
-"""
-Ring attention group registration and flash attention patching.
-
-Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
-package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
-their sequence parallel version of Flash Attention 2.
-"""
-
-import torch.distributed as dist
-from accelerate.logging import get_logger
-
-from axolotl.logging_config import configure_logging
-
-configure_logging()
-LOG = get_logger(__name__)
-
-RING_ATTN_GROUP = None
-
-
-def get_ring_attn_group() -> dist.ProcessGroup:
-    """
-    Getter for ring attention group on this rank.
-
-    Returns:
-        The process group for ring attention for this rank.
-    """
-    return RING_ATTN_GROUP
-
-
-def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
-    """
-    Setter for ring attention group on this rank.
-
-    Args:
-        Process group for ring attention.
-    """
-    global RING_ATTN_GROUP  # pylint: disable=global-statement
-    RING_ATTN_GROUP = ring_attn_group
-
-
-def register_ring_attn(sequence_parallel_degree: int):
-    """
-    Create ring attention group and substitute flash attn with ring flash attn.
-
-    Args:
-        sequence_parallel_degree: Sequence parallelism factor.
-    """
-    LOG.info(
-        "Enabling ring attention sequence parallelism: "
-        f"each sequence will be processed across {sequence_parallel_degree} GPUs"
-    )
-
-    world_size = dist.get_world_size()
-    assert sequence_parallel_degree <= world_size, (
-        f"sequence_parallel_degree ({sequence_parallel_degree}) "
-        f"must be less than or equal to world_size ({world_size})"
-    )
-    assert world_size % sequence_parallel_degree == 0, (
-        f"sequence_parallel_degree ({sequence_parallel_degree}) "
-        f"must evenly divide world_size ({world_size})"
-    )
-
-    # Detailed logging of group formation
-    rank = dist.get_rank()
-    group_assignments = {}
-
-    for i in range(world_size // sequence_parallel_degree):
-        ring_attn_ranks = list(
-            range(
-                i * sequence_parallel_degree,
-                (i + 1) * sequence_parallel_degree,
-            )
-        )
-        group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
-
-        # Track which GPUs are in which groups
-        for r in ring_attn_ranks:
-            group_assignments[r] = i
-
-        if rank in ring_attn_ranks:
-            set_ring_attn_group(group)
-
-    # Log the GPU group assignments
-    if rank == 0:
-        LOG.info(f"Sequence parallel group assignments: {group_assignments}")
-
-    from ring_flash_attn import substitute_hf_flash_attn
-
-    substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
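The grouping loop in `register_ring_attn` above simply partitions consecutive ranks into blocks of `sequence_parallel_degree`. A small standalone illustration of the same arithmetic (no distributed initialization needed):

```python
def ring_attn_groups(world_size: int, sequence_parallel_degree: int) -> dict[int, int]:
    """Map each rank to its sequence-parallel group index, mirroring the loop above."""
    assert world_size % sequence_parallel_degree == 0
    assignments = {}
    for i in range(world_size // sequence_parallel_degree):
        for r in range(i * sequence_parallel_degree, (i + 1) * sequence_parallel_degree):
            assignments[r] = i
    return assignments

print(ring_attn_groups(8, 4))  # {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 1}
```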
@@ -22,9 +22,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "phi3",
     "gemma",
     "gemma2",
-    "gemma3_text",
-    "cohere",
-    "cohere2",
     "gemmoe",
     "starcoder2",
     "deepseek_v2",
@@ -1,278 +0,0 @@
-"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types"""
-
-from copy import deepcopy
-from typing import Optional
-
-from PIL import Image, ImageOps
-from PIL.Image import Resampling
-from torch import Tensor
-from transformers import ProcessorMixin
-from transformers.image_utils import load_image
-
-
-class ProcessingStrategy:
-    """Base Processing Strategy class"""
-
-    def __init__(
-        self,
-        processor: ProcessorMixin,
-        chat_template: Optional[str] = None,
-        image_size: int | tuple[int, int] | None = None,
-        image_resize_algorithm: Resampling | None = None,
-    ):
-        self.processor = processor
-        self.chat_template = chat_template
-        self.image_token = None
-        self.image_token_id = None
-
-        self.image_size = image_size
-        self.image_resize_algorithm = (
-            image_resize_algorithm or Image.Resampling.BILINEAR
-        )
-
-        if hasattr(processor, "image_token"):
-            self.image_token = processor.image_token
-            self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
-                self.image_token
-            )
-
-    def __call__(self, examples: list[dict]) -> list[dict]:
-        """
-        Preprocess conversation examples to ensure consistent format.
-        Converts different conversation formats to OpenAI format with 'messages'.
-        Supports two formats:
-        1. OpenAI format with 'messages'
-        2. Legacy format with 'conversations'
-
-        Args:
-            examples: list of conversation dictionaries
-
-        Returns:
-            list of dicts in OpenAI format with 'messages' key
-
-        Raises:
-            ValueError: If the conversation format is not supported
-        """
-        role_mapping = {
-            "human": "user",
-            "gpt": "assistant",
-        }
-
-        def normalize_role(role: str) -> str:
-            """Normalize role names to OpenAI format. Default to original role if not found."""
-            return role_mapping.get(role, role)
-
-        def convert_legacy_format(example: dict) -> dict:
-            """Convert legacy 'conversations' format to OpenAI 'messages' format."""
-            messages = [
-                {"role": normalize_role(convo["from"]), "content": convo["value"]}
-                for convo in example["conversations"]
-            ]
-
-            # Create new dict without 'conversations' key
-            result = deepcopy(example)
-            result.pop("conversations")
-            result["messages"] = messages
-            return result
-
-        def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]:
-            """Convert regular messages format to Messages format with content type"""
-
-            new_messages = []
-            for message in messages:
-                if isinstance(message["content"], str):
-                    new_messages.append(
-                        {
-                            "role": message["role"],
-                            "content": [
-                                {
-                                    "type": "text",
-                                    "text": message["content"],
-                                }
-                            ],
-                        }
-                    )
-                elif isinstance(message["content"], list):
-                    content = message["content"]
-
-                    new_messages.append(
-                        {
-                            "role": message["role"],
-                            "content": content,
-                        }
-                    )
-
-            return new_messages
-
-        processed_examples = []
-        for example in examples:
-            if not ("messages" in example or "conversations" in example):
-                raise ValueError(
-                    "Only `messages` and `conversations` message keys are currently supported."
-                )
-
-            processed_example = None
-            if "messages" in example:  # OpenAI format
-                processed_example = example
-            else:  # Legacy format
-                processed_example = convert_legacy_format(example)
-
-            # convert regular messages format to Messages format with content type
-            # for compatibility with apply_chat_template
-            processed_example["messages"] = convert_messages_to_multimedia_messages(
-                processed_example["messages"]
-            )
-
-            # find the image key if it exists
-            possible_image_keys = ["images", "image"]
-            image_key = None
-            for key in possible_image_keys:
-                if key in processed_example:
-                    image_key = key
-                    break
-
-            # if the image key exists, add the image to the first message
-            if image_key is not None:
-                # TODO: check if it's normal to be single image only for common datasets
-                # From observation, it's usually a list of single image but some datasets may have several columns for images
-                # Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages
-                image_value = processed_example[image_key][0]
-
-                # Handle image loading (Image, url, path, base64)
-                image_value = load_image(image_value)
-
-                if self.image_size is not None:
-                    assert hasattr(
-                        image_value, "resize"
-                    ), "Image does not have a resize method"
-
-                    if isinstance(self.image_size, tuple):
-                        image_value = image_value.resize(
-                            self.image_size, self.image_resize_algorithm
-                        )
-                    else:
-                        # Set the padding value; here we use black (0, 0, 0) for RGB images
-                        padding_color = (0, 0, 0)
-
-                        # When image_size is an int (square target), preserve aspect ratio then pad
-                        # This is to prevent aspect ratio distortion when resizing to square
-                        image_value = ImageOps.pad(
-                            image_value,
-                            (self.image_size, self.image_size),
-                            method=self.image_resize_algorithm,
-                            color=padding_color,
-                        )
-
-                # Look for any image type in the first message
-                # some dataset have an {type: "image"} in the first message
-                ind_to_add = None
-
-                for i, content in enumerate(
-                    processed_example["messages"][0]["content"]
-                ):
-                    # Usually datasets created with image columns, don't have it in the messages itself
-                    if content["type"] == "image" and all(
-                        k not in content for k in ["image", "url", "path", "base64"]
-                    ):
-                        ind_to_add = i
-                        break
-
-                # If an image type is found, add the image to that index
-                if ind_to_add is not None:
-                    processed_example["messages"][0]["content"][ind_to_add][
-                        "image"
-                    ] = image_value
-                else:
-                    # if no image type is found, add it to end of the first message
-                    processed_example["messages"][0]["content"].append(
-                        {
-                            "type": "image",
-                            "image": image_value,
-                        }
-                    )
-
-            processed_examples.append(processed_example)
-
-        return processed_examples
-
-    def process_labels(self, input_ids: Tensor) -> Tensor:
-        labels = input_ids.clone()
-
-        # The labels are the input_ids, and we mask the padding tokens in the loss computation
-        labels[labels == self.processor.tokenizer.pad_token_id] = -100
-
-        # Ignore the image token index in the loss computation (model specific)
-        labels[labels == self.image_token_id] = -100
-
-        return labels
-
-
-class Qwen2VLProcessingStrategy(ProcessingStrategy):
-    """Processing Strategy class for Qwen2-VL"""
-
-    def __init__(
-        self,
-        processor: ProcessorMixin,
-        chat_template: Optional[str] = None,
-        image_size: int | tuple[int, int] | None = None,
-        image_resize_algorithm: Resampling | None = None,
-    ):
-        super().__init__(processor, chat_template, image_size, image_resize_algorithm)
-        self.image_token = "<|image_pad|>"  # nosec
-        self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
-            self.image_token
-        )
-
-
-class Gemma3ProcessingStrategy(ProcessingStrategy):
-    """Processing Strategy class for Gemma3"""
-
-    def __init__(
-        self,
-        processor: ProcessorMixin,
-        chat_template: Optional[str] = None,
-        image_size: int | tuple[int, int] | None = None,
-        image_resize_algorithm: Resampling | None = None,
-    ):
-        super().__init__(processor, chat_template, image_size, image_resize_algorithm)
-        self.image_token = processor.tokenizer.special_tokens_map["boi_token"]
-        self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
-            self.image_token
-        )
-
-    def process_labels(self, input_ids):
-        labels = input_ids.clone()
-
-        # Follows https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora
-        labels[labels == self.processor.tokenizer.pad_token_id] = -100
-        labels[labels == self.image_token_id] = -100
-        labels[labels == 262144] = -100  # corresponds to <image_soft_token>
-
-        return labels
-
-
-def get_processing_strategy(
-    processor: ProcessorMixin,
-    chat_template,
-    chat_template_type,
-    image_size: int | tuple[int, int] | None = None,
-    image_resize_algorithm: Resampling | None = None,
-):
-    if chat_template_type == "qwen2_vl":
-        return Qwen2VLProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
-        )
-    if chat_template_type == "gemma3":
-        return Gemma3ProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
-        )
-    if chat_template_type in [
-        "llama3_2_vision",
-        "llava",
-        "mistral_v7_tekken",
-        "pixtral",
-    ]:
-        return ProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
-        )
-    raise ValueError(f"Unsupported chat template type: {chat_template_type}")
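To make the conversion in `ProcessingStrategy.__call__` concrete, here is roughly what a legacy `conversations` row looks like before and after normalization (hand-built dictionaries for illustration, not taken from a real dataset):

```python
legacy = {
    "conversations": [
        {"from": "human", "value": "What is in the picture?"},
        {"from": "gpt", "value": "A cat on a sofa."},
    ],
    "images": ["cat.png"],
}

# After __call__, roles are mapped to OpenAI names, content is wrapped in typed
# entries, and the loaded image is attached to the first message:
expected = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in the picture?"},
                {"type": "image", "image": "<PIL.Image loaded from cat.png>"},
            ],
        },
        {"role": "assistant", "content": [{"type": "text", "text": "A cat on a sofa."}]},
    ],
    "images": ["cat.png"],
}
```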
@@ -411,15 +411,11 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         if turn_idx >= len(turns):
             raise ValueError(f"Turn index {turn_idx} out of range")
 
-        # mistral/gemma3 does not output message if it contains only system message
+        # mistral does not output message if it contains only system message
         if (
             turn_idx == 0
             and turns[0].get("role") == "system"
-            and (
-                "mistral" in self.tokenizer.name_or_path.lower()
-                # gemma3 uses gemma tokenizer
-                or "gemma" in self.tokenizer.name_or_path.lower()
-            )
+            and "mistral" in self.tokenizer.name_or_path.lower()
         ):
             return -1, -1
 
@@ -14,7 +14,6 @@ import transformers.modelcard
 from accelerate.logging import get_logger
 from accelerate.utils import save_fsdp_model
 from datasets import Dataset
-from huggingface_hub.errors import OfflineModeIsEnabled
 from peft import PeftConfig, PeftModel
 from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
@@ -170,7 +169,7 @@ def execute_training(
     cfg: DictDefault, trainer: Any, resume_from_checkpoint: str | None
 ):
     """
-    Execute the training process with appropriate SDP kernel configurations.
+    Execute the training process with appropriate backend configurations.
 
     Args:
         cfg: Dictionary mapping `axolotl` config keys to values.
@@ -178,6 +177,9 @@ def execute_training(
         resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
     """
     LOG.info("Starting trainer...")
+    if cfg.group_by_length:
+        LOG.info("hang tight... sorting dataset for group_by_length")
+
     if cfg.flash_optimum:
         with torch.backends.cuda.sdp_kernel(
             # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
@@ -303,7 +305,7 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
             model_card_kwarg["dataset_tags"] = dataset_tags
 
             trainer.create_model_card(**model_card_kwarg)
-        except (AttributeError, UnicodeDecodeError, OfflineModeIsEnabled):
+        except (AttributeError, UnicodeDecodeError):
             pass
     elif cfg.hub_model_id:
         # Defensively push to the hub to ensure the model card is updated
@@ -315,7 +317,6 @@ def save_initial_configs(
     tokenizer: PreTrainedTokenizer,
     model: PreTrainedModel,
     peft_config: PeftConfig | None,
-    processor: ProcessorMixin | None,
 ):
     """
     Save initial configurations before training.
@@ -343,10 +344,6 @@ def save_initial_configs(
     LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
     model.config.save_pretrained(str(output_dir))
-
-    if processor:
-        LOG.info(f"Pre-saving processor to {cfg.output_dir}...")
-        processor.save_pretrained(str(output_dir))
 
 
 def setup_model_card(cfg: DictDefault):
     """
@@ -414,7 +411,6 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
     PeftModel | PreTrainedModel,
     PreTrainedTokenizer,
     PeftConfig | None,
-    ProcessorMixin | None,
 ]:
     """
     Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
@@ -430,7 +426,6 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
         - Model
         - Tokenizer
         - PEFT config
-        - Processor
     """
     # Load tokenizer, processor and model
     model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
@@ -461,7 +456,6 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
         model,
         tokenizer,
         peft_config,
-        processor,
     )
 
 
@@ -484,7 +478,6 @@ def train(
         model,
         tokenizer,
         peft_config,
-        processor,
     ) = setup_model_and_trainer(cfg, dataset_meta)
 
     # Determine if we need to resume from a checkpoint
@@ -500,7 +493,7 @@ def train(
     )
 
     # Save initial configs
-    save_initial_configs(cfg, tokenizer, model, peft_config, processor)
+    save_initial_configs(cfg, tokenizer, model, peft_config)
 
     # Set up signal handler for graceful termination
     setup_signal_handler(cfg, model, safe_serialization)
File diff suppressed because one or more lines are too long
@@ -1,59 +1,14 @@
 """
-Data collators for axolotl to pad labels and position_ids for packed sequences. Also
-includes logic for handling sequence parallelism collation.
+DataCollator for axolotl to pad labels and position_ids for packed sequences
 """
 
-import logging
 from dataclasses import dataclass
 from typing import Any, Optional, Union
 
 import numpy as np
-import torch
-import torch.distributed as dist
 from transformers import PreTrainedTokenizerBase
 from transformers.utils import PaddingStrategy
 
-logger = logging.getLogger(__name__)
-
-
-def adjust_position_ids_for_slice(
-    position_ids: torch.Tensor, start_idx: int
-) -> torch.Tensor:
-    """
-    Adjust position IDs for a sliced sequence to maintain proper relative positions.
-    This handles the case where position IDs might not be contiguous due to sample
-    packing.
-    """
-    # Convert to tensor if not already
-    # Find the boundaries between samples (where position_ids reset)
-    adjusted_pos_ids = position_ids.clone()
-
-    # Process each sequence in the batch
-    for i in range(position_ids.shape[0]):
-        seq = position_ids[i]
-
-        # Find sample boundaries
-        boundaries = []
-        for j in range(1, len(seq)):
-            if seq[j] < seq[j - 1]:
-                boundaries.append(j)
-
-        # No need to adjust if there are no boundaries or this is a single sample
-        if not boundaries:
-            adjusted_pos_ids[i] = seq - start_idx
-            continue
-
-        # Adjust each segment separately
-        prev_boundary = 0
-        for boundary in boundaries:
-            adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx
-            prev_boundary = boundary
-
-        # Last segment
-        adjusted_pos_ids[i, prev_boundary:] -= start_idx
-
-    return adjusted_pos_ids
-
 
 @dataclass
 class DataCollatorForSeq2Seq:
@@ -88,8 +43,6 @@ class DataCollatorForSeq2Seq:
             The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
         return_tensors (`str`):
             The type of Tensor to return. Allowable values are "np", "pt" and "tf".
-        sequence_parallel_degree (`int`):
-            The degree of sequence parallelism. Default to 1 for no sequence parallelism.
     """
 
     tokenizer: PreTrainedTokenizerBase
@@ -100,16 +53,6 @@ class DataCollatorForSeq2Seq:
     label_pad_token_id: int = -100
     position_pad_token_id: int = 0
     return_tensors: str = "pt"
-    sequence_parallel_degree: int = 1
-
-    def __post_init__(self):
-        if self.sequence_parallel_degree > 1:
-            from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
-
-            # Get information about our position in the SP group
-            sp_group = get_ring_attn_group()
-            self.local_rank = dist.get_rank(group=sp_group)
-            self.local_world_size = dist.get_world_size(group=sp_group)
 
     def __call__(self, features, return_tensors=None):
         labels = None
@@ -176,43 +119,8 @@ class DataCollatorForSeq2Seq:
                 )
             features["decoder_input_ids"] = decoder_input_ids
 
-        if self.sequence_parallel_degree > 1:
-            features = self.apply_sequence_parallelism(features)
-
         return features
 
-    def apply_sequence_parallelism(
-        self, batch: dict[str, torch.Tensor]
-    ) -> torch.Tensor:
-        """
-        Apply sequence parallelism slicing to a batch.
-
-        Args:
-            batch: Batch dictionary from parent collator.
-
-        Returns:
-            Sliced batch dictionary.
-        """
-        keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"]
-
-        for key in keys_to_slice:
-            if key in batch:
-                seq_len = batch[key].shape[1]
-                slice_size = seq_len // self.local_world_size
-                start_idx = self.local_rank * slice_size
-                end_idx = (
-                    start_idx + slice_size
-                    if self.local_rank < self.local_world_size - 1
-                    else seq_len
-                )
-                batch[key] = batch[key][:, start_idx:end_idx]
-
-                # Special handling for position_ids
-                if key == "position_ids" and self.local_rank > 0:
-                    batch[key] = adjust_position_ids_for_slice(batch[key], start_idx)
-
-        return batch
-
 
 @dataclass
 class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
@@ -240,7 +148,6 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                     np.array(item[feature]) for item in features_ if feature in item
                 ]
                 out_features[i][feature] = np.concatenate(arrays)
 
        return super().__call__(out_features, return_tensors=return_tensors)
-
 
@@ -270,7 +177,6 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                     np.array(item[feature]) for item in features_ if feature in item
                 ]
                 out_features[i][feature] = np.concatenate(arrays)
 
        return super().__call__(out_features, return_tensors=return_tensors)
-
 
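The sequence-parallel slicing removed above (`apply_sequence_parallelism`) is an even split of the sequence dimension across the ranks of one ring-attention group, with the last rank absorbing any remainder. A toy illustration of the index math (plain Python, no distributed setup required):

```python
def slice_bounds(seq_len: int, local_rank: int, local_world_size: int) -> tuple[int, int]:
    """Start/end indices of the sequence chunk owned by `local_rank`."""
    slice_size = seq_len // local_world_size
    start = local_rank * slice_size
    end = start + slice_size if local_rank < local_world_size - 1 else seq_len
    return start, end

# seq_len=10 over 4 ranks -> the last rank keeps the remainder
print([slice_bounds(10, r, 4) for r in range(4)])  # [(0, 2), (2, 4), (4, 6), (6, 10)]
```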
@@ -2,17 +2,15 @@
 Collators for multi-modal chat messages and packing
 """
 
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import Any, Optional, Union
 
-import torch
-from torch import Tensor
-from transformers import PreTrainedTokenizerBase
+from PIL import Image
+from transformers import PreTrainedTokenizerBase, ProcessorMixin
 from transformers.data.data_collator import DataCollatorMixin
 from transformers.utils import PaddingStrategy
 
-from axolotl.processing_strategies import ProcessingStrategy
 
 
 @dataclass
 class MultiModalChatDataCollator(DataCollatorMixin):
@@ -21,9 +19,11 @@ class MultiModalChatDataCollator(DataCollatorMixin):
     """
 
     tokenizer: PreTrainedTokenizerBase
-    processing_strategy: ProcessingStrategy
-    packing: bool = False
+    processor: ProcessorMixin
     return_tensors: str = "pt"
+    chat_template: Optional[str] = None
+    packing: bool = False
+    max_images: int = -1
     padding: Union[bool, str, PaddingStrategy] = True
     pad_to_multiple_of: Optional[int] = None
 
@@ -31,62 +31,162 @@ class MultiModalChatDataCollator(DataCollatorMixin):
         if self.packing:
             raise ValueError("Packing is currently not supported.")
 
-    def torch_call(self, examples: list[dict]) -> dict[str, Any]:
-        return self.process_rows(examples)
+    def torch_call(
+        self, examples: list[Union[list[int], Any, dict[str, Any]]]
+    ) -> dict[str, Any]:
+        # Handle dict or lists with proper padding and conversion to tensor.
+
+        return self.__class__.process_rows(
+            examples, self.processor, self.chat_template, self.max_images
+        )
+
+    @staticmethod
+    def process_rows(examples, processor, chat_template, max_images, length_only=False):
+        # HINT: use `_torch_collate_batch` to stack and pad tensors
+        # see also DataCollatorWithFlattening and DefaultDataCollator
+
+        # *** This is COPIED from the trl example sft_vlm.py code ***
+        # use this as a starting point
+
+        def _preprocess(examples: list[dict]) -> list[dict]:
+            """
+            Preprocess conversation examples to ensure consistent format.
+
+            Converts different conversation formats to OpenAI format with 'messages'.
+            Supports two formats:
+            1. OpenAI format with 'messages'
+            2. Legacy format with 'conversations'
+
+            Args:
+                examples: list of conversation dictionaries
+
+            Returns:
+                dict in OpenAI format with 'messages' key
+
+            Raises:
+                ValueError: If the conversation format is not supported
+            """
+            role_mapping = {
+                "human": "user",
+                "gpt": "assistant",
+            }
+
+            def normalize_role(role: str) -> str:
+                """Normalize role names to OpenAI format. Default to original role if not found."""
+                return role_mapping.get(role, role)
+
+            def convert_legacy_format(example: dict) -> dict:
+                """Convert legacy 'conversations' format to OpenAI 'messages' format."""
+                messages = [
+                    {
+                        "role": normalize_role(convo["from"]),
+                        "content": convo["value"],
+                    }
+                    for convo in example["conversations"]
+                ]
+
+                # Create new dict without 'conversations' key
+                result = deepcopy(example)
+                result.pop("conversations")
+                return {"messages": messages, **result}
+
+            processed_examples = []
+            for example in examples:
+                # OpenAI format
+                if "messages" in example:
+                    processed_examples.append(example)
+
+                # Legacy format
+                elif "conversations" in example:
+                    processed_examples.append(convert_legacy_format(example))
+
+                else:
+                    raise ValueError(
+                        "Only `messages` and `conversations` message keys are currently supported."
+                    )
+
+            return processed_examples
+
+        def _process_images(examples, max_images):
+            """
+            Process images from examples, ensuring consistency in image presence and applying max_images limit.
+
+            Args:
+                examples: List of dictionaries that may contain 'images' key
+                max_images: Maximum number of images to keep per example (0 means no limit)
+
+            Returns:
+                Either None (if no images) or List[Image objects] (if all examples have images)
+
+            Raises:
+                ValueError: If there's a mix of None and non-None images
+            """
+
+            def get_image(example):
+                if "images" not in example:
+                    return None
+                images = example["images"]
+                if isinstance(images, str):
+                    return Image.open(images)
+                return images
+
+            images = [get_image(example) for example in examples]
+
+            # Count None and non-None images
+            none_count = sum(1 for img in images if img is None)
+
+            # All images are None
+            if none_count == len(images):
+                return None
+
+            # Mix of None and non-None images
+            if none_count > 0:
+                raise ValueError(
+                    "All images should be either None or not None. "
+                    "Please provide images for all examples or None."
+                )
+
+            # Apply max_images limit if specified
+            if max_images > 0:
+                images = [
+                    (
+                        img_batch[:max_images]
+                        if isinstance(img_batch, (list, tuple))
+                        else img_batch
+                    )
+                    for img_batch in images
+                ]
+
+            return images
-    def process_rows(
-        self,
-        examples: list[dict],
-    ) -> dict[str, Tensor]:
         # Preprocess the examples
-        examples = self.processing_strategy(examples)
+        examples = _preprocess(examples)
 
-        # Initialize batch
-        batch: dict[str, Any] = {}
-
-        # Process each example
-        for example in examples:
-            # Apply chat template to process the example
-            # This method requires transformers>=4.49.0
-            result = self.processing_strategy.processor.apply_chat_template(
-                example["messages"],
-                add_generation_prompt=True,
-                tokenize=True,
-                return_tensors="pt",
-                padding=True,
-                return_dict=True,
-                chat_template=self.processing_strategy.chat_template,
+        # Get the texts and images, and apply the chat template
+        texts = [
+            processor.apply_chat_template(
+                example["messages"], chat_template=chat_template, tokenize=False
             )
+            for example in examples
+        ]
 
-            # TODO: Check if need handling for len(input_ids) > sequence_len
+        images = _process_images(examples, max_images=max_images)
 
-            # Add the processed tensors to our batch
-            for key in result.keys():
-                if key not in batch:
-                    batch[key] = []
-
-                batch[key].append(result[key].squeeze(0))
-
-        # Pad sequences to the same length
-        input_ids = torch.nn.utils.rnn.pad_sequence(
-            batch["input_ids"],
-            batch_first=True,
-            padding_value=self.tokenizer.pad_token_id,
+        # Tokenize the texts and process the images
+        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
+
+        # The labels are the input_ids, and we mask the padding tokens in the loss computation
+        labels = batch["input_ids"].clone()
+        labels[labels == processor.tokenizer.pad_token_id] = -100  #
+        # Ignore the image token index in the loss computation (model specific)
+        image_token_id = processor.tokenizer.convert_tokens_to_ids(
+            processor.image_token
         )
+        labels[labels == image_token_id] = -100
+        batch["labels"] = labels
 
-        attention_mask = torch.nn.utils.rnn.pad_sequence(
-            batch["attention_mask"], batch_first=True, padding_value=0
-        )
-
-        # Create the final batch
-        final_batch = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-        }
-
-        # Process the labels
-        final_batch["labels"] = self.processing_strategy.process_labels(
-            final_batch["input_ids"]
-        )
-
-        return final_batch
+        if length_only:
+            return {
+                "length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
+            }
+        return batch
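As a usage sketch for the collator variant on the right-hand side of the hunk above; the import path, checkpoint name, and dataset fields are assumptions for illustration, not taken from the diff:

```python
from transformers import AutoProcessor
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator  # assumed module path

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")  # example checkpoint
collator = MultiModalChatDataCollator(
    tokenizer=processor.tokenizer,
    processor=processor,
    chat_template=None,   # fall back to the processor's default template
    max_images=1,
)

# One OpenAI-style row with a local image path (hypothetical file).
batch = collator([
    {"messages": [{"role": "user", "content": "Describe the image."}],
     "images": ["example.png"]},
])
print(batch["input_ids"].shape, batch["labels"].shape)
```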
@@ -13,7 +13,7 @@ from axolotl.integrations.base import PluginManager
 from axolotl.integrations.config import merge_input_args
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import MULTIMODAL_AUTO_MODEL_MAPPING, load_model_config
+from axolotl.utils.models import load_model_config
 from axolotl.utils.schemas.config import (
     AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
 )
@@ -125,9 +125,6 @@ def normalize_config(cfg):
         with open(ds_config_path, encoding="utf-8") as f:
             cfg.deepspeed = json.load(f)
 
-    if cfg.sequence_parallel_degree is None:
-        cfg.sequence_parallel_degree = 1
-
     if cfg.saves_per_epoch:
         save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
         if save_steps < 1.0:  # prevent saves on every step
@@ -158,7 +155,7 @@ def normalize_config(cfg):
 
     cfg.is_multimodal = (
         hasattr(model_config, "model_type")
-        and model_config.model_type in MULTIMODAL_AUTO_MODEL_MAPPING
+        and model_config.model_type in ["llava", "mllama"]
         or any(
             multimodal_name in cfg.base_model.lower()
             for multimodal_name in [
@@ -171,6 +168,7 @@ def normalize_config(cfg):
         cfg.processor_config = (
             cfg.processor_config or cfg.base_model_config or cfg.base_model
         )
+        model_config = model_config.text_config
 
     cfg.model_config_type = model_config.model_type
 
@@ -6,12 +6,8 @@ from pathlib import Path
 from typing import Optional, Union
 
 from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
-from huggingface_hub import hf_hub_download, snapshot_download
-from huggingface_hub.errors import (
-    HFValidationError,
-    RepositoryNotFoundError,
-    RevisionNotFoundError,
-)
+from huggingface_hub import hf_hub_download
+from huggingface_hub.errors import HFValidationError
 
 from axolotl.utils.dict import DictDefault
 
@@ -74,25 +70,20 @@ def load_dataset_w_config(
     # pylint: disable=invalid-name
     ds: Optional[Union[Dataset, DatasetDict]] = None  # pylint: disable=invalid-name
     ds_from_hub = False
+    ds_trust_remote_code = config_dataset.trust_remote_code
     try:
         # this is just a basic check to see if the path is a
        # valid HF dataset that's loadable
-        snapshot_download(
-            repo_id=config_dataset.path,
-            repo_type="dataset",
+        load_dataset(
+            config_dataset.path,
+            name=config_dataset.name,
+            streaming=True,
             token=use_auth_token,
             revision=config_dataset.revision,
-            ignore_patterns=["*"],
+            trust_remote_code=ds_trust_remote_code,
         )
         ds_from_hub = True
-    except (
-        RepositoryNotFoundError,
-        RevisionNotFoundError,
-        FileNotFoundError,
-        ConnectionError,
-        HFValidationError,
-        ValueError,
-    ):
+    except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
         pass
 
     ds_from_cloud = False
@@ -8,7 +8,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
import types
|
import types
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
||||||
|
|
||||||
import addict
|
import addict
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
@@ -25,7 +25,7 @@ from peft import (
|
|||||||
prepare_model_for_kbit_training,
|
prepare_model_for_kbit_training,
|
||||||
)
|
)
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import (
|
from transformers import ( # noqa: F401
|
||||||
AddedToken,
|
AddedToken,
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
@@ -34,17 +34,12 @@ from transformers import (
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
AwqConfig,
|
AwqConfig,
|
||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
Gemma3ForConditionalGeneration,
|
|
||||||
GPTQConfig,
|
GPTQConfig,
|
||||||
LlavaForConditionalGeneration,
|
LlavaForConditionalGeneration,
|
||||||
Mistral3ForConditionalGeneration,
|
|
||||||
MllamaForConditionalGeneration,
|
MllamaForConditionalGeneration,
|
||||||
PretrainedConfig,
|
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
ProcessorMixin,
|
ProcessorMixin,
|
||||||
Qwen2_5_VLForConditionalGeneration,
|
|
||||||
Qwen2VLForConditionalGeneration,
|
|
||||||
)
|
)
|
||||||
from transformers.integrations.deepspeed import (
|
from transformers.integrations.deepspeed import (
|
||||||
HfTrainerDeepSpeedConfig,
|
HfTrainerDeepSpeedConfig,
|
||||||
@@ -72,16 +67,7 @@ from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrap
|
|||||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
MULTIMODAL_AUTO_MODEL_MAPPING = {
|
|
||||||
"mllama": MllamaForConditionalGeneration,
|
|
||||||
"llava": LlavaForConditionalGeneration,
|
|
||||||
"qwen2_vl": Qwen2VLForConditionalGeneration,
|
|
||||||
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
|
|
||||||
"mistral3": Mistral3ForConditionalGeneration,
|
|
||||||
"gemma3": Gemma3ForConditionalGeneration,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# copied from accelerator.FullyShardedDataParallelPlugin
|
# copied from accelerator.FullyShardedDataParallelPlugin
|
||||||
@@ -108,30 +94,9 @@ def get_module_class_from_name(module, name):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
|
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
||||||
# Set use_cache to False
|
|
||||||
if hasattr(model_config, "use_cache"):
|
|
||||||
model_config.use_cache = False
|
|
||||||
|
|
||||||
if cfg.is_multimodal:
|
if cfg.is_multimodal:
|
||||||
# For multimodal configs, use_cache is set in the text_config
|
model_config = model_config.text_config
|
||||||
if hasattr(model_config, "get_text_config"):
|
|
||||||
text_config = model_config.get_text_config()
|
|
||||||
if hasattr(text_config, "use_cache"):
|
|
||||||
text_config.use_cache = False
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"No text config found for multimodal model. Please raise an Issue with model details."
|
|
||||||
)
|
|
||||||
|
|
||||||
# check if image_size is not set and load image size from model config if available
|
|
||||||
if (
|
|
||||||
cfg.image_size is None
|
|
||||||
and hasattr(model_config, "vision_config")
|
|
||||||
and hasattr(model_config.vision_config, "image_size")
|
|
||||||
):
|
|
||||||
cfg.image_size = model_config.vision_config.image_size
|
|
||||||
LOG.debug(f"Loaded image size: {cfg.image_size} from model config")
|
|
||||||
|
|
||||||
quant_config_exists = (
|
quant_config_exists = (
|
||||||
hasattr(model_config, "quantization_config")
|
hasattr(model_config, "quantization_config")
|
||||||
@@ -470,31 +435,6 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
|||||||
**processor_kwargs,
|
**processor_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Attempt to load image size from processor if available
|
|
||||||
if (
|
|
||||||
cfg.image_size is None
|
|
||||||
and hasattr(processor, "size")
|
|
||||||
and any(dim in processor.size for dim in ["width", "height"])
|
|
||||||
):
|
|
||||||
im_width = None
|
|
||||||
im_height = None
|
|
||||||
if "width" in processor.size:
|
|
||||||
im_width = processor.size["width"]
|
|
||||||
if "height" in processor.size:
|
|
||||||
im_height = processor.size["height"]
|
|
||||||
|
|
||||||
# If both width and height are set, use a tuple
|
|
||||||
if im_width is not None and im_height is not None:
|
|
||||||
cfg.image_size = (im_width, im_height)
|
|
||||||
# If only width is set, use as integer
|
|
||||||
elif im_width is not None:
|
|
||||||
cfg.image_size = im_width
|
|
||||||
# If only height is set, use as integer
|
|
||||||
elif im_height is not None:
|
|
||||||
cfg.image_size = im_height
|
|
||||||
|
|
||||||
LOG.debug(f"Loaded image size: {cfg.image_size} from processor")
|
|
||||||
|
|
||||||
return processor
|
return processor
|
||||||
|
|
||||||
|
|
||||||
@@ -531,8 +471,12 @@ class ModelLoader:
|
|||||||
|
|
||||||
# init model config
|
# init model config
|
||||||
self.model_config = load_model_config(cfg)
|
self.model_config = load_model_config(cfg)
|
||||||
|
if cfg.is_multimodal:
|
||||||
|
self.text_model_config = self.model_config.text_config
|
||||||
|
else:
|
||||||
|
self.text_model_config = self.model_config
|
||||||
|
|
||||||
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
self.AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||||
|
|
||||||
def apply_patches(self) -> None:
|
def apply_patches(self) -> None:
|
||||||
# load any patches from plugins
|
# load any patches from plugins
|
||||||
@@ -603,14 +547,6 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_self_attn_lora(self.cfg)
|
patch_self_attn_lora(self.cfg)
|
||||||
|
|
||||||
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
|
||||||
|
|
||||||
# Initialize ring attn for sequence parallelism. This must be done after
|
|
||||||
# model init but before the first forward pass, since it modifies flash
|
|
||||||
# attn to use ring comm for SP training across multiple GPUs.
|
|
||||||
register_ring_attn(self.cfg.sequence_parallel_degree)
|
|
||||||
|
|
||||||
def patch_attention(self) -> None:
|
def patch_attention(self) -> None:
|
||||||
if hasattr(self.model_config, "model_type"):
|
if hasattr(self.model_config, "model_type"):
|
||||||
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
||||||
@@ -667,7 +603,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
patch_self_attn_lora()
|
patch_self_attn_lora()
|
||||||
|
|
||||||
def patch_llama_derived_model(self):
|
def patch_llama_derived_model(self) -> None:
|
||||||
"""Modify all llama derived models in one block"""
|
"""Modify all llama derived models in one block"""
|
||||||
self.patch_loss_llama()
|
self.patch_loss_llama()
|
||||||
|
|
||||||
@@ -717,16 +653,25 @@ class ModelLoader:
|
|||||||
"Shifted-sparse attention not currently implemented without flash attention."
|
"Shifted-sparse attention not currently implemented without flash attention."
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_auto_model_loader(self):
|
def set_auto_model_loader(self) -> None:
|
||||||
"""
|
"""set self.AutoModelLoader
|
||||||
Set self.auto_model_loader. Defaults to `transformers.AutoModelForCausalLM`
|
- default value: AutoModelForCausalLM (set at __init__)
|
||||||
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
|
- when using a multi modality model, self.AutoModelLoader should
|
||||||
should be set according to the type of the model.
|
be set according to model type of the model
|
||||||
"""
|
"""
|
||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
|
if self.model_config.model_type == "llava":
|
||||||
self.model_config.model_type, AutoModelForVision2Seq
|
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
||||||
)
|
LlavaForConditionalGeneration
|
||||||
|
)
|
||||||
|
elif self.model_config.model_type == "mllama":
|
||||||
|
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
||||||
|
MllamaForConditionalGeneration
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.AutoModelLoader = (
|
||||||
|
AutoModelForVision2Seq # pylint: disable=invalid-name
|
||||||
|
)
|
||||||
|
|
||||||
def set_device_map_config(self) -> None:
|
def set_device_map_config(self) -> None:
|
||||||
device_map = self.cfg.device_map
|
device_map = self.cfg.device_map
|
||||||
@@ -750,7 +695,7 @@ class ModelLoader:
            from accelerate import infer_auto_device_map

            with init_empty_weights():
-                model_canvas = self.auto_model_loader.from_config(
+                model_canvas = self.AutoModelLoader.from_config(
                    self.model_config,
                    trust_remote_code=self.cfg.trust_remote_code or False,
                )
@@ -947,6 +892,8 @@ class ModelLoader:
            quantization_config = (
                quantization_config or self.model_kwargs["quantization_config"]
            )
+            if self.cfg.is_multimodal:
+                self.model_config.text_config = self.text_model_config
            self.model = load_sharded_model_quant(
                self.base_model,
                self.model_config,
@@ -967,26 +914,13 @@ class ModelLoader:

            _ = _configure_zero3_memory_efficient_loading()

-            # Load model with random initialization if specified
-            if self.cfg.random_init_weights:
-                # AutoModel classes support the from_config method
-                if self.auto_model_loader in [
-                    AutoModelForCausalLM,
-                    AutoModelForVision2Seq,
-                ]:
-                    self.model = self.auto_model_loader.from_config(
-                        config=self.model_config,
-                    )
-                else:
-                    self.model = self.auto_model_loader(
-                        config=self.model_config,
-                    )
-            else:
-                self.model = self.auto_model_loader.from_pretrained(
-                    self.base_model,
-                    config=self.model_config,
-                    **self.model_kwargs,
-                )
+            if self.cfg.is_multimodal:
+                self.model_config.text_config = self.text_model_config
+            self.model = self.AutoModelLoader.from_pretrained(
+                self.base_model,
+                config=self.model_config,
+                **self.model_kwargs,
+            )

            # TODO (MengqingCao) split these patches seperately
            if self.cfg.flash_attention and not self.inference:
@@ -1021,8 +955,10 @@ class ModelLoader:
            and self.model_type != "AutoModelForCausalLM"
            and not self.cfg.trust_remote_code
        ):
+            if self.cfg.is_multimodal:
+                self.model_config.text_config = self.text_model_config
            if self.cfg.gptq:
-                self.model = self.auto_model_loader.from_pretrained(
+                self.model = self.AutoModelLoader.from_pretrained(
                    self.base_model,
                    config=self.model_config,
                    trust_remote_code=self.cfg.trust_remote_code or False,
@@ -1036,8 +972,26 @@ class ModelLoader:
                    **self.model_kwargs,
                )
            else:
+                # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
+                # when training starts
+                if (
+                    hasattr(self.text_model_config, "max_seq_len")
+                    and self.text_model_config.max_seq_len
+                    and self.cfg.sequence_len > self.text_model_config.max_seq_len
+                ):
+                    self.text_model_config.max_seq_len = self.cfg.sequence_len
+                    LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
+                elif (
+                    hasattr(self.text_model_config, "max_sequence_length")
+                    and self.text_model_config.max_sequence_length
+                    and self.cfg.sequence_len > self.text_model_config.max_sequence_length
+                ):
+                    self.text_model_config.max_sequence_length = self.cfg.sequence_len
+                    LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
                if self.cfg.gptq:
-                    self.model = self.auto_model_loader.from_pretrained(
+                    if self.cfg.is_multimodal:
+                        self.model_config.text_config = self.text_model_config
+                    self.model = self.AutoModelLoader.from_pretrained(
                        self.base_model,
                        config=self.model_config,
                        trust_remote_code=self.cfg.trust_remote_code or False,
@@ -1055,7 +1009,9 @@ class ModelLoader:

                    _ = _configure_zero3_memory_efficient_loading()

-                self.model = self.auto_model_loader.from_pretrained(
+                if self.cfg.is_multimodal:
+                    self.model_config.text_config = self.text_model_config
+                self.model = self.AutoModelLoader.from_pretrained(
                    self.base_model,
                    config=self.model_config,
                    trust_remote_code=self.cfg.trust_remote_code or False,
@@ -1218,9 +1174,7 @@ class ModelLoader:
            )
        ):
            resize_kwargs = {}
-            if self.cfg.mean_resizing_embeddings is not None and not (
-                self.model_config.model_type == "llava"
-            ):
+            if self.cfg.mean_resizing_embeddings is not None:
                resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
            self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
        else:
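Both sides forward the optional mean_resizing flag to resize_token_embeddings. A minimal sketch of that call (the model/tokenizer names and the added token are placeholders, and it assumes a transformers version whose resize_token_embeddings accepts the mean_resizing keyword):

# Minimal sketch with placeholder names; mean_resizing support is assumed.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")

tokenizer.add_tokens(["<|my_new_token|>"])
# mean_resizing=True initializes the new embedding rows from the mean of the existing rows
model.resize_token_embeddings(len(tokenizer), mean_resizing=True)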
@@ -1319,6 +1273,8 @@ class ModelLoader:
                requires_grad.append(f"{name}: {param.requires_grad}")
            if len(requires_grad) == 0:
                LOG.warning("there are no parameters that require gradient updates")
+        if hasattr(self.model, "config"):
+            self.model.config.use_cache = False

        if self.cfg.flash_optimum:
            from optimum.bettertransformer import BetterTransformer
@@ -1351,7 +1307,7 @@ def load_model(
    """
    Load a model for a given configuration and tokenizer.
    """
-    model_loader = ModelLoader(
+    loader = ModelLoader(
        cfg,
        tokenizer,
        processor=processor,
@@ -1359,7 +1315,7 @@ def load_model(
        reference_model=reference_model,
        **kwargs,
    )
-    return model_loader.load_model()
+    return loader.load_model()


def load_adapter(model, cfg, adapter, inference=False):
@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2024 Nikhil Vyas
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
@@ -1,495 +0,0 @@
-# pylint: skip-file
-# Copied from https://github.com/nikhilvyas/SOAP
-from itertools import chain
-
-import torch
-import torch.optim as optim
-
-# Parts of the code are modifications of Pytorch's AdamW optimizer
-# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py
-
-
-class SOAP(optim.Optimizer):
-    """
-    Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).
-
-    Parameters:
-        params (`Iterable[nn.parameter.Parameter]`):
-            Iterable of parameters to optimize or dictionaries defining parameter groups.
-        lr (`float`, *optional*, defaults to 0.003):
-            The learning rate to use.
-        betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
-            Adam's betas parameters (b1, b2).
-        shampoo_beta (`float`, *optional*, defaults to -1):
-            If >= 0, use this beta for the preconditioner (L and R in paper, state["GG"] below) moving average instead of betas[1].
-        eps (`float`, *optional*, defaults to 1e-08):
-            Adam's epsilon for numerical stability.
-        weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
-        precondition_frequency (`int`, *optional*, defaults to 10):
-            How often to update the preconditioner.
-        max_precond_dim (`int`, *optional*, defaults to 10000):
-            Maximum dimension of the preconditioner.
-            Set to 10000, so that we exclude most common vocab sizes while including layers.
-        merge_dims (`bool`, *optional*, defaults to `False`):
-            Whether or not to merge dimensions of the preconditioner.
-        precondition_1d (`bool`, *optional*, defaults to `False`):
-            Whether or not to precondition 1D gradients.
-        normalize_grads (`bool`, *optional*, defaults to `False`):
-            Whether or not to normalize gradients per layer.
-            Helps at large precondition_frequency (~100 in our experiments),
-            but hurts performance at small precondition_frequency (~10 in our experiments).
-        data_format (`str`, *optional*, defaults to `channels_first`):
-            Data format of the input for convolutional layers.
-            Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW.
-        correct_bias (`bool`, *optional*, defaults to `True`):
-            Whether or not to use bias correction in Adam.
-    """
-
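For orientation, a minimal usage sketch of the SOAP optimizer defined above, using only the documented defaults; the toy model, data, and loop are placeholders and are not part of the removed file:

# Minimal sketch; the toy model and training loop are placeholders for illustration.
import torch
import torch.nn as nn

model = nn.Linear(16, 4)
optimizer = SOAP(model.parameters(), lr=3e-3, betas=(0.95, 0.95), precondition_frequency=10)

for _ in range(20):
    inputs = torch.randn(8, 16)
    loss = model(inputs).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # the very first step only initializes the preconditioner; updates start afterwards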
def __init__(
|
|
||||||
self,
|
|
||||||
params,
|
|
||||||
lr: float = 3e-3,
|
|
||||||
betas=(0.95, 0.95),
|
|
||||||
shampoo_beta: float = -1,
|
|
||||||
eps: float = 1e-8,
|
|
||||||
weight_decay: float = 0.01,
|
|
||||||
precondition_frequency: int = 10,
|
|
||||||
max_precond_dim: int = 10000, #
|
|
||||||
merge_dims: bool = False, # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.
|
|
||||||
precondition_1d: bool = False,
|
|
||||||
normalize_grads: bool = False,
|
|
||||||
data_format: str = "channels_first",
|
|
||||||
correct_bias: bool = True,
|
|
||||||
):
|
|
||||||
defaults = {
|
|
||||||
"lr": lr,
|
|
||||||
"betas": betas,
|
|
||||||
"shampoo_beta": shampoo_beta,
|
|
||||||
"eps": eps,
|
|
||||||
"weight_decay": weight_decay,
|
|
||||||
"precondition_frequency": precondition_frequency,
|
|
||||||
"max_precond_dim": max_precond_dim,
|
|
||||||
"merge_dims": merge_dims,
|
|
||||||
"precondition_1d": precondition_1d,
|
|
||||||
"normalize_grads": normalize_grads,
|
|
||||||
"correct_bias": correct_bias,
|
|
||||||
}
|
|
||||||
super().__init__(params, defaults)
|
|
||||||
self._data_format = data_format
|
|
||||||
|
|
||||||
def merge_dims(self, grad, max_precond_dim):
|
|
||||||
"""
|
|
||||||
Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.
|
|
||||||
"""
|
|
||||||
assert self._data_format in ["channels_first", "channels_last"]
|
|
||||||
if self._data_format == "channels_last" and grad.dim() == 4:
|
|
||||||
grad = grad.permute(0, 3, 1, 2)
|
|
||||||
shape = grad.shape
|
|
||||||
new_shape = []
|
|
||||||
|
|
||||||
curr_shape = 1
|
|
||||||
for sh in shape:
|
|
||||||
temp_shape = curr_shape * sh
|
|
||||||
if temp_shape > max_precond_dim:
|
|
||||||
if curr_shape > 1:
|
|
||||||
new_shape.append(curr_shape)
|
|
||||||
curr_shape = sh
|
|
||||||
else:
|
|
||||||
new_shape.append(sh)
|
|
||||||
curr_shape = 1
|
|
||||||
else:
|
|
||||||
curr_shape = temp_shape
|
|
||||||
|
|
||||||
if curr_shape > 1 or len(new_shape) == 0:
|
|
||||||
new_shape.append(curr_shape)
|
|
||||||
|
|
||||||
new_grad = grad.reshape(new_shape)
|
|
||||||
return new_grad
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def step(self, closure=None):
|
|
||||||
"""
|
|
||||||
Performs a single optimization step.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
|
|
||||||
"""
|
|
||||||
if closure is None:
|
|
||||||
loss = None
|
|
||||||
else:
|
|
||||||
loss = closure()
|
|
||||||
|
|
||||||
for group in self.param_groups:
|
|
||||||
for p in group["params"]:
|
|
||||||
if p.grad is None:
|
|
||||||
continue
|
|
||||||
grad = p.grad
|
|
||||||
|
|
||||||
state = self.state[p]
|
|
||||||
|
|
||||||
if "step" not in state:
|
|
||||||
state["step"] = 0
|
|
||||||
|
|
||||||
# State initialization
|
|
||||||
if "exp_avg" not in state:
|
|
||||||
# Exponential moving average of gradient values
|
|
||||||
state["exp_avg"] = torch.zeros_like(grad)
|
|
||||||
# Exponential moving average of squared gradient values
|
|
||||||
state["exp_avg_sq"] = torch.zeros_like(grad)
|
|
||||||
|
|
||||||
if "Q" not in state:
|
|
||||||
self.init_preconditioner(
|
|
||||||
grad,
|
|
||||||
state,
|
|
||||||
precondition_frequency=group["precondition_frequency"],
|
|
||||||
precondition_1d=group["precondition_1d"],
|
|
||||||
shampoo_beta=(
|
|
||||||
group["shampoo_beta"]
|
|
||||||
if group["shampoo_beta"] >= 0
|
|
||||||
else group["betas"][1]
|
|
||||||
),
|
|
||||||
max_precond_dim=group["max_precond_dim"],
|
|
||||||
merge_dims=group["merge_dims"],
|
|
||||||
)
|
|
||||||
self.update_preconditioner(
|
|
||||||
grad,
|
|
||||||
state,
|
|
||||||
max_precond_dim=group["max_precond_dim"],
|
|
||||||
merge_dims=group["merge_dims"],
|
|
||||||
precondition_1d=group["precondition_1d"],
|
|
||||||
)
|
|
||||||
continue # first step is skipped so that we never use the current gradients in the projection.
|
|
||||||
|
|
||||||
# Projecting gradients to the eigenbases of Shampoo's preconditioner
|
|
||||||
# i.e. projecting to the eigenbases of matrices in state["GG"]
|
|
||||||
grad_projected = self.project(
|
|
||||||
grad,
|
|
||||||
state,
|
|
||||||
merge_dims=group["merge_dims"],
|
|
||||||
max_precond_dim=group["max_precond_dim"],
|
|
||||||
)
|
|
||||||
|
|
||||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
|
||||||
beta1, beta2 = group["betas"]
|
|
||||||
|
|
||||||
state["step"] += 1
|
|
||||||
|
|
||||||
# Decay the first and second moment running average coefficient
|
|
||||||
# In-place operations to update the averages at the same time
|
|
||||||
exp_avg.mul_(beta1).add_(grad_projected, alpha=(1.0 - beta1))
|
|
||||||
exp_avg_sq.mul_(beta2).add_(
|
|
||||||
grad_projected.square(), alpha=(1.0 - beta2)
|
|
||||||
)
|
|
||||||
|
|
||||||
denom = exp_avg_sq.sqrt().add_(group["eps"])
|
|
||||||
|
|
||||||
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
|
|
||||||
# i.e. projecting to the eigenbases of matrices in state["GG"]
|
|
||||||
# exp_avg_projected = self.project(
|
|
||||||
# exp_avg,
|
|
||||||
# state,
|
|
||||||
# merge_dims=group["merge_dims"],
|
|
||||||
# max_precond_dim=group["max_precond_dim"],
|
|
||||||
# )
|
|
||||||
exp_avg_projected = exp_avg
|
|
||||||
|
|
||||||
step_size = group["lr"]
|
|
||||||
if group["correct_bias"]:
|
|
||||||
bias_correction1 = 1.0 - beta1 ** (state["step"])
|
|
||||||
bias_correction2 = 1.0 - beta2 ** (state["step"])
|
|
||||||
step_size = step_size * (bias_correction2**0.5) / bias_correction1
|
|
||||||
|
|
||||||
# Projecting back the preconditioned (by Adam) exponential moving average of gradients
|
|
||||||
# to the original space
|
|
||||||
norm_grad = self.project_back(
|
|
||||||
exp_avg_projected / denom,
|
|
||||||
state,
|
|
||||||
merge_dims=group["merge_dims"],
|
|
||||||
max_precond_dim=group["max_precond_dim"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if group["normalize_grads"]:
|
|
||||||
norm_grad = norm_grad / (1e-30 + torch.mean(norm_grad**2) ** 0.5)
|
|
||||||
|
|
||||||
p.add_(norm_grad, alpha=-step_size)
|
|
||||||
|
|
||||||
# From AdamW code: Just adding the square of the weights to the loss function is *not*
|
|
||||||
# the correct way of using L2 regularization/weight decay with Adam,
|
|
||||||
# since that will interact with the m and v parameters in strange ways.
|
|
||||||
#
|
|
||||||
# Instead we want to decay the weights in a manner that doesn't interact
|
|
||||||
# with the m/v parameters. This is equivalent to adding the square
|
|
||||||
# of the weights to the loss with plain (non-momentum) SGD.
|
|
||||||
# Add weight decay at the end (fixed version)
|
|
||||||
if group["weight_decay"] > 0.0:
|
|
||||||
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
|
|
||||||
|
|
||||||
# Update is done after the gradient step to avoid using current gradients in the projection.
|
|
||||||
self.update_preconditioner(
|
|
||||||
grad,
|
|
||||||
state,
|
|
||||||
max_precond_dim=group["max_precond_dim"],
|
|
||||||
merge_dims=group["merge_dims"],
|
|
||||||
precondition_1d=group["precondition_1d"],
|
|
||||||
)
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
def init_preconditioner(
|
|
||||||
self,
|
|
||||||
grad,
|
|
||||||
state,
|
|
||||||
precondition_frequency=10,
|
|
||||||
shampoo_beta=0.95,
|
|
||||||
max_precond_dim=10000,
|
|
||||||
precondition_1d=False,
|
|
||||||
merge_dims=False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initializes the preconditioner matrices (L and R in the paper).
|
|
||||||
"""
|
|
||||||
state["GG"] = (
|
|
||||||
[]
|
|
||||||
) # Will hold all the preconditioner matrices (L and R in the paper).
|
|
||||||
if grad.dim() == 1:
|
|
||||||
if not precondition_1d or grad.shape[0] > max_precond_dim:
|
|
||||||
state["GG"].append([])
|
|
||||||
else:
|
|
||||||
state["GG"].append(
|
|
||||||
torch.zeros(grad.shape[0], grad.shape[0], device=grad.device)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if merge_dims:
|
|
||||||
grad = self.merge_dims(grad, max_precond_dim)
|
|
||||||
|
|
||||||
for sh in grad.shape:
|
|
||||||
if sh > max_precond_dim:
|
|
||||||
state["GG"].append([])
|
|
||||||
else:
|
|
||||||
state["GG"].append(torch.zeros(sh, sh, device=grad.device))
|
|
||||||
|
|
||||||
state["Q"] = None # Will hold all the eigenbases of the preconditioner.
|
|
||||||
state["precondition_frequency"] = precondition_frequency
|
|
||||||
state["shampoo_beta"] = shampoo_beta
|
|
||||||
|
|
||||||
def project(self, grad, state, merge_dims=False, max_precond_dim=10000):
|
|
||||||
"""
|
|
||||||
Projects the gradient to the eigenbases of the preconditioner.
|
|
||||||
"""
|
|
||||||
original_shape = grad.shape
|
|
||||||
if merge_dims:
|
|
||||||
if grad.dim() == 4 and self._data_format == "channels_last":
|
|
||||||
permuted_shape = grad.permute(0, 3, 1, 2).shape
|
|
||||||
grad = self.merge_dims(grad, max_precond_dim)
|
|
||||||
|
|
||||||
for mat in state["Q"]:
|
|
||||||
if len(mat) > 0:
|
|
||||||
grad = torch.tensordot(
|
|
||||||
grad,
|
|
||||||
mat,
|
|
||||||
dims=[[0], [0]],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
permute_order = list(range(1, len(grad.shape))) + [0]
|
|
||||||
grad = grad.permute(permute_order)
|
|
||||||
|
|
||||||
if merge_dims:
|
|
||||||
if self._data_format == "channels_last" and len(original_shape) == 4:
|
|
||||||
grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
|
|
||||||
else:
|
|
||||||
grad = grad.reshape(original_shape)
|
|
||||||
return grad
|
|
||||||
|
|
||||||
def update_preconditioner(
|
|
||||||
self,
|
|
||||||
grad,
|
|
||||||
state,
|
|
||||||
max_precond_dim=10000,
|
|
||||||
merge_dims=False,
|
|
||||||
precondition_1d=False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
|
|
||||||
"""
|
|
||||||
if state["Q"] is not None:
|
|
||||||
state["exp_avg"] = self.project_back(
|
|
||||||
state["exp_avg"],
|
|
||||||
state,
|
|
||||||
merge_dims=merge_dims,
|
|
||||||
max_precond_dim=max_precond_dim,
|
|
||||||
)
|
|
||||||
if grad.dim() == 1:
|
|
||||||
if precondition_1d and grad.shape[0] <= max_precond_dim:
|
|
||||||
state["GG"][0].lerp_(
|
|
||||||
grad.unsqueeze(1) @ grad.unsqueeze(0), 1 - state["shampoo_beta"]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if merge_dims:
|
|
||||||
new_grad = self.merge_dims(grad, max_precond_dim)
|
|
||||||
for idx, sh in enumerate(new_grad.shape):
|
|
||||||
if sh <= max_precond_dim:
|
|
||||||
outer_product = torch.tensordot(
|
|
||||||
new_grad,
|
|
||||||
new_grad,
|
|
||||||
dims=[
|
|
||||||
[
|
|
||||||
*chain(
|
|
||||||
range(idx), range(idx + 1, len(new_grad.shape))
|
|
||||||
)
|
|
||||||
]
|
|
||||||
]
|
|
||||||
* 2,
|
|
||||||
)
|
|
||||||
state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"])
|
|
||||||
else:
|
|
||||||
for idx, sh in enumerate(grad.shape):
|
|
||||||
if sh <= max_precond_dim:
|
|
||||||
outer_product = torch.tensordot(
|
|
||||||
grad,
|
|
||||||
grad,
|
|
||||||
# Contracts across all dimensions except for k.
|
|
||||||
dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]]
|
|
||||||
* 2,
|
|
||||||
)
|
|
||||||
state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"])
|
|
||||||
|
|
||||||
if state["Q"] is None:
|
|
||||||
state["Q"] = self.get_orthogonal_matrix(state["GG"])
|
|
||||||
if state["step"] > 0 and state["step"] % state["precondition_frequency"] == 0:
|
|
||||||
state["Q"] = self.get_orthogonal_matrix_QR(
|
|
||||||
state, max_precond_dim, merge_dims
|
|
||||||
)
|
|
||||||
# state["Q"] = self.get_fast_QR(state, max_precond_dim, merge_dims)
|
|
||||||
|
|
||||||
if state["step"] > 0:
|
|
||||||
state["exp_avg"] = self.project(
|
|
||||||
state["exp_avg"],
|
|
||||||
state,
|
|
||||||
merge_dims=merge_dims,
|
|
||||||
max_precond_dim=max_precond_dim,
|
|
||||||
)
|
|
||||||
|
|
||||||
def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
|
|
||||||
"""
|
|
||||||
Projects the gradient back to the original space.
|
|
||||||
"""
|
|
||||||
original_shape = grad.shape
|
|
||||||
if merge_dims:
|
|
||||||
if self._data_format == "channels_last" and grad.dim() == 4:
|
|
||||||
permuted_shape = grad.permute(0, 3, 1, 2).shape
|
|
||||||
grad = self.merge_dims(grad, max_precond_dim)
|
|
||||||
for mat in state["Q"]:
|
|
||||||
if len(mat) > 0:
|
|
||||||
grad = torch.tensordot(
|
|
||||||
grad,
|
|
||||||
mat,
|
|
||||||
dims=[[0], [1]],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
permute_order = list(range(1, len(grad.shape))) + [0]
|
|
||||||
grad = grad.permute(permute_order)
|
|
||||||
|
|
||||||
if merge_dims:
|
|
||||||
if self._data_format == "channels_last" and len(original_shape) == 4:
|
|
||||||
grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
|
|
||||||
else:
|
|
||||||
grad = grad.reshape(original_shape)
|
|
||||||
return grad
|
|
||||||
|
|
||||||
def get_orthogonal_matrix(self, mat):
|
|
||||||
"""
|
|
||||||
Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
|
|
||||||
"""
|
|
||||||
matrix = []
|
|
||||||
for m in mat:
|
|
||||||
if len(m) == 0:
|
|
||||||
matrix.append([])
|
|
||||||
continue
|
|
||||||
if m.data.dtype != torch.float:
|
|
||||||
float_data = False
|
|
||||||
original_type = m.data.dtype
|
|
||||||
original_device = m.data.device
|
|
||||||
matrix.append(m.data.float())
|
|
||||||
else:
|
|
||||||
float_data = True
|
|
||||||
matrix.append(m.data)
|
|
||||||
|
|
||||||
final = []
|
|
||||||
for m in matrix:
|
|
||||||
if len(m) == 0:
|
|
||||||
final.append([])
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
_, Q = torch.linalg.eigh(
|
|
||||||
m + 1e-30 * torch.eye(m.shape[0], device=m.device)
|
|
||||||
)
|
|
||||||
except: # pylint: disable=bare-except # noqa: E722
|
|
||||||
_, Q = torch.linalg.eigh(
|
|
||||||
m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device)
|
|
||||||
)
|
|
||||||
Q = Q.to(m.dtype)
|
|
||||||
Q = torch.flip(Q, [1])
|
|
||||||
|
|
||||||
if not float_data:
|
|
||||||
Q = Q.to(original_device).type(original_type)
|
|
||||||
final.append(Q)
|
|
||||||
return final
|
|
||||||
|
|
||||||
def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):
|
|
||||||
"""
|
|
||||||
Computes the eigenbases of the preconditioner using one round of power iteration
|
|
||||||
followed by torch.linalg.qr decomposition.
|
|
||||||
"""
|
|
||||||
precond_list = state["GG"]
|
|
||||||
orth_list = state["Q"]
|
|
||||||
|
|
||||||
matrix = []
|
|
||||||
orth_matrix = []
|
|
||||||
for m, o in zip(precond_list, orth_list):
|
|
||||||
if len(m) == 0:
|
|
||||||
matrix.append([])
|
|
||||||
orth_matrix.append([])
|
|
||||||
continue
|
|
||||||
if m.data.dtype != torch.float:
|
|
||||||
float_data = False
|
|
||||||
original_type = m.data.dtype
|
|
||||||
original_device = m.data.device
|
|
||||||
matrix.append(m.data.float())
|
|
||||||
orth_matrix.append(o.data.float())
|
|
||||||
else:
|
|
||||||
float_data = True
|
|
||||||
matrix.append(m.data.float())
|
|
||||||
orth_matrix.append(o.data.float())
|
|
||||||
|
|
||||||
orig_shape = state["exp_avg_sq"].shape
|
|
||||||
if self._data_format == "channels_last" and len(orig_shape) == 4:
|
|
||||||
permuted_shape = state["exp_avg_sq"].permute(0, 3, 1, 2).shape
|
|
||||||
if merge_dims:
|
|
||||||
exp_avg_sq = self.merge_dims(state["exp_avg_sq"], max_precond_dim)
|
|
||||||
else:
|
|
||||||
exp_avg_sq = state["exp_avg_sq"]
|
|
||||||
|
|
||||||
final = []
|
|
||||||
for ind, (m, o) in enumerate(zip(matrix, orth_matrix)):
|
|
||||||
if len(m) == 0:
|
|
||||||
final.append([])
|
|
||||||
continue
|
|
||||||
est_eig = torch.diag(o.T @ m @ o)
|
|
||||||
sort_idx = torch.argsort(est_eig, descending=True)
|
|
||||||
exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
|
|
||||||
o = o[:, sort_idx]
|
|
||||||
power_iter = m @ o
|
|
||||||
Q, _ = torch.linalg.qr(power_iter)
|
|
||||||
|
|
||||||
if not float_data:
|
|
||||||
Q = Q.to(original_device).type(original_type)
|
|
||||||
final.append(Q)
|
|
||||||
|
|
||||||
if merge_dims:
|
|
||||||
if self._data_format == "channels_last" and len(orig_shape) == 4:
|
|
||||||
exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1)
|
|
||||||
else:
|
|
||||||
exp_avg_sq = exp_avg_sq.reshape(orig_shape)
|
|
||||||
|
|
||||||
state["exp_avg_sq"] = exp_avg_sq
|
|
||||||
return final
|
|
||||||
@@ -104,7 +104,9 @@ def allocate(


class MultipackBatchSampler(BatchSampler):
-    """Batch sampler class for multipack"""
+    """
+    Batch Sampler class for multipack
+    """

    def __init__(
        self,
@@ -1,4 +1,4 @@
-"""Module with Pydantic models for configuration."""
+"""Main Axolotl input configuration Pydantic models"""

# pylint: disable=too-many-lines

@@ -42,7 +42,6 @@ from axolotl.utils.schemas.model import (
    ModelOutputConfig,
    SpecialTokensConfig,
)
-from axolotl.utils.schemas.multimodal import MultiModalConfig
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
from axolotl.utils.schemas.training import HyperparametersConfig
from axolotl.utils.schemas.trl import TRLConfig
@@ -65,7 +64,6 @@ class AxolotlInputConfig(
    LISAConfig,
    GradioConfig,
    RayConfig,
-    MultiModalConfig,
    RemappedParameters,
    DeprecatedParameters,
    BaseModel,
@@ -247,8 +245,6 @@ class AxolotlInputConfig(

    val_set_size: float | None = Field(default=0.0)

-    sequence_parallel_degree: int | None = None
-
    special_tokens: SpecialTokensConfig | None = None
    tokens: list[str] | None = None
    added_tokens_overrides: dict[int, str] | None = None
@@ -1106,29 +1102,6 @@ class AxolotlInputConfig(

        return data

-    @field_validator("sequence_parallel_degree", mode="before")
-    @classmethod
-    def check_sequence_parallel_config(cls, value, info):
-        if not value:
-            value = 1
-
-        if value > 1:
-            if not info.data.get("flash_attention"):
-                raise ValueError(
-                    "flash_attention: true must be set with sequence_parallel_degree > 1"
-                )
-
-            try:
-                import ring_flash_attn  # noqa: F401 # pylint:disable=unused-import
-            except ImportError as exception:
-                raise ImportError(
-                    "sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
-                    "Please install it with `pip install axolotl[ring-flash-attn] "
-                    "or `pip install ring-flash-attn>=0.1.4`."
-                ) from exception
-
-        return value
-
-
class AxolotlConfigWCapabilities(AxolotlInputConfig):
    """wrapper to valdiate gpu capabilities with the configured options"""
@@ -22,7 +22,6 @@ class ChatTemplate(str, Enum):
    mistral_v1 = "mistral_v1"  # pylint: disable=invalid-name
    mistral_v2v3 = "mistral_v2v3"  # pylint: disable=invalid-name
    mistral_v3_tekken = "mistral_v3_tekken"  # pylint: disable=invalid-name
-    mistral_v7_tekken = "mistral_v7_tekken"  # pylint: disable=invalid-name
    gemma = "gemma"  # pylint: disable=invalid-name
    cohere = "cohere"  # pylint: disable=invalid-name
    llama3 = "llama3"  # pylint: disable=invalid-name
@@ -37,10 +36,6 @@ class ChatTemplate(str, Enum):
    tokenizer_default = "tokenizer_default"  # pylint: disable=invalid-name
    exaone = "exaone"  # pylint: disable=invalid-name
    metharme = "metharme"  # pylint: disable=invalid-name
-    pixtral = "pixtral"  # pylint: disable=invalid-name
-    llava = "llava"  # pylint: disable=invalid-name
-    qwen2_vl = "qwen2_vl"  # pylint: disable=invalid-name
-    gemma3 = "gemma3"  # pylint: disable=invalid-name


class CustomSupportedOptimizers(str, Enum):
@@ -52,4 +47,3 @@ class CustomSupportedOptimizers(str, Enum):
    ao_adamw_fp8 = "ao_adamw_fp8"  # pylint: disable=invalid-name
    adopt_adamw = "adopt_adamw"  # pylint: disable=invalid-name
    muon = "muon"  # pylint: disable=invalid-name
-    soap = "soap"  # pylint: disable=invalid-name
@@ -1,48 +0,0 @@
-"""Pydantic models for multimodal-related configuration"""
-
-from typing import Literal
-
-from PIL.Image import Resampling
-from pydantic import BaseModel, Field, field_validator
-
-
-class MultiModalConfig(BaseModel):
-    """Multi-modal configuration subset"""
-
-    image_size: int | tuple[int, int] | None = Field(
-        default=None,
-        json_schema_extra={
-            "description": (
-                "The size of the image to resize to. It can be an integer (resized into padded-square image) or a tuple (width, height)."
-                "If not provided, we will attempt to load from preprocessor.size, otherwise, images won't be resized."
-            )
-        },
-    )
-    image_resize_algorithm: (
-        Literal["bilinear", "bicubic", "lanczos"] | Resampling | None
-    ) = Field(
-        default=None,
-        json_schema_extra={
-            "description": "The resampling algorithm to use for image resizing. Default is bilinear. Please refer to PIL.Image.Resampling for more details."
-        },
-    )
-
-    @field_validator("image_resize_algorithm", mode="before")
-    @classmethod
-    def convert_image_resize_algorithm(cls, image_resize_algorithm):
-        """
-        Convert the image resize algorithm to a PIL.Image.Resampling enum.
-        """
-        if isinstance(image_resize_algorithm, str):
-            image_resize_algorithm = image_resize_algorithm.lower()
-            if image_resize_algorithm == "bilinear":
-                image_resize_algorithm = Resampling.BILINEAR
-            elif image_resize_algorithm == "bicubic":
-                image_resize_algorithm = Resampling.BICUBIC
-            elif image_resize_algorithm == "lanczos":
-                image_resize_algorithm = Resampling.LANCZOS
-            else:
-                raise ValueError(
-                    f"Invalid image resize algorithm: {image_resize_algorithm}"
-                )
-        return image_resize_algorithm
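A small usage sketch of the config model removed above; the field values are examples only, and the assertion mirrors the string-to-Resampling conversion performed by the validator:

# Illustrative usage of the removed MultiModalConfig; values are examples.
from PIL.Image import Resampling

cfg = MultiModalConfig(image_size=(336, 336), image_resize_algorithm="bicubic")
assert cfg.image_resize_algorithm is Resampling.BICUBIC  # the validator maps the string to the PIL enum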
@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
            load_from_cache_file=not cfg.is_preprocess,
            desc="Add position_id column (PoSE)",
        )
-    elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
+    elif cfg.sample_packing:
        drop_long_kwargs = {}
        if filter_map_kwargs:
            drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
@@ -356,7 +356,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
            **filter_map_kwargs,
            **drop_long_kwargs,
        )
-        if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
+        if cfg.eval_sample_packing is not False:
            if eval_dataset:
                eval_dataset = eval_dataset.map(
                    add_position_ids,
@@ -443,7 +443,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                    - 1
                )
                * cfg.num_epochs
-                * cfg.sequence_parallel_degree
            )
            LOG.debug(
                f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
@@ -474,11 +473,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
            LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
            # FIXME: is there a bug here somewhere? the total num steps depends
            # on the agreed on value for sample_packing_eff_est
-            total_num_steps = int(
-                math.floor(
-                    data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
-                )
-            )
+            total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))

            def calc_sample_packing_eff_est(estimates: List[float]):
                LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
@@ -499,12 +494,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
        )
    else:
        total_num_steps = int(
-            math.ceil(
-                len(train_dataset)
-                * cfg.num_epochs
-                * cfg.sequence_parallel_degree
-                / cfg.batch_size
-            )
+            math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
        )
    LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
    return total_num_steps
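The last hunk above drops the sequence_parallel_degree factor from the step estimate. A small worked example with illustrative numbers:

# Illustrative numbers only; shows how removing the factor halves the estimate here.
import math

num_samples, num_epochs, batch_size, sp_degree = 2_000, 3, 16, 2

old_steps = int(math.ceil(num_samples * num_epochs * sp_degree / batch_size))  # 750
new_steps = int(math.ceil(num_samples * num_epochs / batch_size))              # 375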
@@ -11,11 +11,7 @@ import time

import pytest
import requests
-from datasets import load_dataset
from huggingface_hub import snapshot_download
-from transformers import AutoTokenizer
-
-from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline


def retry_on_request_exceptions(max_retries=3, delay=1):
@@ -29,11 +25,9 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
                except (
                    requests.exceptions.ReadTimeout,
                    requests.exceptions.ConnectionError,
-                    requests.exceptions.HTTPError,
                ) as exc:
                    if attempt < max_retries - 1:
-                        wait = 2**attempt * delay  # in seconds
-                        time.sleep(wait)
+                        time.sleep(delay)
                    else:
                        raise exc

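The hunk above only shows the middle of the retry helper; a self-contained sketch of the exponential-backoff variant on the "-" side (the for/try scaffolding and functools.wraps are inferred, not shown in the diff):

# Self-contained sketch mirroring the removed exponential-backoff behavior.
import functools
import time

import requests


def retry_on_request_exceptions(max_retries=3, delay=1):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except (
                    requests.exceptions.ReadTimeout,
                    requests.exceptions.ConnectionError,
                    requests.exceptions.HTTPError,
                ) as exc:
                    if attempt < max_retries - 1:
                        time.sleep(2**attempt * delay)  # exponential backoff in seconds
                    else:
                        raise exc
        return wrapper
    return decorator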
@@ -43,7 +37,6 @@ def retry_on_request_exceptions(max_retries=3, delay=1):


@retry_on_request_exceptions(max_retries=3, delay=5)
-@disable_hf_offline
def snapshot_download_w_retry(*args, **kwargs):
    return snapshot_download(*args, **kwargs)

@@ -51,19 +44,19 @@ def snapshot_download_w_retry(*args, **kwargs):
@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model():
    # download the model
-    snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model")
+    snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M")


@pytest.fixture(scope="session", autouse=True)
def download_llama_68m_random_model():
    # download the model
-    snapshot_download_w_retry("JackFram/llama-68m", repo_type="model")
+    snapshot_download_w_retry("JackFram/llama-68m")


@pytest.fixture(scope="session", autouse=True)
def download_qwen_2_5_half_billion_model():
    # download the model
-    snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
+    snapshot_download_w_retry("Qwen/Qwen2.5-0.5B")


@pytest.fixture(scope="session", autouse=True)
@@ -108,37 +101,6 @@ def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_fozzie_alpaca_dpo_dataset():
|
|
||||||
# download the dataset
|
|
||||||
snapshot_download_w_retry(
|
|
||||||
"fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
|
|
||||||
)
|
|
||||||
snapshot_download_w_retry(
|
|
||||||
"fozziethebeat/alpaca_messages_2k_dpo_test",
|
|
||||||
repo_type="dataset",
|
|
||||||
revision="ea82cff",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
@disable_hf_offline
|
|
||||||
def dataset_fozzie_alpaca_dpo_dataset(
|
|
||||||
download_fozzie_alpaca_dpo_dataset,
|
|
||||||
): # pylint: disable=unused-argument,redefined-outer-name
|
|
||||||
return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
@disable_hf_offline
|
|
||||||
def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
|
|
||||||
download_fozzie_alpaca_dpo_dataset,
|
|
||||||
): # pylint: disable=unused-argument,redefined-outer-name
|
|
||||||
return load_dataset(
|
|
||||||
"fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
|
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
|
||||||
# download the dataset
|
# download the dataset
|
||||||
@@ -147,141 +109,10 @@ def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_argilla_dpo_pairs_dataset():
|
|
||||||
# download the dataset
|
|
||||||
snapshot_download_w_retry(
|
|
||||||
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_tiny_shakespeare_dataset():
|
def download_tiny_shakespeare_dataset():
|
||||||
# download the dataset
|
# download the dataset
|
||||||
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
|
snapshot_download_w_retry("Trelis/tiny-shakespeare", repo_type="dataset")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_deepseek_model_fixture():
|
|
||||||
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_huggyllama_model_fixture():
|
|
||||||
# download the tokenizer only
|
|
||||||
snapshot_download_w_retry(
|
|
||||||
"huggyllama/llama-7b",
|
|
||||||
repo_type="model",
|
|
||||||
allow_patterns=["*token*", "config.json"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_llama_1b_model_fixture():
|
|
||||||
# download the tokenizer only
|
|
||||||
snapshot_download_w_retry(
|
|
||||||
"NousResearch/Llama-3.2-1B",
|
|
||||||
repo_type="model",
|
|
||||||
allow_patterns=["*token*", "config.json"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_llama3_8b_model_fixture():
|
|
||||||
# download the tokenizer only
|
|
||||||
snapshot_download_w_retry(
|
|
||||||
"NousResearch/Meta-Llama-3-8B", repo_type="model", allow_patterns=["*token*"]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_llama3_8b_instruct_model_fixture():
|
|
||||||
# download the tokenizer only
|
|
||||||
snapshot_download_w_retry(
|
|
||||||
"NousResearch/Meta-Llama-3-8B-Instruct",
|
|
||||||
repo_type="model",
|
|
||||||
allow_patterns=["*token*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_phi_35_mini_model_fixture():
|
|
||||||
# download the tokenizer only
|
|
||||||
snapshot_download_w_retry(
|
|
||||||
"microsoft/Phi-3.5-mini-instruct", repo_type="model", allow_patterns=["*token*"]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_phi_3_medium_model_fixture():
|
|
||||||
# download the tokenizer only
|
|
||||||
snapshot_download_w_retry(
|
|
||||||
"microsoft/Phi-3-medium-128k-instruct",
|
|
||||||
repo_type="model",
|
|
||||||
allow_patterns=["*token*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_mistral_7b_model_fixture():
|
|
||||||
# download the tokenizer only
|
|
||||||
snapshot_download_w_retry(
|
|
||||||
"casperhansen/mistral-7b-instruct-v0.1-awq",
|
|
||||||
repo_type="model",
|
|
||||||
allow_patterns=["*token*", "config.json"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_gemma_2b_model_fixture():
|
|
||||||
# download the tokenizer only
|
|
||||||
snapshot_download_w_retry(
|
|
||||||
"unsloth/gemma-2b-it",
|
|
||||||
revision="703fb4a",
|
|
||||||
repo_type="model",
|
|
||||||
allow_patterns=["*token*", "config.json"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_gemma2_9b_model_fixture():
|
|
||||||
# download the tokenizer only
|
|
||||||
snapshot_download_w_retry(
|
|
||||||
"mlx-community/gemma-2-9b-it-4bit",
|
|
||||||
repo_type="model",
|
|
||||||
allow_patterns=["*token*", "config.json"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_mlx_mistral_7b_model_fixture():
|
|
||||||
# download the tokenizer only
|
|
||||||
snapshot_download_w_retry(
|
|
||||||
"mlx-community/Mistral-7B-Instruct-v0.3-4bit",
|
|
||||||
repo_type="model",
|
|
||||||
allow_patterns=["*token*", "config.json"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_llama2_model_fixture():
|
|
||||||
# download the tokenizer only
|
|
||||||
snapshot_download_w_retry(
|
|
||||||
"NousResearch/Llama-2-7b-hf",
|
|
||||||
repo_type="model",
|
|
||||||
allow_patterns=["*token*", "config.json"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
@enable_hf_offline
|
|
||||||
def tokenizer_huggyllama(
|
|
||||||
download_huggyllama_model_fixture,
|
|
||||||
): # pylint: disable=unused-argument,redefined-outer-name
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
|
||||||
tokenizer.pad_token = "</s>"
|
|
||||||
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -347,34 +178,3 @@ def cleanup_monkeypatches():
|
|||||||
module_globals = module_name_tuple[1]
|
module_globals = module_name_tuple[1]
|
||||||
for module_global in module_globals:
|
for module_global in module_globals:
|
||||||
globals().pop(module_global, None)
|
globals().pop(module_global, None)
|
||||||
|
|
||||||
|
|
||||||
# # pylint: disable=redefined-outer-name,unused-argument
|
|
||||||
# def test_load_fixtures(
|
|
||||||
# download_smollm2_135m_model,
|
|
||||||
# download_llama_68m_random_model,
|
|
||||||
# download_qwen_2_5_half_billion_model,
|
|
||||||
# download_tatsu_lab_alpaca_dataset,
|
|
||||||
# download_mhenrichsen_alpaca_2k_dataset,
|
|
||||||
# download_mhenrichsen_alpaca_2k_w_revision_dataset,
|
|
||||||
# download_mlabonne_finetome_100k_dataset,
|
|
||||||
# download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
|
|
||||||
# download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset,
|
|
||||||
# download_fozzie_alpaca_dpo_dataset,
|
|
||||||
# download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
|
|
||||||
# download_argilla_dpo_pairs_dataset,
|
|
||||||
# download_tiny_shakespeare_dataset,
|
|
||||||
# download_deepseek_model_fixture,
|
|
||||||
# download_huggyllama_model_fixture,
|
|
||||||
# download_llama_1b_model_fixture,
|
|
||||||
# download_llama3_8b_model_fixture,
|
|
||||||
# download_llama3_8b_instruct_model_fixture,
|
|
||||||
# download_phi_35_mini_model_fixture,
|
|
||||||
# download_phi_3_medium_model_fixture,
|
|
||||||
# download_mistral_7b_model_fixture,
|
|
||||||
# download_gemma_2b_model_fixture,
|
|
||||||
# download_gemma2_9b_model_fixture,
|
|
||||||
# download_mlx_mistral_7b_model_fixture,
|
|
||||||
# download_llama2_model_fixture,
|
|
||||||
# ):
|
|
||||||
# pass
|
|
||||||
|
|||||||
@@ -10,13 +10,10 @@ from transformers import AddedToken, AutoTokenizer
from axolotl.core.chat.format.chatml import format_message
from axolotl.core.chat.messages import ChatFormattedChats, Chats

-from tests.hf_offline_utils import enable_hf_offline  # noqa
-

@pytest.fixture(scope="session", name="llama_tokenizer")
-@enable_hf_offline
def llama_tokenizer_fixture():
-    return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
+    return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B")


@pytest.fixture(scope="session", name="chatml_tokenizer")
@@ -5,6 +5,7 @@ e2e tests for kd trainer support in Axolotl
from pathlib import Path

import pytest
+from e2e.utils import check_tensorboard, require_torch_2_5_1

from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
@@ -12,8 +13,6 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault

-from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
-

@pytest.fixture(name="kd_min_cfg")
def min_cfg(temp_dir):
@@ -2,13 +2,15 @@
Simple end-to-end test for Liger integration
"""

+from e2e.utils import require_torch_2_4_1
+
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault

-from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
+from ..utils import check_model_output_exists


class LigerIntegrationTestCase:
@@ -8,12 +8,11 @@
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
+from e2e.utils import require_vllm
from transformers.testing_utils import get_torch_dist_unique_port

from axolotl.utils.dict import DictDefault

-from tests.e2e.utils import require_vllm
-

class TestGRPO:
    """
@@ -9,13 +9,12 @@
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
+from e2e.utils import check_tensorboard
from huggingface_hub import snapshot_download
from transformers.testing_utils import get_torch_dist_unique_port

from axolotl.utils.dict import DictDefault

-from tests.e2e.utils import check_tensorboard
-
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"

@@ -9,11 +9,10 @@
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
+from e2e.utils import check_tensorboard, require_torch_lt_2_6_0

from axolotl.utils.dict import DictDefault

-from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
-
LOG = logging.getLogger(__name__)
os.environ["WANDB_DISABLED"] = "true"

@@ -144,7 +144,7 @@ def test_swiglu_mlp_integration(small_llama_model):
|
|||||||
def test_geglu_model_integration():
|
def test_geglu_model_integration():
|
||||||
"""Test GeGLU activation with Gemma model."""
|
"""Test GeGLU activation with Gemma model."""
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="auto"
|
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda"
|
||||||
)
|
)
|
||||||
peft_config = get_peft_config(
|
peft_config = get_peft_config(
|
||||||
{
|
{
|
||||||
@@ -347,7 +347,7 @@ def test_model_architecture(model_config):
     """Test LoRA kernel patches across different model architectures."""
     # Load model with appropriate dtype
    model = AutoModelForCausalLM.from_pretrained(
-        model_config["name"], torch_dtype=model_config["dtype"], device_map="auto"
+        model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda"
     )
 
     # Apply LoRA configuration
@@ -408,7 +408,7 @@ def test_kernel_training_integration():
     )
 
     # Load model
-    model, _, _ = load_model_and_tokenizer(cfg=cfg)
+    model, _ = load_model_and_tokenizer(cfg=cfg)
 
     # Verify correct activation function
     layer = model.model.model.layers[0]
@@ -1,209 +0,0 @@
-"""Tests for sequence parallelism functionality."""
-
-# pylint: disable=redefined-outer-name,unused-argument
-
-from unittest.mock import MagicMock, patch
-
-import pytest
-import torch
-from accelerate.state import PartialState
-
-from axolotl.monkeypatch.attention.ring_attn import (
-    get_ring_attn_group,
-    set_ring_attn_group,
-)
-from axolotl.utils.collators.batching import adjust_position_ids_for_slice
-from axolotl.utils.dict import DictDefault
-
-
-@pytest.fixture
-def partial_state():
-    """Create a real PartialState instance for testing."""
-    state = PartialState()
-    return state
-
-
-@pytest.fixture(name="cfg")
-def fixture_cfg():
-    cfg = DictDefault(
-        {
-            "base_model": "HuggingFaceTB/SmolLM2-135M",
-            "datasets": [
-                {
-                    "path": "mhenrichsen/alpaca_2k_test",
-                    "type": "alpaca",
-                },
-            ],
-            "micro_batch_size": 1,
-            "gradient_accumulation_steps": 1,
-            "learning_rate": 1e-3,
-            "output_dir": "./model-out",
-            "sequence_len": 512,
-            "special_tokens": {
-                "pad_token": "<|endoftext|>",
-            },
-        }
-    )
-
-    return cfg
-
-
-class TestSequenceParallelHelpers:
-    """Test helper functions used in sequence parallelism."""
-
-    def test_adjust_position_ids_for_slice(self, partial_state):
-        """Test position_ids adjustment for sequence slices."""
-        # Create sample position_ids with multiple sequences
-        position_ids = torch.tensor(
-            [
-                # First sequence with 2 samples
-                [0, 1, 2, 3, 4, 0, 1, 2, 3],
-                # Second sequence with 3 samples
-                [0, 1, 2, 0, 1, 2, 3, 0, 1],
-            ]
-        )
-
-        # Adjust as if this was the second slice (start_idx = 4)
-        adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4)
-
-        # For first sequence: [0,1,2,3,4,0,1,2,3] -> [-4,-3,-2,-1,0,-4,-3,-2,-1]
-        # For second sequence: [0,1,2,0,1,2,3,0,1] -> [-4,-3,-2,-4,-3,-2,-1,-4,-3]
-        expected_first_seq = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3]) - 4
-        expected_second_seq = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1]) - 4
-
-        assert torch.all(adjusted[0] == expected_first_seq)
-        assert torch.all(adjusted[1] == expected_second_seq)
-
-
-class TestRingAttention:
-    """Tests for the ring attention functionality."""
-
-    @patch("torch.distributed.get_rank")
-    @patch("torch.distributed.get_world_size")
-    def test_get_ring_attn_group_no_registration(
-        self, mock_world_size, mock_rank, partial_state
-    ):
-        """Test that get_ring_attn_group returns None when no group has been registered."""
-        # Setup mocks
-        mock_world_size.return_value = 4
-        mock_rank.return_value = 0
-
-        # Get the group without registration
-        group = get_ring_attn_group()
-
-        # Verify that None was returned
-        assert group is None
-
-    @patch("torch.distributed.new_group")
-    @patch("torch.distributed.get_rank")
-    @patch("torch.distributed.get_world_size")
-    def test_register_ring_attn(
-        self, mock_world_size, mock_rank, mock_new_group, partial_state
-    ):
-        """Test that ring attention groups are created correctly."""
-        from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
-
-        # Setup mocks
-        mock_world_size.return_value = 8  # 8 GPUs total
-        mock_rank.return_value = 3  # GPU #3
-        mock_group = MagicMock()
-        mock_new_group.return_value = mock_group
-
-        # Call register_ring_attn with size 4
-        register_ring_attn(sequence_parallel_degree=4)
-
-        # Verify the number of calls without examining the arguments
-        assert mock_new_group.call_count == 2
-
-        # Verify that new_group was called
-        mock_new_group.assert_called()
-
-        # Clean up
-        set_ring_attn_group(None)
-
-
-# Mock a simplified DataCollator test
-@patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group")
-@patch("torch.distributed.get_rank")
-@patch("torch.distributed.get_world_size")
-def test_sequence_parallel_slicing(
-    mock_world_size, mock_rank, mock_get_group, partial_state
-):
-    """Test the basic sequence slicing logic without full collator instantiation."""
-    # Setup mocks
-    mock_get_group.return_value = MagicMock()
-    mock_rank.return_value = 1  # Second GPU
-    mock_world_size.return_value = 4  # 4 GPUs total
-
-    # Create a sample batch
-    batch = {
-        "input_ids": torch.tensor(
-            [
-                [101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112],
-                [201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212],
-            ]
-        ),
-        "attention_mask": torch.ones(2, 12),
-    }
-
-    # Simplified slicing logic from SequenceParallelDataCollator
-    def slice_batch(batch, rank, world_size):
-        result = {}
-        for key in batch:
-            seq_len = batch[key].shape[1]
-            slice_size = seq_len // world_size
-            start_idx = rank * slice_size
-            end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
-            result[key] = batch[key][:, start_idx:end_idx]
-        return result
-
-    # Slice the batch
-    result = slice_batch(
-        batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value
-    )
-
-    # Check slicing
-    assert result["input_ids"].shape == (2, 3)  # 12 tokens / 4 GPUs = 3 tokens per GPU
-    expected_input_ids = torch.tensor(
-        [
-            [104, 105, 106],  # Second slice of first sequence
-            [204, 205, 206],  # Second slice of second sequence
-        ]
-    )
-    assert torch.all(result["input_ids"] == expected_input_ids)
-
-
-@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()})
-def test_config_validation_with_valid_inputs(cfg):
-    """Test that valid sequence parallelism configurations pass validation."""
-    # Import the actual model class with appropriate mocks
-    from axolotl.utils.schemas.config import AxolotlInputConfig
-
-    # Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
-    cfg = cfg | {
-        "sequence_parallel_degree": 2,
-        "flash_attention": True,
-    }
-
-    # Should validate without errors
-    config = AxolotlInputConfig(**cfg)
-    assert config.sequence_parallel_degree == 2
-    assert config.flash_attention is True
-
-
-def test_config_validation_with_invalid_inputs(cfg):
-    """Test that invalid sequence parallelism configurations fail validation."""
-    from axolotl.utils.schemas.config import AxolotlInputConfig
-
-    # Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
-    cfg = cfg | {
-        "sequence_parallel_degree": 2,
-        "flash_attention": False,
-    }
-
-    # Should raise ValidationError
-    with pytest.raises(ValueError) as excinfo:
-        AxolotlInputConfig(**cfg)
-
-    # Verify error message
-    assert "flash_attention: true must be set" in str(excinfo.value)
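
The deleted test above exercises a simple rank-based slicing rule: each of the `world_size` ranks takes `seq_len // world_size` tokens of the packed sequence, and the last rank also absorbs any remainder. A minimal standalone sketch of that arithmetic (illustrative only; `slice_for_rank` is a hypothetical name, not an axolotl API):

```python
import torch


def slice_for_rank(batch: dict, rank: int, world_size: int) -> dict:
    """Give each rank a contiguous slice of the sequence dimension."""
    result = {}
    for key, tensor in batch.items():
        seq_len = tensor.shape[1]
        slice_size = seq_len // world_size
        start = rank * slice_size
        end = start + slice_size if rank < world_size - 1 else seq_len
        result[key] = tensor[:, start:end]
    return result


batch = {"input_ids": torch.arange(12).reshape(1, 12)}
# Rank 1 of 4 gets tokens 3..5 of the 12-token sequence.
assert slice_for_rank(batch, rank=1, world_size=4)["input_ids"].tolist() == [[3, 4, 5]]
```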
@@ -1,5 +1,5 @@
 """
-E2E tests for deepseekv3
+E2E tests for lora llama
 """
 
 import logging
@@ -14,8 +14,6 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault
 
-from tests.hf_offline_utils import enable_hf_offline
-
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
 
@@ -25,7 +23,6 @@ class TestDeepseekV3:
     Test case for DeepseekV3 models
     """
 
-    @enable_hf_offline
     @pytest.mark.parametrize(
         "sample_packing",
         [True, False],
@@ -83,7 +80,6 @@ class TestDeepseekV3:
         train(cfg=cfg, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.safetensors").exists()
 
-    @enable_hf_offline
     @pytest.mark.parametrize(
         "sample_packing",
         [True, False],
@@ -1,133 +0,0 @@
-"""
-E2E tests for gemma2
-"""
-
-import logging
-import os
-from pathlib import Path
-
-import pytest
-
-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
-from axolotl.train import train
-from axolotl.utils.config import normalize_config, validate_config
-from axolotl.utils.dict import DictDefault
-
-LOG = logging.getLogger("axolotl.tests.e2e")
-os.environ["WANDB_DISABLED"] = "true"
-
-
-class TestGemma2:
-    """
-    Test case for Gemma2 models
-    """
-
-    @pytest.mark.parametrize(
-        "sample_packing",
-        [True, False],
-    )
-    def test_lora_gemma2(self, temp_dir, sample_packing):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "axolotl-ai-co/gemma-2-33M",
-                "trust_remote_code": True,
-                "sample_packing": sample_packing,
-                "flash_attention": True,
-                "sequence_len": 2048,
-                "adapter": "lora",
-                "lora_r": 8,
-                "lora_alpha": 16,
-                "lora_dropout": 0.05,
-                "lora_target_linear": True,
-                "val_set_size": 0,
-                "datasets": [
-                    {
-                        "path": "mlabonne/FineTome-100k",
-                        "type": "chat_template",
-                        "field_messages": "conversations",
-                        "message_property_mappings": {
-                            "role": "from",
-                            "content": "value",
-                        },
-                        "drop_system_message": True,
-                        "split": "train[:1%]",
-                    },
-                ],
-                "special_tokens": {
-                    "bos_token": "<bos>",
-                    "eos_token": "<eos>",
-                },
-                "chat_template": "gemma",  # gemma2's template is same as gemma
-                "num_epochs": 1,
-                "micro_batch_size": 1,
-                "gradient_accumulation_steps": 4,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_bnb_8bit",
-                "lr_scheduler": "cosine",
-                "max_steps": 5,
-                "save_safetensors": True,
-                "bf16": True,
-            }
-        )
-        cfg = validate_config(cfg)
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.safetensors").exists()
-
-    @pytest.mark.parametrize(
-        "sample_packing",
-        [True, False],
-    )
-    def test_fft_gemma2(self, temp_dir, sample_packing):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "axolotl-ai-co/gemma-2-33M",
-                "trust_remote_code": True,
-                "sample_packing": sample_packing,
-                "flash_attention": True,
-                "sequence_len": 2048,
-                "val_set_size": 0,
-                "datasets": [
-                    {
-                        "path": "mlabonne/FineTome-100k",
-                        "type": "chat_template",
-                        "field_messages": "conversations",
-                        "message_property_mappings": {
-                            "role": "from",
-                            "content": "value",
-                        },
-                        "split": "train[:1%]",
-                        "drop_system_message": True,
-                    },
-                ],
-                "chat_template": "gemma",  # gemma2's template is same as gemma
-                "special_tokens": {
-                    "bos_token": "<bos>",
-                    "eos_token": "<eos>",
-                },
-                "num_epochs": 1,
-                "micro_batch_size": 1,
-                "gradient_accumulation_steps": 4,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_bnb_8bit",
-                "lr_scheduler": "cosine",
-                "max_steps": 5,
-                "save_safetensors": True,
-                "bf16": True,
-            }
-        )
-        cfg = validate_config(cfg)
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
@@ -1,131 +0,0 @@
-"""
-E2E tests for gemma3_text
-"""
-
-import logging
-import os
-from pathlib import Path
-
-import pytest
-
-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
-from axolotl.train import train
-from axolotl.utils.config import normalize_config, validate_config
-from axolotl.utils.dict import DictDefault
-
-LOG = logging.getLogger("axolotl.tests.e2e")
-os.environ["WANDB_DISABLED"] = "true"
-
-
-class TestGemma3Text:
-    """
-    Test case for Gemma3Text models
-    """
-
-    @pytest.mark.parametrize(
-        "sample_packing",
-        [True, False],
-    )
-    def test_lora_gemma3_text(self, temp_dir, sample_packing):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "axolotl-ai-co/gemma-3-34M",
-                "trust_remote_code": True,
-                "sample_packing": sample_packing,
-                "flash_attention": True,
-                "sequence_len": 2048,
-                "adapter": "lora",
-                "lora_r": 8,
-                "lora_alpha": 16,
-                "lora_dropout": 0.05,
-                "lora_target_linear": True,
-                "val_set_size": 0,
-                "datasets": [
-                    {
-                        "path": "mlabonne/FineTome-100k",
-                        "type": "chat_template",
-                        "field_messages": "conversations",
-                        "message_property_mappings": {
-                            "role": "from",
-                            "content": "value",
-                        },
-                        "split": "train[:1%]",
-                    },
-                ],
-                "special_tokens": {
-                    "bos_token": "<bos>",
-                    "eos_token": "<eos>",
-                },
-                "chat_template": "gemma3",
-                "num_epochs": 1,
-                "micro_batch_size": 1,
-                "gradient_accumulation_steps": 4,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_bnb_8bit",
-                "lr_scheduler": "cosine",
-                "max_steps": 5,
-                "save_safetensors": True,
-                "bf16": True,
-            }
-        )
-        cfg = validate_config(cfg)
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "adapter_model.safetensors").exists()
-
-    @pytest.mark.parametrize(
-        "sample_packing",
-        [True, False],
-    )
-    def test_fft_gemma3_text(self, temp_dir, sample_packing):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "axolotl-ai-co/gemma-3-34M",
-                "trust_remote_code": True,
-                "sample_packing": sample_packing,
-                "flash_attention": True,
-                "sequence_len": 2048,
-                "val_set_size": 0,
-                "datasets": [
-                    {
-                        "path": "mlabonne/FineTome-100k",
-                        "type": "chat_template",
-                        "field_messages": "conversations",
-                        "message_property_mappings": {
-                            "role": "from",
-                            "content": "value",
-                        },
-                        "split": "train[:1%]",
-                    },
-                ],
-                "chat_template": "gemma3",
-                "special_tokens": {
-                    "bos_token": "<bos>",
-                    "eos_token": "<eos>",
-                },
-                "num_epochs": 1,
-                "micro_batch_size": 1,
-                "gradient_accumulation_steps": 4,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_bnb_8bit",
-                "lr_scheduler": "cosine",
-                "max_steps": 5,
-                "save_safetensors": True,
-                "bf16": True,
-            }
-        )
-        cfg = validate_config(cfg)
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, dataset_meta=dataset_meta)
-        assert (Path(temp_dir) / "model.safetensors").exists()
@@ -5,14 +5,14 @@ E2E tests for llama
 import logging
 import os
 
+from e2e.utils import check_model_output_exists
+
 from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault
 
-from tests.e2e.utils import check_model_output_exists
-
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
 
@@ -201,46 +201,3 @@ class TestCustomOptimizers(unittest.TestCase):
 
         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
-
-    @with_temp_dir
-    def test_soap(self, temp_dir):
-        # pylint: disable=duplicate-code
-        cfg = DictDefault(
-            {
-                "base_model": "HuggingFaceTB/SmolLM-135M",
-                "sequence_len": 1024,
-                "load_in_8bit": True,
-                "adapter": "lora",
-                "lora_r": 8,
-                "lora_alpha": 16,
-                "lora_dropout": 0.05,
-                "lora_target_linear": True,
-                "val_set_size": 0.1,
-                "special_tokens": {
-                    "pad_token": "<|endoftext|>",
-                },
-                "datasets": [
-                    {
-                        "path": "vicgalle/alpaca-gpt4",
-                        "type": "alpaca",
-                    },
-                ],
-                "num_epochs": 1,
-                "micro_batch_size": 8,
-                "gradient_accumulation_steps": 1,
-                "output_dir": temp_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "soap",
-                "adam_beta1": 0.9,
-                "adam_beta2": 0.95,
-                "lr_scheduler": "cosine",
-            }
-        )
-
-        cfg = validate_config(cfg)
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
@@ -54,7 +54,7 @@ class TestCustomSchedulers(unittest.TestCase):
                 "gradient_accumulation_steps": 1,
                 "output_dir": temp_dir,
                 "learning_rate": 0.00001,
-                "optimizer": "adamw_torch_fused",
+                "optimizer": "adamw_hf",
                 "max_steps": 20,
                 "lr_scheduler": "rex",
                 "warmup_steps": 5,
@@ -1,85 +0,0 @@
-"""
-test utils for helpers and decorators
-"""
-
-import os
-from functools import wraps
-
-from huggingface_hub.utils import reset_sessions
-
-
-def reload_modules(hf_hub_offline):
-    # Force reload of the modules that check this variable
-    import importlib
-
-    import datasets
-    import huggingface_hub.constants
-
-    # Reload the constants module first, as others depend on it
-    importlib.reload(huggingface_hub.constants)
-    huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline
-    importlib.reload(datasets.config)
-    setattr(datasets.config, "HF_HUB_OFFLINE", hf_hub_offline)
-    reset_sessions()
-
-
-def enable_hf_offline(test_func):
-    """
-    test decorator that sets HF_HUB_OFFLINE environment variable to True and restores it after the test even if the test fails.
-    :param test_func:
-    :return:
-    """
-
-    @wraps(test_func)
-    def wrapper(*args, **kwargs):
-        # Save the original value of HF_HUB_OFFLINE environment variable
-        original_hf_offline = os.getenv("HF_HUB_OFFLINE")
-
-        # Set HF_OFFLINE environment variable to True
-        os.environ["HF_HUB_OFFLINE"] = "1"
-
-        reload_modules(True)
-        try:
-            # Run the test function
-            return test_func(*args, **kwargs)
-        finally:
-            # Restore the original value of HF_HUB_OFFLINE environment variable
-            if original_hf_offline is not None:
-                os.environ["HF_HUB_OFFLINE"] = original_hf_offline
-                reload_modules(bool(original_hf_offline))
-            else:
-                del os.environ["HF_HUB_OFFLINE"]
-                reload_modules(False)
-
-    return wrapper
-
-
-def disable_hf_offline(test_func):
-    """
-    test decorator that sets HF_HUB_OFFLINE environment variable to False and restores it after the wrapped func
-    :param test_func:
-    :return:
-    """
-
-    @wraps(test_func)
-    def wrapper(*args, **kwargs):
-        # Save the original value of HF_HUB_OFFLINE environment variable
-        original_hf_offline = os.getenv("HF_HUB_OFFLINE")
-
-        # Set HF_OFFLINE environment variable to True
-        os.environ["HF_HUB_OFFLINE"] = "0"
-
-        reload_modules(False)
-        try:
-            # Run the test function
-            return test_func(*args, **kwargs)
-        finally:
-            # Restore the original value of HF_HUB_OFFLINE environment variable
-            if original_hf_offline is not None:
-                os.environ["HF_HUB_OFFLINE"] = original_hf_offline
-                reload_modules(bool(original_hf_offline))
-            else:
-                del os.environ["HF_HUB_OFFLINE"]
-                reload_modules(False)
-
-    return wrapper
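
The deleted helper module above wraps tests in decorators that flip `HF_HUB_OFFLINE` and then restore it even when the test fails. A minimal sketch of the same save/override/restore pattern, independent of huggingface_hub (illustrative only; `env_override` is a hypothetical helper, not part of the test suite):

```python
import os
from functools import wraps


def env_override(name: str, value: str):
    """Decorator: set an environment variable for the wrapped call, then restore it."""

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            original = os.getenv(name)
            os.environ[name] = value
            try:
                return func(*args, **kwargs)
            finally:
                if original is not None:
                    os.environ[name] = original
                else:
                    del os.environ[name]

        return wrapper

    return decorator


@env_override("HF_HUB_OFFLINE", "1")
def run_offline_test():
    return os.environ["HF_HUB_OFFLINE"]


assert run_offline_test() == "1"
```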
@@ -4,13 +4,12 @@ shared fixtures for prompt strategies tests
 
 import pytest
 from datasets import Dataset
+from huggingface_hub import hf_hub_download
 from transformers import AutoTokenizer
 
 from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
 from axolotl.utils.chat_templates import _CHAT_TEMPLATES
 
-from tests.hf_offline_utils import enable_hf_offline
-
 
 @pytest.fixture(name="assistant_dataset")
 def fixture_assistant_dataset():
@@ -109,27 +108,31 @@ def fixture_toolcalling_dataset():
 
 
 @pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True)
-@enable_hf_offline
-def fixture_llama3_tokenizer(
-    download_llama3_8b_instruct_model_fixture,
-):  # pylint: disable=unused-argument,redefined-outer-name
+def fixture_llama3_tokenizer():
+    hf_hub_download(
+        repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
+        filename="special_tokens_map.json",
+    )
+    hf_hub_download(
+        repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
+        filename="tokenizer_config.json",
+    )
+    hf_hub_download(
+        repo_id="NousResearch/Meta-Llama-3-8B-Instruct", filename="tokenizer.json"
+    )
     tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
 
     return tokenizer
 
 
 @pytest.fixture(name="smollm2_tokenizer", scope="session", autouse=True)
-@enable_hf_offline
 def fixture_smollm2_tokenizer():
     tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
     return tokenizer
 
 
 @pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
-@enable_hf_offline
-def fixture_mistralv03_tokenizer(
-    download_mlx_mistral_7b_model_fixture,
-):  # pylint: disable=unused-argument,redefined-outer-name
+def fixture_mistralv03_tokenizer():
     tokenizer = AutoTokenizer.from_pretrained(
         "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
     )
@@ -137,7 +140,6 @@ def fixture_mistralv03_tokenizer(
 
 
 @pytest.fixture(name="phi35_tokenizer", scope="session", autouse=True)
-@enable_hf_offline
 def fixture_phi35_tokenizer():
     tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
     return tokenizer
@@ -11,8 +11,6 @@ from axolotl.datasets import TokenizedPromptDataset
 from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
 from axolotl.prompters import AlpacaPrompter, PromptStyle
 
-from tests.hf_offline_utils import enable_hf_offline
-
 
 @pytest.fixture(name="alpaca_dataset")
 def fixture_alpaca_dataset():
@@ -28,7 +26,6 @@ def fixture_alpaca_dataset():
 
 
 @pytest.fixture(name="tokenizer")
-@enable_hf_offline
 def fixture_tokenizer():
     # pylint: disable=all
     tokenizer = AutoTokenizer.from_pretrained(
@@ -13,11 +13,8 @@ from axolotl.utils.chat_templates import (
     get_chat_template,
 )
 
-from tests.hf_offline_utils import enable_hf_offline
-
 
 @pytest.fixture(name="llama3_tokenizer")
-@enable_hf_offline
 def fixture_llama3_tokenizer():
     tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
 
@@ -17,8 +17,6 @@ from axolotl.prompt_strategies.chat_template import (
 from axolotl.prompters import IGNORE_TOKEN_ID
 from axolotl.utils.chat_templates import get_chat_template
 
-from tests.hf_offline_utils import enable_hf_offline
-
 logging.basicConfig(level=logging.DEBUG)
 LOG = logging.getLogger("axolotl")
 
@@ -32,14 +30,12 @@ PARAMETRIZE_PARAMS = [
         "mistralv03_tokenizer_chat_template_jinja",
         "[/INST]",
     ),
-    # TODO: temporarily skip gemma due to gemma3 template
-    # Re-enable on new chat_template implementation for perf
-    # (
-    #     "gemma2_tokenizer",
-    #     "jinja",
-    #     "gemma2_tokenizer_chat_template_jinja",
-    #     "<end_of_turn>",
-    # ),
+    (
+        "gemma2_tokenizer",
+        "jinja",
+        "gemma2_tokenizer_chat_template_jinja",
+        "<end_of_turn>",
+    ),
     ("phi35_tokenizer", "phi_35", None, "<|end|>"),
 ]
 
@@ -97,11 +93,7 @@ class TestChatTemplateConfigurations:
             if (
                 turn_idx == 0
                 and turn.get("from") in ["system", "context"]
-                and (
-                    "mistral" in tokenizer.name_or_path.lower()
-                    or "gemma"
-                    in tokenizer.name_or_path.lower()  # temporarily skip gemma due to gemma3 template
-                )
+                and "mistral" in tokenizer.name_or_path.lower()
             ):
                 assert (
                     start_idx == -1 and end_idx == -1
@@ -109,7 +101,6 @@ class TestChatTemplateConfigurations:
                 return True
         return False
 
-    @enable_hf_offline
     def test_train_on_inputs_true(
         self,
         tokenizer,
@@ -11,8 +11,6 @@ from transformers import AutoTokenizer
 from axolotl.prompt_strategies.dpo.chat_template import default
 from axolotl.utils.dict import DictDefault
 
-from tests.hf_offline_utils import enable_hf_offline
-
 
 @pytest.fixture(name="assistant_dataset")
 def fixture_assistant_dataset():
@@ -80,8 +78,15 @@ def fixture_custom_assistant_dataset():
     )
 
 
+@pytest.fixture(name="llama3_tokenizer")
+def fixture_llama3_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
+    tokenizer.eos_token = "<|eot_id|>"
+
+    return tokenizer
+
+
 @pytest.fixture(name="phi3_tokenizer")
-@enable_hf_offline
 def fixture_phi3_tokenizer():
     tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct")
 
@@ -89,7 +94,6 @@ def fixture_phi3_tokenizer():
 
 
 @pytest.fixture(name="gemma_tokenizer")
-@enable_hf_offline
 def fixture_gemma_tokenizer():
     tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a")
 
@@ -10,8 +10,6 @@ from axolotl.prompt_strategies.dpo import load as load_dpo
 from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.dict import DictDefault
 
-from tests.hf_offline_utils import enable_hf_offline
-
 
 @pytest.fixture(name="minimal_dpo_cfg")
 def fixture_cfg():
@@ -36,8 +34,6 @@ class TestDPOChatml:
     Test loading DPO preference datasets with chatml formatting
     """
 
-    @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
-    @enable_hf_offline
     def test_default(self, minimal_dpo_cfg):
         cfg = DictDefault(
             {
@@ -8,15 +8,12 @@ from transformers import LlamaTokenizer
 
 from axolotl.utils.data import encode_pretraining, md5
 
-from tests.hf_offline_utils import enable_hf_offline
-
 
 class TestEncodePretraining(unittest.TestCase):
     """
     test class for encode pretraining and md5 helper
     """
 
-    @enable_hf_offline
     def setUp(self):
         self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
         self.tokenizer.add_special_tokens(
@@ -4,37 +4,31 @@ Test dataset loading under various conditions.
 
 import shutil
 import tempfile
+import unittest
 from pathlib import Path
-from unittest.mock import patch
 
-import pytest
+from conftest import snapshot_download_w_retry
+from constants import (
+    ALPACA_MESSAGES_CONFIG_OG,
+    ALPACA_MESSAGES_CONFIG_REVISION,
+    SPECIAL_TOKENS,
+)
 from datasets import Dataset
-from huggingface_hub import snapshot_download
-from transformers import PreTrainedTokenizer
+from transformers import AutoTokenizer
 
 from axolotl.utils.data import load_tokenized_prepared_datasets
 from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.dict import DictDefault
 
-from tests.constants import (
-    ALPACA_MESSAGES_CONFIG_OG,
-    ALPACA_MESSAGES_CONFIG_REVISION,
-    SPECIAL_TOKENS,
-)
-from tests.hf_offline_utils import enable_hf_offline
 
-
-class TestDatasetPreparation:
+class TestDatasetPreparation(unittest.TestCase):
     """Test a configured dataloader."""
 
-    @pytest.fixture
-    def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer:
-        tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
-        yield tokenizer_huggyllama
-
-    @pytest.fixture
-    def dataset_fixture(self):
-        yield Dataset.from_list(
+    def setUp(self) -> None:
+        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+        self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
+        # Alpaca dataset.
+        self.dataset = Dataset.from_list(
             [
                 {
                     "instruction": "Evaluate this sentence for spelling and grammar mistakes",
@@ -44,9 +38,7 @@ class TestDatasetPreparation:
             ]
         )
 
-    @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
-    @enable_hf_offline
-    def test_load_hub(self, tokenizer):
+    def test_load_hub(self):
         """Core use case. Verify that processing data from the hub works"""
         with tempfile.TemporaryDirectory() as tmp_dir:
             prepared_path = Path(tmp_dir) / "prepared"
@@ -63,28 +55,25 @@ class TestDatasetPreparation:
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
+            dataset, _ = load_tokenized_prepared_datasets(
+                self.tokenizer, cfg, prepared_path
+            )
 
             assert len(dataset) == 2000
             assert "input_ids" in dataset.features
             assert "attention_mask" in dataset.features
             assert "labels" in dataset.features
 
-    @enable_hf_offline
-    @pytest.mark.skip("datasets bug with local datasets when offline")
-    def test_load_local_hub(self, tokenizer):
+    def test_load_local_hub(self):
         """Niche use case. Verify that a local copy of a hub dataset can be loaded"""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
             tmp_ds_path.mkdir(parents=True, exist_ok=True)
-            snapshot_path = snapshot_download(
+            snapshot_download_w_retry(
                 repo_id="mhenrichsen/alpaca_2k_test",
                 repo_type="dataset",
                 local_dir=tmp_ds_path,
             )
-            # offline mode doesn't actually copy it to local_dir, so we
-            # have to copy all the contents in the dir manually from the returned snapshot_path
-            shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
 
             prepared_path = Path(tmp_dir) / "prepared"
             # Right now a local copy that doesn't fully conform to a dataset
@@ -107,7 +96,9 @@ class TestDatasetPreparation:
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
+            dataset, _ = load_tokenized_prepared_datasets(
+                self.tokenizer, cfg, prepared_path
+            )
 
             assert len(dataset) == 2000
             assert "input_ids" in dataset.features
@@ -115,12 +106,11 @@ class TestDatasetPreparation:
             assert "labels" in dataset.features
             shutil.rmtree(tmp_ds_path)
 
-    @enable_hf_offline
-    def test_load_from_save_to_disk(self, tokenizer, dataset_fixture):
+    def test_load_from_save_to_disk(self):
         """Usual use case. Verify datasets saved via `save_to_disk` can be loaded."""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_name = Path(tmp_dir) / "tmp_dataset"
-            dataset_fixture.save_to_disk(str(tmp_ds_name))
+            self.dataset.save_to_disk(str(tmp_ds_name))
 
             prepared_path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(
@@ -136,21 +126,22 @@ class TestDatasetPreparation:
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
+            dataset, _ = load_tokenized_prepared_datasets(
+                self.tokenizer, cfg, prepared_path
+            )
 
             assert len(dataset) == 1
             assert "input_ids" in dataset.features
             assert "attention_mask" in dataset.features
             assert "labels" in dataset.features
 
-    @enable_hf_offline
-    def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
+    def test_load_from_dir_of_parquet(self):
         """Usual use case. Verify a directory of parquet files can be loaded."""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
             tmp_ds_dir.mkdir()
             tmp_ds_path = tmp_ds_dir / "shard1.parquet"
-            dataset_fixture.to_parquet(tmp_ds_path)
+            self.dataset.to_parquet(tmp_ds_path)
 
             prepared_path: Path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(
@@ -171,21 +162,22 @@ class TestDatasetPreparation:
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
+            dataset, _ = load_tokenized_prepared_datasets(
+                self.tokenizer, cfg, prepared_path
+            )
 
             assert len(dataset) == 1
             assert "input_ids" in dataset.features
             assert "attention_mask" in dataset.features
             assert "labels" in dataset.features
 
-    @enable_hf_offline
-    def test_load_from_dir_of_json(self, tokenizer, dataset_fixture):
+    def test_load_from_dir_of_json(self):
         """Standard use case. Verify a directory of json files can be loaded."""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
             tmp_ds_dir.mkdir()
             tmp_ds_path = tmp_ds_dir / "shard1.json"
-            dataset_fixture.to_json(tmp_ds_path)
+            self.dataset.to_json(tmp_ds_path)
 
             prepared_path: Path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(
@@ -206,19 +198,20 @@ class TestDatasetPreparation:
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
+            dataset, _ = load_tokenized_prepared_datasets(
+                self.tokenizer, cfg, prepared_path
+            )
 
             assert len(dataset) == 1
             assert "input_ids" in dataset.features
             assert "attention_mask" in dataset.features
             assert "labels" in dataset.features
 
-    @enable_hf_offline
-    def test_load_from_single_parquet(self, tokenizer, dataset_fixture):
+    def test_load_from_single_parquet(self):
         """Standard use case. Verify a single parquet file can be loaded."""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_path = Path(tmp_dir) / "tmp_dataset.parquet"
-            dataset_fixture.to_parquet(tmp_ds_path)
+            self.dataset.to_parquet(tmp_ds_path)
 
             prepared_path: Path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(
@@ -235,19 +228,20 @@ class TestDatasetPreparation:
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
+            dataset, _ = load_tokenized_prepared_datasets(
+                self.tokenizer, cfg, prepared_path
+            )
 
             assert len(dataset) == 1
             assert "input_ids" in dataset.features
             assert "attention_mask" in dataset.features
             assert "labels" in dataset.features
 
-    @enable_hf_offline
-    def test_load_from_single_json(self, tokenizer, dataset_fixture):
+    def test_load_from_single_json(self):
         """Standard use case. Verify a single json file can be loaded."""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_path = Path(tmp_dir) / "tmp_dataset.json"
-            dataset_fixture.to_json(tmp_ds_path)
+            self.dataset.to_json(tmp_ds_path)
 
             prepared_path: Path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(
@@ -264,15 +258,15 @@ class TestDatasetPreparation:
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
+            dataset, _ = load_tokenized_prepared_datasets(
+                self.tokenizer, cfg, prepared_path
+            )
 
             assert len(dataset) == 1
             assert "input_ids" in dataset.features
             assert "attention_mask" in dataset.features
             assert "labels" in dataset.features
 
-    @pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
-    @enable_hf_offline
     def test_load_hub_with_dpo(self):
         """Verify that processing dpo data from the hub works"""
 
@@ -291,9 +285,7 @@ class TestDatasetPreparation:
         assert len(train_dataset) == 1800
         assert "conversation" in train_dataset.features
 
-    @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
-    @enable_hf_offline
-    def test_load_hub_with_revision(self, tokenizer):
+    def test_load_hub_with_revision(self):
         """Verify that processing data from the hub works with a specific revision"""
         with tempfile.TemporaryDirectory() as tmp_dir:
             prepared_path = Path(tmp_dir) / "prepared"
@@ -315,17 +307,16 @@ class TestDatasetPreparation:
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
+            dataset, _ = load_tokenized_prepared_datasets(
+                self.tokenizer, cfg, prepared_path
+            )
 
             assert len(dataset) == 2000
             assert "input_ids" in dataset.features
             assert "attention_mask" in dataset.features
             assert "labels" in dataset.features
 
-    @enable_hf_offline
-    def test_load_hub_with_revision_with_dpo(
-        self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
-    ):
+    def test_load_hub_with_revision_with_dpo(self):
         """Verify that processing dpo data from the hub works with a specific revision"""
 
         cfg = DictDefault(
@@ -338,34 +329,22 @@ class TestDatasetPreparation:
                 }
             )
 
-        # pylint: disable=duplicate-code
-        with patch(
-            "axolotl.utils.data.shared.load_dataset_w_config"
-        ) as mock_load_dataset:
-            # Set up the mock to return different values on successive calls
-            mock_load_dataset.return_value = (
-                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
-            )
-
-            train_dataset, _ = load_prepare_preference_datasets(cfg)
-
-            assert len(train_dataset) == 1800
-            assert "conversation" in train_dataset.features
+        train_dataset, _ = load_prepare_preference_datasets(cfg)
+
+        assert len(train_dataset) == 1800
+        assert "conversation" in train_dataset.features
 
-    @enable_hf_offline
-    @pytest.mark.skip("datasets bug with local datasets when offline")
-    def test_load_local_hub_with_revision(self, tokenizer):
+    def test_load_local_hub_with_revision(self):
         """Verify that a local copy of a hub dataset can be loaded with a specific revision"""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
             tmp_ds_path.mkdir(parents=True, exist_ok=True)
-            snapshot_path = snapshot_download(
+            snapshot_download_w_retry(
                 repo_id="mhenrichsen/alpaca_2k_test",
                 repo_type="dataset",
                 local_dir=tmp_ds_path,
                 revision="d05c1cb",
             )
-            shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
 
             prepared_path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(
@@ -386,7 +365,9 @@ class TestDatasetPreparation:
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
+            dataset, _ = load_tokenized_prepared_datasets(
+                self.tokenizer, cfg, prepared_path
+            )
 
             assert len(dataset) == 2000
             assert "input_ids" in dataset.features
@@ -394,19 +375,17 @@ class TestDatasetPreparation:
             assert "labels" in dataset.features
             shutil.rmtree(tmp_ds_path)
 
-    @enable_hf_offline
-    def test_loading_local_dataset_folder(self, tokenizer):
+    def test_loading_local_dataset_folder(self):
         """Verify that a dataset downloaded to a local folder can be loaded"""
 
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
             tmp_ds_path.mkdir(parents=True, exist_ok=True)
-            snapshot_path = snapshot_download(
+            snapshot_download_w_retry(
                 repo_id="mhenrichsen/alpaca_2k_test",
                 repo_type="dataset",
                 local_dir=tmp_ds_path,
            )
-            shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
 
             prepared_path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(
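
The hunks above and below swap `snapshot_download` for a `snapshot_download_w_retry` helper imported from the test suite's conftest; its implementation is not shown in this diff. A plausible sketch of such a wrapper (an assumption, not the actual conftest code) that retries transient Hub failures before giving up:

```python
import time

from huggingface_hub import snapshot_download


def snapshot_download_w_retry_sketch(*args, retries: int = 3, wait_s: float = 5.0, **kwargs):
    """Retry snapshot_download a few times to ride out transient Hub errors or rate limits."""
    last_exc = None
    for attempt in range(retries):
        try:
            return snapshot_download(*args, **kwargs)
        except Exception as exc:  # pylint: disable=broad-except
            last_exc = exc
            time.sleep(wait_s * (attempt + 1))
    raise last_exc
```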
@@ -422,10 +401,16 @@ class TestDatasetPreparation:
                 }
             )
 
-            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
+            dataset, _ = load_tokenized_prepared_datasets(
+                self.tokenizer, cfg, prepared_path
+            )
 
             assert len(dataset) == 2000
             assert "input_ids" in dataset.features
             assert "attention_mask" in dataset.features
             assert "labels" in dataset.features
             shutil.rmtree(tmp_ds_path)
+
+
+if __name__ == "__main__":
+    unittest.main()
@@ -8,19 +8,16 @@ import hashlib
 import unittest
 from unittest.mock import patch

-import pytest
+from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
 from datasets import Dataset
+from transformers import AutoTokenizer

-from axolotl.utils.config import normalize_config
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.data.rl import load_prepare_preference_datasets
 from axolotl.utils.data.utils import deduplicate_and_log_datasets
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_processor, load_tokenizer

-from tests.constants import ALPACA_MESSAGES_CONFIG_REVISION
-from tests.hf_offline_utils import enable_hf_offline


 def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
     """
@@ -216,12 +213,13 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
         verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")


-class TestDeduplicateRLDataset:
+class TestDeduplicateRLDataset(unittest.TestCase):
     """Test a configured dataloader with deduplication."""

-    @pytest.fixture
-    def cfg(self):
-        fixture = DictDefault(
+    def setUp(self) -> None:
+        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+        self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
+        self.cfg = DictDefault(
             {
                 "tokenizer_config": "huggyllama/llama-7b",
                 "sequence_len": 1024,
@@ -234,69 +232,36 @@ class TestDeduplicateRLDataset:
                 ],
             }
         )
-        yield fixture

-    @enable_hf_offline
-    def test_load_with_deduplication(
-        self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
-    ):
+    def test_load_with_deduplication(self):
         """Verify that loading with deduplication removes duplicates."""

-        # pylint: disable=duplicate-code
-        with (
-            patch(
-                "axolotl.utils.data.shared.load_dataset_w_config"
-            ) as mock_load_dataset,
-            patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
-        ):
-            # Set up the mock to return different values on successive calls
-            mock_load_dataset.side_effect = [
-                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
-                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
-            ]
-            mock_load_tokenizer.return_value = tokenizer_huggyllama
+        # Load the dataset using the deduplication setting
+        train_dataset, _ = load_prepare_preference_datasets(self.cfg)

-            train_dataset, _ = load_prepare_preference_datasets(cfg)
+        # Verify that the dataset has been deduplicated
+        assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"

-            # Verify that the dataset has been deduplicated
-            assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
+    def test_load_without_deduplication(self):
+        """Verify that loading without deduplication retains duplicates."""
+        self.cfg.dataset_exact_deduplication = False
+        # Load the dataset without deduplication
+        train_dataset, _ = load_prepare_preference_datasets(self.cfg)

-    @enable_hf_offline
-    def test_load_without_deduplication(
-        self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
-    ):
-        # pylint: disable=duplicate-code
-        with (
-            patch(
-                "axolotl.utils.data.shared.load_dataset_w_config"
-            ) as mock_load_dataset,
-            patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
-        ):
-            # Set up the mock to return different values on successive calls
-            mock_load_dataset.side_effect = [
-                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
-                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
-            ]
-            mock_load_tokenizer.return_value = tokenizer_huggyllama
+        # Verify that the dataset retains duplicates
+        assert (
+            len(train_dataset) == 1800 * 2
+        ), "Dataset deduplication occurred when it should not have"

-            cfg.dataset_exact_deduplication = False
-            # Load the dataset without deduplication
-            train_dataset, _ = load_prepare_preference_datasets(cfg)

-            # Verify that the dataset retains duplicates
-            assert (
-                len(train_dataset) == 1800 * 2
-            ), "Dataset deduplication occurred when it should not have"


 class TestDeduplicateNonRL(unittest.TestCase):
     """Test prepare_dataset function with different configurations."""

-    @enable_hf_offline
     def setUp(self) -> None:
+        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+        self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
         self.cfg_1 = DictDefault(
             {
-                "base_model": "huggyllama/llama-7b",
                 "tokenizer_config": "huggyllama/llama-7b",
                 "sequence_len": 1024,
                 "dataset_exact_deduplication": True,
@@ -317,10 +282,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
                 "num_epochs": 1,
             }
         )
-        normalize_config(self.cfg_1)

-    @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
-    @enable_hf_offline
     def test_prepare_dataset_with_deduplication_train(self):
         """Verify that prepare_dataset function processes the dataset correctly with deduplication."""
         self.cfg_1.dataset_exact_deduplication = True
@@ -346,8 +308,6 @@ class TestDeduplicateNonRL(unittest.TestCase):
             "Train dataset should have 2000 samples after deduplication.",
         )

-    @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
-    @enable_hf_offline
     def test_prepare_dataset_with_deduplication_eval(self):
         """Verify that prepare_dataset function processes the dataset correctly with deduplication."""
         self.cfg_1.dataset_exact_deduplication = True
@@ -373,8 +333,6 @@ class TestDeduplicateNonRL(unittest.TestCase):
             "Eval dataset should have 2000 samples after deduplication.",
         )

-    @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
-    @enable_hf_offline
     def test_prepare_dataset_without_deduplication(self):
         """Verify that prepare_dataset function processes the dataset correctly without deduplication."""
         self.cfg_1.dataset_exact_deduplication = False

@@ -12,8 +12,6 @@ from axolotl.utils.data.utils import drop_long_seq_in_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths

-from tests.hf_offline_utils import enable_hf_offline


 @pytest.fixture(name="tokenizer")
 def fixture_tokenizer():
@@ -27,7 +25,6 @@ class TestBatchedSamplerPacking:
     Test class for packing streaming dataset sequences
     """

-    @pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
     @pytest.mark.parametrize(
         "batch_size, num_workers",
         [
@@ -38,12 +35,11 @@ class TestBatchedSamplerPacking:
         ],
     )
     @pytest.mark.parametrize("max_seq_length", [4096, 512])
-    @enable_hf_offline
     def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
         import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401

         dataset = load_dataset(
-            "winglian/tiny-shakespeare",
+            "Trelis/tiny-shakespeare",
             split="train",
         )

@@ -10,15 +10,12 @@ from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
 from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
 from axolotl.prompters import AlpacaPrompter

-from tests.hf_offline_utils import enable_hf_offline


 class TestPacking(unittest.TestCase):
     """
     Test class for packing dataset sequences
     """

-    @enable_hf_offline
     def setUp(self) -> None:
         # pylint: disable=duplicate-code
         self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

@@ -1,60 +1,43 @@
 """Module for testing streaming dataset sequence packing"""

 import functools
-import random
-import string
+import unittest

 import pytest
 import torch
-from datasets import IterableDataset
+from datasets import load_dataset
 from torch.utils.data import DataLoader
+from transformers import AutoTokenizer

 from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
 from axolotl.utils.dict import DictDefault


-class TestPretrainingPacking:
+class TestPretrainingPacking(unittest.TestCase):
     """
     Test class for packing streaming dataset sequences
     """

-    @pytest.fixture
-    def random_text(self):
-        # seed with random.seed(0) for reproducibility
-        random.seed(0)
+    def setUp(self) -> None:
+        # pylint: disable=duplicate-code
+        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+        self.tokenizer.pad_token = "</s>"

-        # generate row of random text with "words" of between 2 and 10 characters and
-        # between 400 to 1200 characters per line
-        def rand_txt():
-            return " ".join(
-                [
-                    "".join(
-                        random.choices(string.ascii_lowercase, k=random.randint(2, 10))
-                    )
-                    for _ in range(random.randint(50, 200))
-                ]
-            )
+    @pytest.mark.flaky(retries=3, delay=5)
+    def test_packing_stream_dataset(self):
+        # pylint: disable=duplicate-code
+        dataset = load_dataset(
+            "allenai/c4",
+            "en",
+            streaming=True,
+        )["train"]

-        # Create a list of 2000 random texts rather than just using it within the
-        # generator so the test runs faster
-        data = [rand_txt() for _ in range(500)]

-        # Create an IterableDataset
-        def generator():
-            for row in data:
-                yield {"text": row}

-        return IterableDataset.from_generator(generator)

-    @pytest.mark.flaky(retries=1, delay=5)
-    def test_packing_stream_dataset(self, tokenizer_huggyllama, random_text):
-        dataset = random_text

         cfg = DictDefault(
             {
                 "pretraining_dataset": [
                     {
-                        "path": "winglian/tiny-shakespeare",
+                        "path": "allenai/c4",
+                        "name": "en",
                         "type": "pretrain",
                     }
                 ],
@@ -71,16 +54,15 @@ class TestPretrainingPacking:
         ds_wrapper_partial = functools.partial(
             get_dataset_wrapper,
             cfg.pretraining_dataset[0],
-            tokenizer_huggyllama,
+            self.tokenizer,
             cfg,
             cfg.pretraining_dataset[0]["type"] or "pretrain",
         )

-        # pylint: disable=duplicate-code
         original_bsz = cfg.micro_batch_size
         train_dataset = wrap_pretraining_dataset(
             dataset,
-            tokenizer_huggyllama,
+            self.tokenizer,
             cfg,
             ds_wrapper_partial,
             max_tokens=cfg.sequence_len,
@@ -96,7 +78,7 @@ class TestPretrainingPacking:
         )
         idx = 0
         for data in trainer_loader:
-            if idx > 3:
+            if idx > 10:
                 break
             assert data["input_ids"].shape == torch.Size(
                 [1, original_bsz * cfg.sequence_len]
@@ -113,3 +95,7 @@ class TestPretrainingPacking:
             # [1, original_bsz * cfg.sequence_len]
             # )
             idx += 1


+if __name__ == "__main__":
+    unittest.main()

@@ -5,7 +5,6 @@ import logging
 import unittest
 from pathlib import Path

-import pytest
 from datasets import load_dataset
 from transformers import AddedToken, AutoTokenizer, LlamaTokenizer

@@ -23,8 +22,6 @@ from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
 from axolotl.prompters import AlpacaPrompter, PromptStyle
 from axolotl.utils.dict import DictDefault

-from tests.hf_offline_utils import enable_hf_offline

 LOG = logging.getLogger("axolotl")

 test_data = {
@@ -66,7 +63,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
     Test class for prompt tokenization strategies.
     """

-    @enable_hf_offline
     def setUp(self) -> None:
         # pylint: disable=duplicate-code
         self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
@@ -123,7 +119,6 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
     Test class for prompt tokenization strategies with sys prompt from the dataset
     """

-    @enable_hf_offline
     def setUp(self) -> None:
         # pylint: disable=duplicate-code
         self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
@@ -165,7 +160,6 @@ class Llama2ChatTokenizationTest(unittest.TestCase):
     Test class for prompt tokenization strategies with sys prompt from the dataset
     """

-    @enable_hf_offline
     def setUp(self) -> None:
         # pylint: disable=duplicate-code
         self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
@@ -244,7 +238,6 @@ If a question does not make any sense, or is not factually coherent, explain why
 class OrpoTokenizationTest(unittest.TestCase):
     """test case for the ORPO tokenization"""

-    @enable_hf_offline
     def setUp(self) -> None:
         # pylint: disable=duplicate-code
         tokenizer = LlamaTokenizer.from_pretrained(
@@ -269,7 +262,6 @@ class OrpoTokenizationTest(unittest.TestCase):
             "argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
         ).select([0])

-    @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
     def test_orpo_integration(self):
         strat = load(
             self.tokenizer,

Some files were not shown because too many files have changed in this diff.