Compare commits
53 Commits
fix/cce-li
...
pixtral_in
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fbf3ca86c9 | ||
|
|
2de866e92f | ||
|
|
295e07dcca | ||
|
|
3c07b6d6b1 | ||
|
|
89dae7dc6d | ||
|
|
1b54af8e54 | ||
|
|
ca7b56cba3 | ||
|
|
ea8269d2eb | ||
|
|
13ca7ed087 | ||
|
|
0dfd8541ee | ||
|
|
75e1d3537f | ||
|
|
2b7f3bd6ab | ||
|
|
d85a229afe | ||
|
|
355cd7c872 | ||
|
|
eab1638686 | ||
|
|
a3a4d22709 | ||
|
|
f9eb7d8663 | ||
|
|
343771a6d3 | ||
|
|
d2c32d0cba | ||
|
|
cec9887609 | ||
|
|
88b2cae748 | ||
|
|
aea2565938 | ||
|
|
1ad56303b2 | ||
|
|
dc055a4ef7 | ||
|
|
169116a50f | ||
|
|
43e412f660 | ||
|
|
7aa57803e1 | ||
|
|
1969fa3bf0 | ||
|
|
4078f37076 | ||
|
|
f073af6d99 | ||
|
|
139d2612fa | ||
|
|
20573fd13e | ||
|
|
2b7b4af81c | ||
|
|
d56260c8d5 | ||
|
|
cac785ec0e | ||
|
|
e62991edef | ||
|
|
fd9e7b55f6 | ||
|
|
c0c53eb62f | ||
|
|
b0fbd4d11d | ||
|
|
1a70d4d6a4 | ||
|
|
d8787a433f | ||
|
|
e775422269 | ||
|
|
97178f5960 | ||
|
|
4698eed43f | ||
|
|
f84c3b37e7 | ||
|
|
c39971c659 | ||
|
|
33a178c788 | ||
|
|
db15605e7e | ||
|
|
9e112bc8b5 | ||
|
|
e038410778 | ||
|
|
f4385c3cf4 | ||
|
|
d58c772df6 | ||
|
|
69265a53b5 |
10
.github/workflows/base.yml
vendored
10
.github/workflows/base.yml
vendored
@@ -1,6 +1,16 @@
|
||||
name: ci-cd-base
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- 'Dockerfile-base'
|
||||
- '.github/workflows/base.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'Dockerfile-base'
|
||||
- '.github/workflows/base.yml'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
|
||||
1
.github/workflows/tests-nightly.yml
vendored
1
.github/workflows/tests-nightly.yml
vendored
@@ -55,6 +55,7 @@ jobs:
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging
|
||||
pip3 install -U -e .
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Run tests
|
||||
|
||||
6
.github/workflows/tests.yml
vendored
6
.github/workflows/tests.yml
vendored
@@ -8,11 +8,15 @@ on:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
- '.github/workflows/*.yml'
|
||||
- 'requirements-tests.txt'
|
||||
- 'cicd/cicd.sh'
|
||||
pull_request:
|
||||
paths:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
- '.github/workflows/*.yml'
|
||||
- 'requirements-tests.txt'
|
||||
- 'cicd/cicd.sh'
|
||||
workflow_dispatch:
|
||||
|
||||
# Cancel jobs on the same ref if a new one is triggered
|
||||
@@ -67,6 +71,8 @@ jobs:
|
||||
run: |
|
||||
pip3 show torch
|
||||
pip3 install -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Run tests
|
||||
|
||||
@@ -147,7 +147,7 @@ pip3 install -e '.[flash-attn,deepspeed]'
|
||||
### Usage
|
||||
```bash
|
||||
# preprocess datasets - optional but recommended
|
||||
CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
|
||||
CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
|
||||
|
||||
# finetune lora
|
||||
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
||||
|
||||
@@ -37,6 +37,9 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py | sh
|
||||
RUN python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
# So we can test the Docker image
|
||||
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
pytest -n8 --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||
pytest -v --durations=10 -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||
|
||||
@@ -40,6 +40,7 @@ with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
|
||||
cicd_image = (
|
||||
Image.from_dockerfile(
|
||||
pathlib.Path(temp_dir) / "Dockerfile",
|
||||
context_mount=None,
|
||||
force_build=True,
|
||||
gpu="A10G",
|
||||
)
|
||||
|
||||
@@ -26,6 +26,9 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py | sh
|
||||
RUN python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
# So we can test the Docker image
|
||||
RUN pip install pytest
|
||||
|
||||
|
||||
@@ -29,7 +29,9 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
||||
|
||||
RUN git lfs install --skip-repo && \
|
||||
pip3 install awscli && \
|
||||
|
||||
@@ -162,6 +162,9 @@ datasets:
|
||||
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
|
||||
shuffle_merged_datasets: true
|
||||
|
||||
Deduplicates datasets and test_datasets with identical entries.
|
||||
dataset_exact_deduplication: true
|
||||
|
||||
# A list of one or more datasets to eval the model with.
|
||||
# You can use either test_datasets, or val_set_size, but not both.
|
||||
test_datasets:
|
||||
@@ -406,7 +409,7 @@ lr_div_factor: # Learning rate div factor
|
||||
# - adamw_torch_fused
|
||||
# - adamw_torch_xla
|
||||
# - adamw_apex_fused
|
||||
# - adopt_adamw (only for torch version >= 2.5.1)
|
||||
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
|
||||
# - adafactor
|
||||
# - adamw_anyprecision
|
||||
# - sgd
|
||||
|
||||
95
examples/llama-3/lora-1b-deduplicate-dpo.yml
Normal file
95
examples/llama-3/lora-1b-deduplicate-dpo.yml
Normal file
@@ -0,0 +1,95 @@
|
||||
base_model: meta-llama/Llama-3.2-1B
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
chat_template: llama3
|
||||
rl: dpo
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
||||
type: chat_template.default
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
||||
type: chat_template.default
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
|
||||
dataset_exact_deduplication: true
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
s2_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
76
examples/llama-3/lora-1b-deduplicate-sft.yml
Normal file
76
examples/llama-3/lora-1b-deduplicate-sft.yml
Normal file
@@ -0,0 +1,76 @@
|
||||
base_model: meta-llama/Llama-3.2-1B
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
dataset_exact_deduplication: true
|
||||
test_value: true
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
lora_modules_to_save:
|
||||
- embed_tokens
|
||||
- lm_head
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
s2_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: <|end_of_text|>
|
||||
63
examples/llava/lora-7b.yaml
Normal file
63
examples/llava/lora-7b.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
base_model: llava-hf/llava-1.5-7b-hf
|
||||
processor_type: AutoProcessor
|
||||
strict: false
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: llava
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
65
examples/pixtral/lora-12b.yml
Normal file
65
examples/pixtral/lora-12b.yml
Normal file
@@ -0,0 +1,65 @@
|
||||
base_model: mistral-community/pixtral-12b
|
||||
processor_type: AutoProcessor
|
||||
strict: false
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: pixtral
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: <|end_of_text|>
|
||||
63
examples/qwen2-vl/lora-7b.yaml
Normal file
63
examples/qwen2-vl/lora-7b.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
base_model: Qwen/Qwen2-VL-7B-Instruct
|
||||
processor_type: AutoProcessor
|
||||
strict: false
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: qwen2_vl
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
@@ -2,4 +2,3 @@ pre-commit
|
||||
black
|
||||
mypy
|
||||
types-requests
|
||||
tbparse
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
pytest
|
||||
pytest-xdist
|
||||
pytest-retry
|
||||
pytest-sugar
|
||||
tbparse
|
||||
|
||||
@@ -26,7 +26,7 @@ numpy>=1.24.4,<=2.0.1
|
||||
evaluate==0.4.1
|
||||
scipy
|
||||
scikit-learn==1.4.2
|
||||
pynvml
|
||||
nvidia-ml-py==12.560.30
|
||||
art
|
||||
gradio==3.50.2
|
||||
tensorboard
|
||||
|
||||
28
scripts/cutcrossentropy_install.py
Normal file
28
scripts/cutcrossentropy_install.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Script to output the correct installation command for cut-cross-entropy."""
|
||||
import importlib.util
|
||||
import sys
|
||||
|
||||
try:
|
||||
import torch
|
||||
except ImportError as exc:
|
||||
raise ImportError("Install torch via `pip install torch`") from exc
|
||||
from packaging.version import Version as V
|
||||
|
||||
v = V(torch.__version__)
|
||||
|
||||
# no cut-cross-entropy support for torch < 2.4.0
|
||||
if v < V("2.4.0"):
|
||||
print("")
|
||||
sys.exit(0)
|
||||
|
||||
cce_spec = importlib.util.find_spec("cut_cross_entropy")
|
||||
cce_spec_transformers = importlib.util.find_spec("cut_cross_entropy.transformers")
|
||||
|
||||
UNINSTALL_PREFIX = ""
|
||||
if cce_spec and not cce_spec_transformers:
|
||||
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
|
||||
)
|
||||
@@ -8,7 +8,10 @@ from packaging.version import Version as V
|
||||
|
||||
v = V(torch.__version__)
|
||||
cuda = str(torch.version.cuda)
|
||||
is_ampere = torch.cuda.get_device_capability()[0] >= 8
|
||||
try:
|
||||
is_ampere = torch.cuda.get_device_capability()[0] >= 8
|
||||
except RuntimeError:
|
||||
is_ampere = False
|
||||
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
|
||||
raise RuntimeError(f"CUDA = {cuda} not supported!")
|
||||
if v <= V("2.1.0"):
|
||||
@@ -29,5 +32,5 @@ else:
|
||||
raise RuntimeError(f"Torch = {v} too new!")
|
||||
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
|
||||
print(
|
||||
f'pip install unsloth-zoo && pip install --no-deps "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"'
|
||||
f'pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[{x}]==2024.11.9"'
|
||||
)
|
||||
|
||||
@@ -27,7 +27,6 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
from transformers.utils.import_utils import _is_package_available
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils.chat_templates import (
|
||||
@@ -38,6 +37,7 @@ from axolotl.utils.comet_ import setup_comet_env_vars
|
||||
from axolotl.utils.config import (
|
||||
normalize_cfg_datasets,
|
||||
normalize_config,
|
||||
prepare_plugins,
|
||||
validate_config,
|
||||
)
|
||||
from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
|
||||
@@ -100,8 +100,8 @@ def print_dep_versions():
|
||||
print("*" * 40)
|
||||
print("**** Axolotl Dependency Versions *****")
|
||||
for pkg in packages:
|
||||
version = _is_package_available(pkg, return_version=True)
|
||||
print(f"{pkg: >{max_len}}: {version[1]: <15}")
|
||||
pkg_version = _is_package_available(pkg, return_version=True)
|
||||
print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
|
||||
print("*" * 40)
|
||||
|
||||
|
||||
@@ -139,7 +139,7 @@ def check_remote_config(config: Union[str, Path]):
|
||||
with open(output_path, "wb") as file:
|
||||
file.write(content)
|
||||
LOG.info(
|
||||
f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
|
||||
f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n"
|
||||
)
|
||||
return output_path
|
||||
|
||||
@@ -426,11 +426,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
|
||||
cfg.axolotl_config_path = config
|
||||
|
||||
if cfg.get("plugins"):
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
for plugin_name in cfg["plugins"]:
|
||||
plugin_manager.register(plugin_name)
|
||||
|
||||
try:
|
||||
device_props = torch.cuda.get_device_properties("cuda")
|
||||
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
|
||||
@@ -444,8 +439,13 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
||||
"compute_capability": gpu_version,
|
||||
},
|
||||
env_capabilities={
|
||||
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
|
||||
},
|
||||
)
|
||||
|
||||
prepare_plugins(cfg)
|
||||
|
||||
prepare_optim_env(cfg)
|
||||
|
||||
prepare_opinionated_env(cfg)
|
||||
|
||||
@@ -19,7 +19,7 @@ from axolotl.common.cli import TrainerCliArgs
|
||||
def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parsed_cfg = load_cfg(config, inference=True, **kwargs)
|
||||
parsed_cfg.sample_packing = False
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
|
||||
@@ -107,6 +107,22 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
||||
return kwargs
|
||||
|
||||
|
||||
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
||||
if isinstance(dataset_tags, str):
|
||||
dataset_tags = [dataset_tags]
|
||||
|
||||
if (dataset_tags is not None) and (kwargs is not None):
|
||||
if "dataset_tags" not in kwargs:
|
||||
kwargs["dataset_tags"] = dataset_tags
|
||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
||||
kwargs["dataset_tags"].extend(dataset_tags)
|
||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
||||
dataset_tags.append(kwargs["dataset_tags"])
|
||||
kwargs["dataset_tags"] = dataset_tags
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingMixins:
|
||||
"""
|
||||
@@ -220,6 +236,14 @@ class AxolotlTrainingMixins:
|
||||
default=1e-6,
|
||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||
)
|
||||
embedding_lr_scale: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||
)
|
||||
embedding_lr: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||
)
|
||||
qlora: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "whether this is a qlora training"},
|
||||
@@ -386,7 +410,7 @@ class SchedulerMixin(Trainer):
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||
else:
|
||||
if use_cosine_quadratic:
|
||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||
@@ -410,10 +434,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
*_args,
|
||||
bench_data_collator=None,
|
||||
eval_data_collator=None,
|
||||
dataset_tags=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.bench_data_collator = bench_data_collator
|
||||
self.eval_data_collator = eval_data_collator
|
||||
self.dataset_tags = dataset_tags
|
||||
super().__init__(*_args, **kwargs)
|
||||
self.train_data_collator = self.data_collator
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
@@ -435,6 +461,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
def create_optimizer(self):
|
||||
if (
|
||||
self.args.loraplus_lr_ratio is None
|
||||
and self.args.embedding_lr_scale is None
|
||||
and self.args.embedding_lr is None
|
||||
and self.args.alternate_optimizer
|
||||
not in [
|
||||
"optimi_adamw",
|
||||
@@ -449,30 +477,59 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in opt_model.named_parameters()
|
||||
if (n in decay_parameters and p.requires_grad)
|
||||
],
|
||||
"weight_decay": self.args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in opt_model.named_parameters()
|
||||
if (n not in decay_parameters and p.requires_grad)
|
||||
],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
params = {
|
||||
"to_weight_decay": {}, # LayerNorm and bias
|
||||
"embeddings": {}, # lm_head, embed_tokens,
|
||||
"no_weight_decay": {},
|
||||
}
|
||||
|
||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||
self.args,
|
||||
opt_model,
|
||||
)
|
||||
|
||||
for name, param in opt_model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
if name.endswith("modules_to_save.default.weight") or any(
|
||||
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||
):
|
||||
params["embeddings"][name] = param
|
||||
elif name in decay_parameters:
|
||||
params["to_weight_decay"][name] = param
|
||||
else:
|
||||
params["no_weight_decay"][name] = param
|
||||
optimizer_grouped_parameters = []
|
||||
if params["to_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["to_weight_decay"].values()),
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
if params["embeddings"]:
|
||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||
if self.args.embedding_lr_scale:
|
||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||
elif self.args.embedding_lr:
|
||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["embeddings"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": lr,
|
||||
}
|
||||
)
|
||||
if params["no_weight_decay"]:
|
||||
optimizer_grouped_parameters.append(
|
||||
{
|
||||
"params": list(params["no_weight_decay"].values()),
|
||||
"weight_decay": 0.0,
|
||||
"lr": optimizer_kwargs["lr"],
|
||||
}
|
||||
)
|
||||
|
||||
if self.args.loraplus_lr_ratio is not None:
|
||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||
loraplus_lr_embedding = getattr(
|
||||
@@ -485,6 +542,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
elif (
|
||||
self.args.embedding_lr_scale is not None
|
||||
or self.args.embedding_lr is not None
|
||||
):
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
)
|
||||
elif self.args.alternate_optimizer == "optimi_adamw":
|
||||
from optimi import AdamW
|
||||
|
||||
@@ -516,7 +580,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
ADOPT(
|
||||
optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
|
||||
optimizer_grouped_parameters,
|
||||
decouple=True,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -871,6 +937,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||
)
|
||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
@@ -994,8 +1063,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
|
||||
tag_names = ["axolotl", "dpo"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args, dataset_tags=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dataset_tags = dataset_tags
|
||||
self.optimizer = None
|
||||
|
||||
def create_optimizer(self):
|
||||
@@ -1034,6 +1104,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||
)
|
||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
@@ -1571,6 +1644,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs[
|
||||
"loraplus_lr_embedding"
|
||||
] = self.cfg.loraplus_lr_embedding
|
||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||
training_arguments_kwargs[
|
||||
@@ -1755,6 +1831,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
|
||||
if (trainer_cls is not AxolotlRewardTrainer) and self.cfg.datasets is not None:
|
||||
trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
trainer = trainer_cls(
|
||||
model=self.model,
|
||||
train_dataset=self.train_dataset,
|
||||
@@ -1817,6 +1897,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
collator = MultiModalChatDataCollator
|
||||
kwargs["processor"] = self.processor
|
||||
kwargs["chat_template"] = training_args.chat_template
|
||||
kwargs["chat_template_type"] = self.cfg.chat_template
|
||||
else:
|
||||
collator = DataCollatorForSeq2Seq
|
||||
|
||||
@@ -2028,6 +2109,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
|
||||
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
|
||||
dpo_trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
dpo_trainer = trainer_cls(
|
||||
*trainer_cls_args,
|
||||
args=training_args,
|
||||
|
||||
@@ -40,7 +40,7 @@ class TRLPPOTrainer(PPOTrainer):
|
||||
query_tensors,
|
||||
return_prompt=False,
|
||||
generate_ref_response=True,
|
||||
**generation_kwargs
|
||||
**generation_kwargs,
|
||||
)
|
||||
batch["response"] = self.tokenizer.batch_decode(response_tensors)
|
||||
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
|
||||
|
||||
325
src/axolotl/integrations/cut_cross_entropy/ACKNOWLEDGEMENTS.md
Normal file
325
src/axolotl/integrations/cut_cross_entropy/ACKNOWLEDGEMENTS.md
Normal file
@@ -0,0 +1,325 @@
|
||||
Acknowledgements
|
||||
|
||||
Portions of this Cut Cross Entropy Software may utilize the following copyrighted
|
||||
material, the use of which is hereby acknowledged.
|
||||
|
||||
|
||||
------
|
||||
|
||||
|
||||
PyTorch
|
||||
|
||||
From PyTorch:
|
||||
|
||||
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
||||
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
||||
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
||||
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
||||
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
||||
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
||||
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
||||
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
||||
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
||||
|
||||
From Caffe2:
|
||||
|
||||
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
|
||||
|
||||
All contributions by Facebook:
|
||||
Copyright (c) 2016 Facebook Inc.
|
||||
|
||||
All contributions by Google:
|
||||
Copyright (c) 2015 Google Inc.
|
||||
All rights reserved.
|
||||
|
||||
All contributions by Yangqing Jia:
|
||||
Copyright (c) 2015 Yangqing Jia
|
||||
All rights reserved.
|
||||
|
||||
All contributions by Kakao Brain:
|
||||
Copyright 2019-2020 Kakao Brain
|
||||
|
||||
All contributions by Cruise LLC:
|
||||
Copyright (c) 2022 Cruise LLC.
|
||||
All rights reserved.
|
||||
|
||||
All contributions by Arm:
|
||||
Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates
|
||||
|
||||
All contributions from Caffe:
|
||||
Copyright(c) 2013, 2014, 2015, the respective contributors
|
||||
All rights reserved.
|
||||
|
||||
All other contributions:
|
||||
Copyright(c) 2015, 2016 the respective contributors
|
||||
All rights reserved.
|
||||
|
||||
Caffe2 uses a copyright model similar to Caffe: each contributor holds
|
||||
copyright over their contributions to Caffe2. The project versioning records
|
||||
all such contribution and copyright details. If a contributor wants to further
|
||||
mark their specific copyright on a particular contribution, they should
|
||||
indicate their copyright solely in the commit message of the change when it is
|
||||
committed.
|
||||
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
|
||||
and IDIAP Research Institute nor the names of its contributors may be
|
||||
used to endorse or promote products derived from this software without
|
||||
specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
Triton
|
||||
|
||||
/*
|
||||
* Copyright 2018-2020 Philippe Tillet
|
||||
* Copyright 2020-2022 OpenAI
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
|
||||
Transformers
|
||||
|
||||
Copyright 2018- The Hugging Face team. All rights reserved.
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
47
src/axolotl/integrations/cut_cross_entropy/LICENSE
Normal file
47
src/axolotl/integrations/cut_cross_entropy/LICENSE
Normal file
@@ -0,0 +1,47 @@
|
||||
Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||
|
||||
IMPORTANT: This Apple software is supplied to you by Apple
|
||||
Inc. ("Apple") in consideration of your agreement to the following
|
||||
terms, and your use, installation, modification or redistribution of
|
||||
this Apple software constitutes acceptance of these terms. If you do
|
||||
not agree with these terms, please do not use, install, modify or
|
||||
redistribute this Apple software.
|
||||
|
||||
In consideration of your agreement to abide by the following terms, and
|
||||
subject to these terms, Apple grants you a personal, non-exclusive
|
||||
license, under Apple's copyrights in this original Apple software (the
|
||||
"Apple Software"), to use, reproduce, modify and redistribute the Apple
|
||||
Software, with or without modifications, in source and/or binary forms;
|
||||
provided that if you redistribute the Apple Software in its entirety and
|
||||
without modifications, you must retain this notice and the following
|
||||
text and disclaimers in all such redistributions of the Apple Software.
|
||||
Neither the name, trademarks, service marks or logos of Apple Inc. may
|
||||
be used to endorse or promote products derived from the Apple Software
|
||||
without specific prior written permission from Apple. Except as
|
||||
expressly stated in this notice, no other rights or licenses, express or
|
||||
implied, are granted by Apple herein, including but not limited to any
|
||||
patent rights that may be infringed by your derivative works or by other
|
||||
works in which the Apple Software may be incorporated.
|
||||
|
||||
The Apple Software is provided by Apple on an "AS IS" basis. APPLE
|
||||
MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
|
||||
THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
|
||||
FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
|
||||
OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
|
||||
|
||||
IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
|
||||
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
|
||||
MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
|
||||
AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
|
||||
STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
|
||||
POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
-------------------------------------------------------------------------------
|
||||
SOFTWARE DISTRIBUTED WITH CUT CROSS ENTROPY:
|
||||
|
||||
The Cut Cross Entropy software includes a number of subcomponents with separate
|
||||
copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.md.
|
||||
-------------------------------------------------------------------------------
|
||||
10
src/axolotl/integrations/cut_cross_entropy/README.md
Normal file
10
src/axolotl/integrations/cut_cross_entropy/README.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# Cut Cross Entropy
|
||||
|
||||
### Usage
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
cut_cross_entropy: true
|
||||
```
|
||||
83
src/axolotl/integrations/cut_cross_entropy/__init__.py
Normal file
83
src/axolotl/integrations/cut_cross_entropy/__init__.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Module for the Plugin for Cut Cross Entropy integration with Axolotl.
|
||||
|
||||
Cut Cross Entropy is an optimized implementation of cross entropy loss
|
||||
from Apple's ML team.
|
||||
"""
|
||||
import importlib
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils import get_pytorch_version
|
||||
|
||||
from ...utils.distributed import zero_only
|
||||
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers]==24.11.4"`'
|
||||
)
|
||||
|
||||
|
||||
class CutCrossEntropyPlugin(BasePlugin):
|
||||
"""
|
||||
Plugin for Cut Cross Entropy integration with Axolotl.
|
||||
"""
|
||||
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.cut_cross_entropy.CutCrossEntropyArgs"
|
||||
|
||||
def _check_requirements(self):
|
||||
"""Check if all requirements are met."""
|
||||
# Check PyTorch version
|
||||
|
||||
major, minor, _ = get_pytorch_version()
|
||||
if (major, minor) < (2, 4):
|
||||
raise ImportError(
|
||||
"Cut Cross Entropy requires PyTorch >= 2.4.0. "
|
||||
f"Current version: {torch.__version__}"
|
||||
)
|
||||
|
||||
# Check if cut_cross_entropy is installed
|
||||
cce_spec = importlib.util.find_spec("cut_cross_entropy")
|
||||
if cce_spec is None:
|
||||
raise ImportError(_CCE_INSTALL_MESSAGE)
|
||||
|
||||
cce_spec_transformers = importlib.util.find_spec(
|
||||
"cut_cross_entropy.transformers"
|
||||
)
|
||||
if cce_spec_transformers is None:
|
||||
raise ImportError(_CCE_INSTALL_MESSAGE)
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
"""Apply cut cross entropy before model loading if enabled."""
|
||||
if cfg.cut_cross_entropy:
|
||||
self._check_requirements()
|
||||
|
||||
from cut_cross_entropy.transformers import cce_patch
|
||||
|
||||
with zero_only():
|
||||
LOG.info(
|
||||
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
|
||||
)
|
||||
|
||||
# The patch checks model_type internally
|
||||
cce_patch(cfg.model_config_type)
|
||||
42
src/axolotl/integrations/cut_cross_entropy/args.py
Normal file
42
src/axolotl/integrations/cut_cross_entropy/args.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Module for handling Cut Cross Entropy input arguments.
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy.args")
|
||||
|
||||
|
||||
class CutCrossEntropyArgs(BaseModel):
|
||||
"""
|
||||
Input args for Cut Cross Entropy.
|
||||
"""
|
||||
|
||||
cut_cross_entropy: Optional[bool] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_dtype_is_half(cls, data):
|
||||
if not (data.get("bf16") or data.get("fp16")):
|
||||
raise ValueError(
|
||||
"Cut Cross Entropy requires fp16/bf16 training for backward pass. "
|
||||
"Please set `bf16` or `fp16` to `True`."
|
||||
)
|
||||
|
||||
return data
|
||||
@@ -4,7 +4,6 @@
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -94,13 +93,32 @@ def replace_llama_qkv_with_fused(model):
|
||||
set_module_name(model, name, qkv)
|
||||
|
||||
|
||||
def patch_llama_cross_entropy():
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||
|
||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
||||
CrossEntropyLoss, inplace_backward=True
|
||||
def patch_fa_llama_cross_entropy():
|
||||
LOG.info(
|
||||
"patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
|
||||
)
|
||||
from flash_attn.ops.triton.cross_entropy import (
|
||||
cross_entropy_loss as flash_attn_cross_entropy_loss,
|
||||
)
|
||||
|
||||
def fa2_fixed_cross_entropy(
|
||||
source,
|
||||
target,
|
||||
num_items_in_batch: int = None,
|
||||
ignore_index: int = -100,
|
||||
**kwargs,
|
||||
): # pylint: disable=unused-argument
|
||||
reduction = "sum" if num_items_in_batch is not None else "mean"
|
||||
loss, _ = flash_attn_cross_entropy_loss(
|
||||
source, target, ignore_index=ignore_index
|
||||
)
|
||||
if reduction == "sum":
|
||||
loss = loss.sum() / num_items_in_batch
|
||||
else:
|
||||
loss = loss.sum() / (target != ignore_index).sum()
|
||||
return loss
|
||||
|
||||
transformers.loss.loss_utils.fixed_cross_entropy = fa2_fixed_cross_entropy
|
||||
|
||||
|
||||
def patch_llama_rms_norm():
|
||||
@@ -147,7 +165,7 @@ def replace_llama_attn_with_flash_attn(
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if cross_entropy:
|
||||
patch_llama_cross_entropy()
|
||||
patch_fa_llama_cross_entropy()
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if rms_norm:
|
||||
|
||||
@@ -46,9 +46,10 @@ def reset_optimizer(
|
||||
*,
|
||||
reset_params: List[str], # where str is the key to a torch.nn.Parameter
|
||||
optimizer_state_keys: List[str],
|
||||
prune_ratio: float = 0.9,
|
||||
optimizer_magnitude_pruning: float = 0.9,
|
||||
):
|
||||
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
|
||||
# pylint:disable=unused-argument
|
||||
pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)
|
||||
n_zeros = 0
|
||||
n_total = 0
|
||||
|
||||
@@ -56,16 +57,22 @@ def reset_optimizer(
|
||||
if isinstance(optimizer, ZeroRedundancyOptimizer):
|
||||
optimizer_state = optimizer.optim.state
|
||||
|
||||
for param in reset_params:
|
||||
param_state = optimizer_state[param]
|
||||
if len(param_state) == 0: # no state for this param, happens for ZeRo optimizer
|
||||
continue
|
||||
for key in optimizer_state_keys:
|
||||
pruning_fn(
|
||||
param_state[key]
|
||||
) # pruning fn has to be inplace to keep the same keys in the dict
|
||||
n_total += param_state[key].numel()
|
||||
n_zeros += torch.sum(param_state[key] == 0).item()
|
||||
for group in optimizer.param_groups:
|
||||
for param in group["params"]:
|
||||
state = optimizer_state[param]
|
||||
for key, value in state.items():
|
||||
if key not in optimizer_state_keys:
|
||||
continue
|
||||
if torch.is_tensor(value):
|
||||
try:
|
||||
pruning_fn(value)
|
||||
n_total += value.numel()
|
||||
n_zeros += torch.sum(value == 0).item()
|
||||
except RuntimeError as exc:
|
||||
if "quantile() input tensor is too large" in str(exc):
|
||||
pass
|
||||
else:
|
||||
raise exc
|
||||
|
||||
_zeroed = n_zeros / (1e-7 + n_total) * 100
|
||||
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
|
||||
@@ -129,6 +136,9 @@ class ReLoRACallback(TrainerCallback):
|
||||
|
||||
if "adam" in args.optim.lower():
|
||||
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
|
||||
if "8bit" in args.optim.lower():
|
||||
optimizer_state_keys.append("state1")
|
||||
optimizer_state_keys.append("state2")
|
||||
else:
|
||||
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
|
||||
|
||||
@@ -160,7 +170,7 @@ class ReLoRACallback(TrainerCallback):
|
||||
optimizer,
|
||||
reset_params=lora_params,
|
||||
optimizer_state_keys=optimizer_state_keys,
|
||||
prune_ratio=args.relora_prune_ratio,
|
||||
optimizer_magnitude_pruning=args.relora_prune_ratio,
|
||||
)
|
||||
|
||||
if self.quantized:
|
||||
|
||||
@@ -259,11 +259,31 @@ def train(
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
if not cfg.hub_model_id:
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
|
||||
try:
|
||||
trainer.create_model_card(
|
||||
model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
|
||||
)
|
||||
except (AttributeError, UnicodeDecodeError):
|
||||
# Check to make sure the base model is from HuggingFace not a local directory
|
||||
hf_api = HfApi()
|
||||
hf_api.model_info(cfg.base_model)
|
||||
|
||||
model_card_kwarg = {
|
||||
"model_name": cfg.output_dir.lstrip("./")
|
||||
.encode("utf-8")
|
||||
.decode("utf-8")
|
||||
}
|
||||
if cfg.datasets is not None:
|
||||
if cfg.rl is not None or cfg.reward_model:
|
||||
model_card_kwarg["dataset_name"] = [
|
||||
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
else:
|
||||
model_card_kwarg["dataset_tags"] = [
|
||||
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
|
||||
trainer.create_model_card(**model_card_kwarg)
|
||||
except (AttributeError, UnicodeDecodeError, RepositoryNotFoundError):
|
||||
pass
|
||||
elif cfg.hub_model_id:
|
||||
# defensively push to the hub to ensure the model card is updated
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
"""
|
||||
Basic utils for Axolotl
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def is_mlflow_available():
|
||||
@@ -10,3 +14,23 @@ def is_mlflow_available():
|
||||
|
||||
def is_comet_available():
|
||||
return importlib.util.find_spec("comet_ml") is not None
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
def get_pytorch_version() -> tuple[int, int, int]:
|
||||
"""
|
||||
Get Pytorch version as a tuple of (major, minor, patch).
|
||||
"""
|
||||
torch_version = torch.__version__
|
||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||
|
||||
if not version_match:
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
major, minor, patch = version_match.groups()
|
||||
major, minor = int(major), int(minor)
|
||||
patch = int(patch) if patch is not None else 0 # Default patch to 0 if not present
|
||||
return major, minor, patch
|
||||
|
||||
|
||||
# pylint: enable=duplicate-code
|
||||
|
||||
@@ -1,13 +1,24 @@
|
||||
"""Benchmarking and measurement utilities"""
|
||||
import functools
|
||||
|
||||
import pynvml
|
||||
import torch
|
||||
from pynvml.nvml import NVMLError
|
||||
from transformers.utils.import_utils import is_torch_npu_available
|
||||
|
||||
from axolotl.utils.distributed import get_device_type
|
||||
|
||||
try:
|
||||
from pynvml import (
|
||||
NVMLError,
|
||||
nvmlDeviceGetHandleByIndex,
|
||||
nvmlDeviceGetMemoryInfo,
|
||||
nvmlInit,
|
||||
)
|
||||
except ImportError:
|
||||
NVMLError = None
|
||||
nvmlDeviceGetHandleByIndex = None
|
||||
nvmlDeviceGetMemoryInfo = None
|
||||
nvmlInit = None
|
||||
|
||||
|
||||
def check_cuda_device(default_value):
|
||||
"""
|
||||
@@ -68,10 +79,12 @@ def gpu_memory_usage_smi(device=0):
|
||||
device = device.index
|
||||
if isinstance(device, str) and device.startswith("cuda:"):
|
||||
device = int(device[5:])
|
||||
if not nvmlInit:
|
||||
return 0.0
|
||||
try:
|
||||
pynvml.nvmlInit()
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
||||
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
nvmlInit()
|
||||
handle = nvmlDeviceGetHandleByIndex(device)
|
||||
info = nvmlDeviceGetMemoryInfo(handle)
|
||||
return info.used / 1024.0**3
|
||||
except NVMLError:
|
||||
return 0.0
|
||||
|
||||
@@ -28,6 +28,7 @@ from transformers import (
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||
from trl.models import unwrap_model_for_generation
|
||||
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
@@ -46,6 +47,7 @@ from axolotl.utils.distributed import (
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
|
||||
@@ -64,7 +66,10 @@ class EvalFirstStepCallback(
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
|
||||
if (
|
||||
args.evaluation_strategy == IntervalStrategy.STEPS
|
||||
and state.global_step == 1
|
||||
):
|
||||
control.should_evaluate = True
|
||||
return control
|
||||
|
||||
@@ -375,7 +380,10 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
||||
for metric in self.cfg.eval_causal_lm_metrics:
|
||||
if metric == "perplexity":
|
||||
max_seq_len = self.cfg.eval_max_new_tokens
|
||||
metrics[metric] = Perplexity(trainer.model, tokenizer, max_seq_len)
|
||||
metrics[metric] = Perplexity(
|
||||
tokenizer=tokenizer,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
metrics[metric] = evaluate.load(metric)
|
||||
@@ -392,8 +400,11 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
||||
eval_dataloader,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
trainer.model.eval()
|
||||
device = torch.device(self.cfg.device)
|
||||
trainer.model_wrapped.eval()
|
||||
|
||||
device = torch.device(
|
||||
self.cfg.device
|
||||
) # Use this instead of trainer.model_wrapped.device as it may return cpu if fsdp offloaded
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
generation_config = GenerationConfig(
|
||||
@@ -430,6 +441,10 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
||||
for k in metric._feature_names() # pylint: disable=protected-access
|
||||
if k in kwargs
|
||||
}
|
||||
|
||||
if isinstance(metric, Perplexity):
|
||||
metric_kwargs["model"] = trainer.model_wrapped
|
||||
|
||||
metric_score = metric.compute(**metric_kwargs)
|
||||
return (
|
||||
metric_score["score"]
|
||||
@@ -465,89 +480,97 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
||||
def predict_with_generate():
|
||||
eval_src, eval_pred, eval_ref = [], [], []
|
||||
|
||||
for batch in tqdm(eval_dataloader):
|
||||
batch_labels = batch["labels"].to(device)
|
||||
batch_input_ids = batch["input_ids"].to(device)
|
||||
with unwrap_model_for_generation(
|
||||
trainer.model_wrapped, trainer.accelerator
|
||||
) as unwrapped_model:
|
||||
for batch in tqdm(eval_dataloader, disable=not is_main_process()):
|
||||
batch_labels = batch["labels"].to(device)
|
||||
batch_input_ids = batch["input_ids"].to(device)
|
||||
|
||||
if "position_ids" in batch:
|
||||
batch_pos_ids = batch["position_ids"].tolist()
|
||||
else:
|
||||
batch_pos_ids = [None] * len(batch["input_ids"])
|
||||
|
||||
prompt_token_ids_list = []
|
||||
completion_token_ids_list = []
|
||||
|
||||
for input_ids_all, labels_all, pos_ids in zip(
|
||||
batch_input_ids,
|
||||
batch_labels,
|
||||
batch_pos_ids,
|
||||
):
|
||||
if pos_ids is None:
|
||||
pos_ranges = [(0, len(input_ids_all) - 1)]
|
||||
if "position_ids" in batch:
|
||||
batch_pos_ids = batch["position_ids"].tolist()
|
||||
else:
|
||||
pos_ranges = find_ranges(pos_ids)
|
||||
batch_pos_ids = [None] * len(batch["input_ids"])
|
||||
|
||||
for pos_range in pos_ranges:
|
||||
start, end = pos_range
|
||||
if start == end:
|
||||
continue
|
||||
prompt_token_ids_list = []
|
||||
completion_token_ids_list = []
|
||||
|
||||
input_ids = input_ids_all[start : end + 1]
|
||||
labels = labels_all[start : end + 1]
|
||||
for input_ids_all, labels_all, pos_ids in zip(
|
||||
batch_input_ids,
|
||||
batch_labels,
|
||||
batch_pos_ids,
|
||||
):
|
||||
if pos_ids is None:
|
||||
pos_ranges = [(0, len(input_ids_all) - 1)]
|
||||
else:
|
||||
pos_ranges = find_ranges(pos_ids)
|
||||
|
||||
tokens_without_loss = labels == IGNORE_INDEX
|
||||
tokens_with_loss = labels != IGNORE_INDEX
|
||||
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
|
||||
prompt_token_includes = (
|
||||
tokens_without_loss & tokens_exclude_padding
|
||||
for pos_range in pos_ranges:
|
||||
start, end = pos_range
|
||||
if start == end:
|
||||
continue
|
||||
|
||||
input_ids = input_ids_all[start : end + 1]
|
||||
labels = labels_all[start : end + 1]
|
||||
|
||||
tokens_without_loss = labels == IGNORE_INDEX
|
||||
tokens_with_loss = labels != IGNORE_INDEX
|
||||
tokens_exclude_padding = (
|
||||
input_ids != tokenizer.pad_token_id
|
||||
)
|
||||
prompt_token_includes = (
|
||||
tokens_without_loss & tokens_exclude_padding
|
||||
)
|
||||
|
||||
prompt_token_ids = input_ids[prompt_token_includes]
|
||||
prompt_token_ids_list.append(prompt_token_ids)
|
||||
|
||||
completion_token_ids = input_ids[tokens_with_loss]
|
||||
completion_token_ids_list.append(completion_token_ids)
|
||||
|
||||
prompt_texts = tokenizer.batch_decode(
|
||||
prompt_token_ids_list, skip_special_tokens=True
|
||||
)
|
||||
completion_texts = tokenizer.batch_decode(
|
||||
completion_token_ids_list, skip_special_tokens=True
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
prompt_encoding = tokenizer(
|
||||
prompt_texts, padding=True, return_tensors="pt"
|
||||
).to(device)
|
||||
|
||||
predictions = unwrapped_model.generate(
|
||||
**prompt_encoding, generation_config=generation_config
|
||||
)
|
||||
|
||||
prompt_token_ids = input_ids[prompt_token_includes]
|
||||
prompt_token_ids_list.append(prompt_token_ids)
|
||||
del prompt_encoding
|
||||
|
||||
completion_token_ids = input_ids[tokens_with_loss]
|
||||
completion_token_ids_list.append(completion_token_ids)
|
||||
prediction_all_tokens = predictions["sequences"].cpu().tolist()
|
||||
prediction_without_prompt_tokens_list = []
|
||||
for prompt_token_ids, prediction_tokens in zip(
|
||||
prompt_token_ids_list, prediction_all_tokens
|
||||
):
|
||||
prediction_without_prompt_tokens = prediction_tokens[
|
||||
len(prompt_token_ids) :
|
||||
]
|
||||
prediction_without_prompt_tokens_list.append(
|
||||
prediction_without_prompt_tokens
|
||||
)
|
||||
|
||||
prompt_texts = tokenizer.batch_decode(
|
||||
prompt_token_ids_list, skip_special_tokens=True
|
||||
)
|
||||
completion_texts = tokenizer.batch_decode(
|
||||
completion_token_ids_list, skip_special_tokens=True
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
prompt_encoding = tokenizer(
|
||||
prompt_texts, padding=True, return_tensors="pt"
|
||||
).to(self.cfg.device)
|
||||
predictions = trainer.model.generate(
|
||||
**prompt_encoding, generation_config=generation_config
|
||||
predicted_texts = tokenizer.batch_decode(
|
||||
prediction_without_prompt_tokens_list,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
prediction_all_tokens = predictions["sequences"].cpu().tolist()
|
||||
prediction_without_prompt_tokens_list = []
|
||||
for prompt_token_ids, prediction_tokens in zip(
|
||||
prompt_token_ids_list, prediction_all_tokens
|
||||
):
|
||||
prediction_without_prompt_tokens = prediction_tokens[
|
||||
len(prompt_token_ids) :
|
||||
]
|
||||
prediction_without_prompt_tokens_list.append(
|
||||
prediction_without_prompt_tokens
|
||||
)
|
||||
|
||||
predicted_texts = tokenizer.batch_decode(
|
||||
prediction_without_prompt_tokens_list, skip_special_tokens=True
|
||||
)
|
||||
|
||||
eval_src.extend(prompt_texts)
|
||||
eval_pred.extend(predicted_texts)
|
||||
eval_ref.extend(completion_texts)
|
||||
eval_src.extend(prompt_texts)
|
||||
eval_pred.extend(predicted_texts)
|
||||
eval_ref.extend(completion_texts)
|
||||
|
||||
return eval_src, eval_pred, eval_ref
|
||||
|
||||
if is_main_process():
|
||||
eval_preds = predict_with_generate()
|
||||
trainer.log(evaluate_preds(*eval_preds))
|
||||
eval_preds = predict_with_generate()
|
||||
trainer.log(evaluate_preds(*eval_preds))
|
||||
|
||||
return control
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ from transformers.modeling_outputs import CausalLMOutput
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
|
||||
class Perplexity:
|
||||
"""
|
||||
@@ -17,16 +19,13 @@ class Perplexity:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_seq_len: int,
|
||||
stride: int = 512,
|
||||
) -> None:
|
||||
self.max_seq_len = max_seq_len
|
||||
self.stride = stride
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.device = model.device
|
||||
self.name = "perplexity"
|
||||
|
||||
def _feature_names(self) -> List[str]:
|
||||
@@ -34,6 +33,7 @@ class Perplexity:
|
||||
|
||||
def compute(
|
||||
self,
|
||||
model: PreTrainedModel,
|
||||
references: Optional[List[str]] = None,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
@@ -41,17 +41,21 @@ class Perplexity:
|
||||
"""
|
||||
assert references is not None, "Missing parameter: references"
|
||||
|
||||
model.eval()
|
||||
|
||||
references_tokenized = self.tokenizer(
|
||||
references, return_tensors="pt", padding=True, truncation=True
|
||||
)
|
||||
input_ids: Tensor = references_tokenized["input_ids"] # type: ignore
|
||||
input_ids = input_ids.to(self.device)
|
||||
input_ids = input_ids.to(model.device)
|
||||
|
||||
sequence_length = input_ids.size(1)
|
||||
|
||||
losses = []
|
||||
prev_end_loc = 0
|
||||
for begin_loc in tqdm(range(0, sequence_length, self.stride)):
|
||||
for begin_loc in tqdm(
|
||||
range(0, sequence_length, self.stride), disable=not is_main_process()
|
||||
):
|
||||
end_loc = min(begin_loc + self.max_seq_len, sequence_length)
|
||||
trg_len = end_loc - prev_end_loc
|
||||
input_ids_slice = input_ids[:, begin_loc:end_loc]
|
||||
@@ -59,7 +63,7 @@ class Perplexity:
|
||||
labels_slice[:, :-trg_len] = -100
|
||||
|
||||
with torch.no_grad():
|
||||
outputs: CausalLMOutput = self.model(
|
||||
outputs: CausalLMOutput = model(
|
||||
input_ids=input_ids_slice, labels=labels_slice
|
||||
)
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1,8 +1,10 @@
|
||||
"""
|
||||
Collators for multi-modal chat messages and packing
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from PIL import Image
|
||||
from transformers import PreTrainedTokenizerBase, ProcessorMixin
|
||||
@@ -20,6 +22,7 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
processor: ProcessorMixin
|
||||
return_tensors: str = "pt"
|
||||
chat_template: Optional[str] = None
|
||||
chat_template_type: Optional[str] = None
|
||||
packing: bool = False
|
||||
max_images: int = -1
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
@@ -30,38 +33,190 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
raise ValueError("Packing is currently not supported.")
|
||||
|
||||
def torch_call(
|
||||
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
|
||||
) -> Dict[str, Any]:
|
||||
self, examples: list[Union[list[int], Any, dict[str, Any]]]
|
||||
) -> dict[str, Any]:
|
||||
# Handle dict or lists with proper padding and conversion to tensor.
|
||||
|
||||
return self.__class__.process_rows(
|
||||
examples, self.processor, self.chat_template, self.max_images
|
||||
examples,
|
||||
self.processor,
|
||||
self.chat_template,
|
||||
self.max_images,
|
||||
chat_template_type=self.chat_template_type,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def process_rows(examples, processor, chat_template, max_images, length_only=False):
|
||||
def preprocess(examples: list[dict]) -> list[dict]:
|
||||
"""
|
||||
Preprocess conversation examples to ensure consistent format.
|
||||
Converts different conversation formats to OpenAI format with 'messages'.
|
||||
Supports two formats:
|
||||
1. OpenAI format with 'messages'
|
||||
2. Legacy format with 'conversations'
|
||||
|
||||
Args:
|
||||
examples: list of conversation dictionaries
|
||||
Returns:
|
||||
dict in OpenAI format with 'messages' key
|
||||
|
||||
Raises:
|
||||
ValueError: If the conversation format is not supported
|
||||
"""
|
||||
role_mapping = {
|
||||
"human": "user",
|
||||
"gpt": "assistant",
|
||||
}
|
||||
|
||||
def normalize_role(role: str) -> str:
|
||||
"""Normalize role names to OpenAI format. Default to original role if not found."""
|
||||
return role_mapping.get(role, role)
|
||||
|
||||
def convert_legacy_format(example: dict) -> dict:
|
||||
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
|
||||
messages = [
|
||||
{
|
||||
"role": normalize_role(convo["from"]),
|
||||
"content": convo["value"],
|
||||
}
|
||||
for convo in example["conversations"]
|
||||
]
|
||||
|
||||
# Create new dict without 'conversations' key
|
||||
result = deepcopy(example)
|
||||
result.pop("conversations")
|
||||
return {"messages": messages, **result}
|
||||
|
||||
processed_examples = []
|
||||
for example in examples:
|
||||
# OpenAI format
|
||||
if "messages" in example:
|
||||
processed_examples.append(example)
|
||||
|
||||
# Legacy format
|
||||
elif "conversations" in example:
|
||||
processed_examples.append(convert_legacy_format(example))
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"Only `messages` and `conversations` message keys are currently supported."
|
||||
)
|
||||
|
||||
return processed_examples
|
||||
|
||||
@staticmethod
|
||||
def process_images(examples, max_images):
|
||||
"""
|
||||
Process images from examples, ensuring consistency in image presence and applying max_images limit.
|
||||
|
||||
Args:
|
||||
examples: List of dictionaries that may contain 'images' key
|
||||
max_images: Maximum number of images to keep per example (0 means no limit)
|
||||
|
||||
Returns:
|
||||
Either None (if no images) or List[Image objects] (if all examples have images)
|
||||
|
||||
Raises:
|
||||
ValueError: If there's a mix of None and non-None images
|
||||
"""
|
||||
|
||||
def get_image(example):
|
||||
if "images" not in example:
|
||||
return None
|
||||
images = example["images"]
|
||||
if isinstance(images, str):
|
||||
return Image.open(images)
|
||||
return images
|
||||
|
||||
images = [get_image(example) for example in examples]
|
||||
|
||||
# Count None and non-None images
|
||||
none_count = sum(1 for img in images if img is None)
|
||||
|
||||
# All images are None
|
||||
if none_count == len(images):
|
||||
return None
|
||||
|
||||
# Mix of None and non-None images
|
||||
if none_count > 0:
|
||||
raise ValueError(
|
||||
"All images should be either None or not None. "
|
||||
"Please provide images for all examples or None."
|
||||
)
|
||||
|
||||
# Apply max_images limit if specified
|
||||
if max_images > 0:
|
||||
images = [
|
||||
(
|
||||
img_batch[:max_images]
|
||||
if isinstance(img_batch, (list, tuple))
|
||||
else img_batch
|
||||
)
|
||||
for img_batch in images
|
||||
]
|
||||
|
||||
return images
|
||||
|
||||
@staticmethod
|
||||
def pixtral_chat_conversion(messages):
|
||||
is_single_message = not isinstance(messages, list)
|
||||
if is_single_message:
|
||||
messages = [messages]
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
if message["role"] == "user":
|
||||
for j, content in enumerate(message["content"]):
|
||||
if "type" in content and content["type"] == "text":
|
||||
messages[i]["content"][j] = {
|
||||
"type": "text",
|
||||
"content": content["text"],
|
||||
}
|
||||
|
||||
if message["role"] == "assistant":
|
||||
messages[i]["content"] = message["content"][0]["text"]
|
||||
|
||||
if is_single_message:
|
||||
return messages[0]
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def process_rows(
|
||||
examples,
|
||||
processor,
|
||||
chat_template,
|
||||
max_images,
|
||||
length_only=False,
|
||||
chat_template_type=None,
|
||||
):
|
||||
# HINT: use `_torch_collate_batch` to stack and pad tensors
|
||||
# see also DataCollatorWithFlattening and DefaultDataCollator
|
||||
|
||||
# *** This is COPIED from the trl example sft_vlm.py code ***
|
||||
# use this as a starting point
|
||||
|
||||
# Get the texts and images, and apply the chat template
|
||||
texts = [
|
||||
processor.apply_chat_template(
|
||||
example["messages"], chat_template=chat_template, tokenize=False
|
||||
)
|
||||
for example in examples
|
||||
]
|
||||
images = [
|
||||
Image.open(example["images"])
|
||||
if isinstance(example["images"], str)
|
||||
else example["images"]
|
||||
for example in examples
|
||||
]
|
||||
# Preprocess the examples
|
||||
examples = __class__.preprocess(examples)
|
||||
|
||||
if max_images > 0:
|
||||
images = [img_batch[:max_images] for img_batch in images]
|
||||
# Get the texts and images, and apply the chat template
|
||||
if chat_template_type == "pixtral":
|
||||
texts = [
|
||||
processor.apply_chat_template(
|
||||
__class__.pixtral_chat_conversion(example["messages"]),
|
||||
chat_template=chat_template,
|
||||
tokenize=False,
|
||||
)
|
||||
for example in examples
|
||||
]
|
||||
else:
|
||||
texts = [
|
||||
processor.apply_chat_template(
|
||||
example["messages"], chat_template=chat_template, tokenize=False
|
||||
)
|
||||
for example in examples
|
||||
]
|
||||
|
||||
images = __class__.process_images(examples, max_images=max_images)
|
||||
if chat_template_type == "llava":
|
||||
# LLava1.5 does not support multiple images
|
||||
images = [image[0] for image in images]
|
||||
|
||||
# Tokenize the texts and process the images
|
||||
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
||||
@@ -70,9 +225,12 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
labels = batch["input_ids"].clone()
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100 #
|
||||
# Ignore the image token index in the loss computation (model specific)
|
||||
image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||
processor.image_token
|
||||
)
|
||||
if chat_template_type == "qwen2_vl":
|
||||
image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
|
||||
else:
|
||||
image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||
processor.image_token
|
||||
)
|
||||
labels[labels == image_token_id] = -100
|
||||
batch["labels"] = labels
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
from transformers.utils.import_utils import is_torch_npu_available
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.integrations.config import merge_input_args
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||
@@ -131,7 +132,7 @@ def normalize_config(cfg):
|
||||
|
||||
cfg.is_multimodal = (
|
||||
hasattr(model_config, "model_type")
|
||||
and model_config.model_type in ["llava", "mllama"]
|
||||
and model_config.model_type in ["llava", "mllama", "qwen2_vl", "qwen2_5_vl"]
|
||||
or any(
|
||||
multimodal_name in cfg.base_model.lower()
|
||||
for multimodal_name in [
|
||||
@@ -144,7 +145,12 @@ def normalize_config(cfg):
|
||||
cfg.processor_config = (
|
||||
cfg.processor_config or cfg.base_model_config or cfg.base_model
|
||||
)
|
||||
model_config = model_config.text_config
|
||||
|
||||
try:
|
||||
model_config = model_config.text_config
|
||||
except AttributeError:
|
||||
# for qwen2_vl
|
||||
model_config = model_config.get_text_config()
|
||||
|
||||
cfg.model_config_type = model_config.model_type
|
||||
|
||||
@@ -229,7 +235,11 @@ def normalize_cfg_datasets(cfg):
|
||||
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
|
||||
|
||||
|
||||
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
||||
def validate_config(
|
||||
cfg: DictDefault,
|
||||
capabilities: Optional[dict] = None,
|
||||
env_capabilities: Optional[dict] = None,
|
||||
):
|
||||
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
|
||||
AxolotlInputConfig = AxolotlInputConfigBase
|
||||
|
||||
@@ -239,14 +249,35 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
||||
AxolotlInputConfig, # pylint: disable=invalid-name
|
||||
) = merge_input_args()
|
||||
|
||||
if capabilities:
|
||||
if capabilities or env_capabilities:
|
||||
if (capabilities and not env_capabilities) or (
|
||||
env_capabilities and not capabilities
|
||||
):
|
||||
raise ValueError(
|
||||
"Both capabilities and env_capabilities must be provided or not provided."
|
||||
)
|
||||
|
||||
return DictDefault(
|
||||
dict(
|
||||
AxolotlConfigWCapabilities(
|
||||
**cfg.to_dict(), capabilities=capabilities
|
||||
**cfg.to_dict(),
|
||||
capabilities=capabilities,
|
||||
env_capabilities=env_capabilities,
|
||||
).model_dump(exclude_none=True)
|
||||
)
|
||||
)
|
||||
|
||||
return DictDefault(
|
||||
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
|
||||
)
|
||||
|
||||
|
||||
def prepare_plugins(cfg):
|
||||
"""
|
||||
Prepare the plugins for the configuration
|
||||
"""
|
||||
|
||||
if cfg.get("plugins"):
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
for plugin_name in cfg["plugins"]:
|
||||
plugin_manager.register(plugin_name)
|
||||
|
||||
@@ -9,6 +9,7 @@ import os
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from packaging import version
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
@@ -21,7 +22,7 @@ from transformers import SchedulerType
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils.import_utils import is_torch_npu_available
|
||||
|
||||
from axolotl.utils.config.models.internals import GPUCapabilities
|
||||
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
||||
|
||||
@@ -50,6 +51,7 @@ class ChatTemplate(str, Enum):
|
||||
cohere = "cohere" # pylint: disable=invalid-name
|
||||
llama3 = "llama3" # pylint: disable=invalid-name
|
||||
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
||||
llava = "llava" # pylint: disable=invalid-name
|
||||
phi_3 = "phi_3" # pylint: disable=invalid-name
|
||||
phi_35 = "phi_35" # pylint: disable=invalid-name
|
||||
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
||||
@@ -59,6 +61,8 @@ class ChatTemplate(str, Enum):
|
||||
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
||||
exaone = "exaone" # pylint: disable=invalid-name
|
||||
metharme = "metharme" # pylint: disable=invalid-name
|
||||
pixtral = "pixtral" # pylint: disable=invalid-name
|
||||
qwen2_vl = "qwen2_vl" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class DeprecatedParameters(BaseModel):
|
||||
@@ -322,11 +326,13 @@ class LoraConfig(BaseModel):
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_adapter(cls, data):
|
||||
if not data.get("adapter") and (
|
||||
data.get("load_in_8bit") or data.get("load_in_4bit")
|
||||
if (
|
||||
not data.get("adapter")
|
||||
and not data.get("inference")
|
||||
and (data.get("load_in_8bit") or data.get("load_in_4bit"))
|
||||
):
|
||||
raise ValueError(
|
||||
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
|
||||
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
|
||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||
)
|
||||
return data
|
||||
@@ -430,6 +436,8 @@ class HyperparametersConfig(BaseModel):
|
||||
group_by_length: Optional[bool] = None
|
||||
|
||||
learning_rate: Union[str, float]
|
||||
embedding_lr: Optional[float] = None
|
||||
embedding_lr_scale: Optional[float] = None
|
||||
weight_decay: Optional[float] = 0.0
|
||||
optimizer: Optional[
|
||||
Union[
|
||||
@@ -622,6 +630,7 @@ class AxolotlInputConfig(
|
||||
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
||||
)
|
||||
dataset_processes: Optional[int] = Field(default=os.cpu_count())
|
||||
dataset_exact_deduplication: Optional[bool] = None
|
||||
dataset_keep_in_memory: Optional[bool] = None
|
||||
dataloader_pin_memory: Optional[bool] = None
|
||||
dataloader_num_workers: Optional[int] = None
|
||||
@@ -1474,6 +1483,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||
|
||||
capabilities: GPUCapabilities
|
||||
env_capabilities: EnvCapabilities
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_bf16(self):
|
||||
@@ -1548,3 +1558,21 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_adopt_torch_version(cls, data):
|
||||
if (data.get("optimizer") is not None) and ("adopt" in data.get("optimizer")):
|
||||
env_capabilities = data.get("env_capabilities", {})
|
||||
torch_version = env_capabilities.get("torch_version")
|
||||
|
||||
if torch_version is None:
|
||||
import torch
|
||||
|
||||
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
||||
|
||||
if version.parse(torch_version) < version.parse("2.5.1"):
|
||||
raise ValueError(
|
||||
"ADOPT optimizer is incompatible with torch version < 2.5.1"
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -12,3 +12,9 @@ class GPUCapabilities(BaseModel):
|
||||
n_gpu: int = Field(default=1)
|
||||
n_node: int = Field(default=1)
|
||||
compute_capability: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class EnvCapabilities(BaseModel):
|
||||
"""model to manage the environment capabilities statically"""
|
||||
|
||||
torch_version: Optional[str] = Field(default=None)
|
||||
|
||||
@@ -13,7 +13,7 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||
from axolotl.prompt_strategies.kto import load as load_kto
|
||||
from axolotl.prompt_strategies.orpo import load as load_orpo
|
||||
from axolotl.utils.data.utils import md5
|
||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process, zero_first
|
||||
from axolotl.utils.models import load_tokenizer
|
||||
@@ -208,4 +208,9 @@ def load_prepare_dpo_datasets(cfg):
|
||||
if eval_dataset and not eval_is_preprocessed:
|
||||
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
|
||||
|
||||
if cfg.dataset_exact_deduplication:
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=train_dataset, eval_dataset=eval_dataset
|
||||
)
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
@@ -44,7 +44,7 @@ from axolotl.prompters import (
|
||||
UnsupportedPrompter,
|
||||
)
|
||||
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
||||
from axolotl.utils.data.utils import md5
|
||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_local_main_process, zero_first
|
||||
from axolotl.utils.trainer import (
|
||||
@@ -136,8 +136,9 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||
train_dataset = train_dataset.with_format("torch")
|
||||
eval_dataset = None
|
||||
if cfg.dataset_exact_deduplication:
|
||||
LOG.info("Deduplication not available for pretrained datasets")
|
||||
return train_dataset, eval_dataset, cfg.max_steps, prompters
|
||||
|
||||
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
||||
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
||||
if total_eval_steps == 0:
|
||||
@@ -584,7 +585,8 @@ def load_prepare_datasets(
|
||||
)
|
||||
train_fingerprint = md5(to_hash_train)
|
||||
test_fingerprint = md5(to_hash_test)
|
||||
|
||||
if cfg.dataset_exact_deduplication:
|
||||
_, _, dataset = deduplicate_and_log_datasets(dataset=dataset)
|
||||
dataset = dataset.train_test_split(
|
||||
test_size=val_set_size,
|
||||
shuffle=False,
|
||||
@@ -596,12 +598,17 @@ def load_prepare_datasets(
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"]
|
||||
elif split == "test":
|
||||
if cfg.dataset_exact_deduplication:
|
||||
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
|
||||
else:
|
||||
eval_dataset = dataset
|
||||
train_dataset = None
|
||||
eval_dataset = dataset
|
||||
else:
|
||||
train_dataset = dataset
|
||||
if cfg.dataset_exact_deduplication:
|
||||
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
|
||||
else:
|
||||
train_dataset = dataset
|
||||
eval_dataset = None
|
||||
|
||||
return train_dataset, eval_dataset, prompters
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
"""data handling helpers"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
@@ -8,3 +13,96 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
|
||||
except TypeError:
|
||||
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
|
||||
|
||||
|
||||
def sha256(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
return hashlib.sha256(to_hash.encode(encoding)).hexdigest()
|
||||
|
||||
|
||||
def deduplicate_dataset(
|
||||
dataset: Dataset, seen_hashes: dict[str, list[int]], other_dataset: Dataset = None
|
||||
) -> Dataset:
|
||||
unique_indices = []
|
||||
|
||||
for idx, row in enumerate(dataset):
|
||||
row_hash = sha256(str(row)) # Using SHA256 for collision resistance.
|
||||
if row_hash not in seen_hashes:
|
||||
seen_hashes[row_hash] = [idx]
|
||||
unique_indices.append(idx)
|
||||
else:
|
||||
# Check for collision by looking up the original dataset indices
|
||||
original_indices = seen_hashes[row_hash]
|
||||
is_duplicate = False
|
||||
for original_idx in original_indices:
|
||||
if (
|
||||
not idx == original_idx
|
||||
and original_idx < len(dataset)
|
||||
and str(dataset[original_idx]) == str(row)
|
||||
):
|
||||
is_duplicate = True
|
||||
break
|
||||
# Check in the other dataset if provided
|
||||
if other_dataset is not None:
|
||||
if original_idx < len(other_dataset) and str(
|
||||
other_dataset[original_idx]
|
||||
) == str(row):
|
||||
is_duplicate = True
|
||||
break
|
||||
if not is_duplicate:
|
||||
seen_hashes[row_hash].append(idx)
|
||||
unique_indices.append(idx)
|
||||
continue
|
||||
return dataset.select(unique_indices)
|
||||
|
||||
|
||||
def deduplicate_and_log_datasets(
|
||||
*,
|
||||
train_dataset: Dataset = None,
|
||||
eval_dataset: Dataset = None,
|
||||
dataset: Dataset = None,
|
||||
) -> tuple[Dataset, Dataset, Dataset]:
|
||||
"""
|
||||
Deduplicates train, eval, and an optional dataset if provided, logging original and new sizes.
|
||||
|
||||
Returns:
|
||||
tuple: Deduplicated train, eval, and additional datasets.
|
||||
"""
|
||||
seen_hashes: dict[str, list[int]] = {}
|
||||
|
||||
# Handle cases where datasets are None
|
||||
if train_dataset is not None:
|
||||
LOG.info(
|
||||
f"Starting deduplication for train dataset. Original size: {len(train_dataset)}"
|
||||
)
|
||||
train_dataset = deduplicate_dataset(
|
||||
dataset=train_dataset, seen_hashes=seen_hashes
|
||||
)
|
||||
LOG.info(
|
||||
f"Deduplication complete for train dataset. New size: {len(train_dataset)}"
|
||||
)
|
||||
else:
|
||||
LOG.info("Train dataset is None. Skipping deduplication.")
|
||||
|
||||
if eval_dataset is not None:
|
||||
LOG.info(
|
||||
f"Starting deduplication for eval dataset. Original size: {len(eval_dataset)}"
|
||||
)
|
||||
eval_dataset = deduplicate_dataset(
|
||||
dataset=eval_dataset, seen_hashes=seen_hashes, other_dataset=train_dataset
|
||||
)
|
||||
LOG.info(
|
||||
f"Deduplication complete for eval dataset. New size: {len(eval_dataset)}"
|
||||
)
|
||||
else:
|
||||
LOG.info("Eval dataset is None. Skipping deduplication.")
|
||||
|
||||
if dataset is not None and (eval_dataset is None and train_dataset is None):
|
||||
LOG.info(
|
||||
f"Starting deduplication for combined dataset. Original size: {len(dataset)}"
|
||||
)
|
||||
dataset = deduplicate_dataset(dataset=dataset, seen_hashes=seen_hashes)
|
||||
LOG.info(
|
||||
f"Deduplication complete for combined dataset. New size: {len(dataset)}"
|
||||
)
|
||||
|
||||
return train_dataset, eval_dataset, dataset
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
import gc
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import types
|
||||
from functools import cached_property
|
||||
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
||||
|
||||
import addict
|
||||
@@ -28,6 +30,7 @@ from transformers import ( # noqa: F401
|
||||
AddedToken,
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForImageTextToText,
|
||||
AutoModelForVision2Seq,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
@@ -89,7 +92,11 @@ def get_module_class_from_name(module, name):
|
||||
|
||||
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
|
||||
if cfg.is_multimodal:
|
||||
model_config = model_config.text_config
|
||||
try:
|
||||
model_config = model_config.text_config
|
||||
except AttributeError:
|
||||
# for qwen2_vl
|
||||
model_config = model_config.get_text_config()
|
||||
|
||||
quant_config_exists = (
|
||||
hasattr(model_config, "quantization_config")
|
||||
@@ -365,7 +372,11 @@ class ModelLoader:
|
||||
# init model config
|
||||
self.model_config = load_model_config(cfg)
|
||||
if cfg.is_multimodal:
|
||||
self.text_model_config = self.model_config.text_config
|
||||
try:
|
||||
self.text_model_config = self.model_config.text_config
|
||||
except AttributeError:
|
||||
# for qwen2_vl
|
||||
self.text_model_config = self.model_config.get_text_config()
|
||||
else:
|
||||
self.text_model_config = self.model_config
|
||||
|
||||
@@ -409,7 +420,7 @@ class ModelLoader:
|
||||
)
|
||||
|
||||
if self.cfg.is_llama_derived_model:
|
||||
self.patch_loss()
|
||||
self.patch_loss_llama()
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||
|
||||
@@ -451,27 +462,34 @@ class ModelLoader:
|
||||
|
||||
replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
|
||||
|
||||
def patch_loss(self) -> None:
|
||||
@cached_property
|
||||
def has_flash_attn(self) -> bool:
|
||||
"""Check if flash attention is installed"""
|
||||
return importlib.util.find_spec("flash_attn") is not None
|
||||
|
||||
def patch_loss_llama(self) -> None:
|
||||
"""
|
||||
Patch loss functions
|
||||
"""
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
patch_llama_cross_entropy,
|
||||
patch_llama_rms_norm,
|
||||
)
|
||||
if self.has_flash_attn:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
patch_fa_llama_cross_entropy,
|
||||
patch_llama_rms_norm,
|
||||
)
|
||||
|
||||
if self.cfg.flash_attn_cross_entropy:
|
||||
patch_llama_cross_entropy()
|
||||
if self.cfg.flash_attn_rms_norm:
|
||||
if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
|
||||
patch_fa_llama_cross_entropy()
|
||||
elif self.cfg.unsloth_cross_entropy_loss:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
||||
|
||||
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||
|
||||
if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
|
||||
patch_llama_rms_norm()
|
||||
elif self.cfg.unsloth_rms_norm:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
|
||||
|
||||
patch_unsloth_layernorm()
|
||||
if self.cfg.unsloth_cross_entropy_loss:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
||||
|
||||
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||
|
||||
@@ -481,6 +499,7 @@ class ModelLoader:
|
||||
"""
|
||||
Modify all llama derived models in one block
|
||||
"""
|
||||
self.patch_loss_llama()
|
||||
|
||||
if self.cfg.flash_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
@@ -528,16 +547,6 @@ class ModelLoader:
|
||||
"Shifted-sparse attention not currently implemented without flash attention."
|
||||
)
|
||||
|
||||
if self.cfg.unsloth_cross_entropy_loss:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
||||
|
||||
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora()
|
||||
|
||||
def set_auto_model_loader(self) -> None:
|
||||
"""set self.AutoModelLoader
|
||||
- default value: AutoModelForCausalLM (set at __init__)
|
||||
@@ -553,6 +562,10 @@ class ModelLoader:
|
||||
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
||||
MllamaForConditionalGeneration
|
||||
)
|
||||
elif self.model_config.model_type == "qwen2_vl":
|
||||
self.AutoModelLoader = ( # pylint: disable=invalid-name
|
||||
AutoModelForImageTextToText
|
||||
)
|
||||
else:
|
||||
self.AutoModelLoader = (
|
||||
AutoModelForVision2Seq # pylint: disable=invalid-name
|
||||
@@ -1045,7 +1058,9 @@ class ModelLoader:
|
||||
and self.model.get_input_embeddings().num_embeddings < embeddings_len
|
||||
):
|
||||
resize_kwargs = {}
|
||||
if self.cfg.mean_resizing_embeddings is not None:
|
||||
if self.cfg.mean_resizing_embeddings is not None and not (
|
||||
self.model_config.model_type == "llava"
|
||||
):
|
||||
resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
|
||||
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
|
||||
else:
|
||||
@@ -1084,14 +1099,17 @@ class ModelLoader:
|
||||
|
||||
self.prepare_model(qlora_fsdp)
|
||||
|
||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||
if (needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp:
|
||||
LOG.info(
|
||||
"converting modules to %s for flash attention", self.cfg.torch_dtype
|
||||
)
|
||||
should_convert = (
|
||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||
((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
|
||||
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
|
||||
)
|
||||
|
||||
if should_convert:
|
||||
LOG.info("Converting modules to %s", self.cfg.torch_dtype)
|
||||
self.convert_embedding_modules_dtype(
|
||||
embedding_modules,
|
||||
embedding_modules=embedding_modules,
|
||||
dist_dtype=self.cfg.torch_dtype,
|
||||
before_kbit_train_or_finetune=False,
|
||||
)
|
||||
|
||||
@@ -6,21 +6,29 @@ Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeo
|
||||
"""
|
||||
# mypy: ignore-errors
|
||||
# pylint: skip-file
|
||||
# flake8: noqa
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
from typing import Callable, List, Optional, Tuple, Union, cast
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.optim.optimizer import (
|
||||
from torch.optim.optimizer import ( # DeviceDict,; _capturable_doc,; _differentiable_doc,; _foreach_doc,; _fused_doc,; _maximize_doc,; _stack_if_compiling,
|
||||
DeviceDict,
|
||||
Optimizer,
|
||||
ParamsT,
|
||||
_capturable_doc,
|
||||
_default_to_fused_or_foreach,
|
||||
_device_dtype_check_for_fused,
|
||||
_differentiable_doc,
|
||||
_disable_dynamo_if_unsupported,
|
||||
_foreach_doc,
|
||||
_fused_doc,
|
||||
_get_capturable_supported_devices,
|
||||
_get_scalar_dtype,
|
||||
_get_value,
|
||||
_maximize_doc,
|
||||
_stack_if_compiling,
|
||||
_use_grad_for_differentiable,
|
||||
_view_as_real,
|
||||
)
|
||||
@@ -35,8 +43,9 @@ class ADOPT(Optimizer):
|
||||
lr: Union[float, Tensor] = 1e-3,
|
||||
betas: Tuple[float, float] = (0.9, 0.9999),
|
||||
eps: float = 1e-6,
|
||||
clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
|
||||
weight_decay: float = 0.0,
|
||||
decoupled: bool = False,
|
||||
decouple: bool = False,
|
||||
*,
|
||||
foreach: Optional[bool] = None,
|
||||
maximize: bool = False,
|
||||
@@ -62,12 +71,14 @@ class ADOPT(Optimizer):
|
||||
if not 0.0 <= weight_decay:
|
||||
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
||||
|
||||
self.clip_lambda = clip_lambda
|
||||
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
decoupled=decoupled,
|
||||
decouple=decouple,
|
||||
maximize=maximize,
|
||||
foreach=foreach,
|
||||
capturable=capturable,
|
||||
@@ -219,8 +230,9 @@ class ADOPT(Optimizer):
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
lr=group["lr"],
|
||||
clip_lambda=self.clip_lambda,
|
||||
weight_decay=group["weight_decay"],
|
||||
decoupled=group["decoupled"],
|
||||
decouple=group["decouple"],
|
||||
eps=group["eps"],
|
||||
maximize=group["maximize"],
|
||||
foreach=group["foreach"],
|
||||
@@ -247,8 +259,9 @@ def _single_tensor_adopt(
|
||||
beta1: float,
|
||||
beta2: float,
|
||||
lr: Union[float, Tensor],
|
||||
clip_lambda: Optional[Callable[[int], float]],
|
||||
weight_decay: float,
|
||||
decoupled: bool,
|
||||
decouple: bool,
|
||||
eps: float,
|
||||
maximize: bool,
|
||||
capturable: bool,
|
||||
@@ -276,14 +289,10 @@ def _single_tensor_adopt(
|
||||
and param.device.type in capturable_supported_devices
|
||||
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
||||
|
||||
# update step
|
||||
step_t += 1
|
||||
step = step_t if capturable or differentiable else _get_value(step_t)
|
||||
|
||||
if weight_decay != 0:
|
||||
if decoupled:
|
||||
param.add_(param, alpha=-lr * weight_decay)
|
||||
else:
|
||||
grad = grad.add(param, alpha=weight_decay)
|
||||
if weight_decay != 0 and not decouple:
|
||||
grad = grad.add(param, alpha=weight_decay)
|
||||
|
||||
if torch.is_complex(param):
|
||||
grad = torch.view_as_real(grad)
|
||||
@@ -293,20 +302,29 @@ def _single_tensor_adopt(
|
||||
exp_avg_sq = torch.view_as_real(exp_avg_sq)
|
||||
param = torch.view_as_real(param)
|
||||
|
||||
step = step_t if capturable or differentiable else _get_value(step_t)
|
||||
if step == 1:
|
||||
if step == 0:
|
||||
exp_avg_sq.addcmul_(grad, grad.conj())
|
||||
# update step
|
||||
step_t += 1
|
||||
continue
|
||||
|
||||
if weight_decay != 0 and decouple:
|
||||
param.add_(param, alpha=-lr * weight_decay)
|
||||
|
||||
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
|
||||
if step == 2:
|
||||
exp_avg.addcdiv_(grad, denom)
|
||||
else:
|
||||
exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
|
||||
normed_grad = grad.div(denom)
|
||||
if clip_lambda is not None:
|
||||
clip = clip_lambda(step)
|
||||
normed_grad.clamp_(-clip, clip)
|
||||
|
||||
exp_avg.lerp_(normed_grad, 1 - beta1)
|
||||
|
||||
param.add_(exp_avg, alpha=-lr)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
|
||||
|
||||
# update step
|
||||
step_t += 1
|
||||
|
||||
|
||||
def _multi_tensor_adopt(
|
||||
params: List[Tensor],
|
||||
@@ -321,8 +339,9 @@ def _multi_tensor_adopt(
|
||||
beta1: float,
|
||||
beta2: float,
|
||||
lr: Union[float, Tensor],
|
||||
clip_lambda: Optional[Callable[[int], float]],
|
||||
weight_decay: float,
|
||||
decoupled: bool,
|
||||
decouple: bool,
|
||||
eps: float,
|
||||
maximize: bool,
|
||||
capturable: bool,
|
||||
@@ -376,6 +395,51 @@ def _multi_tensor_adopt(
|
||||
if maximize:
|
||||
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
|
||||
|
||||
if weight_decay != 0 and not decouple:
|
||||
# Re-use the intermediate memory (device_grads) already allocated for maximize
|
||||
if maximize:
|
||||
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
|
||||
else:
|
||||
device_grads = torch._foreach_add( # type: ignore[assignment]
|
||||
device_grads, device_params, alpha=weight_decay
|
||||
)
|
||||
|
||||
if device_state_steps[0] == 0:
|
||||
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
|
||||
|
||||
# Update steps
|
||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
|
||||
torch._foreach_add_(
|
||||
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||
)
|
||||
else:
|
||||
torch._foreach_add_(device_state_steps, 1)
|
||||
|
||||
continue
|
||||
|
||||
if weight_decay != 0 and decouple:
|
||||
torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
|
||||
|
||||
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
|
||||
torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
|
||||
|
||||
normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
|
||||
if clip_lambda is not None:
|
||||
clip = clip_lambda(device_state_steps[0])
|
||||
torch._foreach_maximum_(normed_grad, -clip)
|
||||
torch._foreach_minimum_(normed_grad, clip)
|
||||
|
||||
torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
|
||||
|
||||
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
|
||||
torch._foreach_mul_(device_exp_avg_sqs, beta2)
|
||||
torch._foreach_addcmul_(
|
||||
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
|
||||
)
|
||||
|
||||
# Update steps
|
||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||
@@ -387,41 +451,6 @@ def _multi_tensor_adopt(
|
||||
else:
|
||||
torch._foreach_add_(device_state_steps, 1)
|
||||
|
||||
if weight_decay != 0:
|
||||
if decoupled:
|
||||
torch._foreach_add_(
|
||||
device_params, device_params, alpha=-lr * weight_decay
|
||||
)
|
||||
else:
|
||||
# Re-use the intermediate memory (device_grads) already allocated for maximize
|
||||
if maximize:
|
||||
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
|
||||
else:
|
||||
device_grads = torch._foreach_add( # type: ignore[assignment]
|
||||
device_grads, device_params, alpha=weight_decay
|
||||
)
|
||||
|
||||
if device_state_steps[0] == 1:
|
||||
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
|
||||
continue
|
||||
|
||||
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
|
||||
exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
|
||||
|
||||
if device_state_steps[0] == 2:
|
||||
torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
|
||||
else:
|
||||
torch._foreach_mul_(device_exp_avgs, beta1)
|
||||
torch._foreach_addcdiv_(
|
||||
device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
|
||||
)
|
||||
|
||||
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
|
||||
torch._foreach_mul_(device_exp_avg_sqs, beta2)
|
||||
torch._foreach_addcmul_(
|
||||
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
|
||||
)
|
||||
|
||||
|
||||
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
|
||||
def adopt(
|
||||
@@ -443,8 +472,9 @@ def adopt(
|
||||
beta1: float,
|
||||
beta2: float,
|
||||
lr: Union[float, Tensor],
|
||||
clip_lambda: Optional[Callable[[int], float]],
|
||||
weight_decay: float,
|
||||
decoupled: bool,
|
||||
decouple: bool,
|
||||
eps: float,
|
||||
maximize: bool,
|
||||
):
|
||||
@@ -497,8 +527,9 @@ def adopt(
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
lr=lr,
|
||||
clip_lambda=clip_lambda,
|
||||
weight_decay=weight_decay,
|
||||
decoupled=decoupled,
|
||||
decouple=decouple,
|
||||
eps=eps,
|
||||
maximize=maximize,
|
||||
capturable=capturable,
|
||||
|
||||
@@ -15,17 +15,58 @@ def download_smollm2_135m_model():
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_tatsu_lab_alpaca_dataset():
|
||||
def download_llama_68m_random_model():
|
||||
# download the model
|
||||
snapshot_download("JackFram/llama-68m")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_qwen_2_5_half_billion_model():
|
||||
# download the model
|
||||
snapshot_download("Qwen/Qwen2.5-0.5B")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_tatsu_lab_alpaca_dataset():
|
||||
# download the dataset
|
||||
snapshot_download("tatsu-lab/alpaca", repo_type="dataset")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_mhenrichsen_alpaca_2k_dataset():
|
||||
# download the model
|
||||
# download the dataset
|
||||
snapshot_download("mhenrichsen/alpaca_2k_test", repo_type="dataset")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_mhenrichsen_alpaca_2k_w_revision_dataset():
|
||||
# download the dataset
|
||||
snapshot_download(
|
||||
"mhenrichsen/alpaca_2k_test", repo_type="dataset", revision="d05c1cb"
|
||||
)
|
||||
|
||||
|
||||
def download_mlabonne_finetome_100k_dataset():
|
||||
# download the dataset
|
||||
snapshot_download("mlabonne/FineTome-100k", repo_type="dataset")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
|
||||
# download the dataset
|
||||
snapshot_download(
|
||||
"argilla/distilabel-capybara-dpo-7k-binarized", repo_type="dataset"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
|
||||
# download the dataset
|
||||
snapshot_download(
|
||||
"arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", repo_type="dataset"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
# Create a temporary directory
|
||||
|
||||
32
tests/constants.py
Normal file
32
tests/constants.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# constants.py
|
||||
"""
|
||||
This module contains constants and configuration dictionaries used for
|
||||
datasets and other utilities in the Axolotl project, specifically for testing.
|
||||
"""
|
||||
# Configuration for Alpaca Messages Dataset
|
||||
ALPACA_MESSAGES_CONFIG_OG = {
|
||||
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
"type": "chat_template.default",
|
||||
"chat_template": "llama3",
|
||||
"field_messages": "conversation",
|
||||
"field_chosen": "chosen",
|
||||
"field_rejected": "rejected",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"roles": {
|
||||
"system": ["system"],
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
},
|
||||
}
|
||||
|
||||
# Revision configuration extending the original
|
||||
ALPACA_MESSAGES_CONFIG_REVISION = ALPACA_MESSAGES_CONFIG_OG.copy()
|
||||
ALPACA_MESSAGES_CONFIG_REVISION["revision"] = "ea82cff"
|
||||
|
||||
|
||||
SPECIAL_TOKENS = {
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
@@ -14,9 +14,7 @@ from axolotl.utils.models import load_model, load_tokenizer
|
||||
def fixture_cfg():
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
||||
"model_type": "AutoModelForCausalLM",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 0.00005,
|
||||
@@ -33,6 +31,9 @@ def fixture_cfg():
|
||||
"dataloader_num_workers": 1,
|
||||
"dataloader_pin_memory": True,
|
||||
"model_config_type": "llama",
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from pathlib import Path
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
@@ -54,8 +54,10 @@ class LigerIntegrationTestCase(unittest.TestCase):
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"max_steps": 10,
|
||||
}
|
||||
)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -99,8 +101,10 @@ class LigerIntegrationTestCase(unittest.TestCase):
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"max_steps": 10,
|
||||
}
|
||||
)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
94
tests/e2e/integrations/test_cut_cross_entropy.py
Normal file
94
tests/e2e/integrations/test_cut_cross_entropy.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
Simple end-to-end test for Cut Cross Entropy integration
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils import get_pytorch_version
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def min_cfg(temp_dir):
|
||||
return {
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"plugins": [
|
||||
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
|
||||
],
|
||||
"cut_cross_entropy": True,
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"output_dir": temp_dir,
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"max_steps": 10,
|
||||
"bf16": "auto",
|
||||
}
|
||||
|
||||
|
||||
class TestCutCrossEntropyIntegration:
|
||||
"""
|
||||
e2e tests for cut_cross_entropy integration with Axolotl
|
||||
"""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
def test_llama_w_cce(self, min_cfg, temp_dir):
|
||||
cfg = DictDefault(min_cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
major, minor, _ = get_pytorch_version()
|
||||
if (major, minor) < (2, 4):
|
||||
with pytest.raises(ImportError):
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
else:
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"attention_type",
|
||||
["flash_attention", "sdp_attention", "xformers_attention"],
|
||||
)
|
||||
def test_llama_w_cce_and_attention(self, min_cfg, temp_dir, attention_type):
|
||||
cfg = DictDefault(
|
||||
min_cfg
|
||||
| {
|
||||
attention_type: True,
|
||||
}
|
||||
)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
major, minor, _ = get_pytorch_version()
|
||||
if (major, minor) < (2, 4):
|
||||
with pytest.raises(ImportError):
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
else:
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
@@ -7,10 +7,13 @@ from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from tbparse import SummaryReader
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import most_recent_subdir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
@@ -26,7 +29,7 @@ class TestMultiGPUEval:
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"load_in_8bit": False,
|
||||
"load_in_4bit": True,
|
||||
"strict": False,
|
||||
@@ -40,8 +43,8 @@ class TestMultiGPUEval:
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {"pad_token": "<|end_of_text|>"},
|
||||
"val_set_size": 0.004,
|
||||
"special_tokens": {"pad_token": "<|endoftext|>"},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "teknium/GPT4-LLM-Cleaned",
|
||||
@@ -66,6 +69,7 @@ class TestMultiGPUEval:
|
||||
"saves_per_epoch": 1,
|
||||
"logging_steps": 1,
|
||||
"weight_decay": 0.0,
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -87,12 +91,18 @@ class TestMultiGPUEval:
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == "eval/loss")] # pylint: disable=invalid-name
|
||||
assert df.value.values[-1] < 2.5, "Loss is too high"
|
||||
|
||||
def test_eval(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"load_in_8bit": False,
|
||||
"load_in_4bit": True,
|
||||
"strict": False,
|
||||
@@ -106,8 +116,8 @@ class TestMultiGPUEval:
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {"pad_token": "<|end_of_text|>"},
|
||||
"val_set_size": 0.0004,
|
||||
"special_tokens": {"pad_token": "<|endoftext|>"},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "teknium/GPT4-LLM-Cleaned",
|
||||
@@ -132,6 +142,7 @@ class TestMultiGPUEval:
|
||||
"saves_per_epoch": 1,
|
||||
"logging_steps": 1,
|
||||
"weight_decay": 0.0,
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -153,3 +164,9 @@ class TestMultiGPUEval:
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == "eval/loss")] # pylint: disable=invalid-name
|
||||
assert df.value.values[-1] < 2.9, "Loss is too high"
|
||||
|
||||
@@ -4,11 +4,11 @@ E2E tests for lora llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from importlib import reload
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from tbparse import SummaryReader
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
@@ -17,7 +17,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
from ..utils import most_recent_subdir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -31,49 +31,55 @@ def reload_transformers():
|
||||
reload(transformers.models.llama.modeling_llama)
|
||||
|
||||
|
||||
class TestFAXentropyLlama(unittest.TestCase):
|
||||
class TestFAXentropyLlama:
|
||||
"""
|
||||
Test case for Llama models using LoRA w multipack
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_packing_fa_cross_entropy(self, temp_dir):
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 4],
|
||||
)
|
||||
def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": True,
|
||||
"flash_attention": True,
|
||||
"flash_attn_cross_entropy": True,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 64,
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.2,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"chat_template": "chatml",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
"path": "mlabonne/FineTome-100k",
|
||||
"field_messages": "conversations",
|
||||
"message_field_content": "value",
|
||||
"message_field_role": "from",
|
||||
"type": "chat_template",
|
||||
"split": "train[:2%]",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 10,
|
||||
"save_steps": 10,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"max_steps": 5,
|
||||
"save_steps": 5,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -87,3 +93,10 @@ class TestFAXentropyLlama(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
|
||||
assert df.value.values[-1] < 1.5, "Loss is too high"
|
||||
|
||||
@@ -6,7 +6,6 @@ import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
@@ -17,35 +16,35 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import most_recent_subdir, with_temp_dir
|
||||
from ..utils import most_recent_subdir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestResumeLlama(unittest.TestCase):
|
||||
class TestResumeLlama:
|
||||
"""
|
||||
Test case for resuming training of llama models
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_resume_qlora_packed(self, temp_dir):
|
||||
def test_resume_lora_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": True,
|
||||
"flash_attention": True,
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 64,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {},
|
||||
"val_set_size": 0.001,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "vicgalle/alpaca-gpt4",
|
||||
@@ -57,11 +56,11 @@ class TestResumeLlama(unittest.TestCase):
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_steps": 10,
|
||||
"save_steps": 3,
|
||||
"save_total_limit": 5,
|
||||
"max_steps": 40,
|
||||
"max_steps": 15,
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
@@ -77,7 +76,7 @@ class TestResumeLlama(unittest.TestCase):
|
||||
|
||||
resume_cfg = cfg | DictDefault(
|
||||
{
|
||||
"resume_from_checkpoint": f"{temp_dir}/checkpoint-30/",
|
||||
"resume_from_checkpoint": f"{temp_dir}/checkpoint-9/",
|
||||
}
|
||||
)
|
||||
normalize_config(resume_cfg)
|
||||
@@ -93,4 +92,4 @@ class TestResumeLlama(unittest.TestCase):
|
||||
)
|
||||
pattern = r"first_step\s+(\d+)"
|
||||
first_steps = int(re.findall(pattern, res.stdout)[0])
|
||||
assert first_steps == 31
|
||||
assert first_steps == 10
|
||||
|
||||
186
tests/e2e/patched/test_unsloth_qlora.py
Normal file
186
tests/e2e/patched/test_unsloth_qlora.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""
|
||||
e2e tests for unsloth qlora
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from e2e.utils import most_recent_subdir
|
||||
from tbparse import SummaryReader
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
class TestUnslothQLoRA:
|
||||
"""
|
||||
Test class for Unsloth QLoRA Llama models
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample_packing",
|
||||
[True, False],
|
||||
)
|
||||
def test_unsloth_llama_qlora_fa2(self, temp_dir, sample_packing):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": sample_packing,
|
||||
"flash_attention": True,
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 16,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"save_steps": 10,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"use_tensorboard": True,
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
|
||||
assert df.value.values[-1] < 2.0, "Loss is too high"
|
||||
|
||||
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": False,
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 16,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"save_steps": 10,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"use_tensorboard": True,
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
|
||||
assert df.value.values[-1] < 2.0, "Loss is too high"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sdp_attention",
|
||||
[True, False],
|
||||
)
|
||||
def test_unsloth_llama_qlora_unpacked_no_fa2_fp16(self, temp_dir, sdp_attention):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": False,
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 16,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"save_steps": 10,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"sdp_attention": sdp_attention,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"use_tensorboard": True,
|
||||
"fp16": True,
|
||||
}
|
||||
)
|
||||
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
|
||||
assert df.value.values[-1] < 2.0, "Loss is too high"
|
||||
121
tests/e2e/test_embeddings_lr.py
Normal file
121
tests/e2e/test_embeddings_lr.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
E2E tests for llama pretrain
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from tbparse import SummaryReader
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import most_recent_subdir, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
"""
|
||||
Test case for embedding_lr*
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_train_w_embedding_lr_scale(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": True,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"max_steps": 5,
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"val_set_size": 0.0,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"embedding_lr_scale": 0.5,
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
|
||||
assert df.value.values[-1] < 2.0, "Loss is too high"
|
||||
|
||||
@with_temp_dir
|
||||
def test_train_w_embedding_lr(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": True,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"max_steps": 5,
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"val_set_size": 0.0,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"embedding_lr": 0.000005,
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
|
||||
assert df.value.values[-1] < 2.0, "Loss is too high"
|
||||
116
tests/e2e/test_llama_vision.py
Normal file
116
tests/e2e/test_llama_vision.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestLlamaVision(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama Vision models
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_llama_vision_text_only_dataset(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",
|
||||
"processor_type": "AutoProcessor",
|
||||
"skip_prepare_dataset": True,
|
||||
"remove_unused_columns": False,
|
||||
"sample_packing": False,
|
||||
"sequence_len": 1024,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
|
||||
"val_set_size": 0,
|
||||
"chat_template": "llama3_2_vision",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "LDJnr/Puffin",
|
||||
"type": "chat_template",
|
||||
"field_messages": "conversations",
|
||||
"message_field_role": "from",
|
||||
"message_field_content": "value",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_llama_vision_multimodal_dataset(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",
|
||||
"processor_type": "AutoProcessor",
|
||||
"skip_prepare_dataset": True,
|
||||
"remove_unused_columns": False,
|
||||
"sample_packing": False,
|
||||
"sequence_len": 1024,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
|
||||
"val_set_size": 0,
|
||||
"chat_template": "llama3_2_vision",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "axolotl-ai-co/llava-instruct-mix-vsft-small",
|
||||
"type": "chat_template",
|
||||
"split": "train",
|
||||
"field_messages": "messages",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
@@ -57,6 +57,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
|
||||
@@ -56,6 +56,7 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "optimi_adamw",
|
||||
"max_steps": 5,
|
||||
"lr_scheduler": "cosine",
|
||||
}
|
||||
)
|
||||
@@ -94,6 +95,7 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
@@ -115,7 +117,7 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.1,
|
||||
"val_set_size": 0.01,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
@@ -126,13 +128,14 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "schedule_free_adamw",
|
||||
"lr_scheduler": "constant",
|
||||
"save_safetensors": True,
|
||||
"max_steps": 10,
|
||||
}
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
@@ -7,13 +7,15 @@ import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from tbparse import SummaryReader
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
from .utils import most_recent_subdir, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -29,35 +31,48 @@ class TestReLoraLlama(unittest.TestCase):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 2048,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"flash_attention": True,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 32,
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_modules": ["q_proj", "v_proj"],
|
||||
"relora_steps": 25,
|
||||
"relora_warmup_steps": 5,
|
||||
"relora_anneal_steps": 5,
|
||||
"relora_steps": 100,
|
||||
"relora_warmup_steps": 20,
|
||||
"relora_anneal_steps": 10,
|
||||
"relora_prune_ratio": 0.9,
|
||||
"relora_cpu_offload": True,
|
||||
"val_set_size": 0.0,
|
||||
"special_tokens": {},
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"chat_template": "chatml",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
"path": "mlabonne/FineTome-100k",
|
||||
"type": "chat_template",
|
||||
"split": "train[:10%]",
|
||||
"field_messages": "conversations",
|
||||
"message_field_role": "from",
|
||||
"message_field_content": "value",
|
||||
},
|
||||
],
|
||||
"warmup_steps": 15,
|
||||
"warmup_steps": 20,
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 4,
|
||||
"max_steps": 205, # at least 2x relora_steps
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
@@ -65,4 +80,14 @@ class TestReLoraLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
assert (
|
||||
Path(temp_dir) / "checkpoint-100/adapter/adapter_model.safetensors"
|
||||
).exists()
|
||||
assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists()
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == "train/grad_norm")] # pylint: disable=invalid-name
|
||||
assert df.value.values[-1] < 0.2, "grad_norm is too high"
|
||||
|
||||
@@ -53,7 +53,7 @@ def require_torch_2_3_1(test_case):
|
||||
|
||||
def require_torch_2_5_1(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torch >= 2.3.1
|
||||
Decorator marking a test that requires torch >= 2.5.1
|
||||
"""
|
||||
|
||||
def is_min_2_5_1():
|
||||
|
||||
@@ -7,6 +7,11 @@ import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from constants import (
|
||||
ALPACA_MESSAGES_CONFIG_OG,
|
||||
ALPACA_MESSAGES_CONFIG_REVISION,
|
||||
SPECIAL_TOKENS,
|
||||
)
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoTokenizer
|
||||
@@ -21,13 +26,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
)
|
||||
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
||||
# Alpaca dataset.
|
||||
self.dataset = Dataset.from_list(
|
||||
[
|
||||
@@ -68,7 +67,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
def test_load_local_hub(self):
|
||||
"""Niche use case. Verify that a local copy of a hub dataset can be loaded"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
|
||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||
snapshot_download(
|
||||
repo_id="mhenrichsen/alpaca_2k_test",
|
||||
@@ -90,7 +89,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
"ds_type": "parquet",
|
||||
"type": "alpaca",
|
||||
"data_files": [
|
||||
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
|
||||
f"{tmp_ds_path}/alpaca_2000.parquet",
|
||||
],
|
||||
},
|
||||
],
|
||||
@@ -277,23 +276,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
"sequence_len": 1024,
|
||||
"rl": "dpo",
|
||||
"chat_template": "llama3",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
"type": "chat_template.default",
|
||||
"chat_template": "llama3",
|
||||
"field_messages": "conversation",
|
||||
"field_chosen": "chosen",
|
||||
"field_rejected": "rejected",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"roles": {
|
||||
"system": ["system"],
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
},
|
||||
}
|
||||
],
|
||||
"datasets": [ALPACA_MESSAGES_CONFIG_OG],
|
||||
}
|
||||
)
|
||||
|
||||
@@ -342,24 +325,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
"sequence_len": 1024,
|
||||
"rl": "dpo",
|
||||
"chat_template": "llama3",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
"type": "chat_template.default",
|
||||
"chat_template": "llama3",
|
||||
"revision": "ea82cff",
|
||||
"field_messages": "conversation",
|
||||
"field_chosen": "chosen",
|
||||
"field_rejected": "rejected",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"roles": {
|
||||
"system": ["system"],
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
},
|
||||
}
|
||||
],
|
||||
"datasets": [ALPACA_MESSAGES_CONFIG_REVISION],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
433
tests/test_exact_deduplication.py
Normal file
433
tests/test_exact_deduplication.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
Test suite for functions in the axolotl.utils.data.utils module, focusing on the deduplicate_and_log_datasets function.
|
||||
|
||||
Additionally, this test suite includes tests for functions that indirectly call deduplicate_and_log_datasets during the execution of the preprocess command.
|
||||
"""
|
||||
import hashlib
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from constants import ALPACA_MESSAGES_CONFIG_REVISION, SPECIAL_TOKENS
|
||||
from datasets import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.data.rl import load_prepare_dpo_datasets
|
||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_processor, load_tokenizer
|
||||
|
||||
|
||||
def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
|
||||
"""
|
||||
Validates deduplication results and size consistency.
|
||||
|
||||
Parameters:
|
||||
- actual_dataset: Deduplicated dataset.
|
||||
- expected_dataset: Expected dataset.
|
||||
- dataset_name: Name of the dataset (e.g., 'train' or 'eval').
|
||||
|
||||
Asserts:
|
||||
- Datasets match in content.
|
||||
- Dataset size matches unique row count.
|
||||
"""
|
||||
# Convert datasets to sets of tuples for unordered comparison
|
||||
actual_rows = set(tuple(row.values()) for row in actual_dataset)
|
||||
expected_rows = set(tuple(row.values()) for row in expected_dataset)
|
||||
|
||||
# Verify deduplication correctness
|
||||
assert actual_rows == expected_rows, f"Mismatch in {dataset_name} dataset"
|
||||
|
||||
# Verify size consistency
|
||||
assert len(actual_rows) == len(
|
||||
actual_dataset
|
||||
), f"Size mismatch in {dataset_name} dataset after deduplication"
|
||||
|
||||
|
||||
class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
"""
|
||||
test class for deduplication function in data utils
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
# Sample data with duplicates
|
||||
self.data = {
|
||||
"column1": ["apple", "banana", "apple", "orange", "banana"],
|
||||
"column2": [1, 2, 1, 3, 2],
|
||||
"column3": ["red", "yellow", "red", "orange", "yellow"],
|
||||
}
|
||||
|
||||
# Expected result after deduplication
|
||||
self.expected_data = {
|
||||
"column1": ["apple", "banana", "orange"],
|
||||
"column2": [1, 2, 3],
|
||||
"column3": ["red", "yellow", "orange"],
|
||||
}
|
||||
|
||||
# Convert to Dataset format
|
||||
self.dataset = Dataset.from_dict(self.data)
|
||||
self.expected_dataset = Dataset.from_dict(self.expected_data)
|
||||
|
||||
def test_deduplication(self):
|
||||
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=self.dataset)
|
||||
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=self.dataset)
|
||||
|
||||
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
|
||||
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
|
||||
|
||||
def test_datasets_are_none(self):
|
||||
# Test when both datasets are None
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=None, eval_dataset=None
|
||||
)
|
||||
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
|
||||
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
|
||||
|
||||
def test_only_train_is_none(self):
|
||||
# Test when only train_dataset is None
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=None, eval_dataset=self.dataset
|
||||
)
|
||||
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
|
||||
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
|
||||
|
||||
def test_only_eval_is_none(self):
|
||||
# Test when only eval_dataset is None
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=self.dataset, eval_dataset=None
|
||||
)
|
||||
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
|
||||
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
|
||||
|
||||
def test_exact_duplicates(self):
|
||||
# Test when datasets are exact duplicates
|
||||
duplicate_data = {
|
||||
"column1": ["apple", "apple", "apple"],
|
||||
"column2": [1, 1, 1],
|
||||
"column3": ["red", "red", "red"],
|
||||
}
|
||||
expected_data = {"column1": ["apple"], "column2": [1], "column3": ["red"]}
|
||||
|
||||
# Convert to Dataset format
|
||||
dataset = Dataset.from_dict(duplicate_data)
|
||||
expected_dataset = Dataset.from_dict(expected_data)
|
||||
|
||||
# Run deduplication
|
||||
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
|
||||
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
|
||||
|
||||
verify_deduplication(train_dataset, expected_dataset, "train_dataset")
|
||||
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
|
||||
|
||||
def test_partial_duplicates(self):
|
||||
# Test when only part of the dataset is a duplicate
|
||||
partial_duplicate_data = {
|
||||
"column1": ["apple", "banana", "apple"],
|
||||
"column2": [1, 2, 1],
|
||||
"column3": ["red", "yellow", "red"],
|
||||
}
|
||||
expected_data = {
|
||||
"column1": ["apple", "banana"],
|
||||
"column2": [1, 2],
|
||||
"column3": ["red", "yellow"],
|
||||
}
|
||||
|
||||
# Convert to Dataset format
|
||||
dataset = Dataset.from_dict(partial_duplicate_data)
|
||||
expected_dataset = Dataset.from_dict(expected_data)
|
||||
|
||||
# Run deduplication
|
||||
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
|
||||
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
|
||||
|
||||
verify_deduplication(train_dataset, expected_dataset, "train_dataset")
|
||||
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
|
||||
|
||||
def test_combined_duplicates_empty(self):
|
||||
# Test when only part of the dataset is a duplicate
|
||||
partial_duplicate_data = {
|
||||
"column1": ["apple", "banana", "apple"],
|
||||
"column2": [1, 2, 1],
|
||||
"column3": ["red", "yellow", "red"],
|
||||
}
|
||||
expected_data_train = {
|
||||
"column1": ["apple", "banana"],
|
||||
"column2": [1, 2],
|
||||
"column3": ["red", "yellow"],
|
||||
}
|
||||
expected_data_eval = {
|
||||
"column1": [],
|
||||
"column2": [],
|
||||
"column3": [],
|
||||
}
|
||||
|
||||
# Convert to Dataset format
|
||||
dataset = Dataset.from_dict(partial_duplicate_data)
|
||||
expected_dataset_train = Dataset.from_dict(expected_data_train)
|
||||
expected_dataset_eval = Dataset.from_dict(expected_data_eval)
|
||||
|
||||
# Run deduplication
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=dataset, eval_dataset=dataset
|
||||
)
|
||||
|
||||
verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
|
||||
verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
|
||||
|
||||
def test_combined_duplicates_one(self):
|
||||
# Test when only part of the dataset is a duplicate
|
||||
partial_duplicate_data_train = {
|
||||
"column1": ["apple", "banana", "apple"],
|
||||
"column2": [1, 2, 1],
|
||||
"column3": ["red", "yellow", "red"],
|
||||
}
|
||||
partial_duplicate_data_eval = {
|
||||
"column1": ["apple", "orange", "apple"],
|
||||
"column2": [1, 2, 1],
|
||||
"column3": ["red", "orange", "red"],
|
||||
}
|
||||
expected_data_train = {
|
||||
"column1": ["apple", "banana"],
|
||||
"column2": [1, 2],
|
||||
"column3": ["red", "yellow"],
|
||||
}
|
||||
expected_data_eval = {
|
||||
"column1": ["orange"],
|
||||
"column2": [2],
|
||||
"column3": ["orange"],
|
||||
}
|
||||
|
||||
# Convert to Dataset format
|
||||
dataset_train = Dataset.from_dict(partial_duplicate_data_train)
|
||||
dataset_eval = Dataset.from_dict(partial_duplicate_data_eval)
|
||||
expected_dataset_train = Dataset.from_dict(expected_data_train)
|
||||
expected_dataset_eval = Dataset.from_dict(expected_data_eval)
|
||||
|
||||
# Run deduplication
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=dataset_train, eval_dataset=dataset_eval
|
||||
)
|
||||
|
||||
verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
|
||||
verify_deduplication(eval_dataset, expected_dataset_eval, "eval_dataset")
|
||||
|
||||
|
||||
class TestDeduplicateRLDataset(unittest.TestCase):
|
||||
"""Test a configured dataloader with deduplication."""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
||||
self.cfg = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"sequence_len": 1024,
|
||||
"rl": "dpo",
|
||||
"chat_template": "llama3",
|
||||
"dataset_exact_deduplication": True,
|
||||
"datasets": [
|
||||
ALPACA_MESSAGES_CONFIG_REVISION,
|
||||
ALPACA_MESSAGES_CONFIG_REVISION,
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
def test_load_with_deduplication(self):
|
||||
"""Verify that loading with deduplication removes duplicates."""
|
||||
|
||||
# Load the dataset using the deduplication setting
|
||||
train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
|
||||
|
||||
# Verify that the dataset has been deduplicated
|
||||
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
|
||||
|
||||
def test_load_without_deduplication(self):
|
||||
"""Verify that loading without deduplication retains duplicates."""
|
||||
self.cfg.dataset_exact_deduplication = False
|
||||
# Load the dataset without deduplication
|
||||
train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
|
||||
|
||||
# Verify that the dataset retains duplicates
|
||||
assert (
|
||||
len(train_dataset) == 1800 * 2
|
||||
), "Dataset deduplication occurred when it should not have"
|
||||
|
||||
|
||||
class TestDeduplicateNonRL(unittest.TestCase):
|
||||
"""Test prepare_dataset function with different configurations."""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
||||
self.cfg_1 = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"sequence_len": 1024,
|
||||
"dataset_exact_deduplication": True,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"val_set_size": 0.0,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"batch_size": 10,
|
||||
"micro_batch_size": 10,
|
||||
"num_epochs": 1,
|
||||
}
|
||||
)
|
||||
|
||||
def test_prepare_dataset_with_deduplication_train(self):
|
||||
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""
|
||||
self.cfg_1.dataset_exact_deduplication = True
|
||||
|
||||
# Load tokenizer and processor
|
||||
tokenizer = load_tokenizer(self.cfg_1)
|
||||
processor = (
|
||||
load_processor(self.cfg_1, tokenizer=tokenizer)
|
||||
if self.cfg_1.processor_type
|
||||
else None
|
||||
)
|
||||
|
||||
# Prepare dataset using the prepare_dataset function
|
||||
train_dataset, _, _, _ = prepare_dataset(
|
||||
self.cfg_1,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
len(train_dataset),
|
||||
2000,
|
||||
"Train dataset should have 2000 samples after deduplication.",
|
||||
)
|
||||
|
||||
def test_prepare_dataset_with_deduplication_eval(self):
|
||||
"""Verify that prepare_dataset function processes the dataset correctly with deduplication."""
|
||||
self.cfg_1.dataset_exact_deduplication = True
|
||||
self.cfg_1.val_set_size = 0.5
|
||||
# Load tokenizer and processor
|
||||
tokenizer = load_tokenizer(self.cfg_1)
|
||||
processor = (
|
||||
load_processor(self.cfg_1, tokenizer=tokenizer)
|
||||
if self.cfg_1.processor_type
|
||||
else None
|
||||
)
|
||||
|
||||
# Prepare dataset using the prepare_dataset function
|
||||
_, eval_dataset, _, _ = prepare_dataset(
|
||||
self.cfg_1,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
len(eval_dataset),
|
||||
1000,
|
||||
"Eval dataset should have 2000 samples after deduplication.",
|
||||
)
|
||||
|
||||
def test_prepare_dataset_without_deduplication(self):
|
||||
"""Verify that prepare_dataset function processes the dataset correctly without deduplication."""
|
||||
self.cfg_1.dataset_exact_deduplication = False
|
||||
self.cfg_1.val_set_size = 0.1
|
||||
# Load tokenizer and processor
|
||||
tokenizer = load_tokenizer(self.cfg_1)
|
||||
processor = (
|
||||
load_processor(self.cfg_1, tokenizer=tokenizer)
|
||||
if self.cfg_1.processor_type
|
||||
else None
|
||||
)
|
||||
|
||||
# Prepare dataset using the prepare_dataset function
|
||||
train_dataset, eval_dataset, _, _ = prepare_dataset(
|
||||
self.cfg_1,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
# Verify that the dataset has been prepared correctly
|
||||
self.assertEqual(
|
||||
len(train_dataset),
|
||||
1800 * 2,
|
||||
"Train dataset should have 3600 samples without deduplication.",
|
||||
)
|
||||
self.assertEqual(
|
||||
len(eval_dataset),
|
||||
200 * 2,
|
||||
"Train dataset should have 400 samples after deduplication.",
|
||||
)
|
||||
|
||||
|
||||
class TestWrongCollisions(unittest.TestCase):
|
||||
"""Creating mock datasets for testing wrong collisions"""
|
||||
|
||||
def setUp(self):
|
||||
self.train_data = {"text": ["sample 5", "sample 6"], "label": [1, 2]}
|
||||
self.eval_data = {
|
||||
"text": [
|
||||
"sample 5",
|
||||
"sample 7",
|
||||
], # Different label but same text as in train_data
|
||||
"label": [2, 3],
|
||||
}
|
||||
self.dataset_data = {
|
||||
"text": ["sample 5", "sample 9", "sample 5"],
|
||||
"label": [1, 2, 8],
|
||||
}
|
||||
self.train_dataset = Dataset.from_dict(self.train_data)
|
||||
self.eval_dataset = Dataset.from_dict(self.eval_data)
|
||||
self.dataset = Dataset.from_dict(self.dataset_data)
|
||||
|
||||
@patch(
|
||||
"axolotl.utils.data.utils.sha256",
|
||||
side_effect=lambda x: hashlib.sha256(
|
||||
"forced_collision_hash".encode("utf-8")
|
||||
).hexdigest()
|
||||
if "sample 5" in x
|
||||
else hashlib.sha256(x.encode("utf-8")).hexdigest(),
|
||||
)
|
||||
def test_deduplication_wrong_collision_train_eval(self, _mock_sha256):
|
||||
dedup_train, dedup_eval, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=self.train_dataset, eval_dataset=self.eval_dataset
|
||||
)
|
||||
self.assertEqual(
|
||||
len(dedup_train),
|
||||
2,
|
||||
"train dataset should not deduplicate rows with forced hash collisions but different labels.",
|
||||
)
|
||||
self.assertEqual(
|
||||
len(dedup_eval),
|
||||
2,
|
||||
"Eval dataset should not deduplicate rows with forced hash collisions but different labels.",
|
||||
)
|
||||
self.assertEqual(
|
||||
len(dedup_eval),
|
||||
len(self.eval_dataset),
|
||||
"The output eval dataset should have the same number of rows as the input eval dataset.",
|
||||
)
|
||||
self.assertEqual(
|
||||
str(dedup_eval),
|
||||
str(self.eval_dataset),
|
||||
"The string representation of the output eval dataset should be identical to the input eval dataset.",
|
||||
)
|
||||
|
||||
def test_deduplication_dataset_only(self):
|
||||
_, _, dedup_dataset = deduplicate_and_log_datasets(dataset=self.dataset)
|
||||
self.assertEqual(
|
||||
len(dedup_dataset), 3, "Dataset should have all original values"
|
||||
)
|
||||
self.assertEqual(
|
||||
str(dedup_dataset),
|
||||
str(self.dataset),
|
||||
"The string representation of the output dataset should not differ.",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -7,35 +7,40 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer
|
||||
|
||||
from axolotl.utils.callbacks.perplexity import Perplexity
|
||||
|
||||
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||
MODEL_NAME = "HuggingFaceTB/SmolLM2-135M"
|
||||
|
||||
|
||||
@fixture()
|
||||
def metric(tokenizer):
|
||||
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
||||
return Perplexity(tokenizer=tokenizer, max_seq_len=512)
|
||||
|
||||
return Perplexity(model, tokenizer, 512)
|
||||
|
||||
@fixture()
|
||||
def model():
|
||||
return AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
||||
|
||||
|
||||
@fixture()
|
||||
def tokenizer():
|
||||
return AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
||||
tokenizer_ = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
||||
tokenizer_.add_special_tokens({"pad_token": "<|endoftext|>"})
|
||||
return tokenizer_
|
||||
|
||||
|
||||
def test_perplexity_longer_than_stride(metric):
|
||||
def test_perplexity_longer_than_stride(model, metric):
|
||||
# taken from https://huggingface.co/datasets/roneneldan/TinyStories
|
||||
sample_text = """
|
||||
Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong. One day, Beep was driving in the park when he saw a big tree. The tree had many leaves that were falling. Beep liked how the leaves fall and wanted to play with them. Beep drove under the tree and watched the leaves fall on him. He laughed and beeped his horn. Beep played with the falling leaves all day. When it was time to go home, Beep knew he needed more fuel. He went to the fuel place and got more healthy fuel. Now, Beep was ready to go fast and play again the next day. And Beep lived happily ever after.
|
||||
One day, a little fish named Fin was swimming near the shore. He saw a big crab and wanted to be friends. "Hi, I am Fin. Do you want to play?" asked the little fish. The crab looked at Fin and said, "No, I don't want to play. I am cold and I don't feel fine." Fin felt sad but wanted to help the crab feel better. He swam away and thought of a plan. He remembered that the sun could make things warm. So, Fin swam to the top of the water and called to the sun, "Please, sun, help my new friend feel fine and not freeze!" The sun heard Fin's call and shone its warm light on the shore. The crab started to feel better and not so cold. He saw Fin and said, "Thank you, little fish, for making me feel fine. I don't feel like I will freeze now. Let's play together!" And so, Fin and the crab played and became good friends.
|
||||
"""
|
||||
result = metric.compute([sample_text])
|
||||
result = metric.compute(model, [sample_text])
|
||||
ppl = result["score"]
|
||||
assert round(ppl, 2) == 5.37
|
||||
assert round(ppl, 2) == 7.41
|
||||
|
||||
|
||||
def test_perplexity_short(metric):
|
||||
def test_perplexity_short(model, metric):
|
||||
# taken from https://huggingface.co/datasets/roneneldan/TinyStories
|
||||
sample_text = "Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun."
|
||||
result = metric.compute([sample_text])
|
||||
result = metric.compute(model, [sample_text])
|
||||
ppl = result["score"]
|
||||
assert round(ppl, 2) == 10.02
|
||||
assert round(ppl, 2) == 10.33
|
||||
|
||||
@@ -672,6 +672,9 @@ class TestValidation(BaseValidation):
|
||||
{
|
||||
"bf16": True,
|
||||
"capabilities": {"bf16": False},
|
||||
"env_capabilities": {
|
||||
"torch_version": "2.5.1",
|
||||
},
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
@@ -1160,6 +1163,38 @@ class TestValidation(BaseValidation):
|
||||
in self._caplog.records[0].message
|
||||
)
|
||||
|
||||
def test_torch_version_adopt_req(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"optimizer": "adopt_adamw",
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*ADOPT optimizer is incompatible with torch version*",
|
||||
):
|
||||
env_capabilities = {"torch_version": "2.3.0"}
|
||||
capabilities = {"bf16": False}
|
||||
_ = validate_config(
|
||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||
)
|
||||
|
||||
env_capabilities = {"torch_version": "2.5.1"}
|
||||
capabilities = {"bf16": False}
|
||||
_ = validate_config(
|
||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||
)
|
||||
|
||||
env_capabilities = {"torch_version": "2.5.2"}
|
||||
capabilities = {"bf16": False}
|
||||
_ = validate_config(
|
||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||
)
|
||||
|
||||
|
||||
class TestValidationCheckModelConfig(BaseValidation):
|
||||
"""
|
||||
|
||||
@@ -72,6 +72,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
||||
"n_gpu": 1,
|
||||
"compute_capability": "8.0",
|
||||
},
|
||||
env_capabilities={
|
||||
"torch_version": "2.5.1",
|
||||
},
|
||||
)
|
||||
|
||||
_check_config()
|
||||
@@ -124,6 +127,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
||||
"n_gpu": 1,
|
||||
"compute_capability": "8.0",
|
||||
},
|
||||
env_capabilities={
|
||||
"torch_version": "2.5.1",
|
||||
},
|
||||
)
|
||||
|
||||
_check_config()
|
||||
@@ -177,6 +183,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
||||
"n_gpu": 1,
|
||||
"compute_capability": "8.0",
|
||||
},
|
||||
env_capabilities={
|
||||
"torch_version": "2.5.1",
|
||||
},
|
||||
)
|
||||
|
||||
_check_config()
|
||||
@@ -231,6 +240,9 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
||||
"n_gpu": 1,
|
||||
"compute_capability": "8.0",
|
||||
},
|
||||
env_capabilities={
|
||||
"torch_version": "2.5.1",
|
||||
},
|
||||
)
|
||||
|
||||
_check_config()
|
||||
|
||||
Reference in New Issue
Block a user