Compare commits

...

26 Commits
mm3 ... 1947fix

Author SHA1 Message Date
sunny
66a1e209e3 changed yml 2024-10-18 10:46:36 -04:00
sunny
dba033cb5b fixed yml for issue1947 testing 2024-10-17 12:31:11 -04:00
sunny
6a32e9a0df added yml for testing issue 1947 2024-10-17 11:26:32 -04:00
sunny
4afb2656b3 added yml for testing issue 1947 2024-10-17 11:23:47 -04:00
JohanWork
6d9a3c4d81 examples: Fix config llama3 (#1833) [skip ci]
* update llama3 config

* llama3 config
2024-10-14 16:00:48 -04:00
Wing Lian
335027f155 upgrade accelerate to 1.0.1 (#1969) 2024-10-13 20:04:30 -04:00
Wing Lian
ec4272c3a0 add ds zero3 to multigpu biweekly tests (#1900)
* add ds zero3 to multigpu biweekly tests

* fix for upstream api change

* use updated accelerate and fix deepspeed tests

* stringify the Path, and run multigpu tests if the multigpu tests change for a PR

* use correct json rather than yaml

* revert accelerate for deepspeed
2024-10-13 17:34:37 -04:00
Wing Lian
68b1369de9 Reward model (#1879) 2024-10-13 15:11:13 -04:00
Wing Lian
cd2d89f467 wip add new proposed message structure (#1904)
* wip add new proposed message structure

* tokenization

* wip

* wip transform builder

* wip make the chat dataset loadable

* wip chatml + llama 3 new chat objects

* chore: lint

* chore: lint

* fix tokenization

* remove dacite dependency since we're using pydantic now

* fix handling when already correctly split in messages

* make sure to remove chat features from tokenized ds

* move chat to be a input transform for messages

* make sure llama3 has the bos token

* remove non-working special token code

* fix messages strat loader
2024-10-13 12:15:18 -04:00
Vincent Haines
1834cdc364 Add support for qwen 2.5 chat template (#1934) 2024-10-12 21:41:43 -04:00
NanoCode012
ac128b7b1d fix: update eval causal lm metrics to add perplexity (#1951) [skip ci] 2024-10-12 21:41:13 -04:00
pandora
31591bd94c Fixing Validation - Mistral Templates (#1962) 2024-10-12 21:40:39 -04:00
Wing Lian
d20b48a61e only install torchao for torch versions >= 2.4.0 (#1963) 2024-10-12 20:53:48 -04:00
Wing Lian
09bf1ceacc update hf deps (#1964)
* update hf deps

* remove deprecated set_caching_enabled
2024-10-12 18:19:48 -04:00
Afrizal Hasbi Azizy
df359c8a6e Handle image input as string paths for MMLMs (#1958)
* Update mm_chat.py

Handle string image (paths)

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-10-11 13:34:13 -04:00
Wing Lian
76883851d2 add warning that sharegpt will be deprecated (#1957)
* add warning that sharegpt will be deprecated

* add helper script for chat_templates and document deprecation

* Update src/axolotl/prompt_strategies/sharegpt.py

Co-authored-by: NanoCode012 <nano@axolotl.ai>

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-10-11 13:33:20 -04:00
Adam Hazell
922db77521 Add MLFlow run name option in config (#1961)
Co-authored-by: Adam Hazell <adam.hazell@mindfoundry.ai>
2024-10-11 13:33:06 -04:00
Thomas Cleberg
e73b8dff8d Add Support for revision Dataset Parameter to specify reading from Huggingface Dataset Revision (#1912)
* Add support for `revision` dataset parameter

* only use revision on hf hub backed datasets

* use revision tied to head

* set download to use revision

* feat: add config to model validator class

* feat: add revision config to RL and tests for it

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2024-10-11 13:32:50 -04:00
Wing Lian
2fbc6b0c64 Axo logo new (#1956)
* update axolotl ascii art

* spacing for logo

* cleanup dithering

* cleanup ascii logo a bit
2024-10-10 15:57:37 -04:00
Wing Lian
8159cbd1ab lm_eval harness post train (#1926)
* wip, lm_eval harness post train

* include latex parser

* add dtype and doc

* add validation when doing bench evals

* automatically add test dataset when doing benches
2024-10-10 15:04:17 -04:00
pandora
979534c851 add mistral templates (#1927)
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-10-10 09:22:53 -04:00
Boris Feld
6d3caadf90 Comet integration (#1939)
* Add first version of a Comet integration

* Remove debug prints

* Add test for Comet Configuration transformation to env variables

* Fix last lint warning

* Update Readme for Comet logging documentation

* Update Comet integration to be optional, update code and tests

* Add documentation for Comet configuration

* Add missing check
2024-10-09 16:03:37 -04:00
aarush gupta
dee77232fe fix type annotations (#1941) [skip ci] 2024-10-09 16:03:16 -04:00
NanoCode012
a560593b1d fix(log): update perplexity log to clarify from eval split (#1952) [skip ci] 2024-10-09 16:02:32 -04:00
Wing Lian
e8d3da0081 upgrade pytorch from 2.4.0 => 2.4.1 (#1950)
* upgrade pytorch from 2.4.0 => 2.4.1

* update xformers for updated pytorch version

* handle xformers version case for torch==2.3.1
2024-10-09 11:53:56 -04:00
Wing Lian
4ca0a47cfb add 2.4.1 to base models (#1953) 2024-10-09 08:43:11 -04:00
65 changed files with 2581 additions and 76 deletions

View File

@@ -28,7 +28,13 @@ jobs:
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.4.0
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout

View File

@@ -27,7 +27,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
pytorch: 2.4.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -84,7 +84,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
pytorch: 2.4.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -26,7 +26,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
pytorch: 2.4.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -83,7 +83,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
pytorch: 2.4.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -25,7 +25,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.0"]
pytorch_version: ["2.3.1", "2.4.1"]
timeout-minutes: 20
steps:
@@ -91,7 +91,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
nightly_build: "true"

View File

@@ -36,7 +36,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.0"]
pytorch_version: ["2.3.1", "2.4.1"]
timeout-minutes: 20
steps:
@@ -94,7 +94,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
steps:

View File

@@ -1,3 +1,3 @@
[settings]
profile=black
known_third_party=wandb
known_third_party=wandb,comet_ml

View File

@@ -14,7 +14,7 @@ Features:
- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
- Easily run with Docker locally or on the cloud
- Log results and optionally checkpoints to wandb or mlflow
- Log results and optionally checkpoints to wandb, mlflow or Comet
- And more!
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
@@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript
type: ... # unimplemented custom format
# fastchat conversation
# fastchat conversation (deprecation soon, use chat_template)
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
- path: ...
type: sharegpt
@@ -515,6 +515,22 @@ wandb_name:
wandb_log_model:
```
##### Comet Logging
Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to wandb with `comet login`.
- wandb options
```yaml
use_comet:
comet_api_key:
comet_workspace:
comet_project_name:
comet_experiment_key:
comet_mode:
comet_online:
comet_experiment_config:
```
##### Special Tokens
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:

View File

@@ -90,6 +90,7 @@ datasets:
shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
# Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -265,8 +266,21 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
# mlflow configuration if you're using it
mlflow_tracking_uri: # URI to mlflow
mlflow_experiment_name: # Your experiment name
mlflow_run_name: # Your run name
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
# Comet configuration if you're using it
# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`.
# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
use_comet: # Enable or disable Comet integration.
comet_api_key: # API key for Comet. Recommended to set via `comet login`.
comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
comet_project_name: # Project name in Comet. Defaults to Uncategorized.
comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key.
comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
# Where to save the full-finetuned model to
output_dir: ./completed-model
@@ -301,7 +315,7 @@ max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf]
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)

View File

@@ -0,0 +1,63 @@
base_model: google/gemma-2-2b
model_type: AutoModelForSequenceClassification
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
reward_model: true
chat_template: gemma
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
remove_unused_columns: false
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -0,0 +1,50 @@
base_model: meta-llama/Llama-3.1-8B-Instruct
save_safetensors: true
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: ./last_run_prepared
output_dir: ./outputs/fft-out
sequence_len: 8192
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
learning_rate: 2e-5
bf16: auto
fp16:
tf32: false
logging_steps: 2
xformers_attention:
flash_attention: true
warmup_steps: 2
evals_per_epoch: 2
save_steps: 2
max_steps: 2
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: true
fsdp_cpu_ram_efficient_loading: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_backward_prefetch: BACKWARD_PRE
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -11,7 +11,6 @@ rl: dpo
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
chat_template: llama3
field_messages: conversation
field_chosen: chosen
field_rejected: rejected

View File

@@ -10,7 +10,6 @@ chat_template: llama3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
chat_template: llama3
field_messages: messages
message_field_role: role
message_field_content: content

View File

@@ -1,11 +1,11 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.13.0
transformers==4.45.1
tokenizers>=0.19.1
bitsandbytes==0.44.0
accelerate==0.34.2
datasets==2.21.0
peft==0.13.2
transformers==4.45.2
tokenizers>=0.20.1
bitsandbytes==0.44.1
accelerate==1.0.1
datasets==3.0.1
deepspeed==0.14.4
pydantic==2.6.3
addict
@@ -16,7 +16,7 @@ flash-attn==2.6.3
sentencepiece
wandb
einops
xformers==0.0.27
xformers==0.0.28.post1
optimum==1.16.2
hf_transfer
colorama
@@ -46,3 +46,11 @@ gcsfs>=2024.5.0
trl==0.9.6
zstandard==0.22.0
fastcore
# lm eval harness
lm_eval==0.4.4
langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.5.0

315
requirements_env.txt Normal file
View File

@@ -0,0 +1,315 @@
accelerate==0.34.1
addict==2.4.0
aiofiles==23.2.1
aiohttp==3.9.0
aiosignal==1.3.1
aiostream==0.5.2
alembic==1.13.1
annotated-types==0.6.0
annoy==1.17.3
ansible==6.7.0
ansible-core==2.13.13
ansible-vault==2.1.0
anyio==3.7.1
appdirs==1.4.4
art==6.0
asgiref==3.7.2
async-timeout==4.0.2
attrdict==2.0.1
attrs==22.2.0
awscli==1.32.75
-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
backoff==2.2.1
base58==2.1.1
beartype==0.17.2
bitnet==0.2.1
bitsandbytes==0.42.0
bittensor==6.7.0
black==23.7.0
blinker==1.7.0
boto3==1.34.75
botocore==1.34.75
cachetools==5.3.3
cachy==0.1.1
certifi==2023.7.22
cffi==1.16.0
cfgv==3.3.1
chai-guanaco==1.2.4
charset-normalizer==3.2.0
cleo==0.6.8
click==8.1.7
cloudpickle==2.0.0
cohere==4.11.2
colorama==0.4.4
coloredlogs==15.0.1
CoLT5-attention==0.10.20
contextlib2==21.6.0
contourpy==1.2.0
cryptography==41.0.3
cycler==0.12.1
cytoolz==0.12.3
databricks-cli==0.18.0
dataclasses-json==0.5.7
datasets==2.11.0
ddt==1.6.0
decorator==5.1.1
deepspeed==0.15.0
# Editable Git install with no remote (dialogpt==0.1)
-e /Users/wing/Projects/ml/dialogpt/src
dill==0.3.6
distlib==0.3.6
docker==7.0.0
docker-pycreds==0.4.0
docstring-parser==0.15
docutils==0.16
ecdsa==0.18.0
einops==0.7.0
einops-exts==0.0.4
einx==0.1.3
entrypoints==0.4
eth-hash==0.6.0
eth-keys==0.5.0
eth-typing==4.0.0
eth-utils==2.3.1
evaluate==0.4.0
exceptiongroup==1.1.1
fastapi==0.109.2
fastcore==1.5.29
ffmpy==0.4.0
filelock==3.12.2
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
fire==0.5.0
first==2.0.2
flake8==7.0.0
Flask==3.0.1
fonttools==4.47.2
frozendict==2.4.1
frozenlist==1.3.3
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
fsspec==2023.6.0
fuzzywuzzy==0.18.0
gitdb==4.0.10
GitPython==3.1.31
google-pasta==0.2.0
gradio==4.42.0
gradio_client==1.3.0
greenlet==2.0.2
grpclib==0.4.7
gunicorn==21.2.0
h11==0.14.0
h2==4.1.0
hpack==4.0.0
httpcore==0.17.3
httpx==0.24.1
huggingface-hub==0.23.4
humanfriendly==10.0
hyperframe==6.0.1
identify==2.5.24
idna==3.4
immutables==0.20
importlib-metadata==6.7.0
importlib-resources==6.1.1
inflection==0.5.1
iniconfig==2.0.0
itsdangerous==2.1.2
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.3.2
jsonlines==3.1.0
jsonschema==2.6.0
kiwisolver==1.4.5
langchain==0.0.144
Levenshtein==0.24.0
libcst==1.1.0
liger-kernel==0.0.0
lion-pytorch==0.1.2
llama-cpp-python==0.1.36
llvmlite==0.40.1
local-attention==1.9.0
loguru==0.7.0
Mako==1.3.2
Markdown==3.5.2
markdown-it-py==3.0.0
markdown2==2.4.10
MarkupSafe==2.1.2
marshmallow==3.19.0
marshmallow-enum==1.5.1
matplotlib==3.8.2
mccabe==0.7.0
mdurl==0.1.2
MEGABYTE-pytorch==0.0.7
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
mlflow==2.10.0
modal==0.62.77
more-itertools==10.2.0
mpmath==1.2.1
msgpack==1.0.7
msgpack-numpy-opentensor==0.5.0
multidict==6.0.4
multiprocess==0.70.14
munch==2.5.0
mypy==1.3.0
mypy-extensions==1.0.0
nest-asyncio==1.6.0
netaddr==0.10.1
networkx==3.0rc1
nh3==0.2.14
nodeenv==1.8.0
nomic==2.0.2
numba==0.57.1
numexpr==2.8.4
numpy==1.24.4
oauthlib==3.2.2
openai==0.27.4
openapi==1.1.0
openapi-schema-pydantic==1.2.4
optimum==1.8.6
orjson==3.10.7
packaging==23.1
pandas==2.0.0
parameterized==0.9.0
password-strength==0.0.3.post2
pastel==0.1.1
pathos==0.3.0
pathspec==0.11.1
pathtools==0.1.2
peft==0.11.1
pendulum==3.0.0
Pillow==9.5.0
pip-tools==1.11.0
platformdirs==3.2.0
pluggy==1.4.0
poetry==0.7.1
pox==0.3.2
ppft==1.7.6.6
pre-commit==3.3.2
prettytable==3.10.0
prompt-toolkit==3.0.39
protobuf==3.20.2
protobuf3-to-dict==0.1.5
psutil==5.9.5
psycopg==3.1.18
PuLP==2.8.0
py==1.11.0
py-bip39-bindings==0.1.11
py-cpuinfo==9.0.0
py-ed25519-zebra-bindings==1.0.1
py-sr25519-bindings==0.2.0
pyarrow==11.0.0
pyasn1==0.6.0
pycodestyle==2.11.1
pycparser==2.21
pycryptodome==3.20.0
pydantic==2.5.3
pydantic_core==2.14.6
pydub==0.25.1
pyfiglet==0.8.post1
pyflakes==3.2.0
Pygments==2.15.1
PyJWT==2.8.0
pylev==1.4.0
PyNaCl==1.5.0
pynvml==11.5.0
pyparsing==2.4.7
pyrsistent==0.14.11
pytest==8.0.2
pytest-asyncio==0.23.4
python-dateutil==2.8.2
python-dotenv==1.0.1
python-Levenshtein==0.24.0
python-multipart==0.0.9
pytz==2023.3
PyYAML==6.0.1
querystring-parser==1.2.4
rapidfuzz==3.6.1
regex==2023.6.3
requests==2.31.0
requests-toolbelt==0.8.0
resolvelib==0.8.1
responses==0.18.0
retry==0.9.2
rich==13.7.0
rsa==4.7.2
ruff==0.6.3
s3transfer==0.10.1
safetensors==0.4.5
sagemaker==2.148.0
scalecodec==1.2.7
schedulefree==1.2.1
schema==0.7.5
scikit-learn==1.4.0
scipy==1.9.3
seaborn==0.13.2
semantic-version==2.10.0
sentencepiece==0.2.0
sentry-sdk==1.19.1
setproctitle==1.3.2
shellingham==1.5.4
shortuuid==1.0.11
shtab==1.6.5
sigtools==4.0.1
six==1.16.0
skypilot==0.4.1
smdebug-rulesconfig==1.0.1
smmap==5.0.0
sniffio==1.3.0
SQLAlchemy==1.4.47
sqlparse==0.4.4
starlette==0.36.3
substrate-interface==1.5.2
svgwrite==1.4.3
sympy==1.11.1
synchronicity==0.6.7
tabulate==0.9.0
tblib==1.7.0
tenacity==8.2.2
tensor-parallel==2.0.0
termcolor==2.2.0
text2art==0.2.0
threadpoolctl==3.2.0
tiktoken==0.6.0
time-machine==2.14.1
timm==0.9.16
tokenizers==0.19.1
tokenmonster==1.1.12
toml==0.9.6
tomli==2.0.1
tomlkit==0.12.0
toolz==0.12.1
torch==2.2.0
torchdata==0.6.1
torchdiffeq==0.2.3
TorchFix==0.4.0
torchtext==0.15.2
torchvision==0.17.0
tqdm==4.66.2
transformers==4.44.2
trl==0.9.6
typer==0.12.5
types-certifi==2021.10.8.3
types-requests==2.31.0.20240125
types-setuptools==69.0.0.20240125
types-toml==0.10.8.7
typing==3.7.4.3
typing-inspect==0.8.0
typing_extensions==4.9.0
tyro==0.5.18
tzdata==2023.3
unique-names-generator==1.0.2
urllib3==2.2.2
uvicorn==0.22.0
vector_quantize_pytorch==1.14.1
virtualenv==20.23.0
voyager==2.0.2
wandb==0.16.2
watchfiles==0.21.0
wavedrom==2.0.3.post3
wcwidth==0.2.6
websocket-client==1.7.0
websockets==12.0
Werkzeug==3.0.1
wonderwords==2.2.0
xxhash==3.2.0
yarl==1.8.2
zetascale==2.2.7
zipp==3.15.0

60
scripts/chat_datasets.py Normal file
View File

@@ -0,0 +1,60 @@
"""
helper script to parse chat datasets into a usable yaml
"""
import click
import yaml
from datasets import load_dataset
@click.command()
@click.argument("dataset", type=str)
@click.option("--split", type=str, default="train")
def parse_dataset(dataset=None, split="train"):
ds_cfg = {}
ds_cfg["path"] = dataset
ds_cfg["split"] = split
ds_cfg["type"] = "chat_template"
ds_cfg["chat_template"] = "<<<Replace based on your model>>>"
dataset = load_dataset(dataset, split=split)
features = dataset.features
feature_keys = features.keys()
field_messages = None
for key in ["conversation", "conversations", "messages"]:
if key in feature_keys:
field_messages = key
break
if not field_messages:
raise ValueError(
f'No conversation field found in dataset: {", ".join(feature_keys)}'
)
ds_cfg["field_messages"] = field_messages
message_fields = features["conversations"][0].keys()
message_field_role = None
for key in ["from", "role"]:
if key in message_fields:
message_field_role = key
break
if not message_field_role:
raise ValueError(
f'No role field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_role"] = message_field_role
message_field_content = None
for key in ["content", "text", "value"]:
if key in message_fields:
message_field_content = key
break
if not message_field_content:
raise ValueError(
f'No content field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_content"] = message_field_content
print(yaml.dump({"datasets": [ds_cfg]}))
if __name__ == "__main__":
parse_dataset()

View File

@@ -30,6 +30,7 @@ def parse_requirements():
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version))
@@ -49,14 +50,24 @@ def parse_requirements():
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 3):
if (major, minor) >= (2, 4):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
elif (major, minor) >= (2, 2):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1")
else:
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.23.post1")

View File

@@ -31,6 +31,7 @@ from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
@@ -54,8 +55,22 @@ LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
AXOLOTL_LOGO = """
#@@ #@@ @@# @@#
@@ @@ @@ @@ =@@# @@ #@ =@@#.
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@
"""
def print_axolotl_text_art(suffix=None):
def print_legacy_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
@@ -68,6 +83,13 @@ def print_axolotl_text_art(suffix=None):
print_dep_versions()
def print_axolotl_text_art(
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
print(AXOLOTL_LOGO)
def print_dep_versions():
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
max_len = max(len(pkg) for pkg in packages)
@@ -421,6 +443,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
return cfg

View File

@@ -27,6 +27,7 @@ from axolotl.prompt_strategies.sharegpt import (
register_chatml_template,
register_llama3_template,
)
from axolotl.utils.trainer import disable_datasets_caching
LOG = logging.getLogger("axolotl.cli.preprocess")
@@ -70,10 +71,11 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
with disable_datasets_caching():
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.download:
model_name = parsed_cfg.base_model

View File

@@ -3,13 +3,11 @@ CLI to run training on a model
"""
import logging
from pathlib import Path
from typing import Tuple, Union
from typing import Union
import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from axolotl.cli import (
check_accelerate_default_config,
@@ -20,6 +18,7 @@ from axolotl.cli import (
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.integrations.base import PluginManager
from axolotl.prompt_strategies.sharegpt import (
register_chatml_template,
register_llama3_template,
@@ -39,7 +38,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
return do_train(parsed_cfg, parsed_cli_args)
def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
def do_train(cfg, cli_args) -> None:
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
@@ -64,7 +63,13 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
plugin_manager = PluginManager.get_instance()
del model
del tokenizer
plugin_manager.post_train_unload(cfg)
if __name__ == "__main__":

View File

View File

View File

@@ -0,0 +1,34 @@
"""
ChatML transformation functions for MessageContents
"""
from typing import Optional
from ..messages import MessageContents, Messages
from .shared import wrap_tools
def format_message(
message: Messages,
message_index: Optional[int] = None, # pylint: disable=unused-argument
) -> Messages:
if message.is_chat_formatted:
return message
# prepend the role prefix within a MessageContents to message.content
message.content.insert(
0,
MessageContents(
type="text",
value=f"<|im_start|>{message.role}\n",
weight=0,
),
)
message.content.append(
MessageContents(type="text", value="<|im_end|>", weight=message.weight)
)
message.content.append(MessageContents(type="text", value="\n", weight=0))
message = wrap_tools(message)
message.is_chat_formatted = True
return message

View File

@@ -0,0 +1,45 @@
"""
Llama 3.x chat formatting functions for MessageContents
"""
from typing import Optional
from ..messages import MessageContents, Messages
from .shared import wrap_tools
def format_message(message: Messages, message_index: Optional[int] = None) -> Messages:
if message.is_chat_formatted:
return message
message_role = message.role
if message.role == "tool":
message_role = "ipython"
# prepend the role prefix within a MessageContents to message.content
message.content.insert(
0,
MessageContents(
type="text",
value=f"<|start_header_id|>{message_role}<|end_header_id|>\n\n",
weight=0,
),
)
message.content.append(
MessageContents(type="text", value="<|eot_id|>", weight=message.weight)
)
message = wrap_tools(message)
if message_index == 0:
message.content.insert(
0,
MessageContents(
type="text",
value="<|begin_of_text|>",
weight=0,
),
)
message.is_chat_formatted = True
return message

View File

@@ -0,0 +1,47 @@
"""
shared functions for format transforms
"""
from axolotl.core.chat.messages import MessageContents, Messages
def wrap_tools(message: Messages):
# loop over message.content by index to find tool calls, we need to wrap each with tags,
# so be wary of indexing issues when changing the list while iterating.
# iterate over the range in reverse order to avoid index shifting
for i in range(len(message.content) - 1, -1, -1):
if message.content[i].type == "tool_call":
# append a </tool_call> MessageContents text tag after
message.content.insert(
i + 1,
MessageContents(
type="text", value="</tool_call>\n", weight=message.weight
),
)
# make sure the actual tool call content ends with a newline
message.content[i].has_newline = True
# prepend a <tool_call> MessageContents text tag before
message.content.insert(
i,
MessageContents(
type="text", value="<tool_call>\n", weight=message.weight
),
)
elif message.content[i].type == "tool_response":
# append a </tool_call> MessageContents text tag after
message.content.insert(
i + 1,
MessageContents(
type="text", value="</tool_response>\n", weight=message.weight
),
)
# make sure the actual tool response content ends with a newline
message.content[i].has_newline = True
# prepend a <tool_call> MessageContents text tag before
message.content.insert(
i,
MessageContents(
type="text", value="<tool_response>\n", weight=message.weight
),
)
return message

View File

@@ -0,0 +1,230 @@
"""
internal message representations of chat messages
"""
import json
from enum import Enum
from typing import Any, Callable, List, Optional, Union
from pydantic import BaseModel
from transformers import PreTrainedTokenizer
class MessageRoles(str, Enum):
"""
Message roles for the system, user, assistant, and tools
"""
system = "system" # pylint: disable=invalid-name
user = "user" # pylint: disable=invalid-name
assistant = "assistant" # pylint: disable=invalid-name
tool = "tool" # pylint: disable=invalid-name
ipython = ( # pylint: disable=invalid-name
# for responses from builtin tools
"ipython"
)
class MessageContentTypes(str, Enum):
"""
Message content types for text, image, audio, tool calls, and tool responses
"""
special_token = "special_token" # pylint: disable=invalid-name # nosec B105
text = "text" # pylint: disable=invalid-name
image = "image" # pylint: disable=invalid-name
audio = "audio" # pylint: disable=invalid-name
tool_call = "tool_call" # pylint: disable=invalid-name # to differentiate regular responses from tool calls from the assistant
tool_response = "tool_response" # pylint: disable=invalid-name
class SpecialToken(str, Enum):
"""
Special tokens for beginning of string and end of string
"""
bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105
eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105
class ToolCallFunction(BaseModel):
"""
Tool call function with name and arguments
"""
name: str
arguments: dict[str, str]
class Tool(BaseModel):
"""
Tool with description, function, and parameters
"""
description: str
function: ToolCallFunction
parameters: dict[str, str] # .properties
class ToolCallContents(BaseModel):
"""
Tool call contents with name, arguments, and optional id
"""
name: str
arguments: dict[str, Union[str, int]]
id: Optional[str] = None # pylint: disable=invalid-name
def __str__(self) -> str:
data = {"name": self.name, "arguments": self.arguments}
if self.id is not None:
data["id"] = self.id
return json.dumps(data)
class ToolResponseContents(BaseModel):
"""
Tool response contents with name, content, and optional id
"""
name: str
content: Union[str, dict[str, Union[str, int, float]]]
id: Optional[str] = None # pylint: disable=invalid-name
def __str__(self) -> str:
data = {"name": self.name, "content": self.content}
if self.id is not None:
data["id"] = self.id
return json.dumps(data)
class MessageContents(BaseModel):
"""
Message contents with type, value, metadata, weight, newline, and end of contents
"""
type: Union[str, MessageContentTypes]
value: Union[str, ToolCallContents, ToolResponseContents, SpecialToken]
meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
weight: Optional[Union[int, float]] = None
has_newline: bool = False
eoc: bool = False # end of contents
def __str__(self) -> str:
str_val = str(self.value)
if self.has_newline and not str_val.endswith("\n"):
str_val += "\n"
return str_val
class Messages(BaseModel):
"""
Messages with role, content, metadata, weight, and chat formatting
"""
role: Union[MessageRoles, str] # allows for arbitrary roles
content: List["MessageContents"]
meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
weight: Optional[Union[int, float]] = None
is_chat_formatted: bool = False
def __str__(self) -> str:
return "".join(str(c) for c in self.content)
def tokenized(
self, tokenizer: PreTrainedTokenizer, ignore_index=-100
) -> dict[str, List[int]]:
# iterate over the contents, tokenizing the concatenated string values up to the current MessageContents
# returns a dictionary mapping w input_ids, attention_mask, and labels
input_ids: List[int] = []
labels: List[int] = []
pending_input_ids: List[int] = []
pending_weight = self.weight
running_content = ""
for _, msg_content in enumerate(self.content):
# TODO also handle non-text content types
if msg_content.type in [
MessageContentTypes.text.value,
MessageContentTypes.tool_call.value,
MessageContentTypes.tool_response.value,
]:
running_content += str(msg_content)
tok_results = tokenizer(running_content, add_special_tokens=False)
tok_input_ids = tok_results["input_ids"]
if pending_input_ids:
new_pending_inputs = tok_input_ids[
len(input_ids) : len(input_ids) + len(pending_input_ids)
]
if new_pending_inputs != pending_input_ids:
# logging.warning("tokenization mismatch from concatenation.")
pending_input_ids = new_pending_inputs
input_ids.extend(pending_input_ids)
if pending_weight:
labels.extend(pending_input_ids)
else:
labels.extend([ignore_index] * len(pending_input_ids))
pending_input_ids = tok_results["input_ids"][len(input_ids) :]
pending_weight = self.weight and msg_content.weight not in [0, 0.0]
input_ids.extend(pending_input_ids)
if pending_weight:
labels.extend(pending_input_ids)
else:
labels.extend([ignore_index] * len(pending_input_ids))
attention_mask = [1] * len(input_ids)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
class Chats(BaseModel):
"""
top level data structure for chat conversations
"""
conversation: List[Messages]
def __str__(self) -> str:
return "".join(str(c) for c in self.conversation)
def tokenized(
self, tokenizer: Callable[[str], dict[str, List[int]]], ignore_index=-100
) -> dict[str, List[int]]:
input_ids = []
attention_mask = []
labels = []
for msg in self.conversation:
msg_results = msg.tokenized(tokenizer, ignore_index)
input_ids.extend(msg_results["input_ids"])
attention_mask.extend(msg_results["attention_mask"])
labels.extend(msg_results["labels"])
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
class ChatFormattedChats(Chats):
"""
Chat formatted chats with formatter and optional train on inputs
"""
formatter: Callable # [[Union[dict, Chats]], Chats]
train_on_inputs: bool = False
def model_post_init(self, __context):
for i, msg in enumerate(self.conversation):
self.conversation[i] = self.formatter(msg, message_index=i)
if self.train_on_inputs:
self.conversation[i].weight = 1
class PreferenceChats(BaseModel):
"""
representation for preference data for chat
"""
prompt: List[Messages]
chosen: Messages
rejected: Messages

View File

View File

@@ -0,0 +1,55 @@
"""
chat dataset module
"""
import os
from typing import Callable, Optional, Union
from datasets import Dataset
from transformers import PreTrainedTokenizer
from axolotl.core.chat.messages import ChatFormattedChats
class TokenizedChatDataset(Dataset):
"""
Tokenized chat dataset
"""
def __init__(
self,
data: Dataset,
model_transform: Union[PreTrainedTokenizer, Callable],
*args,
message_transform: Optional[Callable] = None,
formatter=None,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
**kwargs,
):
def map_fn(ex):
if message_transform is not None:
ex = message_transform(ex)
if formatter is not None:
ex = ChatFormattedChats(
formatter=formatter,
**ex,
)
else:
ex = ChatFormattedChats(
**ex,
)
return ex.tokenized(model_transform)
process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment]
)
num_proc = min(64, process_or_cpu_count)
features = data.features.keys()
tokenized_data = data.map(
map_fn,
num_proc=num_proc,
keep_in_memory=keep_in_memory,
remove_columns=features,
desc="Tokenizing Chats",
)
super().__init__(tokenized_data.data, *args, **kwargs)

View File

@@ -0,0 +1,150 @@
"""
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
"""
from typing import Any, Mapping, Union
def chat_message_transform_builder( # pylint: disable=dangerous-default-value
train_on_inputs=False,
conversations_field: str = "conversations",
message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role"
message_field_content: Union[str, list[str]] = [
"value",
"text",
"content",
], # commonly "content"
message_field_training: Union[str, list[str]] = [
"train",
"weight",
], # commonly "weight"
):
"""Builds a transform that takes a row from the dataset and converts it to a Chat
Args:
train_on_inputs (bool, optional):
If True, the transform will train on the inputs. If False, the transform will train on the targets.
Defaults to False.
conversations_field (str, optional):
The field name of the conversations. Defaults to "conversations".
message_field_role (str | list[str], optional):
The field name of the role. Defaults to "role".
message_field_content (str | list[str], optional):
The field name of the message content. Defaults to "content".
message_field_training (str | list[str], optional):
The field name of the train/weight. Defaults to "weight".
Returns:
Callable:
A function that takes a list of conversations and returns a list of messages.
"""
message_field_role = (
[message_field_role]
if isinstance(message_field_role, str)
else message_field_role
)
message_field_content = (
[message_field_content]
if isinstance(message_field_content, str)
else message_field_content
)
message_weight_fields = (
[message_field_training]
if isinstance(message_field_training, str)
else message_field_training
)
role_value_mappings = {
"system": "system",
"user": "user",
"human": "user",
"assistant": "assistant",
"gpt": "assistant",
"tool": "tool",
"ipython": "ipython",
}
if train_on_inputs:
role_default_weights_mappings = {
"system": 1,
"user": 1,
"assistant": 1,
"tool": 1,
"ipython": 1,
}
else:
role_default_weights_mappings = {
"system": 0,
"user": 0,
"assistant": 1,
"tool": 0,
"ipython": 0,
}
def transform_builder(sample: Mapping[str, Any]):
if conversations_field not in sample:
raise ValueError(f"Field '{conversations_field}' not found in sample.")
# if none of the role fields are in the message, raise an error
if not any(
role in sample[conversations_field][0] for role in message_field_role
):
raise ValueError("No role field found in message.")
role_field = next(
role
for role in message_field_role
if role in sample[conversations_field][0]
)
if not any(
field in sample[conversations_field][0] for field in message_field_content
):
raise ValueError("No message_content field found in message.")
message_content_field = next(
field
for field in message_field_content
if field in sample[conversations_field][0]
)
if not any(
field in sample[conversations_field][0] for field in message_field_training
):
message_weight_field = None
else:
message_weight_field = next(
field
for field in message_weight_fields
if field in sample[conversations_field][0]
)
messages = []
for message in sample[conversations_field]:
role = role_value_mappings[message[role_field]]
weight = (
int(message[message_weight_field])
if message_weight_field
else role_default_weights_mappings[role]
)
# TODO if "tool_calls" in message[message_content_field]: then convert tool call to ToolCallContents
if isinstance(message[message_content_field], str):
messages.append(
{
"role": role,
"content": [
{
"type": "text",
"value": message[message_content_field],
}
],
"weight": weight,
}
)
else:
messages.append(
{
"role": role,
"content": message[message_content_field],
"weight": weight,
}
)
return {"conversation": messages}
return transform_builder

View File

@@ -43,12 +43,14 @@ from trl import (
KTOTrainer,
ORPOConfig,
ORPOTrainer,
RewardConfig,
RewardTrainer,
)
from trl.trainer.utils import pad_to_length
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_mlflow_available
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
@@ -301,6 +303,13 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
)
@dataclass
class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
"""
Reward config for Reward training
"""
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
@@ -398,12 +407,10 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
def __init__(
self,
*_args,
num_epochs=1,
bench_data_collator=None,
eval_data_collator=None,
**kwargs,
):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
super().__init__(*_args, **kwargs)
@@ -1039,6 +1046,14 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
tag_names = ["axolotl", "cpo"]
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
tag_names = ["axolotl", "reward"]
class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
@@ -1111,6 +1126,12 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
return callbacks
@@ -1179,6 +1200,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer, self.tokenizer, "mlflow"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "comet_ml"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
@@ -1203,6 +1229,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
return AxolotlTrainer
def build(self, total_num_steps):
@@ -1430,11 +1458,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
report_to.append("mlflow")
if self.cfg.use_tensorboard:
report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")
training_arguments_kwargs["report_to"] = report_to
training_arguments_kwargs["run_name"] = (
self.cfg.wandb_name if self.cfg.use_wandb else None
)
if self.cfg.use_wandb:
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
elif self.cfg.use_mlflow:
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
else:
training_arguments_kwargs["run_name"] = None
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
@@ -1537,6 +1570,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {}
if self.cfg.reward_model:
trainer_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.optimizer in [
"optimi_adamw",
"ao_adamw_4bit",
@@ -1580,10 +1616,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
"accelerator_config"
] = self.cfg.accelerator_config
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
training_args_cls = (
AxolotlTrainingArguments
if not self.cfg.reward_model
else AxolotlRewardConfig
)
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
training_args = self.hook_post_create_training_args(training_args)
@@ -1605,10 +1644,24 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64
if self.cfg.reward_model:
data_collator_kwargs["max_length"] = self.cfg.sequence_len
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
)
if eval_data_collator := self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
):
if not self.cfg.reward_model:
trainer_kwargs["eval_data_collator"] = eval_data_collator
if not self.cfg.reward_model:
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
)
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
@@ -1616,16 +1669,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
args=training_args,
tokenizer=self.tokenizer,
data_collator=self.build_collator(training_args, **data_collator_kwargs),
eval_data_collator=self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
callbacks=self.get_callbacks(),
num_epochs=self.cfg.num_epochs,
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
@@ -1659,9 +1703,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
V2BatchSamplerDataCollatorForSeq2Seq,
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
RewardDataCollatorWithPadding,
]
]
if use_batch_sampler_collator:
if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
elif use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (

View File

@@ -159,6 +159,29 @@ class BasePlugin:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
def post_train(self, cfg, model):
"""
Performs actions after training is complete.
Parameters:
cfg (dict): The axolotl configuration
model (object): The loaded model.
Returns:
None
"""
def post_train_unload(self, cfg):
"""
Performs actions after training is complete and the model is unloaded.
Parameters:
cfg (dict): The configuration for the plugin.
Returns:
None
"""
def load_plugin(plugin_name: str) -> BasePlugin:
"""
@@ -381,3 +404,17 @@ class PluginManager:
for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
return callbacks
def post_train_unload(self, cfg):
"""
Calls the post_train_unload method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins:
plugin.post_train_unload(cfg)

View File

@@ -0,0 +1,13 @@
# LM Eval Harness
### Usage
```yaml
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
lm_eval_tasks:
- gsm8k
- hellaswag
- arc_easy
```

View File

@@ -0,0 +1,42 @@
"""
Module for the Plugin for LM Eval Harness
"""
import subprocess # nosec
from datetime import datetime
from axolotl.integrations.base import BasePlugin
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
class LMEvalPlugin(BasePlugin):
"""
Plugin for LM Evaluation Harness integraton with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.lm_eval.LMEvalArgs"
def post_train_unload(self, cfg):
tasks = ",".join(cfg.lm_eval_tasks)
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
output_path = cfg.output_dir
output_path += "" if cfg.output_dir.endswith("/") else "/"
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
subprocess.run( # nosec
[
"lm_eval",
"--model",
"hf",
"--model_args",
f"pretrained={cfg.output_dir}{fa2}{dtype}",
"--tasks",
tasks,
"--batch_size",
str(cfg.lm_eval_batch_size),
"--output_path",
output_path,
],
check=True,
)

View File

@@ -0,0 +1,15 @@
"""
Module for handling lm eval harness input arguments.
"""
from typing import List, Optional
from pydantic import BaseModel
class LMEvalArgs(BaseModel):
"""
Input args for lm eval harness
"""
lm_eval_tasks: List[str] = []
lm_eval_batch_size: Optional[int] = 8

View File

@@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio):
def reset_optimizer(
optimizer: torch.optim.Optimizer,
*,
reset_params: list[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: list[str],
reset_params: List[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: List[str],
prune_ratio: float = 0.9,
):
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)

View File

@@ -11,6 +11,10 @@ LOG = logging.getLogger("axolotl.prompt_strategies")
def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
try:
if strategy == "messages":
from .messages import load as messages_load
return messages_load(tokenizer, cfg, ds_cfg, processor=processor)
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
@@ -31,4 +35,5 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
return None
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
return None
raise exc
return None

View File

@@ -0,0 +1,10 @@
### example yaml
```yaml
chat_template: gemma
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
```

View File

@@ -0,0 +1,35 @@
"""Module to load prompt strategies."""
import importlib
import inspect
import logging
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
LOG = logging.getLogger("axolotl.prompt_strategies")
def load(strategy, tokenizer, cfg, ds_cfg):
# pylint: disable=duplicate-code
try:
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(
f".{strategy}", "axolotl.prompt_strategies.bradley_terry"
)
func = getattr(mod, load_fn)
load_kwargs = {}
if strategy == "user_defined":
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
else:
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
return None

View File

@@ -0,0 +1,88 @@
"""
Bradley-Terry model with chat template prompt strategy.
"""
from typing import Any, Dict, Optional
from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter,
ChatTemplateStrategy,
)
from axolotl.utils.chat_templates import chat_templates
class BTChatTemplateStrategy(ChatTemplateStrategy):
"""
Bradley-Terry reward model pairwise chat template prompt strategy.
"""
def tokenize_prompt(self, prompt):
"""
:param prompt: the actual row of data from the underlying dataset
:return:
"""
self.messages = "chosen_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append({"from": "system", "value": prompt["system"]})
prompt[self.messages].append({"from": "user", "value": prompt["input"]})
prompt[self.messages].append({"from": "assistant", "value": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt)
self.messages = "rejected_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append({"from": "system", "value": prompt["system"]})
prompt[self.messages].append({"from": "user", "value": prompt["input"]})
prompt[self.messages].append({"from": "assistant", "value": prompt["rejected"]})
rejected_tokenized = super().tokenize_prompt(prompt)
return {
"input_ids_chosen": chosen_tokenized["input_ids"],
"attention_mask_chosen": chosen_tokenized["attention_mask"],
"labels_chosen": 1.0,
"input_ids_rejected": rejected_tokenized["input_ids"],
"attention_mask_rejected": rejected_tokenized["attention_mask"],
"labels_rejected": 0.0,
}
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"message_field_role": ds_cfg.get("message_field_role", "from"),
"message_field_content": ds_cfg.get("message_field_content", "value"),
"message_field_training": ds_cfg.get("message_field_training", "training"),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail", "train_detail"
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1
if not cfg.reward_model
else cfg.sequence_len,
}
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
}
strategy = BTChatTemplateStrategy(
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
)
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy

View File

@@ -0,0 +1,27 @@
"""
chatml transforms for datasets with system, input, chosen, rejected to match llama3 chat template
"""
def icr(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
chatml transforms for datasets with system, input, chosen, rejected
ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
"""
def transform_fn(sample):
if "system" in sample and sample["system"]:
prompt = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = prompt + f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = prompt + f"{sample['rejected']}<|eot_id|>"
return sample
return transform_fn

View File

@@ -403,6 +403,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
# pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}
prompter_params = {

View File

@@ -0,0 +1,34 @@
"""Module to load message prompt strategies."""
import importlib
import inspect
import logging
LOG = logging.getLogger("axolotl.prompt_strategies.messages")
def load(tokenizer, cfg, ds_cfg, processor=None):
try:
strategy = ds_cfg.get("input_transform", "chat")
# pylint: disable=duplicate-code
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(
f".{strategy}", "axolotl.prompt_strategies.messages"
)
func = getattr(mod, load_fn)
load_kwargs = {}
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
if "processor" in sig.parameters:
load_kwargs["processor"] = processor
return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
raise exc
return None

View File

@@ -0,0 +1,84 @@
"""
Chat dataset wrapping strategy for new internal messages representations
"""
from typing import Any, Callable, Dict, Optional
from axolotl.core.datasets.chat import TokenizedChatDataset
from axolotl.core.datasets.transforms.chat_builder import chat_message_transform_builder
from axolotl.prompt_tokenizers import DatasetWrappingStrategy
class ChatMessageDatasetWrappingStrategy(DatasetWrappingStrategy):
"""
Chat dataset wrapping strategy for new internal messages representations
"""
def __init__(
self,
processor,
message_transform=None,
formatter=None,
**kwargs, # pylint: disable=unused-argument
):
"""
:param processor: tokenizer or image processor
:param kwargs:
"""
self.processor = processor
self.dataset = None
self.message_transform = message_transform
self.formatter = formatter
def wrap_dataset(
self,
dataset,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
**kwargs, # pylint: disable=unused-argument
):
self.dataset = TokenizedChatDataset(
dataset,
message_transform=self.message_transform,
model_transform=self.processor,
formatter=self.formatter,
process_count=process_count,
keep_in_memory=keep_in_memory,
)
return self.dataset
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
field_messages = ds_cfg.get("field_messages")
message_field_role = ds_cfg.get("message_field_role")
message_field_content = ds_cfg.get("message_field_content")
message_field_training = ds_cfg.get("message_field_training")
builder_kwargs = {}
if field_messages:
builder_kwargs["conversations_field"] = field_messages
if message_field_role:
builder_kwargs["message_field_role"] = message_field_role
if message_field_content:
builder_kwargs["message_field_content"] = message_field_content
if message_field_training:
builder_kwargs["message_field_training"] = message_field_training
chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml"))
format_message = (
lambda x: x # noqa E731 # pylint: disable=unnecessary-lambda-assignment
)
if chat_template == "chatml":
from axolotl.core.chat.format.chatml import format_message # noqa F811
if chat_template.startswith("llama3"):
from axolotl.core.chat.format.llama3x import format_message # noqa F811
message_transform: Callable = chat_message_transform_builder(
train_on_inputs=ds_cfg.get("train_on_inputs", False),
**builder_kwargs,
)
strategy = ChatMessageDatasetWrappingStrategy(
tokenizer, message_transform=message_transform, formatter=format_message
)
return strategy

View File

@@ -61,6 +61,9 @@ def build_loader(
default_conversation: Optional[str] = None,
):
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
LOG.warning(
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead.",
)
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg

View File

@@ -30,6 +30,12 @@ class InvalidDataException(Exception):
"""
class DatasetWrappingStrategy(abc.ABC):
"""
Abstract class for wrapping datasets for Chat Messages
"""
class PromptTokenizingStrategy(abc.ABC):
"""
Abstract class for tokenizing strategies

View File

@@ -10,7 +10,6 @@ from typing import Optional, Tuple, Union
import torch
import transformers.modelcard
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from datasets import Dataset
@@ -97,12 +96,11 @@ def train(
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
# we wait unitl the last possible moment to setup Accelerator
Accelerator()
model, peft_config = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
model.generation_config.do_sample = True
if model.generation_config is not None:
model.generation_config.do_sample = True
model_ref = None
if cfg.rl and cfg.rl != "orpo":

View File

@@ -1,8 +1,12 @@
"""
Basic utils for Axolotl
"""
import importlib
import importlib.util
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None
def is_comet_available():
return importlib.util.find_spec("comet_ml") is not None

View File

@@ -29,7 +29,7 @@ from transformers import (
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils import is_mlflow_available
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.callbacks.perplexity import Perplexity
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -462,7 +462,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
references=[[r] for r in references],
predictions=predictions,
)
scores[metric_name] = score
scores["eval_" + metric_name] = score
return scores
def predict_with_generate():
@@ -747,6 +747,15 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
artifact_file="PredictionsVsGroundTruth.json",
tracking_uri=tracking_uri,
)
elif logger == "comet_ml" and is_comet_available():
import comet_ml
experiment = comet_ml.get_running_experiment()
if experiment:
experiment.log_table(
f"{name} - Predictions vs Ground Truth.csv",
pd.DataFrame(table_data),
)
if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader)

View File

@@ -0,0 +1,43 @@
"""Comet module for trainer callbacks"""
import logging
from typing import TYPE_CHECKING
import comet_ml
from transformers import TrainerCallback, TrainerControl, TrainerState
from axolotl.utils.distributed import is_main_process
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
class SaveAxolotlConfigtoCometCallback(TrainerCallback):
"""Callback to save axolotl config to comet"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
comet_experiment = comet_ml.start(source="axolotl")
comet_experiment.log_other("Created from", "axolotl")
comet_experiment.log_asset(
self.axolotl_config_path,
file_name="axolotl-config",
)
LOG.info(
"The Axolotl config has been saved to the Comet Experiment under assets."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to Comet: {err}")
return control

File diff suppressed because one or more lines are too long

View File

@@ -4,6 +4,7 @@ Collators for multi-modal chat messages and packing
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from PIL import Image
from transformers import PreTrainedTokenizerBase, ProcessorMixin
from transformers.data.data_collator import DataCollatorMixin
from transformers.utils import PaddingStrategy
@@ -52,7 +53,12 @@ class MultiModalChatDataCollator(DataCollatorMixin):
)
for example in examples
]
images = [example["images"] for example in examples]
images = [
Image.open(example["images"])
if isinstance(example["images"], str)
else example["images"]
for example in examples
]
if max_images > 0:
images = [img_batch[:max_images] for img_batch in images]

View File

@@ -0,0 +1,93 @@
"""Module for wandb utilities"""
import logging
import os
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.utils.comet_")
COMET_ENV_MAPPING_OVERRIDE = {
"comet_mode": "COMET_START_MODE",
"comet_online": "COMET_START_ONLINE",
}
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = {
"auto_histogram_activation_logging": "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS",
"auto_histogram_epoch_rate": "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE",
"auto_histogram_gradient_logging": "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS",
"auto_histogram_tensorboard_logging": "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD",
"auto_histogram_weight_logging": "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS",
"auto_log_co2": "COMET_AUTO_LOG_CO2",
"auto_metric_logging": "COMET_AUTO_LOG_METRICS",
"auto_metric_step_rate": "COMET_AUTO_LOG_METRIC_STEP_RATE",
"auto_output_logging": "COMET_AUTO_LOG_OUTPUT_LOGGER",
"auto_param_logging": "COMET_AUTO_LOG_PARAMETERS",
"comet_disabled": "COMET_AUTO_LOG_DISABLE",
"display_summary_level": "COMET_DISPLAY_SUMMARY_LEVEL",
"distributed_node_identifier": "COMET_DISTRIBUTED_NODE_IDENTIFIER",
"log_code": "COMET_AUTO_LOG_CODE",
"log_env_cpu": "COMET_AUTO_LOG_ENV_CPU",
"log_env_details": "COMET_AUTO_LOG_ENV_DETAILS",
"log_env_disk": "COMET_AUTO_LOG_ENV_DISK",
"log_env_gpu": "COMET_AUTO_LOG_ENV_GPU",
"log_env_host": "COMET_AUTO_LOG_ENV_HOST",
"log_env_network": "COMET_AUTO_LOG_ENV_NETWORK",
"log_git_metadata": "COMET_AUTO_LOG_GIT_METADATA",
"log_git_patch": "COMET_AUTO_LOG_GIT_PATCH",
"log_graph": "COMET_AUTO_LOG_GRAPH",
"name": "COMET_START_EXPERIMENT_NAME",
"offline_directory": "COMET_OFFLINE_DIRECTORY",
"parse_args": "COMET_AUTO_LOG_CLI_ARGUMENTS",
"tags": "COMET_START_EXPERIMENT_TAGS",
}
def python_value_to_environ_value(python_value):
if isinstance(python_value, bool):
if python_value is True:
return "true"
return "false"
if isinstance(python_value, int):
return str(python_value)
if isinstance(python_value, list): # Comet only have one list of string parameter
return ",".join(map(str, python_value))
return python_value
def setup_comet_env_vars(cfg: DictDefault):
# TODO, we need to convert Axolotl configuration to environment variables
# as Transformers integration are call first and would create an
# Experiment first
for key in cfg.keys():
if key.startswith("comet_") and key != "comet_experiment_config":
value = cfg.get(key, "")
if value is not None and value != "":
env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper())
final_value = python_value_to_environ_value(value)
os.environ[env_variable_name] = final_value
if cfg.comet_experiment_config:
for key, value in cfg.comet_experiment_config.items():
if value is not None and value != "":
config_env_variable_name = (
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key)
)
if config_env_variable_name is None:
LOG.warning(
f"Unknown Comet Experiment Config name {key}, ignoring it"
)
continue
final_value = python_value_to_environ_value(value)
os.environ[config_env_variable_name] = final_value
# Enable comet if project name is present
if cfg.comet_project_name and len(cfg.comet_project_name) > 0:
cfg.use_comet = True

View File

@@ -102,10 +102,12 @@ class SFTDataset(BaseModel):
path: Optional[str] = None
split: Optional[str] = None
type: Optional[Union[str, UserDefinedPrompterType]] = None
input_transform: Optional[str] = None
shards: Optional[int] = None
conversation: Optional[str] = None
chat_template: Optional[str] = None
data_files: Optional[Union[str, List[str]]] = None
input_format: Optional[str] = None
name: Optional[str] = None
ds_type: Optional[str] = None
train_on_split: Optional[str] = None
@@ -125,6 +127,7 @@ class SFTDataset(BaseModel):
drop_system_message: Optional[bool] = None
trust_remote_code: Optional[bool] = False
revision: Optional[str] = None
class UserDefinedDPOType(BaseModel):
@@ -146,6 +149,7 @@ class DPODataset(BaseModel):
split: Optional[str] = None
type: Optional[Union[UserDefinedDPOType, str]] = None
data_files: Optional[List[str]] = None
revision: Optional[str] = None
class UserDefinedKTOType(BaseModel):
@@ -167,6 +171,7 @@ class KTODataset(BaseModel):
type: Optional[Union[UserDefinedKTOType, str]] = None
data_files: Optional[List[str]] = None
trust_remote_code: Optional[bool] = False
revision: Optional[str] = None
class RLType(str, Enum):
@@ -184,7 +189,9 @@ class ChatTemplate(str, Enum):
alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name
inst = "inst" # pylint: disable=invalid-name
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
@@ -193,6 +200,7 @@ class ChatTemplate(str, Enum):
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
class LoftQConfig(BaseModel):
@@ -444,6 +452,7 @@ class MLFlowConfig(BaseModel):
use_mlflow: Optional[bool] = None
mlflow_tracking_uri: Optional[str] = None
mlflow_experiment_name: Optional[str] = None
mlflow_run_name: Optional[str] = None
hf_mlflow_log_artifacts: Optional[bool] = None
@@ -489,6 +498,19 @@ class WandbConfig(BaseModel):
return data
class CometConfig(BaseModel):
"""Comet configuration subset"""
use_comet: Optional[bool] = None
comet_api_key: Optional[str] = None
comet_workspace: Optional[str] = None
comet_project_name: Optional[str] = None
comet_experiment_key: Optional[str] = None
comet_mode: Optional[str] = None
comet_online: Optional[bool] = None
comet_experiment_config: Optional[Dict[str, Any]] = None
class GradioConfig(BaseModel):
"""Gradio configuration subset"""
@@ -509,6 +531,7 @@ class AxolotlInputConfig(
HyperparametersConfig,
WandbConfig,
MLFlowConfig,
CometConfig,
LISAConfig,
GradioConfig,
RemappedParameters,
@@ -528,6 +551,7 @@ class AxolotlInputConfig(
resize_token_embeddings_to_32x: Optional[bool] = None
rl: Optional[RLType] = None
reward_model: Optional[bool] = None
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
@@ -833,6 +857,17 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def hint_reward_model_pad(cls, data):
if data.get("reward_model") and not data.get("pad_to_sequence_len"):
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using reward_model"
)
if data.get("pad_to_sequence_len") is None:
data["pad_to_sequence_len"] = True
return data
@model_validator(mode="before")
@classmethod
def check_gas_bsz(cls, data):
@@ -966,6 +1001,26 @@ class AxolotlInputConfig(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
if data.get("do_bench_eval") and not (
data.get("evals_per_epoch") or data.get("eval_steps")
):
raise ValueError(
"do_bench_eval requires evals_per_epoch or eval_steps to be set."
)
return data
@model_validator(mode="before")
@classmethod
def check_test_datasets_bench(cls, data):
if (
data.get("do_bench_eval")
and not data.get("test_datasets")
and not data.get("val_set_size")
):
LOG.warning(
"`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset."
)
data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}]
return data
@model_validator(mode="before")

View File

@@ -90,6 +90,7 @@ def load_prepare_dpo_datasets(cfg):
ds = load_dataset( # pylint: disable=invalid-name
ds_cfg["path"],
split=ds_cfg["split"],
revision=ds_cfg.get("revision", None),
)
split_datasets.insert(i, ds)

View File

@@ -19,10 +19,12 @@ from transformers import PreTrainedTokenizerBase
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies import load
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
from axolotl.prompt_tokenizers import (
AlpacaMultipleChoicePromptTokenizingStrategy,
AlpacaPromptTokenizingStrategy,
AlpacaReflectionPTStrategy,
DatasetWrappingStrategy,
GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy,
@@ -242,6 +244,7 @@ def load_tokenized_prepared_datasets(
name=config_dataset.name,
streaming=True,
token=use_auth_token,
revision=config_dataset.revision,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
@@ -346,6 +349,7 @@ def load_tokenized_prepared_datasets(
streaming=False,
data_files=config_dataset.data_files,
token=use_auth_token,
revision=config_dataset.revision,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
@@ -380,6 +384,7 @@ def load_tokenized_prepared_datasets(
repo_id=config_dataset.path,
repo_type="dataset",
filename=config_dataset.data_files,
revision=config_dataset.revision,
)
elif isinstance(config_dataset.data_files, list):
fp = []
@@ -389,6 +394,7 @@ def load_tokenized_prepared_datasets(
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
revision=config_dataset.revision,
)
)
else:
@@ -433,8 +439,8 @@ def load_tokenized_prepared_datasets(
config_dataset=config_dataset,
tokenizer=tokenizer,
cfg=cfg,
dataset=ds,
d_base_type=d_base_type,
dataset=ds,
d_prompt_style=d_prompt_style,
processor=processor,
)
@@ -454,7 +460,7 @@ def load_tokenized_prepared_datasets(
else:
LOG.debug("NOT shuffling merged datasets")
if not cfg.skip_prepare_dataset:
if cfg.sample_packing and not cfg.skip_prepare_dataset:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
@@ -569,7 +575,7 @@ def get_dataset_wrapper(
d_base_type,
dataset,
d_prompt_style=None,
processor=None,
processor=None, # pylint: disable=unused-argument
):
dataset_wrapper = None
dataset_prompter = None
@@ -604,8 +610,10 @@ def get_dataset_wrapper(
)
elif cfg.skip_prepare_dataset:
dataset_wrapper = dataset
elif ds_strategy := load(
config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
elif ds_strategy := config_dataset.type.startswith(
"bradley_terry"
) and bradley_terry_load(
config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
):
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
@@ -613,6 +621,18 @@ def get_dataset_wrapper(
dataset,
**ds_kwargs,
)
elif ds_strategy := load(
config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
):
if isinstance(ds_strategy, DatasetWrappingStrategy):
dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
else:
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
ds_strategy,
dataset,
**ds_kwargs,
)
elif d_base_type == "alpaca":
dataset_prompter = AlpacaPrompter(d_prompt_style)
ds_strategy = AlpacaPromptTokenizingStrategy(

View File

@@ -11,7 +11,7 @@ import numpy as np
import torch
import torch.cuda
from accelerate.logging import get_logger
from datasets import set_caching_enabled
from datasets import disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available
@@ -87,10 +87,10 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
@contextmanager
def disable_datasets_caching():
try:
set_caching_enabled(False)
disable_caching()
yield
finally:
set_caching_enabled(True)
enable_caching()
def add_position_ids(sample):
@@ -306,7 +306,11 @@ def process_pretraining_datasets_for_packing(
def calculate_total_num_steps(cfg, train_dataset, update=True):
if not cfg.total_num_tokens and not cfg.skip_prepare_dataset:
if (
not cfg.total_num_tokens
and not cfg.skip_prepare_dataset
and not cfg.reward_model
):
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
.to_pandas()
@@ -323,6 +327,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
not skip_estimates
and not cfg.total_supervised_tokens
and not cfg.skip_prepare_dataset
and not cfg.reward_model
):
total_supervised_tokens = (
train_dataset.data.column("labels")

View File

View File

View File

@@ -0,0 +1,197 @@
"""
Tests for the chat messages module
"""
import unittest
import pytest
from transformers import AddedToken, AutoTokenizer
from axolotl.core.chat.format.chatml import format_message
from axolotl.core.chat.messages import ChatFormattedChats, Chats
@pytest.fixture(scope="session", name="llama_tokenizer")
def llama_tokenizer_fixture():
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B")
@pytest.fixture(scope="session", name="chatml_tokenizer")
def llama_tokenizer_w_chatml(llama_tokenizer):
llama_tokenizer.add_special_tokens(
{
"eos_token": AddedToken(
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
)
}
)
llama_tokenizer.add_tokens(
[
AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
]
)
return llama_tokenizer
@pytest.fixture(scope="session", name="chat_msgs")
def chat_msgs_fixture():
return {
"conversation": [
{
"role": "system",
"content": [
{"type": "text", "value": "You are a helpful assistant."},
],
},
{
"role": "user",
"content": [
{"type": "text", "value": "What is today's stock price of Apple?"},
],
},
{
"role": "assistant",
"content": [
{
"type": "tool_call",
"value": {
"name": "get_date",
"arguments": {},
},
},
{
"type": "tool_call",
"value": {
"name": "get_stock_price",
"arguments": {"symbol": "AAPL"},
},
},
],
"weight": 1,
},
{
"role": "tool",
"content": [
{
"type": "tool_response",
"value": {
"name": "get_date",
"content": {"date": "2024-09-09"},
},
},
{
"type": "tool_response",
"value": {
"name": "get_stock_price",
"content": {"symbol": "AAPL", "price": 123.45},
},
},
],
},
{
"role": "assistant",
"content": [
{
"type": "text",
"value": "The stock price of Apple is $123.45.\n",
"weight": 0,
},
{
"type": "text",
"value": "<reflection>The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.</reflection>",
},
{
"type": "text",
"value": "The stock price of Apple on September 9, 2024 is $123.45.",
},
],
"weight": 1,
},
]
}
class TestMessagesCase:
"""
Test cases for the chat messages module
"""
def test_tool_call_stringify(self, chat_msgs):
chat_msgs_as_obj = Chats(**chat_msgs)
assert '{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}' == str(
chat_msgs_as_obj.conversation[2].content[1].value
)
def test_chatml_formatted_wrapper(self, chat_msgs):
chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
target_chatml = """<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is today's stock price of Apple?<|im_end|>
<|im_start|>assistant
<tool_call>
{"name": "get_date", "arguments": {}}
</tool_call>
<tool_call>
{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}
</tool_call>
<|im_end|>
<|im_start|>tool
<tool_response>
{"name": "get_date", "content": {"date": "2024-09-09"}}
</tool_response>
<tool_response>
{"name": "get_stock_price", "content": {"symbol": "AAPL", "price": 123.45}}
</tool_response>
<|im_end|>
<|im_start|>assistant
The stock price of Apple is $123.45.
<reflection>The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.</reflection>The stock price of Apple on September 9, 2024 is $123.45.<|im_end|>\n"""
assert target_chatml == str(chat_msg_formatted)
def test_chatml_formatting_tool_call(self, chat_msgs):
chat_msgs_as_obj = Chats(**chat_msgs)
target_chatml_turn2 = """<|im_start|>assistant\n<tool_call>\n{"name": "get_date", "arguments": {}}\n</tool_call>\n<tool_call>\n{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}\n</tool_call>\n<|im_end|>\n"""
assert target_chatml_turn2 == str(
format_message(chat_msgs_as_obj.conversation[2])
)
def test_train_labels(self, chatml_tokenizer, chat_msgs):
chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
tokenized = chat_msg_formatted.conversation[2].tokenized(chatml_tokenizer)
# fmt: off
target_labels = [
-100, -100, -100, # role
27, 14506, 13735, 397, 5018, 609, 794,
330, 456, 4257, 498, 330, 16774, 794, 4792, 534, 524,
14506, 13735, 397, 27, 14506, 13735, 397, 5018, 609, 794,
330, 456, 31641, 9217, 498, 330, 16774, 794, 5324, 19314,
794, 330, 84016, 43, 96742, 524, 14506, 13735, 397,
128256, # <|im_end|>
-100 # trailing newline
]
# fmt: on
assert tokenized["labels"] == target_labels
def test_train_labels_2(self, chatml_tokenizer, chat_msgs):
# also test if indivudal contents are set not to train
chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
tokenized = chat_msg_formatted.conversation[4].tokenized(chatml_tokenizer)
# fmt: off
target_labels = [
-100, -100, -100, # role
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # initial response
27, 78098, 16761, 4113, 3319, 4691, 369, 3432, 596, 5708, 3430,
315, 8325, 13, 1115, 24897, 814, 1101, 4934, 279, 2457,
5343, 304, 279, 2077, 4005, 78098, 16761, 5708, 3430, 315,
8325, 389, 6250, 220, 24, 11, 220, 2366, 19, 374, 400,
4513, 13, 1774, 13,
128256, # <|im_end|>
-100, # trailing newline
]
# fmt: on
assert tokenized["labels"] == target_labels
if __name__ == "__main__":
unittest.main()

View File

@@ -19,6 +19,8 @@ from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@pytest.fixture(scope="session", autouse=True)
def download_model():
@@ -346,3 +348,115 @@ class TestMultiGPULlama(unittest.TestCase):
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_ds_zero3_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_ds_zero3_qlora_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)

View File

@@ -0,0 +1,74 @@
"""
E2E tests for reward model lora llama
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestRewardModelLoraLlama(unittest.TestCase):
"""
Test case for Llama reward models using LoRA
"""
@with_temp_dir
def test_rm_fft(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"model_type": "AutoModelForSequenceClassification",
"tokenizer_type": "LlamaTokenizer",
"chat_template": "alpaca",
"reward_model": True,
"sequence_len": 1024,
"pad_to_sequence_len": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.0,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "argilla/distilabel-intel-orca-dpo-pairs",
"type": "bradley_terry.chat_template",
},
],
"remove_unused_columns": False,
"max_steps": 10,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"gradient_checkpointing": True,
"warmup_ratio": 0.1,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -0,0 +1,62 @@
"""
tests for chat_template prompt strategy
"""
# pylint: disable=duplicate-code
import logging
import unittest
from axolotl.prompt_strategies.messages.chat import load
from axolotl.utils.dict import DictDefault
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
class TestMessagesChatLlama3:
"""
Test class for assistant style datasets with llama-3 prompts using the messages chat llama3 strategy.
"""
def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
LOG.info("Loading llama-3 tokenizer with assistant dataset")
strategy = load(
llama3_tokenizer,
DictDefault(
{
"train_on_inputs": False,
"sequence_len": 512,
}
),
DictDefault(
{
"chat_template": "llama3",
"message_field_role": "role",
"message_field_content": "content",
"field_messages": "messages",
}
),
)
res = strategy.wrap_dataset(assistant_dataset)
input_ids = res[0]["input_ids"]
# fmt: off
expected_input_ids = [
128000, # bos
128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot
128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot
128006, 882, 128007,
271, 19045, 29474, 128009,
128006, 78191, 128007,
271, 19045, 29474, 128009,
]
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
if __name__ == "__main__":
unittest.main()

View File

@@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.dict import DictDefault
@@ -267,6 +268,143 @@ class TestDatasetPreparation(unittest.TestCase):
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
def test_load_hub_with_dpo(self):
"""Verify that processing dpo data from the hub works"""
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"rl": "dpo",
"chat_template": "llama3",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"chat_template": "llama3",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
}
],
}
)
train_dataset, _ = load_prepare_dpo_datasets(cfg)
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
def test_load_hub_with_revision(self):
"""Verify that processing data from the hub works with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"revision": "d05c1cb",
},
],
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
def test_load_hub_with_revision_with_dpo(self):
"""Verify that processing dpo data from the hub works with a specific revision"""
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"rl": "dpo",
"chat_template": "llama3",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"chat_template": "llama3",
"revision": "ea82cff",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
}
],
}
)
train_dataset, _ = load_prepare_dpo_datasets(cfg)
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
def test_load_local_hub_with_revision(self):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
revision="d05c1cb",
)
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"ds_type": "parquet",
"type": "alpaca",
"data_files": [
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
],
"revision": "d05c1cb",
},
],
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
if __name__ == "__main__":
unittest.main()

View File

@@ -9,6 +9,7 @@ from typing import Optional
import pytest
from pydantic import ValidationError
from axolotl.utils import is_comet_available
from axolotl.utils.config import validate_config
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
from axolotl.utils.dict import DictDefault
@@ -1329,3 +1330,105 @@ class TestValidationWandb(BaseValidation):
os.environ.pop("WANDB_PROJECT", None)
os.environ.pop("WANDB_DISABLED", None)
@pytest.mark.skipif(is_comet_available() is False, reason="comet_ml is not installed")
class TestValidationComet(BaseValidation):
"""
Validation test for comet
"""
def test_comet_sets_env(self, minimal_cfg):
from axolotl.utils.comet_ import setup_comet_env_vars
comet_config = {
"comet_api_key": "foo",
"comet_workspace": "some_workspace",
"comet_project_name": "some_project",
"comet_experiment_key": "some_experiment_key",
"comet_mode": "get_or_create",
"comet_online": False,
"comet_experiment_config": {
"auto_histogram_activation_logging": False,
"auto_histogram_epoch_rate": 2,
"auto_histogram_gradient_logging": True,
"auto_histogram_tensorboard_logging": False,
"auto_histogram_weight_logging": True,
"auto_log_co2": False,
"auto_metric_logging": True,
"auto_metric_step_rate": 15,
"auto_output_logging": False,
"auto_param_logging": True,
"comet_disabled": False,
"display_summary_level": 2,
"distributed_node_identifier": "some_distributed_node_identifier",
"log_code": True,
"log_env_cpu": False,
"log_env_details": True,
"log_env_disk": False,
"log_env_gpu": True,
"log_env_host": False,
"log_env_network": True,
"log_git_metadata": False,
"log_git_patch": True,
"log_graph": False,
"name": "some_name",
"offline_directory": "some_offline_directory",
"parse_args": True,
"tags": ["tag1", "tag2"],
},
}
cfg = DictDefault(comet_config) | minimal_cfg
new_cfg = validate_config(cfg)
setup_comet_env_vars(new_cfg)
comet_env = {
key: value for key, value in os.environ.items() if key.startswith("COMET_")
}
assert (
len(comet_env)
== len(comet_config) + len(comet_config["comet_experiment_config"]) - 1
)
assert comet_env == {
"COMET_API_KEY": "foo",
"COMET_AUTO_LOG_CLI_ARGUMENTS": "true",
"COMET_AUTO_LOG_CO2": "false",
"COMET_AUTO_LOG_CODE": "true",
"COMET_AUTO_LOG_DISABLE": "false",
"COMET_AUTO_LOG_ENV_CPU": "false",
"COMET_AUTO_LOG_ENV_DETAILS": "true",
"COMET_AUTO_LOG_ENV_DISK": "false",
"COMET_AUTO_LOG_ENV_GPU": "true",
"COMET_AUTO_LOG_ENV_HOST": "false",
"COMET_AUTO_LOG_ENV_NETWORK": "true",
"COMET_AUTO_LOG_GIT_METADATA": "false",
"COMET_AUTO_LOG_GIT_PATCH": "true",
"COMET_AUTO_LOG_GRAPH": "false",
"COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS": "false",
"COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE": "2",
"COMET_AUTO_LOG_HISTOGRAM_GRADIENTS": "true",
"COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD": "false",
"COMET_AUTO_LOG_HISTOGRAM_WEIGHTS": "true",
"COMET_AUTO_LOG_METRIC_STEP_RATE": "15",
"COMET_AUTO_LOG_METRICS": "true",
"COMET_AUTO_LOG_OUTPUT_LOGGER": "false",
"COMET_AUTO_LOG_PARAMETERS": "true",
"COMET_DISPLAY_SUMMARY_LEVEL": "2",
"COMET_DISTRIBUTED_NODE_IDENTIFIER": "some_distributed_node_identifier",
"COMET_EXPERIMENT_KEY": "some_experiment_key",
"COMET_OFFLINE_DIRECTORY": "some_offline_directory",
"COMET_PROJECT_NAME": "some_project",
"COMET_START_EXPERIMENT_NAME": "some_name",
"COMET_START_EXPERIMENT_TAGS": "tag1,tag2",
"COMET_START_MODE": "get_or_create",
"COMET_START_ONLINE": "false",
"COMET_WORKSPACE": "some_workspace",
}
for key in comet_env.keys():
os.environ.pop(key, None)