Compare commits

...

53 Commits

Author SHA1 Message Date
NanoCode012
fbf3ca86c9 feat: add support for qwen25 vl for multimodal 2025-02-18 12:42:29 +07:00
Sunny
2de866e92f revert seq len to 8192 2024-12-08 22:30:20 -05:00
Sunny
295e07dcca settings 2024-12-08 22:22:18 -05:00
bursteratom
3c07b6d6b1 lint 2024-12-06 16:06:57 -05:00
bursteratom
89dae7dc6d lora_target_module 2024-12-06 15:41:09 -05:00
bursteratom
1b54af8e54 lora config 2024-12-06 15:27:18 -05:00
bursteratom
ca7b56cba3 lora config 2024-12-06 15:26:06 -05:00
bursteratom
ea8269d2eb lora config 2024-12-06 15:23:24 -05:00
bursteratom
13ca7ed087 comment out lora target 2024-12-06 15:21:08 -05:00
bursteratom
0dfd8541ee lora config qwen2vl 2024-12-06 14:56:51 -05:00
bursteratom
75e1d3537f qwen2_vl get_text_config 2024-12-06 14:54:06 -05:00
bursteratom
2b7f3bd6ab qwen2_vl get_text_config 2024-12-06 14:52:17 -05:00
bursteratom
d85a229afe get_text_config 2024-12-06 14:50:05 -05:00
bursteratom
355cd7c872 update is_multimodal requirement to include qwen2_vl 2024-12-06 14:43:50 -05:00
bursteratom
eab1638686 lint 2024-12-06 14:37:32 -05:00
bursteratom
a3a4d22709 config init qwen2-vl chat template 2024-12-06 14:24:03 -05:00
bursteratom
f9eb7d8663 qwen2 example 2024-12-06 14:22:08 -05:00
bursteratom
343771a6d3 lint 2024-12-06 13:15:49 -05:00
bursteratom
d2c32d0cba lint 2024-12-06 13:04:42 -05:00
bursteratom
cec9887609 add llava chat template to config 2024-12-06 12:57:20 -05:00
bursteratom
88b2cae748 llava template 2024-12-06 12:54:43 -05:00
bursteratom
aea2565938 for test only 2024-12-06 11:54:07 -05:00
bursteratom
1ad56303b2 lint 2024-12-05 15:34:04 -05:00
bursteratom
dc055a4ef7 lint 2024-12-05 14:59:51 -05:00
bursteratom
169116a50f llava example 2024-12-05 12:58:30 -05:00
bursteratom
43e412f660 comment 2024-12-04 13:18:25 -05:00
Wing Lian
7aa57803e1 fix optimizer reset for relora sft (#1414)
* fix optimizer reset

* set states to reset for 8bit optimizers and handle quantile runtime error for embeddings

* fix relora test to check grad_norm

* use flash attn for relora and tweak hyperparams for test

* fix messages field for test dataset
2024-12-04 12:33:29 -05:00
NanoCode012
1969fa3bf0 fix(readme): update cuda instructions during preprocess (#2114) [skip ci] 2024-12-04 12:33:29 -05:00
NanoCode012
4078f37076 feat: add cut_cross_entropy (#2091)
* feat: add cut_cross_entropy

* fix: add to input

* fix: remove from setup.py

* feat: refactor into an integration

* chore: ignore lint

* feat: add test for cce

* fix: set max_steps for liger test

* chore: Update base model following suggestion

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* chore: update special_tokens following suggestion

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* chore: remove with_temp_dir following comments

* fix: plugins aren't loaded

* chore: update quotes in error message

* chore: lint

* chore: lint

* feat: enable FA on test

* chore: refactor get_pytorch_version

* fix: lock cce commit version

* fix: remove subclassing UT

* fix: downcast even if not using FA and config check

* feat: add test to check different attentions

* feat: add install to CI

* chore: refactor to use parametrize for attention

* fix: pytest not detecting test

* feat: handle torch lower than 2.4

* fix args/kwargs to match docs

* use release version cut-cross-entropy==24.11.4

* fix quotes

* fix: use named params for clarity for modal builder

* fix: handle install from pip

* fix: test check only top level module install

* fix: re-add import check

* uninstall existing version if no transformers submodule in cce

* more dataset fixtures into the cache

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2024-12-04 12:33:29 -05:00
Wing Lian
f073af6d99 fix merge conflict of duplicate max_steps in config for relora (#2116) 2024-12-04 12:33:29 -05:00
Wing Lian
139d2612fa fix so inference can be run against quantized models without adapters (#1834)
* fix so inference can be run against quantized models without adapters

* Update error msg [skip e2e]

Co-authored-by: NanoCode012 <nano@axolotl.ai>

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-12-04 12:33:29 -05:00
Sunny Liu
20573fd13e Add ds model card, rebased (#2101) [skip ci]
* rebased add_ds_model_card

* manual rebasing

* fix redundancy

* lint

* include case when ds_tag is none

* conform to kwargs in create_model_card
2024-12-04 12:33:29 -05:00
NanoCode012
2b7b4af81c fix(vlm): handle legacy conversation data format and check image in data (#2018) [skip ci]
* fix: handle legacy conversation data format and check image in data

* feat: add test for llama vision

* feat: add max_steps to test

* fix: incorrect indent and return preprocess

* feat: use smaller model and dataset

* chore: add extra config for sharegpt dataset
2024-12-04 12:33:29 -05:00
Sunny Liu
d56260c8d5 Check torch version for ADOPT optimizer + integrating new ADOPT updates (#2104)
* added torch check for adopt, wip

* lint

* gonna put torch version checking somewhere else

* added ENVcapabilities class for torch version checking

* lint + pydantic

* ENVCapabilities -> EnvCapabilities

* forgot to git add v0_4_1/__init__.py

* removed redundancy

* add check if env_capabilities not specified

* make env_capabilities compulsory [skip e2e]

* fixup env_capabilities

* modified test_validation.py to accomodate env_capabilities

* adopt torch version test [skip e2e]

* raise error

* test correct torch version

* test torch version above requirement

* Update src/axolotl/utils/config/models/input/v0_4_1/__init__.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* removed unused is_totch_min

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-12-04 12:33:29 -05:00
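The commit above gates the experimental ADOPT optimizer on the installed torch version by threading an env_capabilities check through config validation (see the `env_capabilities={"torch_version": ...}` change to `load_cfg` further down in this diff). As a rough sketch of that kind of version gate — with a hypothetical helper name, not the PR's EnvCapabilities code — one could write:

```python
# Illustrative sketch only (hypothetical helper, not axolotl's EnvCapabilities implementation).
import torch
from packaging.version import Version


def check_adopt_torch_version(optimizer: str) -> None:
    # Strip any local build suffix such as "+cu121" before comparing.
    torch_version = Version(str(torch.__version__).split("+", maxsplit=1)[0])
    if optimizer == "adopt_adamw" and torch_version < Version("2.5.1"):
        raise ValueError(f"{optimizer} requires torch>=2.5.1, found {torch_version}")
```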
Wing Lian
cac785ec0e use pytest sugar and verbose for more info during ci (#2112) [skip ci]
* use pytest sugar and verbose for more info during ci

* also run test suite when test requirements or cicd.sh changes

* also on PR too
2024-12-04 12:33:29 -05:00
Wing Lian
e62991edef make the eval size smaller for the resume test (#2111) [skip ci] 2024-12-04 12:33:29 -05:00
Wing Lian
fd9e7b55f6 build causal_conv1d and mamba-ssm into the base image (#2113)
* build causal_conv1d and mamba-ssm into the base image

* also build base images on changes to Dockerfile-base and base workflow yaml
2024-12-04 12:33:29 -05:00
Wing Lian
c0c53eb62f various tests fixes for flakey tests (#2110)
* add mhenrichsen/alpaca_2k_test with revision dataset download fixture for flaky tests

* log slowest tests

* pin pynvml==11.5.3

* fix load local hub path

* optimize for speed w smaller models and val_set_size

* replace pynvml

* make the resume from checkpoint e2e faster

* make tests smaller
2024-12-04 12:33:29 -05:00
Oliver Molenschot
b0fbd4d11d Add Exact Deduplication Feature to Preprocessing Pipeline (#2072)
* Add example YAML file for training Mistral using DPO

* added deduplication code

* Add exact deduplication feature and update examples

* Improve deduplication for train/eval overlap

Changed the deduplication function to use a more memory-efficient hashing method. Applied Git suggestions to improve clarity and maintainability.\n\nThe deduplication now handles cases where train and eval datasets have overlapping elements.

* Improve deduplication for train/eval overlap

Changed the deduplication function to use a more memory-efficient hashing method. Applied Git suggestions to improve clarity and maintainability.\n\nThe deduplication now handles cases where train and eval datasets have overlapping elements.

* Apply suggestions from code review

To handle the original case where we do not do deduplication

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Improve false collision detection to ensure dataset integrity

- Added test cases to simulate and verify handling of forced hash collisions between datasets.
- Ensured that datasets with identical hashes but different content are correctly identified, preventing incorrect deduplication.
- Updated unit tests to include scenarios where collisions occur across both training and evaluation datasets, as well as within a single dataset.

* Moved the constants file to the tests folder

- Relocated `constants.py` to the `tests` folder to improve modularity and maintain a clear separation between source and test files.
- Renamed `cicd/tests.py` to `cicd/cicd_tests.py` to resolve a conflict with `tests/__init__.py`, which caused Mypy to fail due to duplicate module names.
- Updated all references to `cicd.tests` in the codebase to `cicd.cicd_tests` to reflect the renaming and ensure compatibility.
- These changes ensure Mypy passes the pre-commit hook and maintain alignment with the project's structure.

* revert some changes from previous commit and fix relative import

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2024-12-04 12:33:29 -05:00
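The exact-deduplication feature described above hashes each example and drops repeated rows, including eval rows that also appear in the training set. A minimal sketch of that style of hash-based exact dedup — illustrative only, not the PR's implementation — could look like:

```python
# Sketch only: hash-based exact deduplication in the spirit of the PR,
# not axolotl's actual implementation (names are illustrative).
import hashlib
import json


def _row_hash(row: dict) -> str:
    # Hash a canonical serialization so identical rows collide on purpose.
    return hashlib.sha256(json.dumps(row, sort_keys=True).encode("utf-8")).hexdigest()


def exact_dedup(train_rows: list[dict], eval_rows: list[dict]):
    """Drop exact duplicates within each split and eval rows already seen in train."""
    seen: set[str] = set()
    train_out = []
    for row in train_rows:
        h = _row_hash(row)
        if h not in seen:
            seen.add(h)
            train_out.append(row)
    eval_out = []
    for row in eval_rows:
        h = _row_hash(row)
        if h not in seen:
            seen.add(h)
            eval_out.append(row)
    return train_out, eval_out
```

Because two different rows could in principle hash to the same value, the PR also adds tests that force such collisions and verify that genuinely different content is not merged.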
Wing Lian
1a70d4d6a4 add e2e tests for Unsloth qlora and test the builds (#2093)
* see if unsloth installs cleanly in ci

* check unsloth install on regular tests, not sdist

* fix ampere check exception for ci

* use cached_property instead

* add an e2e test for unsloth qlora

* reduce seq len and mbsz to prevent oom in ci

* add checks for fp16 and sdp_attention

* pin unsloth to a specific release

* add unsloth to docker image too

* fix flash attn xentropy patch

* fix loss, add check for loss when using fa_xentropy

* fix special tokens for test

* typo

* test fa xentropy with and without gradient accum

* pr feedback changes
2024-12-04 12:33:29 -05:00
Wing Lian
d8787a433f support seperate lr for embeddings, similar to loraplus (#1910) [skip ci]
* support seperate lr for embeddings, similar to loraplus

* add test case for train w lr embedding scale

* use kwarg for optimizer

* make sure to handle the optimizer creation

* make sure to handle for embedding_lr too

* use smollm for e2e, check for embeddings lr first before wdecay
2024-12-04 12:33:29 -05:00
NanoCode012
e775422269 fix: ds3 and fsdp lmbench eval (#2102) [ski[p ci]
* fix: ds3 and fsdp lmbench eval

* chore: update comment

* fix: test signature
2024-12-04 12:33:29 -05:00
Wing Lian
97178f5960 add finetome dataset to fixtures, check eval_loss in test (#2106) [skip ci]
* add finetome dataset to fixtures, check eval_loss in test

* add qwen 0.5b to pytest session fixture
2024-12-04 12:33:29 -05:00
bursteratom
4698eed43f set pixtral chat template 2024-12-04 12:11:21 -05:00
bursteratom
f84c3b37e7 lint 2024-12-04 11:59:45 -05:00
bursteratom
c39971c659 stuff 2024-11-27 10:52:36 -05:00
bursteratom
33a178c788 val config pixtral chat template 2024-11-27 10:36:23 -05:00
bursteratom
db15605e7e pixral chat template 2024-11-27 10:34:19 -05:00
bursteratom
9e112bc8b5 lint 2024-11-27 10:33:35 -05:00
bursteratom
e038410778 lint 2024-11-27 10:24:37 -05:00
bursteratom
f4385c3cf4 add special tokens 2024-11-27 10:18:45 -05:00
bursteratom
d58c772df6 pixtral flash-attn false 2024-11-27 10:16:17 -05:00
bursteratom
69265a53b5 stuff 2024-11-27 09:53:41 -05:00
66 changed files with 3019 additions and 403 deletions

View File

@@ -1,6 +1,16 @@
 name: ci-cd-base
 on:
+  push:
+    branches:
+      - "main"
+    paths:
+      - 'Dockerfile-base'
+      - '.github/workflows/base.yml'
+  pull_request:
+    paths:
+      - 'Dockerfile-base'
+      - '.github/workflows/base.yml'
   workflow_dispatch:
 jobs:

View File

@@ -55,6 +55,7 @@ jobs:
           pip3 install --upgrade pip
           pip3 install --upgrade packaging
           pip3 install -U -e .
+          python scripts/cutcrossentropy_install.py | sh
           pip3 install -r requirements-dev.txt -r requirements-tests.txt
       - name: Run tests

View File

@@ -8,11 +8,15 @@ on:
       - '**.py'
       - 'requirements.txt'
       - '.github/workflows/*.yml'
+      - 'requirements-tests.txt'
+      - 'cicd/cicd.sh'
   pull_request:
     paths:
       - '**.py'
       - 'requirements.txt'
       - '.github/workflows/*.yml'
+      - 'requirements-tests.txt'
+      - 'cicd/cicd.sh'
   workflow_dispatch:
 # Cancel jobs on the same ref if a new one is triggered
@@ -67,6 +71,8 @@ jobs:
         run: |
           pip3 show torch
           pip3 install -U -e .
+          python scripts/unsloth_install.py | sh
+          python scripts/cutcrossentropy_install.py | sh
           pip3 install -r requirements-dev.txt -r requirements-tests.txt
       - name: Run tests

View File

@@ -147,7 +147,7 @@ pip3 install -e '.[flash-attn,deepspeed]'
 ### Usage
 ```bash
 # preprocess datasets - optional but recommended
-CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
+CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
 # finetune lora
 accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml

View File

@@ -37,6 +37,9 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
         pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
     fi
+RUN python scripts/unsloth_install.py | sh
+RUN python scripts/cutcrossentropy_install.py | sh
 # So we can test the Docker image
 RUN pip install -r requirements-dev.txt -r requirements-tests.txt

View File

@@ -1,6 +1,6 @@
 #!/bin/bash
 set -e
-pytest -n8 --ignore=tests/e2e/ /workspace/axolotl/tests/
+pytest -v --durations=10 -n8 --ignore=tests/e2e/ /workspace/axolotl/tests/
-pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
+pytest -v --durations=10 -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
-pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
+pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -40,6 +40,7 @@ with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
 cicd_image = (
     Image.from_dockerfile(
         pathlib.Path(temp_dir) / "Dockerfile",
+        context_mount=None,
         force_build=True,
         gpu="A10G",
     )

View File

@@ -26,6 +26,9 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
         pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
     fi
+RUN python scripts/unsloth_install.py | sh
+RUN python scripts/cutcrossentropy_install.py | sh
 # So we can test the Docker image
 RUN pip install pytest

View File

@@ -29,7 +29,9 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
 WORKDIR /workspace
 RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
-    python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
+    python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
+    python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
+    python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
 RUN git lfs install --skip-repo && \
     pip3 install awscli && \

View File

@@ -162,6 +162,9 @@ datasets:
 # The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
 shuffle_merged_datasets: true
+
+# Deduplicates datasets and test_datasets with identical entries.
+dataset_exact_deduplication: true
 # A list of one or more datasets to eval the model with.
 # You can use either test_datasets, or val_set_size, but not both.
 test_datasets:
@@ -406,7 +409,7 @@ lr_div_factor: # Learning rate div factor
 # - adamw_torch_fused
 # - adamw_torch_xla
 # - adamw_apex_fused
-# - adopt_adamw (only for torch version >= 2.5.1)
+# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
 # - adafactor
 # - adamw_anyprecision
 # - sgd

View File

@@ -0,0 +1,95 @@
base_model: meta-llama/Llama-3.2-1B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
chat_template: llama3
rl: dpo
datasets:
  - path: fozziethebeat/alpaca_messages_2k_dpo_test
    type: chat_template.default
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_field_role: role
    message_field_content: content
    roles:
      system:
        - system
      user:
        - user
      assistant:
        - assistant
  - path: fozziethebeat/alpaca_messages_2k_dpo_test
    type: chat_template.default
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_field_role: role
    message_field_content: content
    roles:
      system:
        - system
      user:
        - user
      assistant:
        - assistant
dataset_exact_deduplication: true
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -0,0 +1,76 @@
base_model: meta-llama/Llama-3.2-1B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/lora-out
dataset_exact_deduplication: true
test_value: true
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
- embed_tokens
- lm_head
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -0,0 +1,63 @@
base_model: llava-hf/llava-1.5-7b-hf
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: llava
datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
    field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -0,0 +1,65 @@
base_model: mistral-community/pixtral-12b
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: pixtral
datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
    field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -0,0 +1,63 @@
base_model: Qwen/Qwen2-VL-7B-Instruct
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: qwen2_vl
datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
    field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -2,4 +2,3 @@ pre-commit
 black
 mypy
 types-requests
-tbparse

View File

@@ -1,3 +1,5 @@
 pytest
 pytest-xdist
 pytest-retry
+pytest-sugar
+tbparse

View File

@@ -26,7 +26,7 @@ numpy>=1.24.4,<=2.0.1
 evaluate==0.4.1
 scipy
 scikit-learn==1.4.2
-pynvml
+nvidia-ml-py==12.560.30
 art
 gradio==3.50.2
 tensorboard

View File

@@ -0,0 +1,28 @@
"""Script to output the correct installation command for cut-cross-entropy."""
import importlib.util
import sys
try:
import torch
except ImportError as exc:
raise ImportError("Install torch via `pip install torch`") from exc
from packaging.version import Version as V
v = V(torch.__version__)
# no cut-cross-entropy support for torch < 2.4.0
if v < V("2.4.0"):
print("")
sys.exit(0)
cce_spec = importlib.util.find_spec("cut_cross_entropy")
cce_spec_transformers = importlib.util.find_spec("cut_cross_entropy.transformers")
UNINSTALL_PREFIX = ""
if cce_spec and not cce_spec_transformers:
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
print(
UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
)
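As wired into the CI workflows and Dockerfiles earlier in this diff, the script's output is piped straight into a shell (`python scripts/cutcrossentropy_install.py | sh`): printing an empty string on torch < 2.4.0 turns the install step into a no-op, and the uninstall prefix replaces an existing cut-cross-entropy build that lacks the transformers submodule.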

View File

@@ -8,7 +8,10 @@ from packaging.version import Version as V
 v = V(torch.__version__)
 cuda = str(torch.version.cuda)
-is_ampere = torch.cuda.get_device_capability()[0] >= 8
+try:
+    is_ampere = torch.cuda.get_device_capability()[0] >= 8
+except RuntimeError:
+    is_ampere = False
 if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
     raise RuntimeError(f"CUDA = {cuda} not supported!")
 if v <= V("2.1.0"):
@@ -29,5 +32,5 @@ else:
raise RuntimeError(f"Torch = {v} too new!") raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "") x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
print( print(
f'pip install unsloth-zoo && pip install --no-deps "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"' f'pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[{x}]==2024.11.9"'
) )

View File

@@ -27,7 +27,6 @@ from transformers.utils import is_torch_bf16_gpu_available
 from transformers.utils.import_utils import _is_package_available
 from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
-from axolotl.integrations.base import PluginManager
 from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
 from axolotl.utils.chat_templates import (
@@ -38,6 +37,7 @@ from axolotl.utils.comet_ import setup_comet_env_vars
 from axolotl.utils.config import (
     normalize_cfg_datasets,
     normalize_config,
+    prepare_plugins,
     validate_config,
 )
 from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
@@ -100,8 +100,8 @@ def print_dep_versions():
print("*" * 40) print("*" * 40)
print("**** Axolotl Dependency Versions *****") print("**** Axolotl Dependency Versions *****")
for pkg in packages: for pkg in packages:
version = _is_package_available(pkg, return_version=True) pkg_version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >{max_len}}: {version[1]: <15}") print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
print("*" * 40) print("*" * 40)
@@ -426,11 +426,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
 cfg.axolotl_config_path = config
-if cfg.get("plugins"):
-    plugin_manager = PluginManager.get_instance()
-    for plugin_name in cfg["plugins"]:
-        plugin_manager.register(plugin_name)
 try:
     device_props = torch.cuda.get_device_properties("cuda")
     gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
@@ -444,8 +439,13 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)), "n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version, "compute_capability": gpu_version,
}, },
env_capabilities={
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
},
) )
prepare_plugins(cfg)
prepare_optim_env(cfg) prepare_optim_env(cfg)
prepare_opinionated_env(cfg) prepare_opinionated_env(cfg)

View File

@@ -19,7 +19,7 @@ from axolotl.common.cli import TrainerCliArgs
 def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
-    parsed_cfg = load_cfg(config, **kwargs)
+    parsed_cfg = load_cfg(config, inference=True, **kwargs)
     parsed_cfg.sample_packing = False
     parser = transformers.HfArgumentParser((TrainerCliArgs))
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -107,6 +107,22 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
 return kwargs


+def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
+    if isinstance(dataset_tags, str):
+        dataset_tags = [dataset_tags]
+
+    if (dataset_tags is not None) and (kwargs is not None):
+        if "dataset_tags" not in kwargs:
+            kwargs["dataset_tags"] = dataset_tags
+        elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
+            kwargs["dataset_tags"].extend(dataset_tags)
+        elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
+            dataset_tags.append(kwargs["dataset_tags"])
+            kwargs["dataset_tags"] = dataset_tags
+    return kwargs
+
+
 @dataclass
 class AxolotlTrainingMixins:
     """
@@ -220,6 +236,14 @@ class AxolotlTrainingMixins:
         default=1e-6,
         metadata={"help": "loraplus learning rate for lora embedding layers."},
     )
+    embedding_lr_scale: Optional[float] = field(
+        default=None,
+        metadata={"help": "Scale the learning rate for the embedding layers."},
+    )
+    embedding_lr: Optional[float] = field(
+        default=None,
+        metadata={"help": "absolute learning rate for the embedding layers."},
+    )
     qlora: bool = field(
         default=False,
         metadata={"help": "whether this is a qlora training"},
@@ -386,7 +410,7 @@ class SchedulerMixin(Trainer):
                 min_lr_ratio=self.args.cosine_min_lr_ratio,
             )
         else:
-            return super().create_scheduler(num_training_steps, optimizer)
+            return super().create_scheduler(num_training_steps, optimizer=optimizer)
     else:
         if use_cosine_quadratic:
             LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
@@ -410,10 +434,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         *_args,
         bench_data_collator=None,
         eval_data_collator=None,
+        dataset_tags=None,
         **kwargs,
     ):
         self.bench_data_collator = bench_data_collator
         self.eval_data_collator = eval_data_collator
+        self.dataset_tags = dataset_tags
         super().__init__(*_args, **kwargs)
         self.train_data_collator = self.data_collator
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
@@ -435,6 +461,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
     def create_optimizer(self):
         if (
             self.args.loraplus_lr_ratio is None
+            and self.args.embedding_lr_scale is None
+            and self.args.embedding_lr is None
             and self.args.alternate_optimizer
             not in [
                 "optimi_adamw",
@@ -449,30 +477,59 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
 opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
 if self.optimizer is None:  # pylint: disable=access-member-before-definition
     decay_parameters = self.get_decay_parameter_names(opt_model)
-    optimizer_grouped_parameters = [
-        {
-            "params": [
-                p
-                for n, p in opt_model.named_parameters()
-                if (n in decay_parameters and p.requires_grad)
-            ],
-            "weight_decay": self.args.weight_decay,
-        },
-        {
-            "params": [
-                p
-                for n, p in opt_model.named_parameters()
-                if (n not in decay_parameters and p.requires_grad)
-            ],
-            "weight_decay": 0.0,
-        },
-    ]
+    params = {
+        "to_weight_decay": {},  # LayerNorm and bias
+        "embeddings": {},  # lm_head, embed_tokens,
+        "no_weight_decay": {},
+    }
     optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
         self.args,
         opt_model,
     )
+    for name, param in opt_model.named_parameters():
+        if not param.requires_grad:
+            continue
+        if name.endswith("modules_to_save.default.weight") or any(
+            embed_name in name for embed_name in ["embed_tokens", "lm_head"]
+        ):
+            params["embeddings"][name] = param
+        elif name in decay_parameters:
+            params["to_weight_decay"][name] = param
+        else:
+            params["no_weight_decay"][name] = param
+    optimizer_grouped_parameters = []
+    if params["to_weight_decay"]:
+        optimizer_grouped_parameters.append(
+            {
+                "params": list(params["to_weight_decay"].values()),
+                "weight_decay": self.args.weight_decay,
+                "lr": optimizer_kwargs["lr"],
+            }
+        )
+    if params["embeddings"]:
+        lr = optimizer_kwargs["lr"]  # pylint: disable=invalid-name
+        if self.args.embedding_lr_scale:
+            lr *= self.args.embedding_lr_scale  # pylint: disable=invalid-name
+        elif self.args.embedding_lr:
+            lr = self.args.embedding_lr  # pylint: disable=invalid-name
+        optimizer_grouped_parameters.append(
+            {
+                "params": list(params["embeddings"].values()),
+                "weight_decay": 0.0,
+                "lr": lr,
+            }
+        )
+    if params["no_weight_decay"]:
+        optimizer_grouped_parameters.append(
+            {
+                "params": list(params["no_weight_decay"].values()),
+                "weight_decay": 0.0,
+                "lr": optimizer_kwargs["lr"],
+            }
+        )
     if self.args.loraplus_lr_ratio is not None:
         loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
         loraplus_lr_embedding = getattr(
@@ -485,6 +542,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         loraplus_lr_embedding=loraplus_lr_embedding,
         **optimizer_kwargs,
     )
+elif (
+    self.args.embedding_lr_scale is not None
+    or self.args.embedding_lr is not None
+):
+    self.optimizer = (  # pylint: disable=attribute-defined-outside-init
+        optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+    )
 elif self.args.alternate_optimizer == "optimi_adamw":
     from optimi import AdamW
@@ -516,7 +580,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
 self.optimizer = (  # pylint: disable=attribute-defined-outside-init
     ADOPT(
-        optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
+        optimizer_grouped_parameters,
+        decouple=True,
+        **optimizer_kwargs,
     )
 )
@@ -871,6 +937,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
         model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
         """
+        kwargs = _sanitize_kwargs_for_ds_tagging(
+            dataset_tags=self.dataset_tags, kwargs=kwargs
+        )
         kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
         return super().push_to_hub(*args, **kwargs)
@@ -994,8 +1063,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
tag_names = ["axolotl", "dpo"] tag_names = ["axolotl", "dpo"]
def __init__(self, *args, **kwargs): def __init__(self, *args, dataset_tags=None, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None self.optimizer = None
def create_optimizer(self): def create_optimizer(self):
@@ -1034,6 +1104,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
         Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
         model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
         """
+        kwargs = _sanitize_kwargs_for_ds_tagging(
+            dataset_tags=self.dataset_tags, kwargs=kwargs
+        )
         kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
         return super().push_to_hub(*args, **kwargs)
@@ -1571,6 +1644,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
     training_arguments_kwargs[
         "loraplus_lr_embedding"
     ] = self.cfg.loraplus_lr_embedding
+training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
+training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
 if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
     training_arguments_kwargs["lr_scheduler_type"] = "cosine"
     training_arguments_kwargs[
@@ -1755,6 +1831,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 else:
     trainer_kwargs["tokenizer"] = self.tokenizer
+if (trainer_cls is not AxolotlRewardTrainer) and self.cfg.datasets is not None:
+    trainer_kwargs["dataset_tags"] = [
+        d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
+    ]
 trainer = trainer_cls(
     model=self.model,
     train_dataset=self.train_dataset,
@@ -1817,6 +1897,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
     collator = MultiModalChatDataCollator
     kwargs["processor"] = self.processor
     kwargs["chat_template"] = training_args.chat_template
+    kwargs["chat_template_type"] = self.cfg.chat_template
 else:
     collator = DataCollatorForSeq2Seq
@@ -2028,6 +2109,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
 else:
     dpo_trainer_kwargs["tokenizer"] = self.tokenizer
+if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
+    dpo_trainer_kwargs["dataset_tags"] = [
+        d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
+    ]
 dpo_trainer = trainer_cls(
     *trainer_cls_args,
     args=training_args,

View File

@@ -40,7 +40,7 @@ class TRLPPOTrainer(PPOTrainer):
     query_tensors,
     return_prompt=False,
     generate_ref_response=True,
-    **generation_kwargs
+    **generation_kwargs,
 )
 batch["response"] = self.tokenizer.batch_decode(response_tensors)
 batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)

View File

@@ -0,0 +1,325 @@
Acknowledgements
Portions of this Cut Cross Entropy Software may utilize the following copyrighted
material, the use of which is hereby acknowledged.
------
PyTorch
From PyTorch:
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
From Caffe2:
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
All contributions by Facebook:
Copyright (c) 2016 Facebook Inc.
All contributions by Google:
Copyright (c) 2015 Google Inc.
All rights reserved.
All contributions by Yangqing Jia:
Copyright (c) 2015 Yangqing Jia
All rights reserved.
All contributions by Kakao Brain:
Copyright 2019-2020 Kakao Brain
All contributions by Cruise LLC:
Copyright (c) 2022 Cruise LLC.
All rights reserved.
All contributions by Arm:
Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates
All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors
All rights reserved.
All other contributions:
Copyright(c) 2015, 2016 the respective contributors
All rights reserved.
Caffe2 uses a copyright model similar to Caffe: each contributor holds
copyright over their contributions to Caffe2. The project versioning records
all such contribution and copyright details. If a contributor wants to further
mark their specific copyright on a particular contribution, they should
indicate their copyright solely in the commit message of the change when it is
committed.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
and IDIAP Research Institute nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
Triton
/*
* Copyright 2018-2020 Philippe Tillet
* Copyright 2020-2022 OpenAI
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
Transformers
Copyright 2018- The Hugging Face team. All rights reserved.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -0,0 +1,47 @@
Copyright (C) 2024 Apple Inc. All Rights Reserved.
IMPORTANT: This Apple software is supplied to you by Apple
Inc. ("Apple") in consideration of your agreement to the following
terms, and your use, installation, modification or redistribution of
this Apple software constitutes acceptance of these terms. If you do
not agree with these terms, please do not use, install, modify or
redistribute this Apple software.
In consideration of your agreement to abide by the following terms, and
subject to these terms, Apple grants you a personal, non-exclusive
license, under Apple's copyrights in this original Apple software (the
"Apple Software"), to use, reproduce, modify and redistribute the Apple
Software, with or without modifications, in source and/or binary forms;
provided that if you redistribute the Apple Software in its entirety and
without modifications, you must retain this notice and the following
text and disclaimers in all such redistributions of the Apple Software.
Neither the name, trademarks, service marks or logos of Apple Inc. may
be used to endorse or promote products derived from the Apple Software
without specific prior written permission from Apple. Except as
expressly stated in this notice, no other rights or licenses, express or
implied, are granted by Apple herein, including but not limited to any
patent rights that may be infringed by your derivative works or by other
works in which the Apple Software may be incorporated.
The Apple Software is provided by Apple on an "AS IS" basis. APPLE
MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
-------------------------------------------------------------------------------
SOFTWARE DISTRIBUTED WITH CUT CROSS ENTROPY:
The Cut Cross Entropy software includes a number of subcomponents with separate
copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.md.
-------------------------------------------------------------------------------

View File

@@ -0,0 +1,10 @@
# Cut Cross Entropy
### Usage
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
```
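
For orientation, the plugin string in the YAML above is resolved through Axolotl's plugin manager. A rough Python equivalent, sketched from the `prepare_plugins` helper added further down this page (illustration only, not part of the diff):

```python
from axolotl.integrations.base import PluginManager

plugin_manager = PluginManager.get_instance()
plugin_manager.register(
    "axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin"
)
```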

View File

@@ -0,0 +1,83 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module for the Plugin for Cut Cross Entropy integration with Axolotl.
Cut Cross Entropy is an optimized implementation of cross entropy loss
from Apple's ML team.
"""
import importlib
import logging
import torch
from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
from ...utils.distributed import zero_only
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
_CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers]==24.11.4"`'
)
class CutCrossEntropyPlugin(BasePlugin):
    """
    Plugin for Cut Cross Entropy integration with Axolotl.
    """

    def get_input_args(self):
        return "axolotl.integrations.cut_cross_entropy.CutCrossEntropyArgs"

    def _check_requirements(self):
        """Check if all requirements are met."""
        # Check PyTorch version
        major, minor, _ = get_pytorch_version()
        if (major, minor) < (2, 4):
            raise ImportError(
                "Cut Cross Entropy requires PyTorch >= 2.4.0. "
                f"Current version: {torch.__version__}"
            )

        # Check if cut_cross_entropy is installed
        cce_spec = importlib.util.find_spec("cut_cross_entropy")
        if cce_spec is None:
            raise ImportError(_CCE_INSTALL_MESSAGE)

        cce_spec_transformers = importlib.util.find_spec(
            "cut_cross_entropy.transformers"
        )
        if cce_spec_transformers is None:
            raise ImportError(_CCE_INSTALL_MESSAGE)

    def pre_model_load(self, cfg):
        """Apply cut cross entropy before model loading if enabled."""
        if cfg.cut_cross_entropy:
            self._check_requirements()

            from cut_cross_entropy.transformers import cce_patch

            with zero_only():
                LOG.info(
                    f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
                )

            # The patch checks model_type internally
            cce_patch(cfg.model_config_type)
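
A minimal sketch of how this plugin is exercised, assuming `cut-cross-entropy[transformers]` is installed and a supported model type; `DictDefault` is Axolotl's config wrapper used elsewhere on this page:

```python
from axolotl.integrations.cut_cross_entropy import CutCrossEntropyPlugin
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"cut_cross_entropy": True, "model_config_type": "llama"})
plugin = CutCrossEntropyPlugin()
# Raises ImportError if torch < 2.4 or cut_cross_entropy is missing;
# otherwise patches the loss for the given model type via cce_patch().
plugin.pre_model_load(cfg)
```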

View File

@@ -0,0 +1,42 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module for handling Cut Cross Entropy input arguments.
"""
import logging
from typing import Optional
from pydantic import BaseModel, model_validator
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy.args")
class CutCrossEntropyArgs(BaseModel):
    """
    Input args for Cut Cross Entropy.
    """

    cut_cross_entropy: Optional[bool] = None

    @model_validator(mode="before")
    @classmethod
    def check_dtype_is_half(cls, data):
        if not (data.get("bf16") or data.get("fp16")):
            raise ValueError(
                "Cut Cross Entropy requires fp16/bf16 training for backward pass. "
                "Please set `bf16` or `fp16` to `True`."
            )
        return data
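
A hedged illustration of the validator above, assuming pydantic v2 semantics (the raw config dict, including keys like `bf16`, is visible to the `mode="before"` hook; the class is re-exported from the package `__init__` per the plugin module above):

```python
from pydantic import ValidationError

from axolotl.integrations.cut_cross_entropy import CutCrossEntropyArgs

# Passes: bf16 is set in the raw config data.
CutCrossEntropyArgs.model_validate({"cut_cross_entropy": True, "bf16": True})

# Fails: neither bf16 nor fp16 is enabled.
try:
    CutCrossEntropyArgs.model_validate({"cut_cross_entropy": True})
except ValidationError as err:
    print(err)  # "Cut Cross Entropy requires fp16/bf16 training for backward pass. ..."
```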

View File

@@ -4,7 +4,6 @@
 import logging
 import warnings
-from functools import partial
 from typing import List, Optional, Tuple, Union

 import torch
@@ -94,13 +93,32 @@ def replace_llama_qkv_with_fused(model):
         set_module_name(model, name, qkv)


-def patch_llama_cross_entropy():
-    from flash_attn.losses.cross_entropy import CrossEntropyLoss
-
-    LOG.info("patching with flash_attn.losses.cross_entropy")
-    transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
-        CrossEntropyLoss, inplace_backward=True
+def patch_fa_llama_cross_entropy():
+    LOG.info(
+        "patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
     )
+    from flash_attn.ops.triton.cross_entropy import (
+        cross_entropy_loss as flash_attn_cross_entropy_loss,
+    )
+
+    def fa2_fixed_cross_entropy(
+        source,
+        target,
+        num_items_in_batch: int = None,
+        ignore_index: int = -100,
+        **kwargs,
+    ):  # pylint: disable=unused-argument
+        reduction = "sum" if num_items_in_batch is not None else "mean"
+        loss, _ = flash_attn_cross_entropy_loss(
+            source, target, ignore_index=ignore_index
+        )
+        if reduction == "sum":
+            loss = loss.sum() / num_items_in_batch
+        else:
+            loss = loss.sum() / (target != ignore_index).sum()
+        return loss
+
+    transformers.loss.loss_utils.fixed_cross_entropy = fa2_fixed_cross_entropy


 def patch_llama_rms_norm():
@@ -147,7 +165,7 @@ def replace_llama_attn_with_flash_attn(
     # skip only if explicitly disabled
     if cross_entropy:
-        patch_llama_cross_entropy()
+        patch_fa_llama_cross_entropy()

     # skip only if explicitly disabled
     if rms_norm:
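
To make the reduction semantics of `fa2_fixed_cross_entropy` concrete, here is a plain-PyTorch reference of the same bookkeeping; the flash-attn Triton kernel itself is not reproduced, so this is an illustration rather than the patched code:

```python
import torch
import torch.nn.functional as F

def reference_fixed_cross_entropy(source, target, num_items_in_batch=None, ignore_index=-100):
    # Per-token losses; entries at ignore_index contribute 0.
    losses = F.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
    if num_items_in_batch is not None:  # "sum" path used with gradient accumulation
        return losses.sum() / num_items_in_batch
    return losses.sum() / (target != ignore_index).sum()  # mean over non-ignored tokens

logits = torch.randn(8, 32000)
labels = torch.randint(0, 32000, (8,))
print(reference_fixed_cross_entropy(logits, labels).item())
```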

View File

@@ -46,9 +46,10 @@ def reset_optimizer(
     *,
     reset_params: List[str],  # where str is the key to a torch.nn.Parameter
     optimizer_state_keys: List[str],
-    prune_ratio: float = 0.9,
+    optimizer_magnitude_pruning: float = 0.9,
 ):
-    pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
+    # pylint:disable=unused-argument
+    pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)
     n_zeros = 0
     n_total = 0
@@ -56,16 +57,22 @@ def reset_optimizer(
     if isinstance(optimizer, ZeroRedundancyOptimizer):
         optimizer_state = optimizer.optim.state
-    for param in reset_params:
-        param_state = optimizer_state[param]
-        if len(param_state) == 0:  # no state for this param, happens for ZeRo optimizer
-            continue
-        for key in optimizer_state_keys:
-            pruning_fn(
-                param_state[key]
-            )  # pruning fn has to be inplace to keep the same keys in the dict
-            n_total += param_state[key].numel()
-            n_zeros += torch.sum(param_state[key] == 0).item()
+    for group in optimizer.param_groups:
+        for param in group["params"]:
+            state = optimizer_state[param]
+            for key, value in state.items():
+                if key not in optimizer_state_keys:
+                    continue
+                if torch.is_tensor(value):
+                    try:
+                        pruning_fn(value)
+                        n_total += value.numel()
+                        n_zeros += torch.sum(value == 0).item()
+                    except RuntimeError as exc:
+                        if "quantile() input tensor is too large" in str(exc):
+                            pass
+                        else:
+                            raise exc

     _zeroed = n_zeros / (1e-7 + n_total) * 100
     LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
@@ -129,6 +136,9 @@ class ReLoRACallback(TrainerCallback):
         if "adam" in args.optim.lower():
             optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
+            if "8bit" in args.optim.lower():
+                optimizer_state_keys.append("state1")
+                optimizer_state_keys.append("state2")
         else:
             raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
@@ -160,7 +170,7 @@ class ReLoRACallback(TrainerCallback):
                 optimizer,
                 reset_params=lora_params,
                 optimizer_state_keys=optimizer_state_keys,
-                prune_ratio=args.relora_prune_ratio,
+                optimizer_magnitude_pruning=args.relora_prune_ratio,
             )

             if self.quantized:
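
The reset loop above depends on `magnitude_pruning_`, which is not shown in this diff. A stand-in sketch of the idea, under the assumption that it zeroes the smallest-magnitude optimizer-state entries in place:

```python
import torch

def magnitude_pruning_(tensor: torch.Tensor, prune_ratio: float = 0.9) -> None:
    # Assumed behaviour: zero the smallest prune_ratio fraction of entries, in place.
    threshold = torch.quantile(tensor.abs().float().flatten(), prune_ratio)
    tensor[tensor.abs() < threshold] = 0.0

state = {"exp_avg": torch.randn(100), "exp_avg_sq": torch.rand(100)}
for key, value in state.items():
    if key in ("exp_avg", "exp_avg_sq") and torch.is_tensor(value):
        magnitude_pruning_(value)
print(int(torch.sum(state["exp_avg"] == 0)))  # roughly 90 of 100 entries zeroed
```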

View File

@@ -259,11 +259,31 @@ def train(
             model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

     if not cfg.hub_model_id:
+        from huggingface_hub import HfApi
+        from huggingface_hub.utils import RepositoryNotFoundError
+
         try:
-            trainer.create_model_card(
-                model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
-            )
-        except (AttributeError, UnicodeDecodeError):
+            # Check to make sure the base model is from HuggingFace not a local directory
+            hf_api = HfApi()
+            hf_api.model_info(cfg.base_model)
+
+            model_card_kwarg = {
+                "model_name": cfg.output_dir.lstrip("./")
+                .encode("utf-8")
+                .decode("utf-8")
+            }
+            if cfg.datasets is not None:
+                if cfg.rl is not None or cfg.reward_model:
+                    model_card_kwarg["dataset_name"] = [
+                        d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
+                    ]
+                else:
+                    model_card_kwarg["dataset_tags"] = [
+                        d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
+                    ]
+
+            trainer.create_model_card(**model_card_kwarg)
+        except (AttributeError, UnicodeDecodeError, RepositoryNotFoundError):
             pass
     elif cfg.hub_model_id:
         # defensively push to the hub to ensure the model card is updated

View File

@@ -1,7 +1,11 @@
 """
 Basic utils for Axolotl
 """

 import importlib.util
+import re
+
+import torch


 def is_mlflow_available():
@@ -10,3 +14,23 @@ def is_mlflow_available():

 def is_comet_available():
     return importlib.util.find_spec("comet_ml") is not None
+
+
+# pylint: disable=duplicate-code
+def get_pytorch_version() -> tuple[int, int, int]:
+    """
+    Get Pytorch version as a tuple of (major, minor, patch).
+    """
+    torch_version = torch.__version__
+    version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
+
+    if not version_match:
+        raise ValueError("Invalid version format")
+
+    major, minor, patch = version_match.groups()
+    major, minor = int(major), int(minor)
+    patch = int(patch) if patch is not None else 0  # Default patch to 0 if not present
+    return major, minor, patch
+
+
+# pylint: enable=duplicate-code
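
For example, a build string like `"2.4.1+cu121"` parses to `(2, 4, 1)`, which is what the Cut Cross Entropy plugin compares against:

```python
import re

version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", "2.4.1+cu121")
major, minor, patch = (int(g) if g else 0 for g in version_match.groups())
assert (major, minor, patch) == (2, 4, 1)
assert (major, minor) >= (2, 4)  # gate used by CutCrossEntropyPlugin._check_requirements
```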

View File

@@ -1,13 +1,24 @@
 """Benchmarking and measurement utilities"""

 import functools

-import pynvml
 import torch
-from pynvml.nvml import NVMLError
 from transformers.utils.import_utils import is_torch_npu_available

 from axolotl.utils.distributed import get_device_type

+try:
+    from pynvml import (
+        NVMLError,
+        nvmlDeviceGetHandleByIndex,
+        nvmlDeviceGetMemoryInfo,
+        nvmlInit,
+    )
+except ImportError:
+    NVMLError = None
+    nvmlDeviceGetHandleByIndex = None
+    nvmlDeviceGetMemoryInfo = None
+    nvmlInit = None
+

 def check_cuda_device(default_value):
     """
@@ -68,10 +79,12 @@ def gpu_memory_usage_smi(device=0):
         device = device.index
     if isinstance(device, str) and device.startswith("cuda:"):
         device = int(device[5:])
+    if not nvmlInit:
+        return 0.0
     try:
-        pynvml.nvmlInit()
-        handle = pynvml.nvmlDeviceGetHandleByIndex(device)
-        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+        nvmlInit()
+        handle = nvmlDeviceGetHandleByIndex(device)
+        info = nvmlDeviceGetMemoryInfo(handle)
         return info.used / 1024.0**3
     except NVMLError:
         return 0.0
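
With the guarded import above, a missing `pynvml` no longer breaks importing the module; the probe simply degrades to zero, e.g.:

```python
from axolotl.utils.bench import gpu_memory_usage_smi

# Returns 0.0 when pynvml is unavailable (nvmlInit is None) or NVML errors out.
print(f"{gpu_memory_usage_smi(0):.2f} GiB in use on GPU 0")
```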

View File

@@ -28,6 +28,7 @@ from transformers import (
     TrainingArguments,
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
+from trl.models import unwrap_model_for_generation

 from axolotl.utils import is_comet_available, is_mlflow_available
 from axolotl.utils.bench import log_gpu_memory_usage
@@ -46,6 +47,7 @@ from axolotl.utils.distributed import (

 if TYPE_CHECKING:
     from axolotl.core.trainer_builder import AxolotlTrainingArguments

 IGNORE_INDEX = -100
 LOG = logging.getLogger("axolotl.callbacks")
@@ -64,7 +66,10 @@ class EvalFirstStepCallback(
         control: TrainerControl,
         **kwargs,
     ):
-        if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
+        if (
+            args.evaluation_strategy == IntervalStrategy.STEPS
+            and state.global_step == 1
+        ):
             control.should_evaluate = True
         return control
@@ -375,7 +380,10 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
             for metric in self.cfg.eval_causal_lm_metrics:
                 if metric == "perplexity":
                     max_seq_len = self.cfg.eval_max_new_tokens
-                    metrics[metric] = Perplexity(trainer.model, tokenizer, max_seq_len)
+                    metrics[metric] = Perplexity(
+                        tokenizer=tokenizer,
+                        max_seq_len=max_seq_len,
+                    )
                 else:
                     try:
                         metrics[metric] = evaluate.load(metric)
@@ -392,8 +400,11 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
            eval_dataloader,
            **kwargs,  # pylint: disable=unused-argument
        ):
-            trainer.model.eval()
-            device = torch.device(self.cfg.device)
+            trainer.model_wrapped.eval()
+
+            device = torch.device(
+                self.cfg.device
+            )  # Use this instead of trainer.model_wrapped.device as it may return cpu if fsdp offloaded

             # pylint: disable=duplicate-code
             generation_config = GenerationConfig(
@@ -430,6 +441,10 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
                     for k in metric._feature_names()  # pylint: disable=protected-access
                     if k in kwargs
                 }
+                if isinstance(metric, Perplexity):
+                    metric_kwargs["model"] = trainer.model_wrapped
+
                 metric_score = metric.compute(**metric_kwargs)
                 return (
                     metric_score["score"]
@@ -465,7 +480,10 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
             def predict_with_generate():
                 eval_src, eval_pred, eval_ref = [], [], []

-                for batch in tqdm(eval_dataloader):
+                with unwrap_model_for_generation(
+                    trainer.model_wrapped, trainer.accelerator
+                ) as unwrapped_model:
+                    for batch in tqdm(eval_dataloader, disable=not is_main_process()):
                         batch_labels = batch["labels"].to(device)
                         batch_input_ids = batch["input_ids"].to(device)
@@ -497,7 +515,9 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
                         tokens_without_loss = labels == IGNORE_INDEX
                         tokens_with_loss = labels != IGNORE_INDEX
-                        tokens_exclude_padding = input_ids != tokenizer.pad_token_id
+                        tokens_exclude_padding = (
+                            input_ids != tokenizer.pad_token_id
+                        )
                         prompt_token_includes = (
                             tokens_without_loss & tokens_exclude_padding
                         )
@@ -518,11 +538,14 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
                         with torch.no_grad():
                             prompt_encoding = tokenizer(
                                 prompt_texts, padding=True, return_tensors="pt"
-                            ).to(self.cfg.device)
-                            predictions = trainer.model.generate(
+                            ).to(device)
+
+                            predictions = unwrapped_model.generate(
                                 **prompt_encoding, generation_config=generation_config
                             )
+                            del prompt_encoding

                         prediction_all_tokens = predictions["sequences"].cpu().tolist()
                         prediction_without_prompt_tokens_list = []
                         for prompt_token_ids, prediction_tokens in zip(
@@ -536,7 +559,8 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
                             )
                             predicted_texts = tokenizer.batch_decode(
-                                prediction_without_prompt_tokens_list, skip_special_tokens=True
+                                prediction_without_prompt_tokens_list,
+                                skip_special_tokens=True,
                             )

                             eval_src.extend(prompt_texts)
@@ -545,7 +569,6 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
                 return eval_src, eval_pred, eval_ref

-            if is_main_process():
-                eval_preds = predict_with_generate()
-                trainer.log(evaluate_preds(*eval_preds))
+            eval_preds = predict_with_generate()
+            trainer.log(evaluate_preds(*eval_preds))

View File

@@ -8,6 +8,8 @@ from transformers.modeling_outputs import CausalLMOutput
 from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils import PreTrainedTokenizer

+from axolotl.utils.distributed import is_main_process
+

 class Perplexity:
     """
@@ -17,16 +19,13 @@ class Perplexity:

     def __init__(
         self,
-        model: PreTrainedModel,
         tokenizer: PreTrainedTokenizer,
         max_seq_len: int,
         stride: int = 512,
     ) -> None:
         self.max_seq_len = max_seq_len
         self.stride = stride
-        self.model = model
         self.tokenizer = tokenizer
-        self.device = model.device
         self.name = "perplexity"

     def _feature_names(self) -> List[str]:
@@ -34,6 +33,7 @@ class Perplexity:

     def compute(
         self,
+        model: PreTrainedModel,
         references: Optional[List[str]] = None,
     ) -> Dict[str, float]:
         """
@@ -41,17 +41,21 @@ class Perplexity:
         """
         assert references is not None, "Missing parameter: references"

+        model.eval()
+
         references_tokenized = self.tokenizer(
             references, return_tensors="pt", padding=True, truncation=True
         )
         input_ids: Tensor = references_tokenized["input_ids"]  # type: ignore
-        input_ids = input_ids.to(self.device)
+        input_ids = input_ids.to(model.device)

         sequence_length = input_ids.size(1)

         losses = []
         prev_end_loc = 0
-        for begin_loc in tqdm(range(0, sequence_length, self.stride)):
+        for begin_loc in tqdm(
+            range(0, sequence_length, self.stride), disable=not is_main_process()
+        ):
             end_loc = min(begin_loc + self.max_seq_len, sequence_length)
             trg_len = end_loc - prev_end_loc
             input_ids_slice = input_ids[:, begin_loc:end_loc]
@@ -59,7 +63,7 @@ class Perplexity:
             labels_slice[:, :-trg_len] = -100

             with torch.no_grad():
-                outputs: CausalLMOutput = self.model(
+                outputs: CausalLMOutput = model(
                     input_ids=input_ids_slice, labels=labels_slice
                 )
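
The metric is now model-agnostic at construction time; the model is supplied per call, as the callback change above does. A hedged, roughly self-contained sketch of the new call pattern (`Perplexity` is the class defined above; the exact return keys are an assumption based on how the callback reads `["score"]`):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
model = AutoModelForCausalLM.from_pretrained("gpt2")

metric = Perplexity(tokenizer=tokenizer, max_seq_len=512)
result = metric.compute(
    model=model,
    references=["The quick brown fox jumps over the lazy dog."],
)
print(result)  # dict of floats; the callback reads result["score"] when present
```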

File diff suppressed because one or more lines are too long

View File

@@ -1,8 +1,10 @@
 """
 Collators for multi-modal chat messages and packing
 """

+from copy import deepcopy
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union

 from PIL import Image
 from transformers import PreTrainedTokenizerBase, ProcessorMixin
@@ -20,6 +22,7 @@ class MultiModalChatDataCollator(DataCollatorMixin):
     processor: ProcessorMixin
     return_tensors: str = "pt"
     chat_template: Optional[str] = None
+    chat_template_type: Optional[str] = None
     packing: bool = False
     max_images: int = -1
     padding: Union[bool, str, PaddingStrategy] = True
@@ -30,38 +33,190 @@ class MultiModalChatDataCollator(DataCollatorMixin):
             raise ValueError("Packing is currently not supported.")

     def torch_call(
-        self, examples: List[Union[List[int], Any, Dict[str, Any]]]
-    ) -> Dict[str, Any]:
+        self, examples: list[Union[list[int], Any, dict[str, Any]]]
+    ) -> dict[str, Any]:
         # Handle dict or lists with proper padding and conversion to tensor.
         return self.__class__.process_rows(
-            examples, self.processor, self.chat_template, self.max_images
+            examples,
+            self.processor,
+            self.chat_template,
+            self.max_images,
+            chat_template_type=self.chat_template_type,
         )

     @staticmethod
-    def process_rows(examples, processor, chat_template, max_images, length_only=False):
+    def preprocess(examples: list[dict]) -> list[dict]:
"""
Preprocess conversation examples to ensure consistent format.
Converts different conversation formats to OpenAI format with 'messages'.
Supports two formats:
1. OpenAI format with 'messages'
2. Legacy format with 'conversations'
Args:
examples: list of conversation dictionaries
Returns:
dict in OpenAI format with 'messages' key
Raises:
ValueError: If the conversation format is not supported
"""
role_mapping = {
"human": "user",
"gpt": "assistant",
}
def normalize_role(role: str) -> str:
"""Normalize role names to OpenAI format. Default to original role if not found."""
return role_mapping.get(role, role)
def convert_legacy_format(example: dict) -> dict:
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
messages = [
{
"role": normalize_role(convo["from"]),
"content": convo["value"],
}
for convo in example["conversations"]
]
# Create new dict without 'conversations' key
result = deepcopy(example)
result.pop("conversations")
return {"messages": messages, **result}
processed_examples = []
for example in examples:
# OpenAI format
if "messages" in example:
processed_examples.append(example)
# Legacy format
elif "conversations" in example:
processed_examples.append(convert_legacy_format(example))
else:
raise ValueError(
"Only `messages` and `conversations` message keys are currently supported."
)
return processed_examples
@staticmethod
def process_images(examples, max_images):
"""
Process images from examples, ensuring consistency in image presence and applying max_images limit.
Args:
examples: List of dictionaries that may contain 'images' key
max_images: Maximum number of images to keep per example (0 means no limit)
Returns:
Either None (if no images) or List[Image objects] (if all examples have images)
Raises:
ValueError: If there's a mix of None and non-None images
"""
def get_image(example):
if "images" not in example:
return None
images = example["images"]
if isinstance(images, str):
return Image.open(images)
return images
images = [get_image(example) for example in examples]
# Count None and non-None images
none_count = sum(1 for img in images if img is None)
# All images are None
if none_count == len(images):
return None
# Mix of None and non-None images
if none_count > 0:
raise ValueError(
"All images should be either None or not None. "
"Please provide images for all examples or None."
)
# Apply max_images limit if specified
if max_images > 0:
images = [
(
img_batch[:max_images]
if isinstance(img_batch, (list, tuple))
else img_batch
)
for img_batch in images
]
return images
@staticmethod
def pixtral_chat_conversion(messages):
is_single_message = not isinstance(messages, list)
if is_single_message:
messages = [messages]
for i, message in enumerate(messages):
if message["role"] == "user":
for j, content in enumerate(message["content"]):
if "type" in content and content["type"] == "text":
messages[i]["content"][j] = {
"type": "text",
"content": content["text"],
}
if message["role"] == "assistant":
messages[i]["content"] = message["content"][0]["text"]
if is_single_message:
return messages[0]
return messages
@staticmethod
+    @staticmethod
+    def process_rows(
+        examples,
+        processor,
+        chat_template,
+        max_images,
+        length_only=False,
+        chat_template_type=None,
+    ):
         # HINT: use `_torch_collate_batch` to stack and pad tensors
         # see also DataCollatorWithFlattening and DefaultDataCollator

         # *** This is COPIED from the trl example sft_vlm.py code ***
         # use this as a starting point

+        # Preprocess the examples
+        examples = __class__.preprocess(examples)
+
         # Get the texts and images, and apply the chat template
-        texts = [
-            processor.apply_chat_template(
-                example["messages"], chat_template=chat_template, tokenize=False
-            )
-            for example in examples
-        ]
-        images = [
-            Image.open(example["images"])
-            if isinstance(example["images"], str)
-            else example["images"]
-            for example in examples
-        ]
-
-        if max_images > 0:
-            images = [img_batch[:max_images] for img_batch in images]
+        if chat_template_type == "pixtral":
+            texts = [
+                processor.apply_chat_template(
+                    __class__.pixtral_chat_conversion(example["messages"]),
+                    chat_template=chat_template,
+                    tokenize=False,
+                )
+                for example in examples
+            ]
+        else:
+            texts = [
+                processor.apply_chat_template(
+                    example["messages"], chat_template=chat_template, tokenize=False
+                )
+                for example in examples
+            ]
+
+        images = __class__.process_images(examples, max_images=max_images)
+
+        if chat_template_type == "llava":
+            # LLava1.5 does not support multiple images
+            images = [image[0] for image in images]

         # Tokenize the texts and process the images
         batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
@@ -70,6 +225,9 @@ class MultiModalChatDataCollator(DataCollatorMixin):
         labels = batch["input_ids"].clone()
         labels[labels == processor.tokenizer.pad_token_id] = -100  #
         # Ignore the image token index in the loss computation (model specific)
-        image_token_id = processor.tokenizer.convert_tokens_to_ids(
-            processor.image_token
-        )
+        if chat_template_type == "qwen2_vl":
+            image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
+        else:
+            image_token_id = processor.tokenizer.convert_tokens_to_ids(
+                processor.image_token
+            )
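
A toy run of the new preprocessing path, using the legacy `conversations` format and the role mapping defined above; since `preprocess` is a staticmethod, no collator needs to be constructed (`MultiModalChatDataCollator` is the class from this file):

```python
legacy_row = {
    "conversations": [
        {"from": "human", "value": "What is in this image?"},
        {"from": "gpt", "value": "A cat on a couch."},
    ],
    "images": None,
}
processed = MultiModalChatDataCollator.preprocess([legacy_row])
print(processed[0]["messages"][0])  # {'role': 'user', 'content': 'What is in this image?'}
```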

View File

@@ -7,6 +7,7 @@ import torch
 from transformers.utils import is_torch_bf16_gpu_available
 from transformers.utils.import_utils import is_torch_npu_available

+from axolotl.integrations.base import PluginManager
 from axolotl.integrations.config import merge_input_args
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.config.models.input.v0_4_1 import (
@@ -131,7 +132,7 @@ def normalize_config(cfg):
     cfg.is_multimodal = (
         hasattr(model_config, "model_type")
-        and model_config.model_type in ["llava", "mllama"]
+        and model_config.model_type in ["llava", "mllama", "qwen2_vl", "qwen2_5_vl"]
         or any(
             multimodal_name in cfg.base_model.lower()
             for multimodal_name in [
@@ -144,7 +145,12 @@ def normalize_config(cfg):
         cfg.processor_config = (
             cfg.processor_config or cfg.base_model_config or cfg.base_model
         )
-        model_config = model_config.text_config
+        try:
+            model_config = model_config.text_config
+        except AttributeError:
+            # for qwen2_vl
+            model_config = model_config.get_text_config()

     cfg.model_config_type = model_config.model_type
@@ -229,7 +235,11 @@ def normalize_cfg_datasets(cfg):
             cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja


-def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
+def validate_config(
+    cfg: DictDefault,
+    capabilities: Optional[dict] = None,
+    env_capabilities: Optional[dict] = None,
+):
     AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
     AxolotlInputConfig = AxolotlInputConfigBase
@@ -239,14 +249,35 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
         AxolotlInputConfig,  # pylint: disable=invalid-name
     ) = merge_input_args()

-    if capabilities:
+    if capabilities or env_capabilities:
+        if (capabilities and not env_capabilities) or (
+            env_capabilities and not capabilities
+        ):
+            raise ValueError(
+                "Both capabilities and env_capabilities must be provided or not provided."
+            )
+
         return DictDefault(
             dict(
                 AxolotlConfigWCapabilities(
-                    **cfg.to_dict(), capabilities=capabilities
+                    **cfg.to_dict(),
+                    capabilities=capabilities,
+                    env_capabilities=env_capabilities,
                 ).model_dump(exclude_none=True)
             )
         )
+
     return DictDefault(
         dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
     )
+
+
+def prepare_plugins(cfg):
+    """
+    Prepare the plugins for the configuration
+    """
+    if cfg.get("plugins"):
+        plugin_manager = PluginManager.get_instance()
+        for plugin_name in cfg["plugins"]:
+            plugin_manager.register(plugin_name)
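
Callers now pass the two capability dicts together or not at all. A hedged sketch, assuming `cfg` is a `DictDefault` of the user config and these helpers are imported from the module above; the field names follow `GPUCapabilities`/`EnvCapabilities` shown later on this page:

```python
validated = validate_config(
    cfg,
    capabilities={"n_gpu": 1, "n_node": 1, "compute_capability": "8.0"},
    env_capabilities={"torch_version": "2.5.1"},
)
prepare_plugins(validated)  # registers e.g. the CutCrossEntropyPlugin if listed under `plugins`
```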

View File

@@ -9,6 +9,7 @@ import os
 from enum import Enum
 from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union

+from packaging import version
 from pydantic import (
     BaseModel,
     Field,
@@ -21,7 +22,7 @@ from transformers import SchedulerType
 from transformers.training_args import OptimizerNames
 from transformers.utils.import_utils import is_torch_npu_available

-from axolotl.utils.config.models.internals import GPUCapabilities
+from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities

 LOG = logging.getLogger("axolotl.utils.config.models.input")
@@ -50,6 +51,7 @@ class ChatTemplate(str, Enum):
     cohere = "cohere"  # pylint: disable=invalid-name
     llama3 = "llama3"  # pylint: disable=invalid-name
     llama3_2_vision = "llama3_2_vision"  # pylint: disable=invalid-name
+    llava = "llava"  # pylint: disable=invalid-name
     phi_3 = "phi_3"  # pylint: disable=invalid-name
     phi_35 = "phi_35"  # pylint: disable=invalid-name
     deepseek_v2 = "deepseek_v2"  # pylint: disable=invalid-name
@@ -59,6 +61,8 @@ class ChatTemplate(str, Enum):
     tokenizer_default = "tokenizer_default"  # pylint: disable=invalid-name
     exaone = "exaone"  # pylint: disable=invalid-name
     metharme = "metharme"  # pylint: disable=invalid-name
+    pixtral = "pixtral"  # pylint: disable=invalid-name
+    qwen2_vl = "qwen2_vl"  # pylint: disable=invalid-name


 class DeprecatedParameters(BaseModel):
@@ -322,11 +326,13 @@ class LoraConfig(BaseModel):
     @model_validator(mode="before")
     @classmethod
     def validate_adapter(cls, data):
-        if not data.get("adapter") and (
-            data.get("load_in_8bit") or data.get("load_in_4bit")
+        if (
+            not data.get("adapter")
+            and not data.get("inference")
+            and (data.get("load_in_8bit") or data.get("load_in_4bit"))
         ):
             raise ValueError(
-                "load_in_8bit and load_in_4bit are not supported without setting an adapter."
+                "load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
                 "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
             )
         return data
@@ -430,6 +436,8 @@ class HyperparametersConfig(BaseModel):
     group_by_length: Optional[bool] = None

     learning_rate: Union[str, float]
+    embedding_lr: Optional[float] = None
+    embedding_lr_scale: Optional[float] = None
     weight_decay: Optional[float] = 0.0
     optimizer: Optional[
         Union[
@@ -622,6 +630,7 @@ class AxolotlInputConfig(
         json_schema_extra={"description": "streaming dataset to use for pretraining"},
     )
     dataset_processes: Optional[int] = Field(default=os.cpu_count())
+    dataset_exact_deduplication: Optional[bool] = None
     dataset_keep_in_memory: Optional[bool] = None
     dataloader_pin_memory: Optional[bool] = None
     dataloader_num_workers: Optional[int] = None
@@ -1474,6 +1483,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to valdiate gpu capabilities with the configured options"""

     capabilities: GPUCapabilities
+    env_capabilities: EnvCapabilities

     @model_validator(mode="after")
     def check_bf16(self):
@@ -1548,3 +1558,21 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
                 "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
             )
         return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_adopt_torch_version(cls, data):
+        if (data.get("optimizer") is not None) and ("adopt" in data.get("optimizer")):
+            env_capabilities = data.get("env_capabilities", {})
+            torch_version = env_capabilities.get("torch_version")
+
+            if torch_version is None:
+                import torch
+
+                torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
+
+            if version.parse(torch_version) < version.parse("2.5.1"):
+                raise ValueError(
+                    "ADOPT optimizer is incompatible with torch version < 2.5.1"
+                )
+        return data
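
The ADOPT gate reduces to a `packaging` version comparison on the (possibly build-suffixed) torch version, e.g.:

```python
from packaging import version

torch_version = "2.4.1+cu121".split("+", maxsplit=1)[0]
print(version.parse(torch_version) < version.parse("2.5.1"))  # True -> config error for ADOPT
```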

View File

@@ -12,3 +12,9 @@ class GPUCapabilities(BaseModel):
     n_gpu: int = Field(default=1)
     n_node: int = Field(default=1)
     compute_capability: Optional[str] = Field(default=None)
+
+
+class EnvCapabilities(BaseModel):
+    """model to manage the environment capabilities statically"""
+
+    torch_version: Optional[str] = Field(default=None)

View File

@@ -13,7 +13,7 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.prompt_strategies.dpo import load as load_dpo
 from axolotl.prompt_strategies.kto import load as load_kto
 from axolotl.prompt_strategies.orpo import load as load_orpo
-from axolotl.utils.data.utils import md5
+from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
 from axolotl.utils.models import load_tokenizer
@@ -208,4 +208,9 @@ def load_prepare_dpo_datasets(cfg):
     if eval_dataset and not eval_is_preprocessed:
         _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)

+    if cfg.dataset_exact_deduplication:
+        train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
+            train_dataset=train_dataset, eval_dataset=eval_dataset
+        )
+
     return train_dataset, eval_dataset

View File

@@ -44,7 +44,7 @@ from axolotl.prompters import (
     UnsupportedPrompter,
 )
 from axolotl.utils.data.pretraining import wrap_pretraining_dataset
-from axolotl.utils.data.utils import md5
+from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_local_main_process, zero_first
 from axolotl.utils.trainer import (
@@ -136,8 +136,9 @@ def prepare_dataset(cfg, tokenizer, processor=None):
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         train_dataset = train_dataset.with_format("torch")
         eval_dataset = None
+        if cfg.dataset_exact_deduplication:
+            LOG.info("Deduplication not available for pretrained datasets")
         return train_dataset, eval_dataset, cfg.max_steps, prompters

     if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
         total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
         if total_eval_steps == 0:
@@ -584,7 +585,8 @@ def load_prepare_datasets(
         )
         train_fingerprint = md5(to_hash_train)
         test_fingerprint = md5(to_hash_test)
+        if cfg.dataset_exact_deduplication:
+            _, _, dataset = deduplicate_and_log_datasets(dataset=dataset)
         dataset = dataset.train_test_split(
             test_size=val_set_size,
             shuffle=False,
@@ -596,12 +598,17 @@ def load_prepare_datasets(
         train_dataset = dataset["train"]
         eval_dataset = dataset["test"]
     elif split == "test":
-        train_dataset = None
+        if cfg.dataset_exact_deduplication:
+            _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
+        else:
             eval_dataset = dataset
+        train_dataset = None
     else:
+        if cfg.dataset_exact_deduplication:
+            train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
+        else:
             train_dataset = dataset
         eval_dataset = None

     return train_dataset, eval_dataset, prompters

View File

@@ -1,6 +1,11 @@
 """data handling helpers"""

 import hashlib
+import logging
+
+from datasets import Dataset
+
+LOG = logging.getLogger("axolotl")


 def md5(to_hash: str, encoding: str = "utf-8") -> str:
@@ -8,3 +13,96 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
         return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
     except TypeError:
         return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec
def sha256(to_hash: str, encoding: str = "utf-8") -> str:
return hashlib.sha256(to_hash.encode(encoding)).hexdigest()
def deduplicate_dataset(
dataset: Dataset, seen_hashes: dict[str, list[int]], other_dataset: Dataset = None
) -> Dataset:
unique_indices = []
for idx, row in enumerate(dataset):
row_hash = sha256(str(row)) # Using SHA256 for collision resistance.
if row_hash not in seen_hashes:
seen_hashes[row_hash] = [idx]
unique_indices.append(idx)
else:
# Check for collision by looking up the original dataset indices
original_indices = seen_hashes[row_hash]
is_duplicate = False
for original_idx in original_indices:
if (
not idx == original_idx
and original_idx < len(dataset)
and str(dataset[original_idx]) == str(row)
):
is_duplicate = True
break
# Check in the other dataset if provided
if other_dataset is not None:
if original_idx < len(other_dataset) and str(
other_dataset[original_idx]
) == str(row):
is_duplicate = True
break
if not is_duplicate:
seen_hashes[row_hash].append(idx)
unique_indices.append(idx)
continue
return dataset.select(unique_indices)
def deduplicate_and_log_datasets(
*,
train_dataset: Dataset = None,
eval_dataset: Dataset = None,
dataset: Dataset = None,
) -> tuple[Dataset, Dataset, Dataset]:
"""
Deduplicates train, eval, and an optional dataset if provided, logging original and new sizes.
Returns:
tuple: Deduplicated train, eval, and additional datasets.
"""
seen_hashes: dict[str, list[int]] = {}
# Handle cases where datasets are None
if train_dataset is not None:
LOG.info(
f"Starting deduplication for train dataset. Original size: {len(train_dataset)}"
)
train_dataset = deduplicate_dataset(
dataset=train_dataset, seen_hashes=seen_hashes
)
LOG.info(
f"Deduplication complete for train dataset. New size: {len(train_dataset)}"
)
else:
LOG.info("Train dataset is None. Skipping deduplication.")
if eval_dataset is not None:
LOG.info(
f"Starting deduplication for eval dataset. Original size: {len(eval_dataset)}"
)
eval_dataset = deduplicate_dataset(
dataset=eval_dataset, seen_hashes=seen_hashes, other_dataset=train_dataset
)
LOG.info(
f"Deduplication complete for eval dataset. New size: {len(eval_dataset)}"
)
else:
LOG.info("Eval dataset is None. Skipping deduplication.")
if dataset is not None and (eval_dataset is None and train_dataset is None):
LOG.info(
f"Starting deduplication for combined dataset. Original size: {len(dataset)}"
)
dataset = deduplicate_dataset(dataset=dataset, seen_hashes=seen_hashes)
LOG.info(
f"Deduplication complete for combined dataset. New size: {len(dataset)}"
)
return train_dataset, eval_dataset, dataset
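
A quick illustration of the helper on a toy `datasets.Dataset`; the second copy of an identical row is dropped:

```python
from datasets import Dataset

from axolotl.utils.data.utils import deduplicate_and_log_datasets

ds = Dataset.from_list([{"text": "a"}, {"text": "a"}, {"text": "b"}])
train, _, _ = deduplicate_and_log_datasets(train_dataset=ds)
print(len(train))  # 2
```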

View File

@@ -2,10 +2,12 @@
 # pylint: disable=too-many-lines
 import gc
+import importlib
 import logging
 import math
 import os
 import types
+from functools import cached_property
 from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401

 import addict
@@ -28,6 +30,7 @@ from transformers import (  # noqa: F401
     AddedToken,
     AutoConfig,
     AutoModelForCausalLM,
+    AutoModelForImageTextToText,
     AutoModelForVision2Seq,
     AutoProcessor,
     AutoTokenizer,
@@ -89,7 +92,11 @@ def get_module_class_from_name(module, name):

 def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
     if cfg.is_multimodal:
-        model_config = model_config.text_config
+        try:
+            model_config = model_config.text_config
+        except AttributeError:
+            # for qwen2_vl
+            model_config = model_config.get_text_config()

     quant_config_exists = (
         hasattr(model_config, "quantization_config")
@@ -365,7 +372,11 @@ class ModelLoader:
         # init model config
         self.model_config = load_model_config(cfg)
         if cfg.is_multimodal:
-            self.text_model_config = self.model_config.text_config
+            try:
+                self.text_model_config = self.model_config.text_config
+            except AttributeError:
+                # for qwen2_vl
+                self.text_model_config = self.model_config.get_text_config()
         else:
             self.text_model_config = self.model_config
@@ -409,7 +420,7 @@ class ModelLoader:
             )

         if self.cfg.is_llama_derived_model:
-            self.patch_loss()
+            self.patch_loss_llama()

         if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
             from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -451,27 +462,34 @@ class ModelLoader:
                 replace_stablelm_attn_with_flash_attn(self.cfg.base_model)

-    def patch_loss(self) -> None:
+    @cached_property
+    def has_flash_attn(self) -> bool:
+        """Check if flash attention is installed"""
+        return importlib.util.find_spec("flash_attn") is not None
+
+    def patch_loss_llama(self) -> None:
         """
         Patch loss functions
         """
-        from axolotl.monkeypatch.llama_attn_hijack_flash import (
-            patch_llama_cross_entropy,
-            patch_llama_rms_norm,
-        )
+        if self.has_flash_attn:
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                patch_fa_llama_cross_entropy,
+                patch_llama_rms_norm,
+            )

-        if self.cfg.flash_attn_cross_entropy:
-            patch_llama_cross_entropy()
-        if self.cfg.flash_attn_rms_norm:
+        if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
+            patch_fa_llama_cross_entropy()
+        elif self.cfg.unsloth_cross_entropy_loss:
+            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
+
+            integrate_cross_entropy_loss_patch(model_type="llama")
+
+        if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
             patch_llama_rms_norm()
         elif self.cfg.unsloth_rms_norm:
             from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm

             patch_unsloth_layernorm()
-        if self.cfg.unsloth_cross_entropy_loss:
-            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
-
-            integrate_cross_entropy_loss_patch(model_type="llama")
+
         if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
             from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -481,6 +499,7 @@ class ModelLoader:
         """
         Modify all llama derived models in one block
         """
+        self.patch_loss_llama()

         if self.cfg.flash_attention:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
@@ -528,16 +547,6 @@ class ModelLoader:
                     "Shifted-sparse attention not currently implemented without flash attention."
                 )

-        if self.cfg.unsloth_cross_entropy_loss:
-            from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
-
-            integrate_cross_entropy_loss_patch(model_type="llama")
-
-        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
-            from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
-
-            patch_self_attn_lora()
-
     def set_auto_model_loader(self) -> None:
         """set self.AutoModelLoader
         - default value: AutoModelForCausalLM (set at __init__)
@@ -553,6 +562,10 @@ class ModelLoader:
                 self.AutoModelLoader = (  # pylint: disable=invalid-name
                     MllamaForConditionalGeneration
                 )
+            elif self.model_config.model_type == "qwen2_vl":
+                self.AutoModelLoader = (  # pylint: disable=invalid-name
+                    AutoModelForImageTextToText
+                )
             else:
                 self.AutoModelLoader = (
                     AutoModelForVision2Seq  # pylint: disable=invalid-name
@@ -1045,7 +1058,9 @@ class ModelLoader:
             and self.model.get_input_embeddings().num_embeddings < embeddings_len
         ):
             resize_kwargs = {}
-            if self.cfg.mean_resizing_embeddings is not None:
+            if self.cfg.mean_resizing_embeddings is not None and not (
+                self.model_config.model_type == "llava"
+            ):
                 resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
             self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
         else:
@@ -1084,14 +1099,17 @@ class ModelLoader:
             self.prepare_model(qlora_fsdp)

-        # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
-        # convert them back to fp16/bf16 for flash-attn compatibility.
-        if (needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp:
-            LOG.info(
-                "converting modules to %s for flash attention", self.cfg.torch_dtype
-            )
+        should_convert = (
+            # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
+            # convert them back to fp16/bf16 for flash-attn compatibility.
+            ((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
+            or self.cfg.cut_cross_entropy  # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
+        )
+
+        if should_convert:
+            LOG.info("Converting modules to %s", self.cfg.torch_dtype)
             self.convert_embedding_modules_dtype(
-                embedding_modules,
+                embedding_modules=embedding_modules,
                 dist_dtype=self.cfg.torch_dtype,
                 before_kbit_train_or_finetune=False,
             )
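
The new `has_flash_attn` capability probe is just an import check; outside the loader it boils down to:

```python
import importlib.util

has_flash_attn = importlib.util.find_spec("flash_attn") is not None
print(has_flash_attn)
```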

View File

@@ -6,21 +6,29 @@ Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeo
""" """
# mypy: ignore-errors # mypy: ignore-errors
# pylint: skip-file # pylint: skip-file
# flake8: noqa
# mypy: allow-untyped-decorators # mypy: allow-untyped-decorators
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union, cast from typing import Callable, List, Optional, Tuple, Union, cast
import torch import torch
from torch import Tensor from torch import Tensor
from torch.optim.optimizer import ( from torch.optim.optimizer import ( # DeviceDict,; _capturable_doc,; _differentiable_doc,; _foreach_doc,; _fused_doc,; _maximize_doc,; _stack_if_compiling,
DeviceDict,
Optimizer, Optimizer,
ParamsT, ParamsT,
_capturable_doc,
_default_to_fused_or_foreach, _default_to_fused_or_foreach,
_device_dtype_check_for_fused, _device_dtype_check_for_fused,
_differentiable_doc,
_disable_dynamo_if_unsupported, _disable_dynamo_if_unsupported,
_foreach_doc,
_fused_doc,
_get_capturable_supported_devices, _get_capturable_supported_devices,
_get_scalar_dtype, _get_scalar_dtype,
_get_value, _get_value,
_maximize_doc,
_stack_if_compiling,
_use_grad_for_differentiable, _use_grad_for_differentiable,
_view_as_real, _view_as_real,
) )
@@ -35,8 +43,9 @@ class ADOPT(Optimizer):
lr: Union[float, Tensor] = 1e-3, lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.9999), betas: Tuple[float, float] = (0.9, 0.9999),
eps: float = 1e-6, eps: float = 1e-6,
clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
weight_decay: float = 0.0, weight_decay: float = 0.0,
decoupled: bool = False, decouple: bool = False,
*, *,
foreach: Optional[bool] = None, foreach: Optional[bool] = None,
maximize: bool = False, maximize: bool = False,
@@ -62,12 +71,14 @@ class ADOPT(Optimizer):
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}") raise ValueError(f"Invalid weight_decay value: {weight_decay}")
self.clip_lambda = clip_lambda
defaults = dict( defaults = dict(
lr=lr, lr=lr,
betas=betas, betas=betas,
eps=eps, eps=eps,
weight_decay=weight_decay, weight_decay=weight_decay,
decoupled=decoupled, decouple=decouple,
maximize=maximize, maximize=maximize,
foreach=foreach, foreach=foreach,
capturable=capturable, capturable=capturable,
@@ -219,8 +230,9 @@ class ADOPT(Optimizer):
beta1=beta1, beta1=beta1,
beta2=beta2, beta2=beta2,
lr=group["lr"], lr=group["lr"],
clip_lambda=self.clip_lambda,
weight_decay=group["weight_decay"], weight_decay=group["weight_decay"],
decoupled=group["decoupled"], decouple=group["decouple"],
eps=group["eps"], eps=group["eps"],
maximize=group["maximize"], maximize=group["maximize"],
foreach=group["foreach"], foreach=group["foreach"],
@@ -247,8 +259,9 @@ def _single_tensor_adopt(
beta1: float, beta1: float,
beta2: float, beta2: float,
lr: Union[float, Tensor], lr: Union[float, Tensor],
clip_lambda: Optional[Callable[[int], float]],
weight_decay: float, weight_decay: float,
decoupled: bool, decouple: bool,
eps: float, eps: float,
maximize: bool, maximize: bool,
capturable: bool, capturable: bool,
@@ -276,13 +289,9 @@ def _single_tensor_adopt(
             and param.device.type in capturable_supported_devices
         ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
-        # update step
-        step_t += 1
-        if weight_decay != 0:
-            if decoupled:
-                param.add_(param, alpha=-lr * weight_decay)
-            else:
-                grad = grad.add(param, alpha=weight_decay)
+        step = step_t if capturable or differentiable else _get_value(step_t)
+        if weight_decay != 0 and not decouple:
+            grad = grad.add(param, alpha=weight_decay)
         if torch.is_complex(param):
@@ -293,20 +302,29 @@ def _single_tensor_adopt(
             exp_avg_sq = torch.view_as_real(exp_avg_sq)
             param = torch.view_as_real(param)
-        step = step_t if capturable or differentiable else _get_value(step_t)
-        if step == 1:
+        if step == 0:
             exp_avg_sq.addcmul_(grad, grad.conj())
+            # update step
+            step_t += 1
             continue
+        if weight_decay != 0 and decouple:
+            param.add_(param, alpha=-lr * weight_decay)
         denom = torch.clamp(exp_avg_sq.sqrt(), eps)
-        if step == 2:
-            exp_avg.addcdiv_(grad, denom)
-        else:
-            exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
+        normed_grad = grad.div(denom)
+        if clip_lambda is not None:
+            clip = clip_lambda(step)
+            normed_grad.clamp_(-clip, clip)
+        exp_avg.lerp_(normed_grad, 1 - beta1)
         param.add_(exp_avg, alpha=-lr)
         exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
+        # update step
+        step_t += 1
def _multi_tensor_adopt( def _multi_tensor_adopt(
params: List[Tensor], params: List[Tensor],
@@ -321,8 +339,9 @@ def _multi_tensor_adopt(
beta1: float, beta1: float,
beta2: float, beta2: float,
lr: Union[float, Tensor], lr: Union[float, Tensor],
clip_lambda: Optional[Callable[[int], float]],
weight_decay: float, weight_decay: float,
decoupled: bool, decouple: bool,
eps: float, eps: float,
maximize: bool, maximize: bool,
capturable: bool, capturable: bool,
@@ -376,6 +395,18 @@ def _multi_tensor_adopt(
if maximize: if maximize:
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
if weight_decay != 0 and not decouple:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add( # type: ignore[assignment]
device_grads, device_params, alpha=weight_decay
)
if device_state_steps[0] == 0:
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
# Update steps # Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
@@ -387,34 +418,21 @@ def _multi_tensor_adopt(
         else:
             torch._foreach_add_(device_state_steps, 1)
-        if weight_decay != 0:
-            if decoupled:
-                torch._foreach_add_(
-                    device_params, device_params, alpha=-lr * weight_decay
-                )
-            else:
-                # Re-use the intermediate memory (device_grads) already allocated for maximize
-                if maximize:
-                    torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
-                else:
-                    device_grads = torch._foreach_add(  # type: ignore[assignment]
-                        device_grads, device_params, alpha=weight_decay
-                    )
-        if device_state_steps[0] == 1:
-            torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
             continue
-        exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
-        exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
-        if device_state_steps[0] == 2:
-            torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
-        else:
-            torch._foreach_mul_(device_exp_avgs, beta1)
-            torch._foreach_addcdiv_(
-                device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
-            )
+        if weight_decay != 0 and decouple:
+            torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
+        exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
+        torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
+        normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
+        if clip_lambda is not None:
+            clip = clip_lambda(device_state_steps[0])
+            torch._foreach_maximum_(normed_grad, -clip)
+            torch._foreach_minimum_(normed_grad, clip)
+        torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
         torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
         torch._foreach_mul_(device_exp_avg_sqs, beta2)
@@ -422,6 +440,17 @@ def _multi_tensor_adopt(
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2 device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
) )
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
else:
torch._foreach_add_(device_state_steps, 1)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt) @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
def adopt( def adopt(
@@ -443,8 +472,9 @@ def adopt(
beta1: float, beta1: float,
beta2: float, beta2: float,
lr: Union[float, Tensor], lr: Union[float, Tensor],
clip_lambda: Optional[Callable[[int], float]],
weight_decay: float, weight_decay: float,
decoupled: bool, decouple: bool,
eps: float, eps: float,
maximize: bool, maximize: bool,
): ):
@@ -497,8 +527,9 @@ def adopt(
beta1=beta1, beta1=beta1,
beta2=beta2, beta2=beta2,
lr=lr, lr=lr,
clip_lambda=clip_lambda,
weight_decay=weight_decay, weight_decay=weight_decay,
decoupled=decoupled, decouple=decouple,
eps=eps, eps=eps,
maximize=maximize, maximize=maximize,
capturable=capturable, capturable=capturable,
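Net effect of this file: the vendored ADOPT optimizer is updated to the revised algorithm. The gradient is first normalized by the previous second-moment estimate, optionally clipped with clip_lambda(step) (default step**0.25), and weight decay can be applied in decoupled, AdamW-style form via the renamed decouple flag. A minimal single-parameter sketch of the update, assuming real-valued tensors and ignoring the capturable/complex/foreach paths:

    import torch


    def adopt_step_sketch(
        param, grad, exp_avg, exp_avg_sq, step,
        lr=1e-3, beta1=0.9, beta2=0.9999, eps=1e-6,
        weight_decay=0.0, decouple=False, clip_lambda=lambda s: s**0.25,
    ):
        # Minimal single-parameter sketch of the revised update; not the vendored code.
        if weight_decay != 0 and not decouple:
            grad = grad.add(param, alpha=weight_decay)  # classic L2-style decay folded into the gradient
        if step == 0:
            exp_avg_sq.addcmul_(grad, grad)  # bootstrap the second moment, skip the parameter update
            return
        if weight_decay != 0 and decouple:
            param.add_(param, alpha=-lr * weight_decay)  # decoupled (AdamW-style) decay
        denom = torch.clamp(exp_avg_sq.sqrt(), eps)
        normed_grad = grad / denom
        if clip_lambda is not None:
            clip = clip_lambda(step)
            normed_grad.clamp_(-clip, clip)  # clip threshold grows like step**0.25
        exp_avg.lerp_(normed_grad, 1 - beta1)  # EMA of the normalized gradient
        param.add_(exp_avg, alpha=-lr)  # parameter step
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # update second moment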

View File

@@ -15,17 +15,58 @@ def download_smollm2_135m_model():
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_tatsu_lab_alpaca_dataset(): def download_llama_68m_random_model():
# download the model # download the model
snapshot_download("JackFram/llama-68m")
@pytest.fixture(scope="session", autouse=True)
def download_qwen_2_5_half_billion_model():
# download the model
snapshot_download("Qwen/Qwen2.5-0.5B")
@pytest.fixture(scope="session", autouse=True)
def download_tatsu_lab_alpaca_dataset():
# download the dataset
snapshot_download("tatsu-lab/alpaca", repo_type="dataset") snapshot_download("tatsu-lab/alpaca", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_dataset(): def download_mhenrichsen_alpaca_2k_dataset():
# download the model # download the dataset
snapshot_download("mhenrichsen/alpaca_2k_test", repo_type="dataset") snapshot_download("mhenrichsen/alpaca_2k_test", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_w_revision_dataset():
# download the dataset
snapshot_download(
"mhenrichsen/alpaca_2k_test", repo_type="dataset", revision="d05c1cb"
)
def download_mlabonne_finetome_100k_dataset():
# download the dataset
snapshot_download("mlabonne/FineTome-100k", repo_type="dataset")
@pytest.fixture
def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
# download the dataset
snapshot_download(
"argilla/distilabel-capybara-dpo-7k-binarized", repo_type="dataset"
)
@pytest.fixture
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
# download the dataset
snapshot_download(
"arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", repo_type="dataset"
)
@pytest.fixture @pytest.fixture
def temp_dir(): def temp_dir():
# Create a temporary directory # Create a temporary directory

32
tests/constants.py Normal file
View File

@@ -0,0 +1,32 @@
# constants.py
"""
This module contains constants and configuration dictionaries used for
datasets and other utilities in the Axolotl project, specifically for testing.
"""
# Configuration for Alpaca Messages Dataset
ALPACA_MESSAGES_CONFIG_OG = {
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"chat_template": "llama3",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
}
# Revision configuration extending the original
ALPACA_MESSAGES_CONFIG_REVISION = ALPACA_MESSAGES_CONFIG_OG.copy()
ALPACA_MESSAGES_CONFIG_REVISION["revision"] = "ea82cff"
SPECIAL_TOKENS = {
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
}
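The new constants module centralizes dataset configs and tokenizer special tokens so individual tests stop repeating them. Typical usage would look roughly like this (illustrative only, mirroring the DPO tests later in this diff):

    # Illustrative use of the shared test constants defined above.
    from constants import ALPACA_MESSAGES_CONFIG_OG, SPECIAL_TOKENS

    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            "tokenizer_config": "huggyllama/llama-7b",
            "sequence_len": 1024,
            "rl": "dpo",
            "chat_template": "llama3",
            "special_tokens": SPECIAL_TOKENS,
            "datasets": [ALPACA_MESSAGES_CONFIG_OG],
        }
    )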

View File

@@ -14,9 +14,7 @@ from axolotl.utils.models import load_model, load_tokenizer
def fixture_cfg(): def fixture_cfg():
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", "base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "LlamaTokenizer",
"micro_batch_size": 1, "micro_batch_size": 1,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"learning_rate": 0.00005, "learning_rate": 0.00005,
@@ -33,6 +31,9 @@ def fixture_cfg():
"dataloader_num_workers": 1, "dataloader_num_workers": 1,
"dataloader_pin_memory": True, "dataloader_pin_memory": True,
"model_config_type": "llama", "model_config_type": "llama",
"special_tokens": {
"pad_token": "<|endoftext|>",
},
} }
) )

View File

@@ -7,7 +7,7 @@ from pathlib import Path
from axolotl.cli import load_datasets from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir from ..utils import with_temp_dir
@@ -54,8 +54,10 @@ class LigerIntegrationTestCase(unittest.TestCase):
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"save_safetensors": True, "save_safetensors": True,
"bf16": "auto", "bf16": "auto",
"max_steps": 10,
} }
) )
prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -99,8 +101,10 @@ class LigerIntegrationTestCase(unittest.TestCase):
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"save_safetensors": True, "save_safetensors": True,
"bf16": "auto", "bf16": "auto",
"max_steps": 10,
} }
) )
prepare_plugins(cfg)
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -0,0 +1,94 @@
"""
Simple end-to-end test for Cut Cross Entropy integration
"""
from pathlib import Path
import pytest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
# pylint: disable=duplicate-code
@pytest.fixture()
def min_cfg(temp_dir):
return {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"plugins": [
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
],
"cut_cross_entropy": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"output_dir": temp_dir,
"lr_scheduler": "cosine",
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
}
class TestCutCrossEntropyIntegration:
"""
e2e tests for cut_cross_entropy integration with Axolotl
"""
# pylint: disable=redefined-outer-name
def test_llama_w_cce(self, min_cfg, temp_dir):
cfg = DictDefault(min_cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
with pytest.raises(ImportError):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
@pytest.mark.parametrize(
"attention_type",
["flash_attention", "sdp_attention", "xformers_attention"],
)
def test_llama_w_cce_and_attention(self, min_cfg, temp_dir, attention_type):
cfg = DictDefault(
min_cfg
| {
attention_type: True,
}
)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
with pytest.raises(ImportError):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
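Both tests gate on the installed PyTorch version because cut_cross_entropy requires torch >= 2.4. The shared check boils down to the following sketch, assuming get_pytorch_version() returns a (major, minor, patch) tuple as the tests imply:

    import pytest

    from axolotl.utils import get_pytorch_version


    def run_or_expect_import_error(run_train):
        # Sketch of the version gate used by the CCE tests.
        major, minor, _ = get_pytorch_version()
        if (major, minor) < (2, 4):
            # cut_cross_entropy cannot be imported on older torch, so training should fail
            with pytest.raises(ImportError):
                run_train()
        else:
            run_train()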

View File

@@ -7,10 +7,13 @@ from pathlib import Path
import yaml import yaml
from accelerate.test_utils import execute_subprocess_async from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import most_recent_subdir
LOG = logging.getLogger("axolotl.tests.e2e.multigpu") LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
@@ -26,7 +29,7 @@ class TestMultiGPUEval:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "JackFram/llama-68m", "base_model": "HuggingFaceTB/SmolLM2-135M",
"load_in_8bit": False, "load_in_8bit": False,
"load_in_4bit": True, "load_in_4bit": True,
"strict": False, "strict": False,
@@ -40,8 +43,8 @@ class TestMultiGPUEval:
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"], "lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.1, "val_set_size": 0.004,
"special_tokens": {"pad_token": "<|end_of_text|>"}, "special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [ "datasets": [
{ {
"path": "teknium/GPT4-LLM-Cleaned", "path": "teknium/GPT4-LLM-Cleaned",
@@ -66,6 +69,7 @@ class TestMultiGPUEval:
"saves_per_epoch": 1, "saves_per_epoch": 1,
"logging_steps": 1, "logging_steps": 1,
"weight_decay": 0.0, "weight_decay": 0.0,
"use_tensorboard": True,
} }
) )
@@ -87,12 +91,18 @@ class TestMultiGPUEval:
str(Path(temp_dir) / "config.yaml"), str(Path(temp_dir) / "config.yaml"),
] ]
) )
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "eval/loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.5, "Loss is too high"
def test_eval(self, temp_dir): def test_eval(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "JackFram/llama-68m", "base_model": "HuggingFaceTB/SmolLM2-135M",
"load_in_8bit": False, "load_in_8bit": False,
"load_in_4bit": True, "load_in_4bit": True,
"strict": False, "strict": False,
@@ -106,8 +116,8 @@ class TestMultiGPUEval:
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"], "lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.1, "val_set_size": 0.0004,
"special_tokens": {"pad_token": "<|end_of_text|>"}, "special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [ "datasets": [
{ {
"path": "teknium/GPT4-LLM-Cleaned", "path": "teknium/GPT4-LLM-Cleaned",
@@ -132,6 +142,7 @@ class TestMultiGPUEval:
"saves_per_epoch": 1, "saves_per_epoch": 1,
"logging_steps": 1, "logging_steps": 1,
"weight_decay": 0.0, "weight_decay": 0.0,
"use_tensorboard": True,
} }
) )
@@ -153,3 +164,9 @@ class TestMultiGPUEval:
str(Path(temp_dir) / "config.yaml"), str(Path(temp_dir) / "config.yaml"),
] ]
) )
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "eval/loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.9, "Loss is too high"
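The TensorBoard assertions added here (and in several tests below) all follow one recipe: find the newest runs/ subdirectory, parse its event file with tbparse, and check the last logged scalar. A small helper capturing that recipe might look like this (sketch; the most_recent_subdir import path varies per test module):

    import os

    from tbparse import SummaryReader

    from e2e.utils import most_recent_subdir  # import path differs between test modules


    def last_scalar(temp_dir: str, tag: str) -> float:
        # Read the newest TensorBoard run dir and return the last value logged for `tag`.
        tb_log_path = most_recent_subdir(temp_dir + "/runs")
        event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
        scalars = SummaryReader(event_file).scalars
        return scalars[scalars.tag == tag].value.values[-1]


    # usage in a test:
    # assert last_scalar(temp_dir, "eval/loss") < 2.5, "Loss is too high"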

View File

@@ -4,11 +4,11 @@ E2E tests for lora llama
import logging import logging
import os import os
import unittest
from importlib import reload from importlib import reload
from pathlib import Path from pathlib import Path
import pytest import pytest
from tbparse import SummaryReader
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets from axolotl.cli import load_datasets
@@ -17,7 +17,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir from ..utils import most_recent_subdir
LOG = logging.getLogger("axolotl.tests.e2e") LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
@@ -31,49 +31,55 @@ def reload_transformers():
reload(transformers.models.llama.modeling_llama) reload(transformers.models.llama.modeling_llama)
class TestFAXentropyLlama(unittest.TestCase): class TestFAXentropyLlama:
""" """
Test case for Llama models using LoRA w multipack Test case for Llama models using LoRA w multipack
""" """
@with_temp_dir @pytest.mark.parametrize(
def test_lora_packing_fa_cross_entropy(self, temp_dir): "gradient_accumulation_steps",
[1, 4],
)
def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "JackFram/llama-68m", "base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024, "sequence_len": 1024,
"sample_packing": True, "sample_packing": True,
"flash_attention": True, "flash_attention": True,
"flash_attn_cross_entropy": True, "flash_attn_cross_entropy": True,
"load_in_8bit": True, "load_in_8bit": True,
"adapter": "lora", "adapter": "lora",
"lora_r": 32, "lora_r": 8,
"lora_alpha": 64, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"val_set_size": 0.2, "val_set_size": 0.05,
"special_tokens": { "special_tokens": {
"unk_token": "<unk>", "pad_token": "<|endoftext|>",
"bos_token": "<s>",
"eos_token": "</s>",
}, },
"chat_template": "chatml",
"datasets": [ "datasets": [
{ {
"path": "mhenrichsen/alpaca_2k_test", "path": "mlabonne/FineTome-100k",
"type": "alpaca", "field_messages": "conversations",
"message_field_content": "value",
"message_field_role": "from",
"type": "chat_template",
"split": "train[:2%]",
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 10, "max_steps": 5,
"save_steps": 10, "save_steps": 5,
"micro_batch_size": 8, "micro_batch_size": 2,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"use_tensorboard": True,
} }
) )
if is_torch_bf16_gpu_available(): if is_torch_bf16_gpu_available():
@@ -87,3 +93,10 @@ class TestFAXentropyLlama(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists() assert (Path(temp_dir) / "adapter_model.bin").exists()
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 1.5, "Loss is too high"

View File

@@ -6,7 +6,6 @@ import logging
import os import os
import re import re
import subprocess import subprocess
import unittest
from pathlib import Path from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
@@ -17,35 +16,35 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import most_recent_subdir, with_temp_dir from ..utils import most_recent_subdir
LOG = logging.getLogger("axolotl.tests.e2e") LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
class TestResumeLlama(unittest.TestCase): class TestResumeLlama:
""" """
Test case for resuming training of llama models Test case for resuming training of llama models
""" """
@with_temp_dir def test_resume_lora_packed(self, temp_dir):
def test_resume_qlora_packed(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "JackFram/llama-68m", "base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024, "sequence_len": 1024,
"sample_packing": True, "sample_packing": True,
"flash_attention": True, "flash_attention": True,
"load_in_4bit": True, "load_in_8bit": True,
"adapter": "qlora", "adapter": "lora",
"lora_r": 32, "lora_r": 8,
"lora_alpha": 64, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"val_set_size": 0.1, "val_set_size": 0.001,
"special_tokens": {}, "special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [ "datasets": [
{ {
"path": "vicgalle/alpaca-gpt4", "path": "vicgalle/alpaca-gpt4",
@@ -57,11 +56,11 @@ class TestResumeLlama(unittest.TestCase):
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"save_steps": 10, "save_steps": 3,
"save_total_limit": 5, "save_total_limit": 5,
"max_steps": 40, "max_steps": 15,
"use_tensorboard": True, "use_tensorboard": True,
} }
) )
@@ -77,7 +76,7 @@ class TestResumeLlama(unittest.TestCase):
resume_cfg = cfg | DictDefault( resume_cfg = cfg | DictDefault(
{ {
"resume_from_checkpoint": f"{temp_dir}/checkpoint-30/", "resume_from_checkpoint": f"{temp_dir}/checkpoint-9/",
} }
) )
normalize_config(resume_cfg) normalize_config(resume_cfg)
@@ -93,4 +92,4 @@ class TestResumeLlama(unittest.TestCase):
) )
pattern = r"first_step\s+(\d+)" pattern = r"first_step\s+(\d+)"
first_steps = int(re.findall(pattern, res.stdout)[0]) first_steps = int(re.findall(pattern, res.stdout)[0])
assert first_steps == 31 assert first_steps == 10

View File

@@ -0,0 +1,186 @@
"""
e2e tests for unsloth qlora
"""
import logging
import os
from pathlib import Path
import pytest
from e2e.utils import most_recent_subdir
from tbparse import SummaryReader
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
# pylint: disable=duplicate-code
class TestUnslothQLoRA:
"""
Test class for Unsloth QLoRA Llama models
"""
@pytest.mark.parametrize(
"sample_packing",
[True, False],
)
def test_unsloth_llama_qlora_fa2(self, temp_dir, sample_packing):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"sample_packing": sample_packing,
"flash_attention": True,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"save_steps": 10,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"sample_packing": False,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"save_steps": 10,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"
@pytest.mark.parametrize(
"sdp_attention",
[True, False],
)
def test_unsloth_llama_qlora_unpacked_no_fa2_fp16(self, temp_dir, sdp_attention):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"sample_packing": False,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"save_steps": 10,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"sdp_attention": sdp_attention,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
"fp16": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"

View File

@@ -0,0 +1,121 @@
"""
E2E tests for llama pretrain
"""
import logging
import os
import unittest
from pathlib import Path
from tbparse import SummaryReader
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import most_recent_subdir, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestEmbeddingsLrScale(unittest.TestCase):
"""
Test case for embedding_lr*
"""
@with_temp_dir
def test_train_w_embedding_lr_scale(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"sequence_len": 1024,
"sample_packing": True,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"max_steps": 5,
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"embedding_lr_scale": 0.5,
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"
@with_temp_dir
def test_train_w_embedding_lr(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"sequence_len": 1024,
"sample_packing": True,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"max_steps": 5,
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"embedding_lr": 0.000005,
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"

View File

@@ -0,0 +1,116 @@
"""
E2E tests for lora llama
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestLlamaVision(unittest.TestCase):
"""
Test case for Llama Vision models
"""
@with_temp_dir
def test_lora_llama_vision_text_only_dataset(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",
"processor_type": "AutoProcessor",
"skip_prepare_dataset": True,
"remove_unused_columns": False,
"sample_packing": False,
"sequence_len": 1024,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
"val_set_size": 0,
"chat_template": "llama3_2_vision",
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@with_temp_dir
def test_lora_llama_vision_multimodal_dataset(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",
"processor_type": "AutoProcessor",
"skip_prepare_dataset": True,
"remove_unused_columns": False,
"sample_packing": False,
"sequence_len": 1024,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
"val_set_size": 0,
"chat_template": "llama3_2_vision",
"datasets": [
{
"path": "axolotl-ai-co/llava-instruct-mix-vsft-small",
"type": "chat_template",
"split": "train",
"field_messages": "messages",
},
],
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
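Note that lora_target_modules is given as a regex here rather than a list, so PEFT adapts only the language-model projection layers and leaves the vision tower frozen. A quick, illustrative way to sanity-check what such a pattern selects (the module names below are hypothetical):

    import re

    pattern = r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj"

    # Hypothetical module names, just to illustrate which layers the regex selects.
    names = [
        "language_model.model.layers.0.self_attn.q_proj",
        "language_model.model.layers.3.mlp.up_proj",
        "vision_model.encoder.layers.0.self_attn.q_proj",
    ]
    for name in names:
        print(name, bool(re.fullmatch(pattern, name)))
    # -> the first two match; the vision tower layer does not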

View File

@@ -57,6 +57,7 @@ class TestLoraLlama(unittest.TestCase):
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"max_steps": 20,
} }
) )
normalize_config(cfg) normalize_config(cfg)

View File

@@ -56,6 +56,7 @@ class TestCustomOptimizers(unittest.TestCase):
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "optimi_adamw", "optimizer": "optimi_adamw",
"max_steps": 5,
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
} }
) )
@@ -94,6 +95,7 @@ class TestCustomOptimizers(unittest.TestCase):
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 8, "micro_batch_size": 8,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": temp_dir, "output_dir": temp_dir,
@@ -115,7 +117,7 @@ class TestCustomOptimizers(unittest.TestCase):
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024, "sequence_len": 1024,
"val_set_size": 0.1, "val_set_size": 0.01,
"special_tokens": { "special_tokens": {
"pad_token": "<|endoftext|>", "pad_token": "<|endoftext|>",
}, },
@@ -126,13 +128,14 @@ class TestCustomOptimizers(unittest.TestCase):
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"micro_batch_size": 4, "micro_batch_size": 2,
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "schedule_free_adamw", "optimizer": "schedule_free_adamw",
"lr_scheduler": "constant", "lr_scheduler": "constant",
"save_safetensors": True, "save_safetensors": True,
"max_steps": 10,
} }
) )
# pylint: disable=duplicate-code # pylint: disable=duplicate-code

View File

@@ -7,13 +7,15 @@ import os
import unittest import unittest
from pathlib import Path from pathlib import Path
from tbparse import SummaryReader
from axolotl.cli import load_datasets from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir from .utils import most_recent_subdir, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e") LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
@@ -29,35 +31,48 @@ class TestReLoraLlama(unittest.TestCase):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "JackFram/llama-68m", "base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "LlamaTokenizer", "sequence_len": 2048,
"sequence_len": 1024, "sample_packing": True,
"pad_to_sequence_len": True,
"flash_attention": True,
"load_in_8bit": True, "load_in_8bit": True,
"adapter": "lora", "adapter": "lora",
"lora_r": 32, "lora_r": 8,
"lora_alpha": 16, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_modules": ["q_proj", "v_proj"], "lora_target_modules": ["q_proj", "v_proj"],
"relora_steps": 25, "relora_steps": 100,
"relora_warmup_steps": 5, "relora_warmup_steps": 20,
"relora_anneal_steps": 5, "relora_anneal_steps": 10,
"relora_prune_ratio": 0.9,
"relora_cpu_offload": True, "relora_cpu_offload": True,
"val_set_size": 0.0, "val_set_size": 0.0,
"special_tokens": {}, "special_tokens": {
"pad_token": "<|endoftext|>",
},
"chat_template": "chatml",
"datasets": [ "datasets": [
{ {
"path": "mhenrichsen/alpaca_2k_test", "path": "mlabonne/FineTome-100k",
"type": "alpaca", "type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
}, },
], ],
"warmup_steps": 15, "warmup_steps": 20,
"num_epochs": 2, "num_epochs": 2,
"micro_batch_size": 4, "max_steps": 205, # at least 2x relora_steps
"micro_batch_size": 2,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"save_safetensors": True,
"use_tensorboard": True,
} }
) )
normalize_config(cfg) normalize_config(cfg)
@@ -65,4 +80,14 @@ class TestReLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists() assert (
Path(temp_dir) / "checkpoint-100/adapter/adapter_model.safetensors"
).exists()
assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists()
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/grad_norm")] # pylint: disable=invalid-name
assert df.value.values[-1] < 0.2, "grad_norm is too high"

View File

@@ -53,7 +53,7 @@ def require_torch_2_3_1(test_case):
def require_torch_2_5_1(test_case): def require_torch_2_5_1(test_case):
""" """
Decorator marking a test that requires torch >= 2.3.1 Decorator marking a test that requires torch >= 2.5.1
""" """
def is_min_2_5_1(): def is_min_2_5_1():

View File

@@ -7,6 +7,11 @@ import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
from constants import (
ALPACA_MESSAGES_CONFIG_OG,
ALPACA_MESSAGES_CONFIG_REVISION,
SPECIAL_TOKENS,
)
from datasets import Dataset from datasets import Dataset
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from transformers import AutoTokenizer from transformers import AutoTokenizer
@@ -21,13 +26,7 @@ class TestDatasetPreparation(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens( self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
}
)
# Alpaca dataset. # Alpaca dataset.
self.dataset = Dataset.from_list( self.dataset = Dataset.from_list(
[ [
@@ -68,7 +67,7 @@ class TestDatasetPreparation(unittest.TestCase):
def test_load_local_hub(self): def test_load_local_hub(self):
"""Niche use case. Verify that a local copy of a hub dataset can be loaded""" """Niche use case. Verify that a local copy of a hub dataset can be loaded"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test") tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
tmp_ds_path.mkdir(parents=True, exist_ok=True) tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download( snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test", repo_id="mhenrichsen/alpaca_2k_test",
@@ -90,7 +89,7 @@ class TestDatasetPreparation(unittest.TestCase):
"ds_type": "parquet", "ds_type": "parquet",
"type": "alpaca", "type": "alpaca",
"data_files": [ "data_files": [
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet", f"{tmp_ds_path}/alpaca_2000.parquet",
], ],
}, },
], ],
@@ -277,23 +276,7 @@ class TestDatasetPreparation(unittest.TestCase):
"sequence_len": 1024, "sequence_len": 1024,
"rl": "dpo", "rl": "dpo",
"chat_template": "llama3", "chat_template": "llama3",
"datasets": [ "datasets": [ALPACA_MESSAGES_CONFIG_OG],
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"chat_template": "llama3",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
}
],
} }
) )
@@ -342,24 +325,7 @@ class TestDatasetPreparation(unittest.TestCase):
"sequence_len": 1024, "sequence_len": 1024,
"rl": "dpo", "rl": "dpo",
"chat_template": "llama3", "chat_template": "llama3",
"datasets": [ "datasets": [ALPACA_MESSAGES_CONFIG_REVISION],
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"chat_template": "llama3",
"revision": "ea82cff",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
}
],
} }
) )

View File

@@ -0,0 +1,433 @@
"""
Test suite for functions in the axolotl.utils.data.utils module, focusing on the deduplicate_and_log_datasets function.
Additionally, this test suite includes tests for functions that indirectly call deduplicate_and_log_datasets during the execution of the preprocess command.
"""
import hashlib
import unittest
from unittest.mock import patch
from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
"""
Validates deduplication results and size consistency.
Parameters:
- actual_dataset: Deduplicated dataset.
- expected_dataset: Expected dataset.
- dataset_name: Name of the dataset (e.g., 'train' or 'eval').
Asserts:
- Datasets match in content.
- Dataset size matches unique row count.
"""
# Convert datasets to sets of tuples for unordered comparison
actual_rows = set(tuple(row.values()) for row in actual_dataset)
expected_rows = set(tuple(row.values()) for row in expected_dataset)
# Verify deduplication correctness
assert actual_rows == expected_rows, f"Mismatch in {dataset_name} dataset"
# Verify size consistency
assert len(actual_rows) == len(
actual_dataset
), f"Size mismatch in {dataset_name} dataset after deduplication"
class TestDeduplicateIndividualFunctions(unittest.TestCase):
"""
test class for deduplication function in data utils
"""
def setUp(self):
# Sample data with duplicates
self.data = {
"column1": ["apple", "banana", "apple", "orange", "banana"],
"column2": [1, 2, 1, 3, 2],
"column3": ["red", "yellow", "red", "orange", "yellow"],
}
# Expected result after deduplication
self.expected_data = {
"column1": ["apple", "banana", "orange"],
"column2": [1, 2, 3],
"column3": ["red", "yellow", "orange"],
}
# Convert to Dataset format
self.dataset = Dataset.from_dict(self.data)
self.expected_dataset = Dataset.from_dict(self.expected_data)
def test_deduplication(self):
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=self.dataset)
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=self.dataset)
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
def test_datasets_are_none(self):
# Test when both datasets are None
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=None, eval_dataset=None
)
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
def test_only_train_is_none(self):
# Test when only train_dataset is None
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=None, eval_dataset=self.dataset
)
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
def test_only_eval_is_none(self):
# Test when only eval_dataset is None
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=self.dataset, eval_dataset=None
)
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
def test_exact_duplicates(self):
# Test when datasets are exact duplicates
duplicate_data = {
"column1": ["apple", "apple", "apple"],
"column2": [1, 1, 1],
"column3": ["red", "red", "red"],
}
expected_data = {"column1": ["apple"], "column2": [1], "column3": ["red"]}
# Convert to Dataset format
dataset = Dataset.from_dict(duplicate_data)
expected_dataset = Dataset.from_dict(expected_data)
# Run deduplication
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
verify_deduplication(train_dataset, expected_dataset, "train_dataset")
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
def test_partial_duplicates(self):
# Test when only part of the dataset is a duplicate
partial_duplicate_data = {
"column1": ["apple", "banana", "apple"],
"column2": [1, 2, 1],
"column3": ["red", "yellow", "red"],
}
expected_data = {
"column1": ["apple", "banana"],
"column2": [1, 2],
"column3": ["red", "yellow"],
}
# Convert to Dataset format
dataset = Dataset.from_dict(partial_duplicate_data)
expected_dataset = Dataset.from_dict(expected_data)
# Run deduplication
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
verify_deduplication(train_dataset, expected_dataset, "train_dataset")
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
def test_combined_duplicates_empty(self):
# Test when train and eval are identical, so eval is empty after cross-split deduplication
partial_duplicate_data = {
"column1": ["apple", "banana", "apple"],
"column2": [1, 2, 1],
"column3": ["red", "yellow", "red"],
}
expected_data_train = {
"column1": ["apple", "banana"],
"column2": [1, 2],
"column3": ["red", "yellow"],
}
expected_data_eval = {
"column1": [],
"column2": [],
"column3": [],
}
# Convert to Dataset format
dataset = Dataset.from_dict(partial_duplicate_data)
expected_dataset_train = Dataset.from_dict(expected_data_train)
expected_dataset_eval = Dataset.from_dict(expected_data_eval)
# Run deduplication
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=dataset, eval_dataset=dataset
)
verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
def test_combined_duplicates_one(self):
# Test when train and eval overlap partially, so eval keeps only its unique rows
partial_duplicate_data_train = {
"column1": ["apple", "banana", "apple"],
"column2": [1, 2, 1],
"column3": ["red", "yellow", "red"],
}
partial_duplicate_data_eval = {
"column1": ["apple", "orange", "apple"],
"column2": [1, 2, 1],
"column3": ["red", "orange", "red"],
}
expected_data_train = {
"column1": ["apple", "banana"],
"column2": [1, 2],
"column3": ["red", "yellow"],
}
expected_data_eval = {
"column1": ["orange"],
"column2": [2],
"column3": ["orange"],
}
# Convert to Dataset format
dataset_train = Dataset.from_dict(partial_duplicate_data_train)
dataset_eval = Dataset.from_dict(partial_duplicate_data_eval)
expected_dataset_train = Dataset.from_dict(expected_data_train)
expected_dataset_eval = Dataset.from_dict(expected_data_eval)
# Run deduplication
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=dataset_train, eval_dataset=dataset_eval
)
verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
class TestDeduplicateRLDataset(unittest.TestCase):
"""Test a configured dataloader with deduplication."""
def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
self.cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"rl": "dpo",
"chat_template": "llama3",
"dataset_exact_deduplication": True,
"datasets": [
ALPACA_MESSAGES_CONFIG_REVISION,
ALPACA_MESSAGES_CONFIG_REVISION,
],
}
)
def test_load_with_deduplication(self):
"""Verify that loading with deduplication removes duplicates."""
# Load the dataset using the deduplication setting
train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
# Verify that the dataset has been deduplicated
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
def test_load_without_deduplication(self):
"""Verify that loading without deduplication retains duplicates."""
self.cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
# Verify that the dataset retains duplicates
assert (
len(train_dataset) == 1800 * 2
), "Dataset deduplication occurred when it should not have"
class TestDeduplicateNonRL(unittest.TestCase):
"""Test prepare_dataset function with different configurations."""
def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
self.cfg_1 = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"dataset_exact_deduplication": True,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"val_set_size": 0.0,
"gradient_accumulation_steps": 4,
"batch_size": 10,
"micro_batch_size": 10,
"num_epochs": 1,
}
)
def test_prepare_dataset_with_deduplication_train(self):
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""
self.cfg_1.dataset_exact_deduplication = True
# Load tokenizer and processor
tokenizer = load_tokenizer(self.cfg_1)
processor = (
load_processor(self.cfg_1, tokenizer=tokenizer)
if self.cfg_1.processor_type
else None
)
# Prepare dataset using the prepare_dataset function
train_dataset, _, _, _ = prepare_dataset(
self.cfg_1,
tokenizer,
processor=processor,
)
self.assertEqual(
len(train_dataset),
2000,
"Train dataset should have 2000 samples after deduplication.",
)
def test_prepare_dataset_with_deduplication_eval(self):
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""
self.cfg_1.dataset_exact_deduplication = True
self.cfg_1.val_set_size = 0.5
# Load tokenizer and processor
tokenizer = load_tokenizer(self.cfg_1)
processor = (
load_processor(self.cfg_1, tokenizer=tokenizer)
if self.cfg_1.processor_type
else None
)
# Prepare dataset using the prepare_dataset function
_, eval_dataset, _, _ = prepare_dataset(
self.cfg_1,
tokenizer,
processor=processor,
)
self.assertEqual(
len(eval_dataset),
1000,
"Eval dataset should have 2000 samples after deduplication.",
)
def test_prepare_dataset_without_deduplication(self):
"""Verify that prepare_dataset function processes the dataset correctly without deduplication."""
self.cfg_1.dataset_exact_deduplication = False
self.cfg_1.val_set_size = 0.1
# Load tokenizer and processor
tokenizer = load_tokenizer(self.cfg_1)
processor = (
load_processor(self.cfg_1, tokenizer=tokenizer)
if self.cfg_1.processor_type
else None
)
# Prepare dataset using the prepare_dataset function
train_dataset, eval_dataset, _, _ = prepare_dataset(
self.cfg_1,
tokenizer,
processor=processor,
)
# Verify that the dataset has been prepared correctly
self.assertEqual(
len(train_dataset),
1800 * 2,
"Train dataset should have 3600 samples without deduplication.",
)
self.assertEqual(
len(eval_dataset),
200 * 2,
"Train dataset should have 400 samples after deduplication.",
)
class TestWrongCollisions(unittest.TestCase):
"""Creating mock datasets for testing wrong collisions"""
def setUp(self):
self.train_data = {"text": ["sample 5", "sample 6"], "label": [1, 2]}
self.eval_data = {
"text": [
"sample 5",
"sample 7",
], # Different label but same text as in train_data
"label": [2, 3],
}
self.dataset_data = {
"text": ["sample 5", "sample 9", "sample 5"],
"label": [1, 2, 8],
}
self.train_dataset = Dataset.from_dict(self.train_data)
self.eval_dataset = Dataset.from_dict(self.eval_data)
self.dataset = Dataset.from_dict(self.dataset_data)
@patch(
"axolotl.utils.data.utils.sha256",
side_effect=lambda x: hashlib.sha256(
"forced_collision_hash".encode("utf-8")
).hexdigest()
if "sample 5" in x
else hashlib.sha256(x.encode("utf-8")).hexdigest(),
)
def test_deduplication_wrong_collision_train_eval(self, _mock_sha256):
dedup_train, dedup_eval, _ = deduplicate_and_log_datasets(
train_dataset=self.train_dataset, eval_dataset=self.eval_dataset
)
self.assertEqual(
len(dedup_train),
2,
"train dataset should not deduplicate rows with forced hash collisions but different labels.",
)
self.assertEqual(
len(dedup_eval),
2,
"Eval dataset should not deduplicate rows with forced hash collisions but different labels.",
)
self.assertEqual(
len(dedup_eval),
len(self.eval_dataset),
"The output eval dataset should have the same number of rows as the input eval dataset.",
)
self.assertEqual(
str(dedup_eval),
str(self.eval_dataset),
"The string representation of the output eval dataset should be identical to the input eval dataset.",
)
def test_deduplication_dataset_only(self):
_, _, dedup_dataset = deduplicate_and_log_datasets(dataset=self.dataset)
self.assertEqual(
len(dedup_dataset), 3, "Dataset should have all original values"
)
self.assertEqual(
str(dedup_dataset),
str(self.dataset),
"The string representation of the output dataset should not differ.",
)
if __name__ == "__main__":
unittest.main()
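The deduplication helper exercised above is patched at axolotl.utils.data.utils.sha256 to force collisions, and the tests expect colliding rows with different labels to survive. A minimal sketch of exact deduplication consistent with that behaviour (the function name and structure are assumptions for illustration, not the library's actual implementation):

import hashlib

from datasets import Dataset


def exact_dedup(dataset: Dataset) -> Dataset:
    """Drop a row only when an earlier row is identical in content, not merely hash-equal."""
    buckets = {}  # sha256 hex digest -> rows already kept under that digest
    keep = []
    for idx, row in enumerate(dataset):
        key = hashlib.sha256(str(row).encode("utf-8")).hexdigest()
        bucket = buckets.setdefault(key, [])
        if any(row == kept_row for kept_row in bucket):
            continue  # true duplicate: same content, not just a colliding hash
        bucket.append(row)
        keep.append(idx)
    return dataset.select(keep)

Because rows are compared on content whenever their hashes collide, the forced "sample 5" collisions in TestWrongCollisions do not remove rows whose labels differ, which is exactly what the assertions above check.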

View File

@@ -7,35 +7,40 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer
 from axolotl.utils.callbacks.perplexity import Perplexity

-MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+MODEL_NAME = "HuggingFaceTB/SmolLM2-135M"


 @fixture()
 def metric(tokenizer):
-    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
-    return Perplexity(model, tokenizer, 512)
+    return Perplexity(tokenizer=tokenizer, max_seq_len=512)
+
+
+@fixture()
+def model():
+    return AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)


 @fixture()
 def tokenizer():
-    return AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
+    tokenizer_ = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
+    tokenizer_.add_special_tokens({"pad_token": "<|endoftext|>"})
+    return tokenizer_


-def test_perplexity_longer_than_stride(metric):
+def test_perplexity_longer_than_stride(model, metric):
     # taken from https://huggingface.co/datasets/roneneldan/TinyStories
     sample_text = """
 Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong. One day, Beep was driving in the park when he saw a big tree. The tree had many leaves that were falling. Beep liked how the leaves fall and wanted to play with them. Beep drove under the tree and watched the leaves fall on him. He laughed and beeped his horn. Beep played with the falling leaves all day. When it was time to go home, Beep knew he needed more fuel. He went to the fuel place and got more healthy fuel. Now, Beep was ready to go fast and play again the next day. And Beep lived happily ever after.
 One day, a little fish named Fin was swimming near the shore. He saw a big crab and wanted to be friends. "Hi, I am Fin. Do you want to play?" asked the little fish. The crab looked at Fin and said, "No, I don't want to play. I am cold and I don't feel fine." Fin felt sad but wanted to help the crab feel better. He swam away and thought of a plan. He remembered that the sun could make things warm. So, Fin swam to the top of the water and called to the sun, "Please, sun, help my new friend feel fine and not freeze!" The sun heard Fin's call and shone its warm light on the shore. The crab started to feel better and not so cold. He saw Fin and said, "Thank you, little fish, for making me feel fine. I don't feel like I will freeze now. Let's play together!" And so, Fin and the crab played and became good friends.
 """
-    result = metric.compute([sample_text])
+    result = metric.compute(model, [sample_text])
     ppl = result["score"]
-    assert round(ppl, 2) == 5.37
+    assert round(ppl, 2) == 7.41


-def test_perplexity_short(metric):
+def test_perplexity_short(model, metric):
     # taken from https://huggingface.co/datasets/roneneldan/TinyStories
     sample_text = "Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun."
-    result = metric.compute([sample_text])
+    result = metric.compute(model, [sample_text])
     ppl = result["score"]
-    assert round(ppl, 2) == 10.02
+    assert round(ppl, 2) == 10.33
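The changes above decouple the model from the perplexity metric: Perplexity is now constructed from just a tokenizer and a maximum sequence length, and the model is supplied to each compute call. A minimal usage sketch based only on the signatures visible in this diff (the setup mirrors the fixtures above):

from transformers import AutoModelForCausalLM, AutoTokenizer

from axolotl.utils.callbacks.perplexity import Perplexity

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M", trust_remote_code=True)
tokenizer.add_special_tokens({"pad_token": "<|endoftext|>"})
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M", trust_remote_code=True)

metric = Perplexity(tokenizer=tokenizer, max_seq_len=512)
result = metric.compute(model, ["Once upon a time, there was a little car named Beep."])
print(result["score"])  # the tests above pin this score for two fixed sample texts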

View File

@@ -672,6 +672,9 @@ class TestValidation(BaseValidation):
                 {
                     "bf16": True,
                     "capabilities": {"bf16": False},
+                    "env_capabilities": {
+                        "torch_version": "2.5.1",
+                    },
                 }
             )
             | minimal_cfg
@@ -1160,6 +1163,38 @@ class TestValidation(BaseValidation):
                 in self._caplog.records[0].message
             )

+    def test_torch_version_adopt_req(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "optimizer": "adopt_adamw",
+                }
+            )
+            | minimal_cfg
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r".*ADOPT optimizer is incompatible with torch version*",
+        ):
+            env_capabilities = {"torch_version": "2.3.0"}
+            capabilities = {"bf16": False}
+            _ = validate_config(
+                cfg, capabilities=capabilities, env_capabilities=env_capabilities
+            )
+
+        env_capabilities = {"torch_version": "2.5.1"}
+        capabilities = {"bf16": False}
+        _ = validate_config(
+            cfg, capabilities=capabilities, env_capabilities=env_capabilities
+        )
+
+        env_capabilities = {"torch_version": "2.5.2"}
+        capabilities = {"bf16": False}
+        _ = validate_config(
+            cfg, capabilities=capabilities, env_capabilities=env_capabilities
+        )
+

 class TestValidationCheckModelConfig(BaseValidation):
     """

View File

@@ -72,6 +72,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
"n_gpu": 1, "n_gpu": 1,
"compute_capability": "8.0", "compute_capability": "8.0",
}, },
env_capabilities={
"torch_version": "2.5.1",
},
) )
_check_config() _check_config()
@@ -124,6 +127,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
"n_gpu": 1, "n_gpu": 1,
"compute_capability": "8.0", "compute_capability": "8.0",
}, },
env_capabilities={
"torch_version": "2.5.1",
},
) )
_check_config() _check_config()
@@ -177,6 +183,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
"n_gpu": 1, "n_gpu": 1,
"compute_capability": "8.0", "compute_capability": "8.0",
}, },
env_capabilities={
"torch_version": "2.5.1",
},
) )
_check_config() _check_config()
@@ -231,6 +240,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
"n_gpu": 1, "n_gpu": 1,
"compute_capability": "8.0", "compute_capability": "8.0",
}, },
env_capabilities={
"torch_version": "2.5.1",
},
) )
_check_config() _check_config()
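All four hunks above make the same change: the dataset-config validation tests now pass env_capabilities alongside the GPU capabilities. A sketch of the resulting call shape (import paths and the placeholder cfg contents are assumptions; only the two keyword-argument mappings shown here appear in the diff):

from axolotl.utils.config import validate_config  # import path assumed
from axolotl.utils.dict import DictDefault  # import path assumed

cfg = DictDefault({"optimizer": "adamw_torch"})  # placeholder config under test

validate_config(
    cfg,
    capabilities={
        "n_gpu": 1,
        "compute_capability": "8.0",
    },
    env_capabilities={
        "torch_version": "2.5.1",
    },
)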