Compare commits


2 Commits

Author        SHA1          Message                  Date
sunny         a02af506ed    pixtral example          2024-10-03 16:11:15 -04:00
sunny         431a0b0f9d    added pixtral example    2024-10-03 16:01:21 -04:00
73 changed files with 238 additions and 3411 deletions

View File

@@ -28,13 +28,7 @@ jobs:
         cuda_version: 12.4.1
         cudnn_version: ""
         python_version: "3.11"
-        pytorch: 2.4.1
-        torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-      - cuda: "124"
-        cuda_version: 12.4.1
-        cudnn_version: ""
-        python_version: "3.11"
-        pytorch: 2.4.1
+        pytorch: 2.4.0
         torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
   steps:
     - name: Checkout

View File

@@ -27,7 +27,7 @@ jobs:
       - cuda: 124
         cuda_version: 12.4.1
         python_version: "3.11"
-        pytorch: 2.4.1
+        pytorch: 2.4.0
         axolotl_extras:
   runs-on: axolotl-gpu-runner
   steps:
@@ -84,7 +84,7 @@ jobs:
       - cuda: 124
         cuda_version: 12.4.1
         python_version: "3.11"
-        pytorch: 2.4.1
+        pytorch: 2.4.0
         axolotl_extras:
   runs-on: axolotl-gpu-runner
   steps:

View File

@@ -26,7 +26,7 @@ jobs:
       - cuda: 124
         cuda_version: 12.4.1
         python_version: "3.11"
-        pytorch: 2.4.1
+        pytorch: 2.4.0
         axolotl_extras:
   runs-on: axolotl-gpu-runner
   steps:
@@ -83,7 +83,7 @@ jobs:
       - cuda: 124
         cuda_version: 12.4.1
         python_version: "3.11"
-        pytorch: 2.4.1
+        pytorch: 2.4.0
         axolotl_extras:
   runs-on: axolotl-gpu-runner
   steps:

View File

@@ -25,7 +25,7 @@ jobs:
       fail-fast: false
       matrix:
         python_version: ["3.10", "3.11"]
-        pytorch_version: ["2.3.1", "2.4.1"]
+        pytorch_version: ["2.3.1", "2.4.0"]
     timeout-minutes: 20
     steps:
@@ -91,7 +91,7 @@ jobs:
       - cuda: 124
         cuda_version: 12.4.1
         python_version: "3.11"
-        pytorch: 2.4.1
+        pytorch: 2.4.0
         num_gpus: 1
         axolotl_extras:
         nightly_build: "true"

View File

@@ -36,7 +36,7 @@ jobs:
       fail-fast: false
       matrix:
         python_version: ["3.10", "3.11"]
-        pytorch_version: ["2.3.1", "2.4.1"]
+        pytorch_version: ["2.3.1", "2.4.0"]
     timeout-minutes: 20
     steps:
@@ -94,7 +94,7 @@ jobs:
       - cuda: 124
         cuda_version: 12.4.1
         python_version: "3.11"
-        pytorch: 2.4.1
+        pytorch: 2.4.0
         num_gpus: 1
         axolotl_extras:
     steps:

View File

@@ -1,3 +1,3 @@
 [settings]
 profile=black
-known_third_party=wandb,comet_ml
+known_third_party=wandb

View File

@@ -14,7 +14,7 @@ Features:
 - Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
 - Works with single GPU or multiple GPUs via FSDP or Deepspeed
 - Easily run with Docker locally or on the cloud
-- Log results and optionally checkpoints to wandb, mlflow or Comet
+- Log results and optionally checkpoints to wandb or mlflow
 - And more!

 <a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
@@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
 - typescript
 type: ... # unimplemented custom format

-# fastchat conversation (deprecation soon, use chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template)
+# fastchat conversation
 # See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 - path: ...
 type: sharegpt
@@ -515,22 +515,6 @@ wandb_name:
 wandb_log_model:
 ```

-##### Comet Logging
-
-Make sure your `COMET_API_KEY` environment variable is set (recommended) or you log in to Comet with `comet login`.
-
-- Comet options
-```yaml
-use_comet:
-comet_api_key:
-comet_workspace:
-comet_project_name:
-comet_experiment_key:
-comet_mode:
-comet_online:
-comet_experiment_config:
-```
-
 ##### Special Tokens

 It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:

View File

@@ -83,14 +83,13 @@ lora_on_cpu: true
 datasets:
   # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
   - path: vicgalle/alpaca-gpt4
     # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
     type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
     ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
     data_files: # Optional[str] path to source data files
     shards: # Optional[int] number of shards to split data into
     name: # Optional[str] name of dataset configuration to load
     train_on_split: train # Optional[str] name of dataset split to load from
-    revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.

     # Optional[str] fastchat conversation type, only used with type: sharegpt
     conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -124,48 +123,6 @@ datasets:
     # For `completion` datasets only, uses the provided field instead of `text` column
     field:

-  # Using chat template
-  - path: ...
-    # Set type to `chat_template` to use this strategy
-    type: chat_template
-    # Specify the name of the chat template to use
-    # The name of the chat template to use for training, following values are supported:
-    # - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default.
-    # - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
-    # - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
-    # - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
-    chat_template: tokenizer_default
-    # Custom jinja template for chat template. This will be only used if `chat_template` is set to `jinja` or empty (in which case chat_template is automatically set to `jinja`).
-    chat_template_jinja:
-    # The key in the data example that contains the messages. Default is "messages".
-    field_messages: messages
-    # The key in the message turn that contains the role. Default is "role".
-    message_field_role: role
-    # The key in the message turn that contains the content. Default is "content".
-    message_field_content: content
-    # Optional[Dict[str, List]]. Roles mapping for the messages.
-    roles:
-      user: ["human", "user"]
-      assistant: ["gpt", "assistant", "ai"]
-      system: ["system"]
-
-    ## NOTE: Leaving the below empty will default to using the simple legacy tokenization strategy where only last message is trained on.
-
-    # Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
-    roles_to_train: ["gpt", "assistant"]
-    # Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
-    # - all: train on all EOS tokens
-    # - turn: train on the EOS token at the end of each trainable turn
-    # - last: train on the last EOS token in the conversation
-    train_on_eos: last
-    # The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
-    message_field_training: training
-    # The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
-    # The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
-    # See example at `docs/dataset-formats/conversation.qmd`
-    message_field_training_detail: train_detail
-
 # If false, the datasets will not be shuffled and will keep their original order in `datasets`.
 # The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
 shuffle_merged_datasets: true
@@ -184,16 +141,9 @@ test_datasets:
 # use RL training: 'dpo', 'ipo', 'kto'
 rl:

-# The name of the chat template to use for training, following values are supported:
-# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
-# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
-# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
-# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
-# The selected chat template will be saved to the tokenizer_config.json for easier inferencing
-# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template.
-chat_template: tokenizer_default
-# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
-chat_template_jinja: null
+# Saves the desired chat template to the tokenizer_config.json for easier inferencing
+# Currently supports chatml and inst (mistral/mixtral)
+chat_template: chatml
 # Changes the default system message
 default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
 # Axolotl attempts to save the dataset as an arrow after packing the data together so
@@ -315,21 +265,8 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
 # mlflow configuration if you're using it
 mlflow_tracking_uri: # URI to mlflow
 mlflow_experiment_name: # Your experiment name
-mlflow_run_name: # Your run name
 hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry

-# Comet configuration if you're using it
-# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you log in to Comet with `comet login`.
-# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
-use_comet: # Enable or disable Comet integration.
-comet_api_key: # API key for Comet. Recommended to set via `comet login`.
-comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
-comet_project_name: # Project name in Comet. Defaults to Uncategorized.
-comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key.
-comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.
-comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
-comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
-
 # Where to save the full-finetuned model to
 output_dir: ./completed-model
@@ -364,7 +301,7 @@ max_steps:
 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
 eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
-eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
+eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf"]
 loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
 loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)

View File

@@ -6,8 +6,6 @@ order: 3
 ## sharegpt

-UPDATE: ShareGPT is being deprecated in the next release. Please see `chat_template` section below.
-
 conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)

 ```{.json filename="data.jsonl"}
@@ -71,138 +69,3 @@ creates a chat where bot is asked to tell a joke, then explain why the joke is f
```{.json filename="data.jsonl"}
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
```
## chat_template

Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. It supports the tokenizer's template, a supported named template, or a custom jinja2 template.

```{.json filename="data.jsonl"}
{"conversations": [{"role": "...", "content": "..."}]}
```

See `config.qmd` for full configs and supported templates.

### Migrating from sharegpt

Most configs can be adapted as follows:

```yaml
# old
chat_template: chatml
datasets:
  - path: ...
    type: sharegpt
    conversation: chatml

# new (if using tokenizer's chat_template)
datasets:
  - path: ...
    type: chat_template

    field_messages: conversations
    message_field_role: from
    message_field_content: value

# new (if setting a new chat_template like chatml, gemma, etc)
chat_template: chatml
datasets:
  - path: ...
    type: chat_template

    field_messages: conversations
    message_field_role: from
    message_field_content: value
```

We recommend checking the examples below for other use cases.

### Examples

1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only the last message.

```yaml
datasets:
  - path: ...
    type: chat_template
```

2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.

```yaml
chat_template: gemma # this overwrites the tokenizer's chat_template
datasets:
  - path: ...
    type: chat_template
    roles_to_train: ["assistant"]
```

3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.

```yaml
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
datasets:
  - path: ...
    type: chat_template
    roles_to_train: ["assistant"]
```

4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.

```yaml
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"

datasets:
  - path: ...
    type: chat_template
    roles_to_train: ["assistant"]
```

5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation

For a data sample that looks like:

```{.json filename="data.jsonl"}
{
  "conversations": [
    {"from": "system", "value": "You are an AI assistant.", "train": false},
    {"from": "human", "value": "Hello", "train": false},
    {"from": "assistant", "value": "Hello", "train": true},
    {"from": "human", "value": "How are you?", "train": true},
    {
      "from": "assistant",
      "value": "I'm doing very well, thank you!",
      "train_detail": [
        {"begin_offset": 0, "end_offset": 8, "train": false},
        {"begin_offset": 9, "end_offset": 18, "train": true},
        {"begin_offset": 19, "end_offset": 30, "train": false}
      ]
    },
    {
      "from": "human",
      "value": "I'm doing very well, thank you!",
      "train": true
    },
    {"from": "assistant", "value": "Hi there!", "train": true}
  ]
}
```

The configuration would look like:

```yaml
datasets:
  - path: ...
    type: chat_template
    chat_template: tokenizer_default
    field_messages: conversations
    message_field_role: from
    message_field_content: value
    roles_to_train: []
    train_on_eos: turn
    message_field_training: train
    message_field_training_detail: train_detail
```

Tip: It is not necessary to use both `message_field_training` and `message_field_training_detail` at the same time.

View File

@@ -1,63 +0,0 @@
base_model: google/gemma-2-2b
model_type: AutoModelForSequenceClassification
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
reward_model: true
chat_template: gemma
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
remove_unused_columns: false
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -0,0 +1,65 @@
base_model: mistral-community/pixtral-12b
processor_type: AutoProcessor
load_in_8bit: true
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: llama3_2_vision
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
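
For readers trying the new example above: an axolotl config is plain YAML, typically launched with `accelerate launch -m axolotl.cli.train <config>`. A minimal sanity-check sketch follows; the `examples/pixtral/lora-8bit.yaml` path is an assumption, since the compare view does not show filenames.

```python
# Minimal sketch: sanity-check the example config before launching training.
# The path below is an assumption; the compare view hides filenames.
import yaml

with open("examples/pixtral/lora-8bit.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# a few keys this example depends on
for key in ("base_model", "processor_type", "chat_template", "datasets"):
    assert key in cfg, f"missing expected key: {key}"

print(cfg["base_model"])  # mistral-community/pixtral-12b
```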

View File

@@ -1,11 +1,11 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
-peft==0.13.2
-transformers==4.45.2
-tokenizers>=0.20.1
-bitsandbytes==0.44.1
-accelerate==1.0.1
-datasets==3.0.1
+peft==0.13.0
+transformers==4.45.1
+tokenizers>=0.19.1
+bitsandbytes==0.44.0
+accelerate==0.34.2
+datasets==2.21.0
 deepspeed==0.14.4
 pydantic==2.6.3
 addict
@@ -16,7 +16,7 @@ flash-attn==2.6.3
 sentencepiece
 wandb
 einops
-xformers==0.0.28.post1
+xformers==0.0.27
 optimum==1.16.2
 hf_transfer
 colorama
@@ -46,11 +46,3 @@ gcsfs>=2024.5.0
 trl==0.9.6
 zstandard==0.22.0
 fastcore
-
-# lm eval harness
-lm_eval==0.4.4
-langdetect==1.0.9
-immutabledict==4.2.0
-antlr4-python3-runtime==4.13.2
-
-torchao==0.5.0

View File

@@ -1,315 +0,0 @@
accelerate==0.34.1
addict==2.4.0
aiofiles==23.2.1
aiohttp==3.9.0
aiosignal==1.3.1
aiostream==0.5.2
alembic==1.13.1
annotated-types==0.6.0
annoy==1.17.3
ansible==6.7.0
ansible-core==2.13.13
ansible-vault==2.1.0
anyio==3.7.1
appdirs==1.4.4
art==6.0
asgiref==3.7.2
async-timeout==4.0.2
attrdict==2.0.1
attrs==22.2.0
awscli==1.32.75
-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
backoff==2.2.1
base58==2.1.1
beartype==0.17.2
bitnet==0.2.1
bitsandbytes==0.42.0
bittensor==6.7.0
black==23.7.0
blinker==1.7.0
boto3==1.34.75
botocore==1.34.75
cachetools==5.3.3
cachy==0.1.1
certifi==2023.7.22
cffi==1.16.0
cfgv==3.3.1
chai-guanaco==1.2.4
charset-normalizer==3.2.0
cleo==0.6.8
click==8.1.7
cloudpickle==2.0.0
cohere==4.11.2
colorama==0.4.4
coloredlogs==15.0.1
CoLT5-attention==0.10.20
contextlib2==21.6.0
contourpy==1.2.0
cryptography==41.0.3
cycler==0.12.1
cytoolz==0.12.3
databricks-cli==0.18.0
dataclasses-json==0.5.7
datasets==2.11.0
ddt==1.6.0
decorator==5.1.1
deepspeed==0.15.0
# Editable Git install with no remote (dialogpt==0.1)
-e /Users/wing/Projects/ml/dialogpt/src
dill==0.3.6
distlib==0.3.6
docker==7.0.0
docker-pycreds==0.4.0
docstring-parser==0.15
docutils==0.16
ecdsa==0.18.0
einops==0.7.0
einops-exts==0.0.4
einx==0.1.3
entrypoints==0.4
eth-hash==0.6.0
eth-keys==0.5.0
eth-typing==4.0.0
eth-utils==2.3.1
evaluate==0.4.0
exceptiongroup==1.1.1
fastapi==0.109.2
fastcore==1.5.29
ffmpy==0.4.0
filelock==3.12.2
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
fire==0.5.0
first==2.0.2
flake8==7.0.0
Flask==3.0.1
fonttools==4.47.2
frozendict==2.4.1
frozenlist==1.3.3
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
fsspec==2023.6.0
fuzzywuzzy==0.18.0
gitdb==4.0.10
GitPython==3.1.31
google-pasta==0.2.0
gradio==4.42.0
gradio_client==1.3.0
greenlet==2.0.2
grpclib==0.4.7
gunicorn==21.2.0
h11==0.14.0
h2==4.1.0
hpack==4.0.0
httpcore==0.17.3
httpx==0.24.1
huggingface-hub==0.23.4
humanfriendly==10.0
hyperframe==6.0.1
identify==2.5.24
idna==3.4
immutables==0.20
importlib-metadata==6.7.0
importlib-resources==6.1.1
inflection==0.5.1
iniconfig==2.0.0
itsdangerous==2.1.2
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.3.2
jsonlines==3.1.0
jsonschema==2.6.0
kiwisolver==1.4.5
langchain==0.0.144
Levenshtein==0.24.0
libcst==1.1.0
liger-kernel==0.0.0
lion-pytorch==0.1.2
llama-cpp-python==0.1.36
llvmlite==0.40.1
local-attention==1.9.0
loguru==0.7.0
Mako==1.3.2
Markdown==3.5.2
markdown-it-py==3.0.0
markdown2==2.4.10
MarkupSafe==2.1.2
marshmallow==3.19.0
marshmallow-enum==1.5.1
matplotlib==3.8.2
mccabe==0.7.0
mdurl==0.1.2
MEGABYTE-pytorch==0.0.7
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
mlflow==2.10.0
modal==0.62.77
more-itertools==10.2.0
mpmath==1.2.1
msgpack==1.0.7
msgpack-numpy-opentensor==0.5.0
multidict==6.0.4
multiprocess==0.70.14
munch==2.5.0
mypy==1.3.0
mypy-extensions==1.0.0
nest-asyncio==1.6.0
netaddr==0.10.1
networkx==3.0rc1
nh3==0.2.14
nodeenv==1.8.0
nomic==2.0.2
numba==0.57.1
numexpr==2.8.4
numpy==1.24.4
oauthlib==3.2.2
openai==0.27.4
openapi==1.1.0
openapi-schema-pydantic==1.2.4
optimum==1.8.6
orjson==3.10.7
packaging==23.1
pandas==2.0.0
parameterized==0.9.0
password-strength==0.0.3.post2
pastel==0.1.1
pathos==0.3.0
pathspec==0.11.1
pathtools==0.1.2
peft==0.11.1
pendulum==3.0.0
Pillow==9.5.0
pip-tools==1.11.0
platformdirs==3.2.0
pluggy==1.4.0
poetry==0.7.1
pox==0.3.2
ppft==1.7.6.6
pre-commit==3.3.2
prettytable==3.10.0
prompt-toolkit==3.0.39
protobuf==3.20.2
protobuf3-to-dict==0.1.5
psutil==5.9.5
psycopg==3.1.18
PuLP==2.8.0
py==1.11.0
py-bip39-bindings==0.1.11
py-cpuinfo==9.0.0
py-ed25519-zebra-bindings==1.0.1
py-sr25519-bindings==0.2.0
pyarrow==11.0.0
pyasn1==0.6.0
pycodestyle==2.11.1
pycparser==2.21
pycryptodome==3.20.0
pydantic==2.5.3
pydantic_core==2.14.6
pydub==0.25.1
pyfiglet==0.8.post1
pyflakes==3.2.0
Pygments==2.15.1
PyJWT==2.8.0
pylev==1.4.0
PyNaCl==1.5.0
pynvml==11.5.0
pyparsing==2.4.7
pyrsistent==0.14.11
pytest==8.0.2
pytest-asyncio==0.23.4
python-dateutil==2.8.2
python-dotenv==1.0.1
python-Levenshtein==0.24.0
python-multipart==0.0.9
pytz==2023.3
PyYAML==6.0.1
querystring-parser==1.2.4
rapidfuzz==3.6.1
regex==2023.6.3
requests==2.31.0
requests-toolbelt==0.8.0
resolvelib==0.8.1
responses==0.18.0
retry==0.9.2
rich==13.7.0
rsa==4.7.2
ruff==0.6.3
s3transfer==0.10.1
safetensors==0.4.5
sagemaker==2.148.0
scalecodec==1.2.7
schedulefree==1.2.1
schema==0.7.5
scikit-learn==1.4.0
scipy==1.9.3
seaborn==0.13.2
semantic-version==2.10.0
sentencepiece==0.2.0
sentry-sdk==1.19.1
setproctitle==1.3.2
shellingham==1.5.4
shortuuid==1.0.11
shtab==1.6.5
sigtools==4.0.1
six==1.16.0
skypilot==0.4.1
smdebug-rulesconfig==1.0.1
smmap==5.0.0
sniffio==1.3.0
SQLAlchemy==1.4.47
sqlparse==0.4.4
starlette==0.36.3
substrate-interface==1.5.2
svgwrite==1.4.3
sympy==1.11.1
synchronicity==0.6.7
tabulate==0.9.0
tblib==1.7.0
tenacity==8.2.2
tensor-parallel==2.0.0
termcolor==2.2.0
text2art==0.2.0
threadpoolctl==3.2.0
tiktoken==0.6.0
time-machine==2.14.1
timm==0.9.16
tokenizers==0.19.1
tokenmonster==1.1.12
toml==0.9.6
tomli==2.0.1
tomlkit==0.12.0
toolz==0.12.1
torch==2.2.0
torchdata==0.6.1
torchdiffeq==0.2.3
TorchFix==0.4.0
torchtext==0.15.2
torchvision==0.17.0
tqdm==4.66.2
transformers==4.44.2
trl==0.9.6
typer==0.12.5
types-certifi==2021.10.8.3
types-requests==2.31.0.20240125
types-setuptools==69.0.0.20240125
types-toml==0.10.8.7
typing==3.7.4.3
typing-inspect==0.8.0
typing_extensions==4.9.0
tyro==0.5.18
tzdata==2023.3
unique-names-generator==1.0.2
urllib3==2.2.2
uvicorn==0.22.0
vector_quantize_pytorch==1.14.1
virtualenv==20.23.0
voyager==2.0.2
wandb==0.16.2
watchfiles==0.21.0
wavedrom==2.0.3.post3
wcwidth==0.2.6
websocket-client==1.7.0
websockets==12.0
Werkzeug==3.0.1
wonderwords==2.2.0
xxhash==3.2.0
yarl==1.8.2
zetascale==2.2.7
zipp==3.15.0

View File

@@ -1,60 +0,0 @@
"""
helper script to parse chat datasets into a usable yaml
"""
import click
import yaml
from datasets import load_dataset
@click.command()
@click.argument("dataset", type=str)
@click.option("--split", type=str, default="train")
def parse_dataset(dataset=None, split="train"):
ds_cfg = {}
ds_cfg["path"] = dataset
ds_cfg["split"] = split
ds_cfg["type"] = "chat_template"
ds_cfg["chat_template"] = "<<<Replace based on your model>>>"
dataset = load_dataset(dataset, split=split)
features = dataset.features
feature_keys = features.keys()
field_messages = None
for key in ["conversation", "conversations", "messages"]:
if key in feature_keys:
field_messages = key
break
if not field_messages:
raise ValueError(
f'No conversation field found in dataset: {", ".join(feature_keys)}'
)
ds_cfg["field_messages"] = field_messages
message_fields = features["conversations"][0].keys()
message_field_role = None
for key in ["from", "role"]:
if key in message_fields:
message_field_role = key
break
if not message_field_role:
raise ValueError(
f'No role field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_role"] = message_field_role
message_field_content = None
for key in ["content", "text", "value"]:
if key in message_fields:
message_field_content = key
break
if not message_field_content:
raise ValueError(
f'No content field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_content"] = message_field_content
print(yaml.dump({"datasets": [ds_cfg]}))
if __name__ == "__main__":
parse_dataset()
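
The deleted helper above is a click command; a usage sketch follows, where the import path and dataset name are illustrative rather than taken from the diff.

```python
# Sketch: exercising the deleted helper via click's test runner.
# The import path (a scripts/chat_datasets.py module) and dataset are assumptions.
from click.testing import CliRunner

from chat_datasets import parse_dataset

runner = CliRunner()
result = runner.invoke(parse_dataset, ["Open-Orca/SlimOrca", "--split", "train"])
print(result.output)  # a ready-to-paste `datasets:` stanza in YAML
```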

View File

@@ -30,7 +30,6 @@ def parse_requirements():
     try:
         xformers_version = [req for req in _install_requires if "xformers" in req][0]
-        torchao_version = [req for req in _install_requires if "torchao" in req][0]
         if "Darwin" in platform.system():
             # don't install xformers on MacOS
             _install_requires.pop(_install_requires.index(xformers_version))
@@ -50,24 +49,14 @@ def parse_requirements():
         else:
             raise ValueError("Invalid version format")

-        if (major, minor) >= (2, 4):
-            if patch == 0:
-                _install_requires.pop(_install_requires.index(xformers_version))
-                _install_requires.append("xformers>=0.0.27")
-        elif (major, minor) >= (2, 3):
-            _install_requires.pop(_install_requires.index(torchao_version))
+        if (major, minor) >= (2, 3):
             if patch == 0:
                 _install_requires.pop(_install_requires.index(xformers_version))
                 _install_requires.append("xformers>=0.0.26.post1")
-            else:
-                _install_requires.pop(_install_requires.index(xformers_version))
-                _install_requires.append("xformers>=0.0.27")
         elif (major, minor) >= (2, 2):
-            _install_requires.pop(_install_requires.index(torchao_version))
             _install_requires.pop(_install_requires.index(xformers_version))
             _install_requires.append("xformers>=0.0.25.post1")
         else:
-            _install_requires.pop(_install_requires.index(torchao_version))
             _install_requires.pop(_install_requires.index(xformers_version))
             _install_requires.append("xformers>=0.0.23.post1")
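
For readers skimming the hunk above: both sides pin xformers based on the installed torch version; the right side simply drops the torch 2.4 special case and the torchao bookkeeping. A condensed sketch of the pattern (not the file's exact code):

```python
# Condensed sketch of the torch-version gate in setup.py (not the exact code).
import re


def xformers_pin(torch_version: str) -> str:
    match = re.match(r"^(\d+)\.(\d+)", torch_version)
    major, minor = (int(part) for part in match.groups())
    if (major, minor) >= (2, 3):
        return "xformers>=0.0.26.post1"
    if (major, minor) >= (2, 2):
        return "xformers>=0.0.25.post1"
    return "xformers>=0.0.23.post1"


assert xformers_pin("2.4.0") == "xformers>=0.0.26.post1"
```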

View File

@@ -30,8 +30,7 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
 from axolotl.integrations.base import PluginManager
 from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
-from axolotl.utils.chat_templates import get_chat_template
-from axolotl.utils.comet_ import setup_comet_env_vars
+from axolotl.utils.chat_templates import chat_templates
 from axolotl.utils.config import (
     normalize_cfg_datasets,
     normalize_config,
@@ -55,22 +54,8 @@ LOG = logging.getLogger("axolotl.scripts")
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

-AXOLOTL_LOGO = """
-    (ASCII axolotl logo; alignment lost in extraction)
-"""
-
-
-def print_legacy_axolotl_text_art(suffix=None):
+def print_axolotl_text_art(suffix=None):
     font = "nancyj"
     ascii_text = " axolotl"
     if suffix:
@@ -83,13 +68,6 @@ def print_legacy_axolotl_text_art(suffix=None):
     print_dep_versions()

-
-def print_axolotl_text_art(
-    **kwargs,  # pylint: disable=unused-argument
-):
-    if is_main_process():
-        print(AXOLOTL_LOGO)
-
 def print_dep_versions():
     packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
     max_len = max(len(pkg) for pkg in packages)
@@ -272,7 +250,7 @@ def do_inference_gradio(
             importlib.import_module("axolotl.prompters"), prompter
         )
     elif cfg.chat_template:
-        chat_template_str = get_chat_template(cfg.chat_template)
+        chat_template_str = chat_templates(cfg.chat_template)

     model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -443,8 +421,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
     setup_mlflow_env_vars(cfg)

-    setup_comet_env_vars(cfg)
-
     return cfg

View File

@@ -27,7 +27,6 @@ from axolotl.prompt_strategies.sharegpt import (
     register_chatml_template,
     register_llama3_template,
 )
-from axolotl.utils.trainer import disable_datasets_caching

 LOG = logging.getLogger("axolotl.cli.preprocess")
@@ -71,11 +70,10 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
         LOG.warning(msg)
         parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH

-    with disable_datasets_caching():
-        if parsed_cfg.rl:  # and parsed_cfg.rl != "orpo":
-            load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
-        else:
-            load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    if parsed_cfg.rl:  # and parsed_cfg.rl != "orpo":
+        load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    else:
+        load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)

     if parsed_cli_args.download:
         model_name = parsed_cfg.base_model

View File

@@ -3,11 +3,13 @@ CLI to run training on a model
""" """
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Union from typing import Tuple, Union
import fire import fire
from dotenv import load_dotenv from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser from transformers.hf_argparser import HfArgumentParser
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from axolotl.cli import ( from axolotl.cli import (
check_accelerate_default_config, check_accelerate_default_config,
@@ -18,7 +20,6 @@ from axolotl.cli import (
     print_axolotl_text_art,
 )
 from axolotl.common.cli import TrainerCliArgs
-from axolotl.integrations.base import PluginManager
 from axolotl.prompt_strategies.sharegpt import (
     register_chatml_template,
     register_llama3_template,
@@ -38,7 +39,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
     return do_train(parsed_cfg, parsed_cli_args)

-def do_train(cfg, cli_args) -> None:
+def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     print_axolotl_text_art()
     check_accelerate_default_config()
     check_user_token()
@@ -63,13 +64,7 @@ def do_train(cfg, cli_args) -> None:
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-    model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-    plugin_manager = PluginManager.get_instance()
-
-    del model
-    del tokenizer
-
-    plugin_manager.post_train_unload(cfg)
+    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

 if __name__ == "__main__":

View File

@@ -1,34 +0,0 @@
"""
ChatML transformation functions for MessageContents
"""
from typing import Optional
from ..messages import MessageContents, Messages
from .shared import wrap_tools
def format_message(
message: Messages,
message_index: Optional[int] = None, # pylint: disable=unused-argument
) -> Messages:
if message.is_chat_formatted:
return message
# prepend the role prefix within a MessageContents to message.content
message.content.insert(
0,
MessageContents(
type="text",
value=f"<|im_start|>{message.role}\n",
weight=0,
),
)
message.content.append(
MessageContents(type="text", value="<|im_end|>", weight=message.weight)
)
message.content.append(MessageContents(type="text", value="\n", weight=0))
message = wrap_tools(message)
message.is_chat_formatted = True
return message
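
To make the removed formatter concrete, a short sketch of its output for a single user message; the import paths assume the deleted `axolotl.core.chat` module layout, which the compare view does not show.

```python
# Sketch: what the deleted ChatML formatter produced for one user turn.
# Import paths assume the old axolotl.core.chat layout.
from axolotl.core.chat.format.chatml import format_message
from axolotl.core.chat.messages import MessageContents, Messages

msg = Messages(
    role="user",
    content=[MessageContents(type="text", value="Hello", weight=1)],
    weight=1,
)
print(str(format_message(msg)))  # "<|im_start|>user\nHello<|im_end|>\n"
```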

View File

@@ -1,45 +0,0 @@
"""
Llama 3.x chat formatting functions for MessageContents
"""
from typing import Optional
from ..messages import MessageContents, Messages
from .shared import wrap_tools
def format_message(message: Messages, message_index: Optional[int] = None) -> Messages:
if message.is_chat_formatted:
return message
message_role = message.role
if message.role == "tool":
message_role = "ipython"
# prepend the role prefix within a MessageContents to message.content
message.content.insert(
0,
MessageContents(
type="text",
value=f"<|start_header_id|>{message_role}<|end_header_id|>\n\n",
weight=0,
),
)
message.content.append(
MessageContents(type="text", value="<|eot_id|>", weight=message.weight)
)
message = wrap_tools(message)
if message_index == 0:
message.content.insert(
0,
MessageContents(
type="text",
value="<|begin_of_text|>",
weight=0,
),
)
message.is_chat_formatted = True
return message

View File

@@ -1,47 +0,0 @@
"""
shared functions for format transforms
"""
from axolotl.core.chat.messages import MessageContents, Messages
def wrap_tools(message: Messages):
# loop over message.content by index to find tool calls, we need to wrap each with tags,
# so be wary of indexing issues when changing the list while iterating.
# iterate over the range in reverse order to avoid index shifting
for i in range(len(message.content) - 1, -1, -1):
if message.content[i].type == "tool_call":
# append a </tool_call> MessageContents text tag after
message.content.insert(
i + 1,
MessageContents(
type="text", value="</tool_call>\n", weight=message.weight
),
)
# make sure the actual tool call content ends with a newline
message.content[i].has_newline = True
# prepend a <tool_call> MessageContents text tag before
message.content.insert(
i,
MessageContents(
type="text", value="<tool_call>\n", weight=message.weight
),
)
elif message.content[i].type == "tool_response":
# append a </tool_call> MessageContents text tag after
message.content.insert(
i + 1,
MessageContents(
type="text", value="</tool_response>\n", weight=message.weight
),
)
# make sure the actual tool response content ends with a newline
message.content[i].has_newline = True
# prepend a <tool_call> MessageContents text tag before
message.content.insert(
i,
MessageContents(
type="text", value="<tool_response>\n", weight=message.weight
),
)
return message
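
A sketch of the wrapping behavior on a single tool-call message, again assuming the deleted modules are importable from their old locations:

```python
# Sketch: wrap_tools surrounds tool_call content with <tool_call> tags.
# Import paths assume the old axolotl.core.chat layout.
from axolotl.core.chat.format.shared import wrap_tools
from axolotl.core.chat.messages import MessageContents, Messages

msg = Messages(
    role="assistant",
    content=[
        MessageContents(
            type="tool_call",
            value='{"name": "get_time", "arguments": {}}',
            weight=1,
        )
    ],
    weight=1,
)
print(str(wrap_tools(msg)))
# '<tool_call>\n{"name": "get_time", "arguments": {}}\n</tool_call>\n'
```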

View File

@@ -1,230 +0,0 @@
"""
internal message representations of chat messages
"""
import json
from enum import Enum
from typing import Any, Callable, List, Optional, Union
from pydantic import BaseModel
from transformers import PreTrainedTokenizer
class MessageRoles(str, Enum):
"""
Message roles for the system, user, assistant, and tools
"""
system = "system" # pylint: disable=invalid-name
user = "user" # pylint: disable=invalid-name
assistant = "assistant" # pylint: disable=invalid-name
tool = "tool" # pylint: disable=invalid-name
ipython = ( # pylint: disable=invalid-name
# for responses from builtin tools
"ipython"
)
class MessageContentTypes(str, Enum):
"""
Message content types for text, image, audio, tool calls, and tool responses
"""
special_token = "special_token" # pylint: disable=invalid-name # nosec B105
text = "text" # pylint: disable=invalid-name
image = "image" # pylint: disable=invalid-name
audio = "audio" # pylint: disable=invalid-name
tool_call = "tool_call" # pylint: disable=invalid-name # to differentiate regular responses from tool calls from the assistant
tool_response = "tool_response" # pylint: disable=invalid-name
class SpecialToken(str, Enum):
"""
Special tokens for beginning of string and end of string
"""
bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105
eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105
class ToolCallFunction(BaseModel):
"""
Tool call function with name and arguments
"""
name: str
arguments: dict[str, str]
class Tool(BaseModel):
"""
Tool with description, function, and parameters
"""
description: str
function: ToolCallFunction
parameters: dict[str, str] # .properties
class ToolCallContents(BaseModel):
"""
Tool call contents with name, arguments, and optional id
"""
name: str
arguments: dict[str, Union[str, int]]
id: Optional[str] = None # pylint: disable=invalid-name
def __str__(self) -> str:
data = {"name": self.name, "arguments": self.arguments}
if self.id is not None:
data["id"] = self.id
return json.dumps(data)
class ToolResponseContents(BaseModel):
"""
Tool response contents with name, content, and optional id
"""
name: str
content: Union[str, dict[str, Union[str, int, float]]]
id: Optional[str] = None # pylint: disable=invalid-name
def __str__(self) -> str:
data = {"name": self.name, "content": self.content}
if self.id is not None:
data["id"] = self.id
return json.dumps(data)
class MessageContents(BaseModel):
"""
Message contents with type, value, metadata, weight, newline, and end of contents
"""
type: Union[str, MessageContentTypes]
value: Union[str, ToolCallContents, ToolResponseContents, SpecialToken]
meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
weight: Optional[Union[int, float]] = None
has_newline: bool = False
eoc: bool = False # end of contents
def __str__(self) -> str:
str_val = str(self.value)
if self.has_newline and not str_val.endswith("\n"):
str_val += "\n"
return str_val
class Messages(BaseModel):
"""
Messages with role, content, metadata, weight, and chat formatting
"""
role: Union[MessageRoles, str] # allows for arbitrary roles
content: List["MessageContents"]
meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
weight: Optional[Union[int, float]] = None
is_chat_formatted: bool = False
def __str__(self) -> str:
return "".join(str(c) for c in self.content)
def tokenized(
self, tokenizer: PreTrainedTokenizer, ignore_index=-100
) -> dict[str, List[int]]:
# iterate over the contents, tokenizing the concatenated string values up to the current MessageContents
# returns a dictionary mapping w input_ids, attention_mask, and labels
input_ids: List[int] = []
labels: List[int] = []
pending_input_ids: List[int] = []
pending_weight = self.weight
running_content = ""
for _, msg_content in enumerate(self.content):
# TODO also handle non-text content types
if msg_content.type in [
MessageContentTypes.text.value,
MessageContentTypes.tool_call.value,
MessageContentTypes.tool_response.value,
]:
running_content += str(msg_content)
tok_results = tokenizer(running_content, add_special_tokens=False)
tok_input_ids = tok_results["input_ids"]
if pending_input_ids:
new_pending_inputs = tok_input_ids[
len(input_ids) : len(input_ids) + len(pending_input_ids)
]
if new_pending_inputs != pending_input_ids:
# logging.warning("tokenization mismatch from concatenation.")
pending_input_ids = new_pending_inputs
input_ids.extend(pending_input_ids)
if pending_weight:
labels.extend(pending_input_ids)
else:
labels.extend([ignore_index] * len(pending_input_ids))
pending_input_ids = tok_results["input_ids"][len(input_ids) :]
pending_weight = self.weight and msg_content.weight not in [0, 0.0]
input_ids.extend(pending_input_ids)
if pending_weight:
labels.extend(pending_input_ids)
else:
labels.extend([ignore_index] * len(pending_input_ids))
attention_mask = [1] * len(input_ids)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
class Chats(BaseModel):
"""
top level data structure for chat conversations
"""
conversation: List[Messages]
def __str__(self) -> str:
return "".join(str(c) for c in self.conversation)
def tokenized(
self, tokenizer: Callable[[str], dict[str, List[int]]], ignore_index=-100
) -> dict[str, List[int]]:
input_ids = []
attention_mask = []
labels = []
for msg in self.conversation:
msg_results = msg.tokenized(tokenizer, ignore_index)
input_ids.extend(msg_results["input_ids"])
attention_mask.extend(msg_results["attention_mask"])
labels.extend(msg_results["labels"])
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
class ChatFormattedChats(Chats):
"""
Chat formatted chats with formatter and optional train on inputs
"""
formatter: Callable # [[Union[dict, Chats]], Chats]
train_on_inputs: bool = False
def model_post_init(self, __context):
for i, msg in enumerate(self.conversation):
self.conversation[i] = self.formatter(msg, message_index=i)
if self.train_on_inputs:
self.conversation[i].weight = 1
class PreferenceChats(BaseModel):
"""
representation for preference data for chat
"""
prompt: List[Messages]
chosen: Messages
rejected: Messages

View File

@@ -1,55 +0,0 @@
"""
chat dataset module
"""
import os
from typing import Callable, Optional, Union
from datasets import Dataset
from transformers import PreTrainedTokenizer
from axolotl.core.chat.messages import ChatFormattedChats
class TokenizedChatDataset(Dataset):
"""
Tokenized chat dataset
"""
def __init__(
self,
data: Dataset,
model_transform: Union[PreTrainedTokenizer, Callable],
*args,
message_transform: Optional[Callable] = None,
formatter=None,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
**kwargs,
):
def map_fn(ex):
if message_transform is not None:
ex = message_transform(ex)
if formatter is not None:
ex = ChatFormattedChats(
formatter=formatter,
**ex,
)
else:
ex = ChatFormattedChats(
**ex,
)
return ex.tokenized(model_transform)
process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment]
)
num_proc = min(64, process_or_cpu_count)
features = data.features.keys()
tokenized_data = data.map(
map_fn,
num_proc=num_proc,
keep_in_memory=keep_in_memory,
remove_columns=features,
desc="Tokenizing Chats",
)
super().__init__(tokenized_data.data, *args, **kwargs)

View File

@@ -1,150 +0,0 @@
"""
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
"""
from typing import Any, Mapping, Union
def chat_message_transform_builder( # pylint: disable=dangerous-default-value
train_on_inputs=False,
conversations_field: str = "conversations",
message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role"
message_field_content: Union[str, list[str]] = [
"value",
"text",
"content",
], # commonly "content"
message_field_training: Union[str, list[str]] = [
"train",
"weight",
], # commonly "weight"
):
"""Builds a transform that takes a row from the dataset and converts it to a Chat
Args:
train_on_inputs (bool, optional):
If True, the transform will train on the inputs. If False, the transform will train on the targets.
Defaults to False.
conversations_field (str, optional):
The field name of the conversations. Defaults to "conversations".
message_field_role (str | list[str], optional):
The field name of the role. Defaults to "role".
message_field_content (str | list[str], optional):
The field name of the message content. Defaults to "content".
message_field_training (str | list[str], optional):
The field name of the train/weight. Defaults to "weight".
Returns:
Callable:
A function that takes a list of conversations and returns a list of messages.
"""
message_field_role = (
[message_field_role]
if isinstance(message_field_role, str)
else message_field_role
)
message_field_content = (
[message_field_content]
if isinstance(message_field_content, str)
else message_field_content
)
message_weight_fields = (
[message_field_training]
if isinstance(message_field_training, str)
else message_field_training
)
role_value_mappings = {
"system": "system",
"user": "user",
"human": "user",
"assistant": "assistant",
"gpt": "assistant",
"tool": "tool",
"ipython": "ipython",
}
if train_on_inputs:
role_default_weights_mappings = {
"system": 1,
"user": 1,
"assistant": 1,
"tool": 1,
"ipython": 1,
}
else:
role_default_weights_mappings = {
"system": 0,
"user": 0,
"assistant": 1,
"tool": 0,
"ipython": 0,
}
def transform_builder(sample: Mapping[str, Any]):
if conversations_field not in sample:
raise ValueError(f"Field '{conversations_field}' not found in sample.")
# if none of the role fields are in the message, raise an error
if not any(
role in sample[conversations_field][0] for role in message_field_role
):
raise ValueError("No role field found in message.")
role_field = next(
role
for role in message_field_role
if role in sample[conversations_field][0]
)
if not any(
field in sample[conversations_field][0] for field in message_field_content
):
raise ValueError("No message_content field found in message.")
message_content_field = next(
field
for field in message_field_content
if field in sample[conversations_field][0]
)
if not any(
field in sample[conversations_field][0] for field in message_field_training
):
message_weight_field = None
else:
message_weight_field = next(
field
for field in message_weight_fields
if field in sample[conversations_field][0]
)
messages = []
for message in sample[conversations_field]:
role = role_value_mappings[message[role_field]]
weight = (
int(message[message_weight_field])
if message_weight_field
else role_default_weights_mappings[role]
)
# TODO if "tool_calls" in message[message_content_field]: then convert tool call to ToolCallContents
if isinstance(message[message_content_field], str):
messages.append(
{
"role": role,
"content": [
{
"type": "text",
"value": message[message_content_field],
}
],
"weight": weight,
}
)
else:
messages.append(
{
"role": role,
"content": message[message_content_field],
"weight": weight,
}
)
return {"conversation": messages}
return transform_builder
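
Putting the deleted pieces together, a sketch of the intended pipeline from a raw sharegpt-style dataset to tokenized rows; the module path for the transform builder is a guess, as the compare view does not show filenames.

```python
# Sketch: deleted transform builder + TokenizedChatDataset wired together.
# Import paths assume the old axolotl.core layout; the builder's module path is a guess.
from datasets import Dataset
from transformers import AutoTokenizer

from axolotl.core.chat.format.chatml import format_message
from axolotl.core.datasets.chat import TokenizedChatDataset
from axolotl.core.datasets.transforms.chat_builder import chat_message_transform_builder

raw = Dataset.from_list(
    [
        {
            "conversations": [
                {"from": "human", "value": "Hi"},
                {"from": "gpt", "value": "Hello!"},
            ]
        }
    ]
)
transform = chat_message_transform_builder(train_on_inputs=False)
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")  # illustrative
ds = TokenizedChatDataset(
    raw,
    model_transform=tokenizer,
    message_transform=transform,
    formatter=format_message,
)
```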

View File

@@ -43,14 +43,12 @@ from trl import (
     KTOTrainer,
     ORPOConfig,
     ORPOTrainer,
-    RewardConfig,
-    RewardTrainer,
 )
-from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
+from trl.trainer.utils import pad_to_length

 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
-from axolotl.utils import is_comet_available, is_mlflow_available
+from axolotl.utils import is_mlflow_available
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
     GPUStatsCallback,
@@ -63,7 +61,7 @@ from axolotl.utils.callbacks import (
     log_prediction_callback_factory,
 )
 from axolotl.utils.callbacks.lisa import lisa_callback_factory
-from axolotl.utils.chat_templates import get_chat_template
+from axolotl.utils.chat_templates import chat_templates
 from axolotl.utils.collators import (
     BatchSamplerDataCollatorForSeq2Seq,
     DataCollatorForSeq2Seq,
@@ -303,13 +301,6 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
     )

-
-@dataclass
-class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
-    """
-    Reward config for Reward training
-    """
-
 class SchedulerMixin(Trainer):
     """
     Mixin class for scheduler setup in CausalTrainer.
@@ -407,10 +398,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
     def __init__(
         self,
         *_args,
+        num_epochs=1,
         bench_data_collator=None,
         eval_data_collator=None,
         **kwargs,
     ):
+        self.num_epochs = num_epochs
         self.bench_data_collator = bench_data_collator
         self.eval_data_collator = eval_data_collator
         super().__init__(*_args, **kwargs)
@@ -1046,14 +1039,6 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
     tag_names = ["axolotl", "cpo"]

-
-class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
-    """
-    Extend the base RewardTrainer for axolotl helpers
-    """
-
-    tag_names = ["axolotl", "reward"]
-
 class TrainerBuilderBase(abc.ABC):
     """
     Base class for trainer builder
@@ -1126,12 +1111,6 @@ class TrainerBuilderBase(abc.ABC):
             callbacks.append(
                 SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
             )
-        if self.cfg.use_comet and is_comet_available():
-            from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
-
-            callbacks.append(
-                SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
-            )

         return callbacks
@@ -1200,11 +1179,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 trainer, self.tokenizer, "mlflow"
             )
             callbacks.append(LogPredictionCallback(self.cfg))
-        if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
-            LogPredictionCallback = log_prediction_callback_factory(
-                trainer, self.tokenizer, "comet_ml"
-            )
-            callbacks.append(LogPredictionCallback(self.cfg))

         if self.cfg.do_bench_eval:
             callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
@@ -1229,8 +1203,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return ReLoRATrainer return ReLoRATrainer
if self.cfg.model_config_type == "mamba": if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
return AxolotlTrainer return AxolotlTrainer
def build(self, total_num_steps): def build(self, total_num_steps):
@@ -1458,16 +1430,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
report_to.append("mlflow") report_to.append("mlflow")
if self.cfg.use_tensorboard: if self.cfg.use_tensorboard:
report_to.append("tensorboard") report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")
training_arguments_kwargs["report_to"] = report_to training_arguments_kwargs["report_to"] = report_to
-if self.cfg.use_wandb:
-    training_arguments_kwargs["run_name"] = self.cfg.wandb_name
-elif self.cfg.use_mlflow:
-    training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
-else:
-    training_arguments_kwargs["run_name"] = None
+training_arguments_kwargs["run_name"] = (
+    self.cfg.wandb_name if self.cfg.use_wandb else None
+)
training_arguments_kwargs["optim"] = ( training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf" self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
) )
@@ -1556,7 +1523,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template: if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template( training_arguments_kwargs["chat_template"] = chat_templates(
self.cfg.chat_template self.cfg.chat_template
) )
@@ -1570,9 +1537,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {} trainer_kwargs = {}
if self.cfg.reward_model:
trainer_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.optimizer in [ if self.cfg.optimizer in [
"optimi_adamw", "optimi_adamw",
"ao_adamw_4bit", "ao_adamw_4bit",
@@ -1616,13 +1580,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
"accelerator_config" "accelerator_config"
] = self.cfg.accelerator_config ] = self.cfg.accelerator_config
-training_args_cls = (
-    AxolotlTrainingArguments
-    if not self.cfg.reward_model
-    else AxolotlRewardConfig
-)
-training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
-    **training_arguments_kwargs,
-)
+training_args = (
+    AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
+        **training_arguments_kwargs,
+    )
+)
training_args = self.hook_post_create_training_args(training_args) training_args = self.hook_post_create_training_args(training_args)
@@ -1644,24 +1605,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64 data_collator_kwargs["pad_to_multiple_of"] = 64
if self.cfg.reward_model:
data_collator_kwargs["max_length"] = self.cfg.sequence_len
trainer_cls = self._get_trainer_cls() trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer( trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls trainer_kwargs, trainer_cls
) )
if eval_data_collator := self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
):
if not self.cfg.reward_model:
trainer_kwargs["eval_data_collator"] = eval_data_collator
if not self.cfg.reward_model:
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
)
trainer = trainer_cls( trainer = trainer_cls(
model=self.model, model=self.model,
train_dataset=self.train_dataset, train_dataset=self.train_dataset,
@@ -1669,7 +1616,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
args=training_args, args=training_args,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
data_collator=self.build_collator(training_args, **data_collator_kwargs), data_collator=self.build_collator(training_args, **data_collator_kwargs),
eval_data_collator=self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
callbacks=self.get_callbacks(), callbacks=self.get_callbacks(),
num_epochs=self.cfg.num_epochs,
**trainer_kwargs, **trainer_kwargs,
) )
trainer = self.hook_post_create_trainer(trainer) trainer = self.hook_post_create_trainer(trainer)
@@ -1703,12 +1659,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
V2BatchSamplerDataCollatorForSeq2Seq, V2BatchSamplerDataCollatorForSeq2Seq,
BatchSamplerDataCollatorForSeq2Seq, BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq, DataCollatorForSeq2Seq,
RewardDataCollatorWithPadding,
] ]
] ]
-if self.cfg.reward_model:
-    collator = RewardDataCollatorWithPadding
-elif use_batch_sampler_collator:
+if use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES: if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq collator = V2BatchSamplerDataCollatorForSeq2Seq
elif ( elif (

View File

@@ -159,29 +159,6 @@ class BasePlugin:
List[callable]: A list of callback functions to be added to the TrainingArgs List[callable]: A list of callback functions to be added to the TrainingArgs
""" """
def post_train(self, cfg, model):
"""
Performs actions after training is complete.
Parameters:
cfg (dict): The axolotl configuration
model (object): The loaded model.
Returns:
None
"""
def post_train_unload(self, cfg):
"""
Performs actions after training is complete and the model is unloaded.
Parameters:
cfg (dict): The configuration for the plugin.
Returns:
None
"""
def load_plugin(plugin_name: str) -> BasePlugin: def load_plugin(plugin_name: str) -> BasePlugin:
""" """
@@ -404,17 +381,3 @@ class PluginManager:
for plugin in self.plugins: for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer)) callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
return callbacks return callbacks
def post_train_unload(self, cfg):
"""
Calls the post_train_unload method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins:
plugin.post_train_unload(cfg)
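For orientation, a minimal sketch of a user plugin implementing the `post_train` / `post_train_unload` hooks shown above; the class name and its behavior are hypothetical, not part of axolotl:

```python
# Minimal sketch, assuming the BasePlugin interface documented above.
# ArchiveOutputsPlugin and what it does are hypothetical examples.
from axolotl.integrations.base import BasePlugin


class ArchiveOutputsPlugin(BasePlugin):
    """Example plugin hooking the post-training lifecycle."""

    def post_train(self, cfg, model):
        # runs while the trained model is still loaded
        model.save_pretrained(cfg["output_dir"])

    def post_train_unload(self, cfg):
        # runs after the model is unloaded, e.g. to launch follow-up jobs
        print(f"training finished, artifacts in {cfg['output_dir']}")
```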

View File

@@ -1,13 +0,0 @@
# LM Eval Harness
### Usage
```yaml
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
lm_eval_tasks:
- gsm8k
- hellaswag
- arc_easy
```

View File

@@ -1,42 +0,0 @@
"""
Module for the Plugin for LM Eval Harness
"""
import subprocess # nosec
from datetime import datetime
from axolotl.integrations.base import BasePlugin
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
class LMEvalPlugin(BasePlugin):
"""
Plugin for LM Evaluation Harness integration with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.lm_eval.LMEvalArgs"
def post_train_unload(self, cfg):
tasks = ",".join(cfg.lm_eval_tasks)
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
output_path = cfg.output_dir
output_path += "" if cfg.output_dir.endswith("/") else "/"
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
subprocess.run( # nosec
[
"lm_eval",
"--model",
"hf",
"--model_args",
f"pretrained={cfg.output_dir}{fa2}{dtype}",
"--tasks",
tasks,
"--batch_size",
str(cfg.lm_eval_batch_size),
"--output_path",
output_path,
],
check=True,
)
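The `subprocess.run` above is equivalent to invoking the `lm_eval` CLI by hand; a rough reconstruction of the command it assembles, with illustrative config values:

```python
# Sketch reconstructing the command LMEvalPlugin.post_train_unload builds;
# all cfg values below are made up for illustration.
from datetime import datetime

cfg = {
    "lm_eval_tasks": ["gsm8k", "hellaswag", "arc_easy"],
    "lm_eval_batch_size": 8,
    "flash_attention": True,
    "bf16": True,
    "output_dir": "./outputs/out/",
}

fa2 = ",attn_implementation=flash_attention_2" if cfg["flash_attention"] else ""
dtype = ",dtype=bfloat16" if cfg["bf16"] else ",dtype=float16"
output_path = (
    cfg["output_dir"] + "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
)

cmd = [
    "lm_eval",
    "--model", "hf",
    "--model_args", f"pretrained={cfg['output_dir']}{fa2}{dtype}",
    "--tasks", ",".join(cfg["lm_eval_tasks"]),
    "--batch_size", str(cfg["lm_eval_batch_size"]),
    "--output_path", output_path,
]
print(" ".join(cmd))
```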

View File

@@ -1,15 +0,0 @@
"""
Module for handling lm eval harness input arguments.
"""
from typing import List, Optional
from pydantic import BaseModel
class LMEvalArgs(BaseModel):
"""
Input args for lm eval harness
"""
lm_eval_tasks: List[str] = []
lm_eval_batch_size: Optional[int] = 8

View File

@@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio):
def reset_optimizer( def reset_optimizer(
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
*, *,
reset_params: List[str], # where str is the key to a torch.nn.Parameter reset_params: list[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: List[str], optimizer_state_keys: list[str],
prune_ratio: float = 0.9, prune_ratio: float = 0.9,
): ):
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio) pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)

View File

@@ -11,10 +11,6 @@ LOG = logging.getLogger("axolotl.prompt_strategies")
def load(strategy, tokenizer, cfg, ds_cfg, processor=None): def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
try: try:
if strategy == "messages":
from .messages import load as messages_load
return messages_load(tokenizer, cfg, ds_cfg, processor=processor)
load_fn = "load" load_fn = "load"
if strategy.split(".")[-1].startswith("load_"): if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1] load_fn = strategy.split(".")[-1]
@@ -35,5 +31,4 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
return None return None
except Exception as exc: # pylint: disable=broad-exception-caught except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}") LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
-        raise exc
-    return None
+        return None

View File

@@ -1,10 +0,0 @@
### example yaml
```yaml
chat_template: gemma
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
```

View File

@@ -1,35 +0,0 @@
"""Module to load prompt strategies."""
import importlib
import inspect
import logging
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry")
def load(strategy, tokenizer, cfg, ds_cfg):
# pylint: disable=duplicate-code
try:
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(
f".{strategy}", "axolotl.prompt_strategies.bradley_terry"
)
func = getattr(mod, load_fn)
load_kwargs = {}
if strategy == "user_defined":
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
else:
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
return None

View File

@@ -1,102 +0,0 @@
"""
Bradley-Terry model with chat template prompt strategy.
"""
import logging
from typing import Any, Dict, Optional
from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter,
ChatTemplateStrategy,
)
from axolotl.utils.chat_templates import get_chat_template_from_config
# Configure the logger
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template")
LOG.setLevel(logging.INFO)
class BTChatTemplateStrategy(ChatTemplateStrategy):
"""
Bradley-Terry reward model pairwise chat template prompt strategy.
"""
def tokenize_prompt(self, prompt):
"""
:param prompt: the actual row of data from the underlying dataset
:return:
"""
self.messages = "chosen_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt)
self.messages = "rejected_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append(
{"role": "assistant", "content": prompt["rejected"]}
)
rejected_tokenized = super().tokenize_prompt(prompt)
return {
"input_ids_chosen": chosen_tokenized["input_ids"],
"attention_mask_chosen": chosen_tokenized["attention_mask"],
"labels_chosen": 1.0,
"input_ids_rejected": rejected_tokenized["input_ids"],
"attention_mask_rejected": rejected_tokenized["attention_mask"],
"labels_rejected": 0.0,
}
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail", None
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1
if not cfg.reward_model
else cfg.sequence_len,
}
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", []),
"train_on_eos": ds_cfg.get("train_on_eos", None),
}
strategy = BTChatTemplateStrategy(
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
)
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy
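Each dataset row tokenized by `BTChatTemplateStrategy` becomes a paired example for the reward trainer; the schema looks roughly like this (token ids are placeholders):

```python
# Illustrative output of BTChatTemplateStrategy.tokenize_prompt; the token
# ids are made up, only the keys and the label convention are meaningful.
row = {
    "input_ids_chosen": [128000, 882, 271, 13],    # prompt + chosen reply
    "attention_mask_chosen": [1, 1, 1, 1],
    "labels_chosen": 1.0,                           # preferred response
    "input_ids_rejected": [128000, 882, 271, 17],  # prompt + rejected reply
    "attention_mask_rejected": [1, 1, 1, 1],
    "labels_rejected": 0.0,                         # dispreferred response
}
```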

View File

@@ -1,27 +0,0 @@
"""
transforms for datasets with system, input, chosen, rejected to match the llama3 chat template
"""
def icr(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
transforms for datasets with system, input, chosen, rejected
ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
"""
def transform_fn(sample):
if "system" in sample and sample["system"]:
prompt = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = prompt + f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = prompt + f"{sample['rejected']}<|eot_id|>"
return sample
return transform_fn
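Run on one ORCA-style row, the transform wraps both completions in the same llama3-formatted prompt; for example:

```python
# Toy illustration of transform_fn above (same string-building logic).
sample = {
    "system": "You are a helpful assistant.",
    "input": "What is 2+2?",
    "chosen": "4",
    "rejected": "5",
}
prompt = (
    f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
    f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["chosen"] = prompt + f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = prompt + f"{sample['rejected']}<|eot_id|>"
print(sample["chosen"])  # system + user turns followed by "4<|eot_id|>"
```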

View File

@@ -9,7 +9,7 @@ from transformers import ProcessorMixin
from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.chat_templates import chat_templates
# Configure the logger # Configure the logger
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -403,16 +403,11 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
 def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
-    # pylint: disable=duplicate-code
     ds_cfg = ds_cfg or {}
-    chat_template_string = get_chat_template_from_config(
-        cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
-    )
-    LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
     prompter_params = {
         "tokenizer": tokenizer,
-        "chat_template": chat_template_string,
+        "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"message_field_role": ds_cfg.get("message_field_role", "role"), "message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"), "message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None), "message_field_training": ds_cfg.get("message_field_training", None),

View File

@@ -2,16 +2,15 @@
DPO prompt strategies for using tokenizer chat templates. DPO prompt strategies for using tokenizer chat templates.
""" """
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template from axolotl.utils.chat_templates import chat_templates
def default( def default(
cfg, dataset_idx=0, **kwargs cfg, dataset_idx=0, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument ): # pylint: disable=possibly-unused-variable,unused-argument
ds_cfg = cfg["datasets"][dataset_idx] ds_cfg = cfg["datasets"][dataset_idx]
-chat_template_choice, chat_template_jinja = extract_chat_template_args(
-    cfg=cfg, ds_cfg=ds_cfg
-)
+chat_template_str = chat_templates(cfg.chat_template)
field_messages = ds_cfg.get("field_messages", "messages") field_messages = ds_cfg.get("field_messages", "messages")
field_chosen = ds_cfg.get("field_chosen", "chosen") field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected") field_rejected = ds_cfg.get("field_rejected", "rejected")
@@ -31,12 +30,6 @@ def default(
role_map[source] = target role_map[source] = target
def transform_fn(sample, tokenizer=None): def transform_fn(sample, tokenizer=None):
chat_template_string = get_chat_template(
user_choice=chat_template_choice,
jinja_template=chat_template_jinja,
tokenizer=tokenizer,
)
messages = sample[field_messages] messages = sample[field_messages]
messages = [ messages = [
{ {
@@ -53,29 +46,28 @@ def default(
"role": role_map[sample[field_rejected][field_message_role]], "role": role_map[sample[field_rejected][field_message_role]],
"content": sample[field_rejected][field_message_content], "content": sample[field_rejected][field_message_content],
} }
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
result = {} result = {}
result["prompt"] = tokenizer.apply_chat_template( result["prompt"] = tokenizer.apply_chat_template(
messages, messages,
add_generation_prompt=True, add_generation_prompt=True,
chat_template=chat_template_string, chat_template=chat_template_str,
tokenize=False, tokenize=False,
) )
result["chosen"] = tokenizer.apply_chat_template( result["chosen"] = tokenizer.apply_chat_template(
[dummy_user_message, chosen], [chosen],
add_generation_prompt=False, add_generation_prompt=False,
chat_template=chat_template_string, chat_template=chat_template_str,
tokenize=False, tokenize=False,
) )
chosen_strip_index = result["chosen"].find(chosen["content"]) chosen_strip_index = result["chosen"].find(chosen["content"])
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip() result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
result["rejected"] = tokenizer.apply_chat_template( result["rejected"] = tokenizer.apply_chat_template(
[dummy_user_message, rejected], [rejected],
add_generation_prompt=False, add_generation_prompt=False,
chat_template=chat_template_string, chat_template=chat_template_str,
tokenize=False, tokenize=False,
) )
rejected_strip_index = result["rejected"].find(rejected["content"]) rejected_strip_index = result["rejected"].find(rejected["content"])
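The `dummy_user_message` used on the removed side exists because some chat templates reject a conversation that opens with an assistant turn; the completion is rendered behind a throwaway user turn and then sliced back out with `find`. A minimal sketch (the model name is an assumption; any tokenizer shipping a chat template works):

```python
# Sketch of the dummy-turn trick; NousResearch/Meta-Llama-3.1-8B-Instruct is
# just an assumed example of a tokenizer that ships a chat template.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B-Instruct")
chosen = {"role": "assistant", "content": "Paris."}
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}

rendered = tokenizer.apply_chat_template(
    [dummy_user_message, chosen], add_generation_prompt=False, tokenize=False
)
# keep only the assistant content and whatever closes the turn
completion = rendered[rendered.find(chosen["content"]) :].rstrip()
print(completion)  # roughly "Paris.<|eot_id|>"
```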

View File

@@ -1,34 +0,0 @@
"""Module to load message prompt strategies."""
import importlib
import inspect
import logging
LOG = logging.getLogger("axolotl.prompt_strategies.messages")
def load(tokenizer, cfg, ds_cfg, processor=None):
try:
strategy = ds_cfg.get("input_transform", "chat")
# pylint: disable=duplicate-code
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(
f".{strategy}", "axolotl.prompt_strategies.messages"
)
func = getattr(mod, load_fn)
load_kwargs = {}
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
if "processor" in sig.parameters:
load_kwargs["processor"] = processor
return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
raise exc
return None

View File

@@ -1,84 +0,0 @@
"""
Chat dataset wrapping strategy for new internal messages representations
"""
from typing import Any, Callable, Dict, Optional
from axolotl.core.datasets.chat import TokenizedChatDataset
from axolotl.core.datasets.transforms.chat_builder import chat_message_transform_builder
from axolotl.prompt_tokenizers import DatasetWrappingStrategy
class ChatMessageDatasetWrappingStrategy(DatasetWrappingStrategy):
"""
Chat dataset wrapping strategy for new internal messages representations
"""
def __init__(
self,
processor,
message_transform=None,
formatter=None,
**kwargs, # pylint: disable=unused-argument
):
"""
:param processor: tokenizer or image processor
:param kwargs:
"""
self.processor = processor
self.dataset = None
self.message_transform = message_transform
self.formatter = formatter
def wrap_dataset(
self,
dataset,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
**kwargs, # pylint: disable=unused-argument
):
self.dataset = TokenizedChatDataset(
dataset,
message_transform=self.message_transform,
model_transform=self.processor,
formatter=self.formatter,
process_count=process_count,
keep_in_memory=keep_in_memory,
)
return self.dataset
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
field_messages = ds_cfg.get("field_messages")
message_field_role = ds_cfg.get("message_field_role")
message_field_content = ds_cfg.get("message_field_content")
message_field_training = ds_cfg.get("message_field_training")
builder_kwargs = {}
if field_messages:
builder_kwargs["conversations_field"] = field_messages
if message_field_role:
builder_kwargs["message_field_role"] = message_field_role
if message_field_content:
builder_kwargs["message_field_content"] = message_field_content
if message_field_training:
builder_kwargs["message_field_training"] = message_field_training
chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml"))
format_message = (
lambda x: x # noqa E731 # pylint: disable=unnecessary-lambda-assignment
)
if chat_template == "chatml":
from axolotl.core.chat.format.chatml import format_message # noqa F811
if chat_template.startswith("llama3"):
from axolotl.core.chat.format.llama3x import format_message # noqa F811
message_transform: Callable = chat_message_transform_builder(
train_on_inputs=ds_cfg.get("train_on_inputs", False),
**builder_kwargs,
)
strategy = ChatMessageDatasetWrappingStrategy(
tokenizer, message_transform=message_transform, formatter=format_message
)
return strategy

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel
from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
from axolotl.prompters import Prompter from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.chat_templates import chat_templates
class Message(BaseModel): class Message(BaseModel):
@@ -28,13 +28,18 @@ def load(
""" """
chatml transforms for datasets with system, input, chosen, rejected chatml transforms for datasets with system, input, chosen, rejected
""" """
-chat_template_string = get_chat_template_from_config(
-    cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
-)
-tokenizer.chat_template = chat_template_string
+chat_template = chat_templates("chatml")
+if ds_cfg and "chat_template" in ds_cfg:
+    chat_template = ds_cfg["chat_template"]
+    try:
+        chat_template = chat_templates(chat_template)
+    except ValueError:
+        pass
+tokenizer.chat_template = chat_template
 return ORPOTokenizingStrategy(
-    ORPOPrompter(chat_template_string, tokenizer),
+    ORPOPrompter(chat_template, tokenizer),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
@@ -243,30 +248,28 @@ class ORPOPrompter(Prompter):
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
dataset_parser = ORPODatasetParsingStrategy() dataset_parser = ORPODatasetParsingStrategy()
chat_template_str = chat_templates(cfg.chat_template)
def transform_fn(sample, tokenizer=None): def transform_fn(sample, tokenizer=None):
res = {} res = {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, tokenizer=tokenizer
)
res["prompt"] = tokenizer.apply_chat_template( res["prompt"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages], [msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
add_generation_prompt=True, add_generation_prompt=True,
chat_template=chat_template_string, chat_template=chat_template_str,
tokenize=False, tokenize=False,
) )
prompt_str_len = len(res["prompt"]) prompt_str_len = len(res["prompt"])
res["chosen"] = tokenizer.apply_chat_template( res["chosen"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages], [msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
add_generation_prompt=False, add_generation_prompt=False,
chat_template=chat_template_string, chat_template=chat_template_str,
tokenize=False, tokenize=False,
)[prompt_str_len:] )[prompt_str_len:]
res["rejected"] = tokenizer.apply_chat_template( res["rejected"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages], [msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
add_generation_prompt=False, add_generation_prompt=False,
chat_template=chat_template_string, chat_template=chat_template_str,
tokenize=False, tokenize=False,
)[prompt_str_len:] )[prompt_str_len:]

View File

@@ -61,9 +61,6 @@ def build_loader(
default_conversation: Optional[str] = None, default_conversation: Optional[str] = None,
): ):
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
LOG.warning(
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead. https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template",
)
conversation = ( conversation = (
ds_cfg["conversation"] ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg if ds_cfg and "conversation" in ds_cfg

View File

@@ -30,12 +30,6 @@ class InvalidDataException(Exception):
""" """
class DatasetWrappingStrategy(abc.ABC):
"""
Abstract class for wrapping datasets for Chat Messages
"""
class PromptTokenizingStrategy(abc.ABC): class PromptTokenizingStrategy(abc.ABC):
""" """
Abstract class for tokenizing strategies Abstract class for tokenizing strategies

View File

@@ -10,6 +10,7 @@ from typing import Optional, Tuple, Union
import torch import torch
import transformers.modelcard import transformers.modelcard
from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model from accelerate.utils import save_fsdp_model
from datasets import Dataset from datasets import Dataset
@@ -96,11 +97,12 @@ def train(
if cfg.adapter: if cfg.adapter:
msg += " and peft_config..." msg += " and peft_config..."
LOG.debug(msg) LOG.debug(msg)
+    # we wait until the last possible moment to set up the Accelerator
+    Accelerator()
     model, peft_config = load_model(
         cfg, tokenizer, processor=processor, inference=cli_args.inference
     )
-    if model.generation_config is not None:
-        model.generation_config.do_sample = True
+    model.generation_config.do_sample = True
model_ref = None model_ref = None
if cfg.rl and cfg.rl != "orpo": if cfg.rl and cfg.rl != "orpo":

View File

@@ -1,12 +1,8 @@
""" """
Basic utils for Axolotl Basic utils for Axolotl
""" """
import importlib.util import importlib
def is_mlflow_available(): def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None return importlib.util.find_spec("mlflow") is not None
def is_comet_available():
return importlib.util.find_spec("comet_ml") is not None

View File

@@ -29,7 +29,7 @@ from transformers import (
) )
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils import is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.callbacks.perplexity import Perplexity from axolotl.utils.callbacks.perplexity import Perplexity
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -462,7 +462,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
references=[[r] for r in references], references=[[r] for r in references],
predictions=predictions, predictions=predictions,
) )
scores["eval_" + metric_name] = score scores[metric_name] = score
return scores return scores
def predict_with_generate(): def predict_with_generate():
@@ -747,15 +747,6 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
artifact_file="PredictionsVsGroundTruth.json", artifact_file="PredictionsVsGroundTruth.json",
tracking_uri=tracking_uri, tracking_uri=tracking_uri,
) )
elif logger == "comet_ml" and is_comet_available():
import comet_ml
experiment = comet_ml.get_running_experiment()
if experiment:
experiment.log_table(
f"{name} - Predictions vs Ground Truth.csv",
pd.DataFrame(table_data),
)
if is_main_process(): if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader) log_table_from_dataloader("Eval", eval_dataloader)

View File

@@ -1,43 +0,0 @@
"""Comet module for trainer callbacks"""
import logging
from typing import TYPE_CHECKING
import comet_ml
from transformers import TrainerCallback, TrainerControl, TrainerState
from axolotl.utils.distributed import is_main_process
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
class SaveAxolotlConfigtoCometCallback(TrainerCallback):
"""Callback to save axolotl config to comet"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
comet_experiment = comet_ml.start(source="axolotl")
comet_experiment.log_other("Created from", "axolotl")
comet_experiment.log_asset(
self.axolotl_config_path,
file_name="axolotl-config",
)
LOG.info(
"The Axolotl config has been saved to the Comet Experiment under assets."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to Comet: {err}")
return control

File diff suppressed because one or more lines are too long

View File

@@ -4,7 +4,6 @@ Collators for multi-modal chat messages and packing
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from PIL import Image
from transformers import PreTrainedTokenizerBase, ProcessorMixin from transformers import PreTrainedTokenizerBase, ProcessorMixin
from transformers.data.data_collator import DataCollatorMixin from transformers.data.data_collator import DataCollatorMixin
from transformers.utils import PaddingStrategy from transformers.utils import PaddingStrategy
@@ -53,12 +52,7 @@ class MultiModalChatDataCollator(DataCollatorMixin):
) )
for example in examples for example in examples
] ]
-images = [
-    Image.open(example["images"])
-    if isinstance(example["images"], str)
-    else example["images"]
-    for example in examples
-]
+images = [example["images"] for example in examples]
if max_images > 0: if max_images > 0:
images = [img_batch[:max_images] for img_batch in images] images = [img_batch[:max_images] for img_batch in images]
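The removed branch treats string entries as image paths and opens them with PIL before collation; the same normalization as a standalone helper (a sketch):

```python
# Sketch of the path-vs-image normalization removed above.
from PIL import Image


def normalize_example_images(example):
    """Open a path lazily; pass through anything already image-like."""
    imgs = example["images"]
    return Image.open(imgs) if isinstance(imgs, str) else imgs
```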

View File

@@ -1,93 +0,0 @@
"""Module for wandb utilities"""
import logging
import os
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.utils.comet_")
COMET_ENV_MAPPING_OVERRIDE = {
"comet_mode": "COMET_START_MODE",
"comet_online": "COMET_START_ONLINE",
}
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = {
"auto_histogram_activation_logging": "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS",
"auto_histogram_epoch_rate": "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE",
"auto_histogram_gradient_logging": "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS",
"auto_histogram_tensorboard_logging": "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD",
"auto_histogram_weight_logging": "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS",
"auto_log_co2": "COMET_AUTO_LOG_CO2",
"auto_metric_logging": "COMET_AUTO_LOG_METRICS",
"auto_metric_step_rate": "COMET_AUTO_LOG_METRIC_STEP_RATE",
"auto_output_logging": "COMET_AUTO_LOG_OUTPUT_LOGGER",
"auto_param_logging": "COMET_AUTO_LOG_PARAMETERS",
"comet_disabled": "COMET_AUTO_LOG_DISABLE",
"display_summary_level": "COMET_DISPLAY_SUMMARY_LEVEL",
"distributed_node_identifier": "COMET_DISTRIBUTED_NODE_IDENTIFIER",
"log_code": "COMET_AUTO_LOG_CODE",
"log_env_cpu": "COMET_AUTO_LOG_ENV_CPU",
"log_env_details": "COMET_AUTO_LOG_ENV_DETAILS",
"log_env_disk": "COMET_AUTO_LOG_ENV_DISK",
"log_env_gpu": "COMET_AUTO_LOG_ENV_GPU",
"log_env_host": "COMET_AUTO_LOG_ENV_HOST",
"log_env_network": "COMET_AUTO_LOG_ENV_NETWORK",
"log_git_metadata": "COMET_AUTO_LOG_GIT_METADATA",
"log_git_patch": "COMET_AUTO_LOG_GIT_PATCH",
"log_graph": "COMET_AUTO_LOG_GRAPH",
"name": "COMET_START_EXPERIMENT_NAME",
"offline_directory": "COMET_OFFLINE_DIRECTORY",
"parse_args": "COMET_AUTO_LOG_CLI_ARGUMENTS",
"tags": "COMET_START_EXPERIMENT_TAGS",
}
def python_value_to_environ_value(python_value):
if isinstance(python_value, bool):
if python_value is True:
return "true"
return "false"
if isinstance(python_value, int):
return str(python_value)
if isinstance(python_value, list): # Comet only have one list of string parameter
return ",".join(map(str, python_value))
return python_value
def setup_comet_env_vars(cfg: DictDefault):
# TODO: we need to convert the Axolotl configuration to environment variables,
# as the Transformers integration is called first and would otherwise create
# an Experiment before these settings apply
for key in cfg.keys():
if key.startswith("comet_") and key != "comet_experiment_config":
value = cfg.get(key, "")
if value is not None and value != "":
env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper())
final_value = python_value_to_environ_value(value)
os.environ[env_variable_name] = final_value
if cfg.comet_experiment_config:
for key, value in cfg.comet_experiment_config.items():
if value is not None and value != "":
config_env_variable_name = (
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key)
)
if config_env_variable_name is None:
LOG.warning(
f"Unknown Comet Experiment Config name {key}, ignoring it"
)
continue
final_value = python_value_to_environ_value(value)
os.environ[config_env_variable_name] = final_value
# Enable comet if project name is present
if cfg.comet_project_name and len(cfg.comet_project_name) > 0:
cfg.use_comet = True

View File

@@ -228,7 +228,6 @@ def normalize_cfg_datasets(cfg):
f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template" f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template"
) )
cfg.datasets[idx].chat_template = cfg.chat_template cfg.datasets[idx].chat_template = cfg.chat_template
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None): def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):

View File

@@ -8,16 +8,9 @@ import logging
import os import os
from enum import Enum from enum import Enum
from importlib.metadata import version from importlib.metadata import version
-from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
-from pydantic import (
-    BaseModel,
-    Field,
-    StringConstraints,
-    conlist,
-    field_validator,
-    model_validator,
-)
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from pydantic import BaseModel, Field, conlist, field_validator, model_validator
from transformers import SchedulerType from transformers import SchedulerType
from transformers.training_args import OptimizerNames from transformers.training_args import OptimizerNames
@@ -28,37 +21,6 @@ LOG = logging.getLogger("axolotl.utils.config.models.input")
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
"""Chat templates configuration subset"""
alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
jinja = "jinja" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel): class DeprecatedParameters(BaseModel):
"""configurations that are deprecated""" """configurations that are deprecated"""
@@ -140,22 +102,14 @@ class SFTDataset(BaseModel):
path: Optional[str] = None path: Optional[str] = None
split: Optional[str] = None split: Optional[str] = None
type: Optional[Union[str, UserDefinedPrompterType]] = None type: Optional[Union[str, UserDefinedPrompterType]] = None
input_transform: Optional[str] = None
shards: Optional[int] = None shards: Optional[int] = None
conversation: Optional[str] = None conversation: Optional[str] = None
-    # Do not make this too strict or it will break the validator to choose different dataset class
-    chat_template: Optional[
-        Union[
-            ChatTemplate,
-            str,
-        ]
-    ] = None
-    chat_template_jinja: Optional[str] = None
+    chat_template: Optional[str] = None
data_files: Optional[Union[str, List[str]]] = None data_files: Optional[Union[str, List[str]]] = None
input_format: Optional[str] = None
name: Optional[str] = None name: Optional[str] = None
ds_type: Optional[str] = None ds_type: Optional[str] = None
train_on_split: Optional[str] = None train_on_split: Optional[str] = None
field: Optional[str] = None field: Optional[str] = None
field_human: Optional[str] = None field_human: Optional[str] = None
field_model: Optional[str] = None field_model: Optional[str] = None
@@ -166,31 +120,11 @@ class SFTDataset(BaseModel):
message_field_training_detail: Optional[str] = None message_field_training_detail: Optional[str] = None
roles_to_train: Optional[List[str]] = None roles_to_train: Optional[List[str]] = None
train_on_eos: Optional[str] = None train_on_eos: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None roles: Optional[Dict[str, List[str]]] = None
drop_system_message: Optional[bool] = None drop_system_message: Optional[bool] = None
trust_remote_code: Optional[bool] = False trust_remote_code: Optional[bool] = False
revision: Optional[str] = None
@model_validator(mode="before")
@classmethod
def check_chat_template_config(cls, data):
# Set chat_template to tokenizer_default if not set
if data.get("type") == "chat_template" and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.tokenizer_default
# if chat_template is set to jinja, chat_template_jinja is required
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
"chat_template_jinja"
):
raise ValueError(
"chat_template_jinja is required when chat_template is set to jinja"
)
# If chat_template_jinja is set, set chat_template to jinja
if data.get("chat_template_jinja") and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.jinja
return data
class UserDefinedDPOType(BaseModel): class UserDefinedDPOType(BaseModel):
@@ -212,7 +146,6 @@ class DPODataset(BaseModel):
split: Optional[str] = None split: Optional[str] = None
type: Optional[Union[UserDefinedDPOType, str]] = None type: Optional[Union[UserDefinedDPOType, str]] = None
data_files: Optional[List[str]] = None data_files: Optional[List[str]] = None
revision: Optional[str] = None
class UserDefinedKTOType(BaseModel): class UserDefinedKTOType(BaseModel):
@@ -234,7 +167,32 @@ class KTODataset(BaseModel):
type: Optional[Union[UserDefinedKTOType, str]] = None type: Optional[Union[UserDefinedKTOType, str]] = None
data_files: Optional[List[str]] = None data_files: Optional[List[str]] = None
trust_remote_code: Optional[bool] = False trust_remote_code: Optional[bool] = False
revision: Optional[str] = None
class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
"""Chat templates configuration subset"""
alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name
inst = "inst" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
class LoftQConfig(BaseModel): class LoftQConfig(BaseModel):
@@ -486,7 +444,6 @@ class MLFlowConfig(BaseModel):
use_mlflow: Optional[bool] = None use_mlflow: Optional[bool] = None
mlflow_tracking_uri: Optional[str] = None mlflow_tracking_uri: Optional[str] = None
mlflow_experiment_name: Optional[str] = None mlflow_experiment_name: Optional[str] = None
mlflow_run_name: Optional[str] = None
hf_mlflow_log_artifacts: Optional[bool] = None hf_mlflow_log_artifacts: Optional[bool] = None
@@ -532,19 +489,6 @@ class WandbConfig(BaseModel):
return data return data
class CometConfig(BaseModel):
"""Comet configuration subset"""
use_comet: Optional[bool] = None
comet_api_key: Optional[str] = None
comet_workspace: Optional[str] = None
comet_project_name: Optional[str] = None
comet_experiment_key: Optional[str] = None
comet_mode: Optional[str] = None
comet_online: Optional[bool] = None
comet_experiment_config: Optional[Dict[str, Any]] = None
class GradioConfig(BaseModel): class GradioConfig(BaseModel):
"""Gradio configuration subset""" """Gradio configuration subset"""
@@ -565,7 +509,6 @@ class AxolotlInputConfig(
HyperparametersConfig, HyperparametersConfig,
WandbConfig, WandbConfig,
MLFlowConfig, MLFlowConfig,
CometConfig,
LISAConfig, LISAConfig,
GradioConfig, GradioConfig,
RemappedParameters, RemappedParameters,
@@ -585,7 +528,6 @@ class AxolotlInputConfig(
resize_token_embeddings_to_32x: Optional[bool] = None resize_token_embeddings_to_32x: Optional[bool] = None
rl: Optional[RLType] = None rl: Optional[RLType] = None
reward_model: Optional[bool] = None
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
@@ -752,13 +694,7 @@ class AxolotlInputConfig(
gpu_memory_limit: Optional[Union[int, str]] = None gpu_memory_limit: Optional[Union[int, str]] = None
low_cpu_mem_usage: Optional[bool] = None low_cpu_mem_usage: Optional[bool] = None
-    chat_template: Optional[
-        Union[
-            ChatTemplate,
-            Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
-        ]
-    ] = None
-    chat_template_jinja: Optional[str] = None
+    chat_template: Optional[ChatTemplate] = None
default_system_message: Optional[str] = None default_system_message: Optional[str] = None
fix_untrained_tokens: Optional[bool] = None fix_untrained_tokens: Optional[bool] = None
@@ -867,23 +803,6 @@ class AxolotlInputConfig(
return data return data
@model_validator(mode="before")
@classmethod
def check_chat_template_config(cls, data):
# if chat_template is set to jinja, chat_template_jinja is required
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
"chat_template_jinja"
):
raise ValueError(
"chat_template_jinja is required when chat_template is set to jinja"
)
# If chat_template_jinja is set, set chat_template to jinja
if data.get("chat_template_jinja") and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.jinja
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_sample_packing_wo_flash(cls, data): def check_sample_packing_wo_flash(cls, data):
@@ -914,17 +833,6 @@ class AxolotlInputConfig(
) )
return data return data
@model_validator(mode="before")
@classmethod
def hint_reward_model_pad(cls, data):
if data.get("reward_model") and not data.get("pad_to_sequence_len"):
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using reward_model"
)
if data.get("pad_to_sequence_len") is None:
data["pad_to_sequence_len"] = True
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_gas_bsz(cls, data): def check_gas_bsz(cls, data):
@@ -1058,26 +966,6 @@ class AxolotlInputConfig(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch." "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
) )
if data.get("do_bench_eval") and not (
data.get("evals_per_epoch") or data.get("eval_steps")
):
raise ValueError(
"do_bench_eval requires evals_per_epoch or eval_steps to be set."
)
return data
@model_validator(mode="before")
@classmethod
def check_test_datasets_bench(cls, data):
if (
data.get("do_bench_eval")
and not data.get("test_datasets")
and not data.get("val_set_size")
):
LOG.warning(
"`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset."
)
data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}]
return data return data
@model_validator(mode="before") @model_validator(mode="before")

View File

@@ -90,7 +90,6 @@ def load_prepare_dpo_datasets(cfg):
ds = load_dataset( # pylint: disable=invalid-name ds = load_dataset( # pylint: disable=invalid-name
ds_cfg["path"], ds_cfg["path"],
split=ds_cfg["split"], split=ds_cfg["split"],
revision=ds_cfg.get("revision", None),
) )
split_datasets.insert(i, ds) split_datasets.insert(i, ds)

View File

@@ -19,12 +19,10 @@ from transformers import PreTrainedTokenizerBase
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import TokenizedPromptDataset from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies import load from axolotl.prompt_strategies import load
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
from axolotl.prompt_tokenizers import ( from axolotl.prompt_tokenizers import (
AlpacaMultipleChoicePromptTokenizingStrategy, AlpacaMultipleChoicePromptTokenizingStrategy,
AlpacaPromptTokenizingStrategy, AlpacaPromptTokenizingStrategy,
AlpacaReflectionPTStrategy, AlpacaReflectionPTStrategy,
DatasetWrappingStrategy,
GPTeacherPromptTokenizingStrategy, GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy, JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy, OpenAssistantPromptTokenizingStrategy,
@@ -244,7 +242,6 @@ def load_tokenized_prepared_datasets(
name=config_dataset.name, name=config_dataset.name,
streaming=True, streaming=True,
token=use_auth_token, token=use_auth_token,
revision=config_dataset.revision,
) )
ds_from_hub = True ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError): except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
@@ -349,7 +346,6 @@ def load_tokenized_prepared_datasets(
streaming=False, streaming=False,
data_files=config_dataset.data_files, data_files=config_dataset.data_files,
token=use_auth_token, token=use_auth_token,
revision=config_dataset.revision,
**load_ds_kwargs, **load_ds_kwargs,
) )
elif ds_from_cloud and remote_file_system: elif ds_from_cloud and remote_file_system:
@@ -384,7 +380,6 @@ def load_tokenized_prepared_datasets(
repo_id=config_dataset.path, repo_id=config_dataset.path,
repo_type="dataset", repo_type="dataset",
filename=config_dataset.data_files, filename=config_dataset.data_files,
revision=config_dataset.revision,
) )
elif isinstance(config_dataset.data_files, list): elif isinstance(config_dataset.data_files, list):
fp = [] fp = []
@@ -394,7 +389,6 @@ def load_tokenized_prepared_datasets(
repo_id=config_dataset.path, repo_id=config_dataset.path,
repo_type="dataset", repo_type="dataset",
filename=file, filename=file,
revision=config_dataset.revision,
) )
) )
else: else:
@@ -439,8 +433,8 @@ def load_tokenized_prepared_datasets(
config_dataset=config_dataset, config_dataset=config_dataset,
tokenizer=tokenizer, tokenizer=tokenizer,
cfg=cfg, cfg=cfg,
d_base_type=d_base_type,
dataset=ds, dataset=ds,
d_base_type=d_base_type,
d_prompt_style=d_prompt_style, d_prompt_style=d_prompt_style,
processor=processor, processor=processor,
) )
@@ -460,7 +454,7 @@ def load_tokenized_prepared_datasets(
else: else:
LOG.debug("NOT shuffling merged datasets") LOG.debug("NOT shuffling merged datasets")
if cfg.sample_packing and not cfg.skip_prepare_dataset: if not cfg.skip_prepare_dataset:
dataset, _ = process_datasets_for_packing(cfg, dataset, None) dataset, _ = process_datasets_for_packing(cfg, dataset, None)
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset: if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
@@ -575,7 +569,7 @@ def get_dataset_wrapper(
d_base_type, d_base_type,
dataset, dataset,
d_prompt_style=None, d_prompt_style=None,
processor=None, # pylint: disable=unused-argument processor=None,
): ):
dataset_wrapper = None dataset_wrapper = None
dataset_prompter = None dataset_prompter = None
@@ -610,10 +604,8 @@ def get_dataset_wrapper(
) )
elif cfg.skip_prepare_dataset: elif cfg.skip_prepare_dataset:
dataset_wrapper = dataset dataset_wrapper = dataset
-elif ds_strategy := config_dataset.type.startswith(
-    "bradley_terry"
-) and bradley_terry_load(
-    config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
-):
+elif ds_strategy := load(
+    config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
+):
dataset_prompter = UnsupportedPrompter() dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset( dataset_wrapper = TokenizedPromptDataset(
@@ -621,18 +613,6 @@ def get_dataset_wrapper(
dataset, dataset,
**ds_kwargs, **ds_kwargs,
) )
elif ds_strategy := load(
config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
):
if isinstance(ds_strategy, DatasetWrappingStrategy):
dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
else:
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
ds_strategy,
dataset,
**ds_kwargs,
)
elif d_base_type == "alpaca": elif d_base_type == "alpaca":
dataset_prompter = AlpacaPrompter(d_prompt_style) dataset_prompter = AlpacaPrompter(d_prompt_style)
ds_strategy = AlpacaPromptTokenizingStrategy( ds_strategy = AlpacaPromptTokenizingStrategy(

View File

@@ -50,7 +50,7 @@ from axolotl.monkeypatch.multipack import (
) )
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import zero_only from axolotl.utils.distributed import zero_only
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
@@ -293,10 +293,7 @@ def load_tokenizer(cfg):
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
 if cfg.chat_template:
-    chat_template_string = get_chat_template_from_config(
-        cfg=cfg,
-        tokenizer=tokenizer,
-    )
+    chat_template_string = chat_templates(cfg.chat_template)
if cfg.default_system_message and cfg.chat_template == "chatml": if cfg.default_system_message and cfg.chat_template == "chatml":
chat_template_string = chat_template_string.replace( chat_template_string = chat_template_string.replace(
"You are a helpful assistant.", cfg.default_system_message "You are a helpful assistant.", cfg.default_system_message
@@ -1117,7 +1114,8 @@ def load_lora(model, cfg, inference=False, config_only=False):
fan_in_fan_out=cfg.lora_fan_in_fan_out, fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None, modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
bias="none", bias="none",
task_type="CAUSAL_LM", # task_type="CAUSAL_LM",
task_type="CONDITIONAL_GENERATION" if cfg.is_multimodal else "CAUSAL_LM",
**lora_config_kwargs, **lora_config_kwargs,
) )
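The `task_type` switch matters because PEFT keys the adapter's forward signature off the task; a reduced sketch of the config being built (hyperparameters are illustrative, and `is_multimodal` stands in for `cfg.is_multimodal`):

```python
# Reduced sketch of the LoraConfig above; values are illustrative only.
from peft import LoraConfig

is_multimodal = True  # stand-in for cfg.is_multimodal

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CONDITIONAL_GENERATION" if is_multimodal else "CAUSAL_LM",
)
print(lora_config.task_type)
```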

View File

@@ -11,7 +11,7 @@ import numpy as np
import torch import torch
import torch.cuda import torch.cuda
from accelerate.logging import get_logger from accelerate.logging import get_logger
from datasets import disable_caching, enable_caching from datasets import set_caching_enabled
from torch.utils.data import DataLoader, RandomSampler from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
@@ -87,10 +87,10 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
@contextmanager @contextmanager
def disable_datasets_caching(): def disable_datasets_caching():
try: try:
disable_caching() set_caching_enabled(False)
yield yield
finally: finally:
enable_caching() set_caching_enabled(True)
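Both versions implement the same context manager over the `datasets` caching toggle (the newer `disable_caching`/`enable_caching` API on one side, the deprecated `set_caching_enabled` on the other); usage looks like this, a sketch assuming the helper is importable from `axolotl.utils.trainer`:

```python
# Usage sketch of disable_datasets_caching; the toy dataset is illustrative.
from datasets import Dataset

from axolotl.utils.trainer import disable_datasets_caching

ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5]]})

with disable_datasets_caching():
    # map() runs without persisting cache files while inside the block
    ds = ds.map(lambda ex: {"position_ids": list(range(len(ex["input_ids"])))})

print(ds[0]["position_ids"])  # [0, 1, 2]
```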
def add_position_ids(sample): def add_position_ids(sample):
@@ -306,11 +306,7 @@ def process_pretraining_datasets_for_packing(
def calculate_total_num_steps(cfg, train_dataset, update=True):
-if (
-not cfg.total_num_tokens
-and not cfg.skip_prepare_dataset
-and not cfg.reward_model
-):
+if not cfg.total_num_tokens and not cfg.skip_prepare_dataset:
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
.to_pandas()
@@ -327,7 +323,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
not skip_estimates
and not cfg.total_supervised_tokens
and not cfg.skip_prepare_dataset
-and not cfg.reward_model
):
total_supervised_tokens = (
train_dataset.data.column("labels")
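Both totals boil down to summing sequence lengths over the tokenized dataset, with the supervised count skipping label positions masked as -100. A standalone toy equivalent (data values are illustrative):

import numpy as np
from datasets import Dataset

# toy tokenized dataset standing in for train_dataset
ds = Dataset.from_dict(
    {
        "input_ids": [[1, 2, 3], [4, 5]],
        "labels": [[-100, 2, 3], [4, 5]],
    }
)

total_num_tokens = int(np.sum([len(ids) for ids in ds["input_ids"]]))
# supervised tokens are the label positions not masked with -100
total_supervised_tokens = int(
    np.sum([sum(tok != -100 for tok in labels) for labels in ds["labels"]])
)
print(total_num_tokens, total_supervised_tokens)  # 5 4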

View File

@@ -1,197 +0,0 @@
"""
Tests for the chat messages module
"""
import unittest
import pytest
from transformers import AddedToken, AutoTokenizer
from axolotl.core.chat.format.chatml import format_message
from axolotl.core.chat.messages import ChatFormattedChats, Chats
@pytest.fixture(scope="session", name="llama_tokenizer")
def llama_tokenizer_fixture():
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B")
@pytest.fixture(scope="session", name="chatml_tokenizer")
def llama_tokenizer_w_chatml(llama_tokenizer):
llama_tokenizer.add_special_tokens(
{
"eos_token": AddedToken(
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
)
}
)
llama_tokenizer.add_tokens(
[
AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
]
)
return llama_tokenizer
@pytest.fixture(scope="session", name="chat_msgs")
def chat_msgs_fixture():
return {
"conversation": [
{
"role": "system",
"content": [
{"type": "text", "value": "You are a helpful assistant."},
],
},
{
"role": "user",
"content": [
{"type": "text", "value": "What is today's stock price of Apple?"},
],
},
{
"role": "assistant",
"content": [
{
"type": "tool_call",
"value": {
"name": "get_date",
"arguments": {},
},
},
{
"type": "tool_call",
"value": {
"name": "get_stock_price",
"arguments": {"symbol": "AAPL"},
},
},
],
"weight": 1,
},
{
"role": "tool",
"content": [
{
"type": "tool_response",
"value": {
"name": "get_date",
"content": {"date": "2024-09-09"},
},
},
{
"type": "tool_response",
"value": {
"name": "get_stock_price",
"content": {"symbol": "AAPL", "price": 123.45},
},
},
],
},
{
"role": "assistant",
"content": [
{
"type": "text",
"value": "The stock price of Apple is $123.45.\n",
"weight": 0,
},
{
"type": "text",
"value": "<reflection>The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.</reflection>",
},
{
"type": "text",
"value": "The stock price of Apple on September 9, 2024 is $123.45.",
},
],
"weight": 1,
},
]
}
class TestMessagesCase:
"""
Test cases for the chat messages module
"""
def test_tool_call_stringify(self, chat_msgs):
chat_msgs_as_obj = Chats(**chat_msgs)
assert '{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}' == str(
chat_msgs_as_obj.conversation[2].content[1].value
)
def test_chatml_formatted_wrapper(self, chat_msgs):
chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
target_chatml = """<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
What is today's stock price of Apple?<|im_end|>
<|im_start|>assistant
<tool_call>
{"name": "get_date", "arguments": {}}
</tool_call>
<tool_call>
{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}
</tool_call>
<|im_end|>
<|im_start|>tool
<tool_response>
{"name": "get_date", "content": {"date": "2024-09-09"}}
</tool_response>
<tool_response>
{"name": "get_stock_price", "content": {"symbol": "AAPL", "price": 123.45}}
</tool_response>
<|im_end|>
<|im_start|>assistant
The stock price of Apple is $123.45.
<reflection>The original query asked for today's stock price of Apple. This implies they also wanted the date included in the response.</reflection>The stock price of Apple on September 9, 2024 is $123.45.<|im_end|>\n"""
assert target_chatml == str(chat_msg_formatted)
def test_chatml_formatting_tool_call(self, chat_msgs):
chat_msgs_as_obj = Chats(**chat_msgs)
target_chatml_turn2 = """<|im_start|>assistant\n<tool_call>\n{"name": "get_date", "arguments": {}}\n</tool_call>\n<tool_call>\n{"name": "get_stock_price", "arguments": {"symbol": "AAPL"}}\n</tool_call>\n<|im_end|>\n"""
assert target_chatml_turn2 == str(
format_message(chat_msgs_as_obj.conversation[2])
)
def test_train_labels(self, chatml_tokenizer, chat_msgs):
chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
tokenized = chat_msg_formatted.conversation[2].tokenized(chatml_tokenizer)
# fmt: off
target_labels = [
-100, -100, -100, # role
27, 14506, 13735, 397, 5018, 609, 794,
330, 456, 4257, 498, 330, 16774, 794, 4792, 534, 524,
14506, 13735, 397, 27, 14506, 13735, 397, 5018, 609, 794,
330, 456, 31641, 9217, 498, 330, 16774, 794, 5324, 19314,
794, 330, 84016, 43, 96742, 524, 14506, 13735, 397,
128256, # <|im_end|>
-100 # trailing newline
]
# fmt: on
assert tokenized["labels"] == target_labels
def test_train_labels_2(self, chatml_tokenizer, chat_msgs):
# also test that individual contents can be set not to train
chat_msg_formatted = ChatFormattedChats(**chat_msgs, formatter=format_message)
tokenized = chat_msg_formatted.conversation[4].tokenized(chatml_tokenizer)
# fmt: off
target_labels = [
-100, -100, -100, # role
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # initial response
27, 78098, 16761, 4113, 3319, 4691, 369, 3432, 596, 5708, 3430,
315, 8325, 13, 1115, 24897, 814, 1101, 4934, 279, 2457,
5343, 304, 279, 2077, 4005, 78098, 16761, 5708, 3430, 315,
8325, 389, 6250, 220, 24, 11, 220, 2366, 19, 374, 400,
4513, 13, 1774, 13,
128256, # <|im_end|>
-100, # trailing newline
]
# fmt: on
assert tokenized["labels"] == target_labels
if __name__ == "__main__":
unittest.main()
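For context, the -100 entries in the expected labels are the standard ignore index: PyTorch's cross-entropy loss (and therefore Hugging Face causal-LM heads) skips those positions, so only the unmasked spans contribute to the gradient. A minimal illustration:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)  # 4 token positions, vocabulary of 10
labels = torch.tensor([-100, 2, 7, -100])  # only positions 1 and 2 are trained

# the default ignore_index is already -100; shown explicitly for clarity
loss = F.cross_entropy(logits, labels, ignore_index=-100)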

View File

@@ -19,8 +19,6 @@ from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
-AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@pytest.fixture(scope="session", autouse=True)
def download_model():
@@ -348,115 +346,3 @@ class TestMultiGPULlama(unittest.TestCase):
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_ds_zero3_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_ds_zero3_qlora_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero3_bf16.json"),
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)

View File

@@ -1,74 +0,0 @@
"""
E2E tests for reward model lora llama
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestRewardModelLoraLlama(unittest.TestCase):
"""
Test case for Llama reward models using LoRA
"""
@with_temp_dir
def test_rm_fft(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"model_type": "AutoModelForSequenceClassification",
"tokenizer_type": "LlamaTokenizer",
"chat_template": "alpaca",
"reward_model": True,
"sequence_len": 1024,
"pad_to_sequence_len": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.0,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "argilla/distilabel-intel-orca-dpo-pairs",
"type": "bradley_terry.chat_template",
},
],
"remove_unused_columns": False,
"max_steps": 10,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"gradient_checkpointing": True,
"warmup_ratio": 0.1,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -1,62 +0,0 @@
"""
tests for chat_template prompt strategy
"""
# pylint: disable=duplicate-code
import logging
import unittest
from axolotl.prompt_strategies.messages.chat import load
from axolotl.utils.dict import DictDefault
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
class TestMessagesChatLlama3:
"""
Test class for assistant style datasets with llama-3 prompts using the messages chat llama3 strategy.
"""
def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
LOG.info("Loading llama-3 tokenizer with assistant dataset")
strategy = load(
llama3_tokenizer,
DictDefault(
{
"train_on_inputs": False,
"sequence_len": 512,
}
),
DictDefault(
{
"chat_template": "llama3",
"message_field_role": "role",
"message_field_content": "content",
"field_messages": "messages",
}
),
)
res = strategy.wrap_dataset(assistant_dataset)
input_ids = res[0]["input_ids"]
# fmt: off
expected_input_ids = [
128000, # bos
128006, 882, 128007, # user header
271, 15339, 128009, # user prompt eot
128006, 78191, 128007, # assistant header
271, 15339, 128009, # assistant response eot
128006, 882, 128007,
271, 19045, 29474, 128009,
128006, 78191, 128007,
271, 19045, 29474, 128009,
]
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
if __name__ == "__main__":
unittest.main()

View File

@@ -1,125 +0,0 @@
"""
Tests for utils in axolotl.utils.chat_templates
"""
import unittest
import pytest
from transformers import AutoTokenizer
from axolotl.utils.chat_templates import (
_CHAT_TEMPLATES,
extract_chat_template_args,
get_chat_template,
)
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
return tokenizer
class TestGetChatTemplateUtils:
"""
Tests the get_chat_template function.
"""
def test_known_chat_template(self):
chat_template_str = get_chat_template("llama3")
assert chat_template_str == _CHAT_TEMPLATES["llama3"]
def test_invalid_chat_template(self):
with pytest.raises(ValueError) as exc:
get_chat_template("invalid_template")
assert str(exc) == "Template 'invalid_template' not found."
def test_tokenizer_default_no_tokenizer(self):
with pytest.raises(ValueError):
get_chat_template("tokenizer_default", tokenizer=None)
def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
with pytest.raises(ValueError):
get_chat_template("tokenizer_default", tokenizer=llama3_tokenizer)
def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
llama3_tokenizer.chat_template = "test_template"
chat_template_str = get_chat_template(
"tokenizer_default", tokenizer=llama3_tokenizer
)
assert chat_template_str == "test_template"
def test_tokenizer_default_fallback_no_tokenizer(self):
with pytest.raises(ValueError):
get_chat_template("tokenizer_default_fallback_test", tokenizer=None)
def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
self, llama3_tokenizer
):
chat_template_str = get_chat_template(
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
)
assert chat_template_str == get_chat_template("chatml")
def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
self, llama3_tokenizer
):
llama3_tokenizer.chat_template = "test_template"
chat_template_str = get_chat_template(
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
)
assert chat_template_str == "test_template"
def test_jinja_template_mode(self):
jinja_template = "example_jinja_template"
chat_template_str = get_chat_template("jinja", jinja_template=jinja_template)
assert chat_template_str == jinja_template
def test_jinja_template_mode_no_jinja_template(self):
with pytest.raises(ValueError):
get_chat_template("jinja", jinja_template=None)
def test_extract_chat_template_args(self):
# No ds_cfg
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={"chat_template": "chatml"},
)
assert chat_template_choice == "chatml"
assert chat_template_jinja is None
# ds_cfg provided
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={
"chat_template": "jinja",
"chat_template_jinja": "global_jinja_template",
},
ds_cfg={"chat_template": "llama3", "chat_template_jinja": None},
)
assert chat_template_choice == "llama3"
assert chat_template_jinja is None
# ds_cfg provided with jinja template
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={"chat_template": "chatml", "chat_template_jinja": None},
ds_cfg={
"chat_template": "jinja",
"chat_template_jinja": "ds_jinja_template",
},
)
assert chat_template_choice == "jinja"
assert chat_template_jinja == "ds_jinja_template"
# ds_cfg provided with no chat_template
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={
"chat_template": "jinja",
"chat_template_jinja": "global_jinja_template",
},
ds_cfg={"chat_template": None, "chat_template_jinja": "ds_jinja_template"},
)
assert chat_template_choice == "jinja"
assert chat_template_jinja == "global_jinja_template"
if __name__ == "__main__":
unittest.main()
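For readers without the source at hand, the resolution order these tests pin down can be sketched as follows. This is an illustrative re-implementation, not the library code; `resolve_chat_template` and the `templates` mapping are hypothetical names:

def resolve_chat_template(choice, templates, tokenizer=None, jinja_template=None):
    if choice == "jinja":
        if not jinja_template:
            raise ValueError("chat_template choice is 'jinja' but no template given")
        return jinja_template
    if choice == "tokenizer_default":
        if tokenizer is None or not getattr(tokenizer, "chat_template", None):
            raise ValueError("tokenizer has no chat_template set")
        return tokenizer.chat_template
    if choice.startswith("tokenizer_default_fallback_"):
        if tokenizer is None:
            raise ValueError("tokenizer required for fallback resolution")
        if getattr(tokenizer, "chat_template", None):
            return tokenizer.chat_template
        # fall back to the named built-in template, e.g. "chatml"
        return resolve_chat_template(choice[len("tokenizer_default_fallback_"):], templates)
    try:
        return templates[choice]
    except KeyError as exc:
        raise ValueError(f"Template '{choice}' not found.") from exc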

View File

@@ -11,7 +11,7 @@ from axolotl.prompt_strategies.chat_template import (
load,
)
from axolotl.prompters import IGNORE_TOKEN_ID
-from axolotl.utils.chat_templates import get_chat_template
+from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.dict import DictDefault
logging.basicConfig(level=logging.DEBUG)
@@ -73,7 +73,7 @@ class TestAssistantChatTemplateLlama3:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
-chat_template=get_chat_template("llama3"),
+chat_template=chat_templates("llama3"),
message_field_role="role",
message_field_content="content",
roles={
@@ -113,7 +113,7 @@ class TestAssistantChatTemplateLlama3:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
phi35_tokenizer,
-chat_template=get_chat_template("phi_35"),
+chat_template=chat_templates("phi_35"),
message_field_role="role",
message_field_content="content",
roles={
@@ -171,7 +171,7 @@ class TestAssistantChatTemplateLlama3:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
-chat_template=get_chat_template("llama3"),
+chat_template=chat_templates("llama3"),
message_field_role="role",
message_field_content="content",
message_field_training="training",
@@ -230,7 +230,7 @@ class TestSharegptChatTemplateLlama3:
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
-llama3_tokenizer, chat_template=get_chat_template("llama3")
+llama3_tokenizer, chat_template=chat_templates("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -283,7 +283,7 @@ class TestSharegptChatTemplateLlama3:
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
-llama3_tokenizer, chat_template=get_chat_template("llama3")
+llama3_tokenizer, chat_template=chat_templates("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -336,7 +336,7 @@ class TestSharegptChatTemplateLlama3:
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
-llama3_tokenizer, chat_template=get_chat_template("llama3")
+llama3_tokenizer, chat_template=chat_templates("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,

View File

@@ -12,7 +12,7 @@ from axolotl.prompt_strategies.chat_template import (
ChatTemplateStrategy,
)
from axolotl.prompters import IGNORE_TOKEN_ID
-from axolotl.utils.chat_templates import get_chat_template
+from axolotl.utils.chat_templates import chat_templates
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
@@ -35,7 +35,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_inputs=True") LOG.info("Testing with train_on_inputs=True")
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter( ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3") llama3_tokenizer, chat_template=chat_templates("llama3")
), ),
tokenizer=llama3_tokenizer, tokenizer=llama3_tokenizer,
train_on_inputs=True, train_on_inputs=True,
@@ -80,7 +80,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_inputs=False") LOG.info("Testing with train_on_inputs=False")
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter( ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3") llama3_tokenizer, chat_template=chat_templates("llama3")
), ),
tokenizer=llama3_tokenizer, tokenizer=llama3_tokenizer,
train_on_inputs=False, train_on_inputs=False,
@@ -123,7 +123,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing roles_to_train with assistant only") LOG.info("Testing roles_to_train with assistant only")
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter( ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3") llama3_tokenizer, chat_template=chat_templates("llama3")
), ),
tokenizer=llama3_tokenizer, tokenizer=llama3_tokenizer,
train_on_inputs=False, train_on_inputs=False,
@@ -151,7 +151,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing roles_to_train with all roles") LOG.info("Testing roles_to_train with all roles")
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter( ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3") llama3_tokenizer, chat_template=chat_templates("llama3")
), ),
tokenizer=llama3_tokenizer, tokenizer=llama3_tokenizer,
train_on_inputs=True, train_on_inputs=True,
@@ -184,7 +184,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with empty roles_to_train") LOG.info("Testing with empty roles_to_train")
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter( ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3") llama3_tokenizer, chat_template=chat_templates("llama3")
), ),
tokenizer=llama3_tokenizer, tokenizer=llama3_tokenizer,
train_on_inputs=False, train_on_inputs=False,
@@ -205,7 +205,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_eos='all'") LOG.info("Testing with train_on_eos='all'")
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter( ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3") llama3_tokenizer, chat_template=chat_templates("llama3")
), ),
tokenizer=llama3_tokenizer, tokenizer=llama3_tokenizer,
train_on_inputs=False, train_on_inputs=False,
@@ -232,7 +232,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_eos='turn'") LOG.info("Testing with train_on_eos='turn'")
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter( ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3") llama3_tokenizer, chat_template=chat_templates("llama3")
), ),
tokenizer=llama3_tokenizer, tokenizer=llama3_tokenizer,
train_on_inputs=False, train_on_inputs=False,
@@ -282,7 +282,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_eos='last'") LOG.info("Testing with train_on_eos='last'")
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter( ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3") llama3_tokenizer, chat_template=chat_templates("llama3")
), ),
tokenizer=llama3_tokenizer, tokenizer=llama3_tokenizer,
train_on_inputs=False, train_on_inputs=False,
@@ -315,7 +315,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_eos='none'") LOG.info("Testing with train_on_eos='none'")
strategy = ChatTemplateStrategy( strategy = ChatTemplateStrategy(
ChatTemplatePrompter( ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3") llama3_tokenizer, chat_template=chat_templates("llama3")
), ),
tokenizer=llama3_tokenizer, tokenizer=llama3_tokenizer,
train_on_inputs=False, train_on_inputs=False,
@@ -343,7 +343,7 @@ class TestChatTemplateConfigurations:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
-chat_template=get_chat_template("llama3"),
+chat_template=chat_templates("llama3"),
drop_system_message=True,
),
tokenizer=llama3_tokenizer,
@@ -371,7 +371,7 @@ class TestChatTemplateConfigurations:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
-chat_template=get_chat_template("llama3"),
+chat_template=chat_templates("llama3"),
roles=custom_roles,
),
tokenizer=llama3_tokenizer,
@@ -424,7 +424,7 @@ class TestChatTemplateConfigurations:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
-chat_template=get_chat_template("llama3"),
+chat_template=chat_templates("llama3"),
message_field_training="train",
message_field_training_detail="train_detail",
),

View File

@@ -86,20 +86,6 @@ def fixture_llama3_tokenizer():
return tokenizer
-@pytest.fixture(name="phi3_tokenizer")
-def fixture_phi3_tokenizer():
-tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct")
-return tokenizer
-@pytest.fixture(name="gemma_tokenizer")
-def fixture_gemma_tokenizer():
-tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a")
-return tokenizer
class TestAssistantDPOChatTemplateLlama3:
"""
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
@@ -113,7 +99,7 @@ class TestAssistantDPOChatTemplateLlama3:
"chat_template": "llama3", "chat_template": "llama3",
"datasets": [ "datasets": [
{ {
"type": "chat_template", "chat_template": "llama3",
} }
], ],
} }
@@ -138,7 +124,7 @@ class TestAssistantDPOChatTemplateLlama3:
"chat_template": "llama3", "chat_template": "llama3",
"datasets": [ "datasets": [
{ {
"type": "chat_template", "chat_template": "llama3",
"field_messages": "conversation", "field_messages": "conversation",
"field_chosen": "better", "field_chosen": "better",
"field_rejected": "worse", "field_rejected": "worse",
@@ -166,65 +152,5 @@ class TestAssistantDPOChatTemplateLlama3:
assert result["rejected"] == "party on<|eot_id|>" assert result["rejected"] == "party on<|eot_id|>"
class TestAssistantDPOChatTemplatePhi3:
"""
Test class for assistant style datasets with phi-3 prompts using the tokenizer's chat_template strategy.
"""
def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn = default(
DictDefault(
{
"chat_template": "tokenizer_default",
"datasets": [
{
"type": "chat_template",
}
],
}
)
)
result = transform_fn(assistant_dataset[0], tokenizer=phi3_tokenizer)
assert result["prompt"] == (
"<|user|>\nhello<|end|>\n"
+ "<|assistant|>\nhello<|end|>\n"
+ "<|user|>\ngoodbye<|end|>\n"
+ "<|assistant|>\n"
)
assert result["chosen"] == "goodbye<|end|>"
assert result["rejected"] == "party on<|end|>"
class TestAssistantDPOChatTemplateGemma:
"""
Test class for assistant style datasets with gemma prompts using the tokenizer's chat_template strategy.
"""
def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn = default(
DictDefault(
{
"chat_template": "tokenizer_default",
"datasets": [
{
"type": "chat_template",
}
],
}
)
)
result = transform_fn(assistant_dataset[0], tokenizer=gemma_tokenizer)
assert result["prompt"] == (
"<bos><start_of_turn>user\nhello<end_of_turn>\n"
+ "<start_of_turn>model\nhello<end_of_turn>\n"
+ "<start_of_turn>user\ngoodbye<end_of_turn>\n"
+ "<start_of_turn>model\n"
)
assert result["chosen"] == "goodbye<end_of_turn>"
assert result["rejected"] == "party on<end_of_turn>"
if __name__ == "__main__":
unittest.main()

View File

@@ -12,7 +12,6 @@ from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from axolotl.utils.data import load_tokenized_prepared_datasets
-from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.dict import DictDefault
@@ -268,143 +267,6 @@ class TestDatasetPreparation(unittest.TestCase):
assert "attention_mask" in dataset.features assert "attention_mask" in dataset.features
assert "labels" in dataset.features assert "labels" in dataset.features
def test_load_hub_with_dpo(self):
"""Verify that processing dpo data from the hub works"""
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"rl": "dpo",
"chat_template": "llama3",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"chat_template": "llama3",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
}
],
}
)
train_dataset, _ = load_prepare_dpo_datasets(cfg)
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
def test_load_hub_with_revision(self):
"""Verify that processing data from the hub works with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"revision": "d05c1cb",
},
],
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
def test_load_hub_with_revision_with_dpo(self):
"""Verify that processing dpo data from the hub works with a specific revision"""
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"rl": "dpo",
"chat_template": "llama3",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"chat_template": "llama3",
"revision": "ea82cff",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
}
],
}
)
train_dataset, _ = load_prepare_dpo_datasets(cfg)
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
def test_load_local_hub_with_revision(self):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
revision="d05c1cb",
)
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"ds_type": "parquet",
"type": "alpaca",
"data_files": [
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
],
"revision": "d05c1cb",
},
],
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
if __name__ == "__main__":
unittest.main()

View File

@@ -9,7 +9,6 @@ from typing import Optional
import pytest
from pydantic import ValidationError
-from axolotl.utils import is_comet_available
from axolotl.utils.config import validate_config
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
from axolotl.utils.dict import DictDefault
@@ -1330,105 +1329,3 @@ class TestValidationWandb(BaseValidation):
os.environ.pop("WANDB_PROJECT", None) os.environ.pop("WANDB_PROJECT", None)
os.environ.pop("WANDB_DISABLED", None) os.environ.pop("WANDB_DISABLED", None)
@pytest.mark.skipif(is_comet_available() is False, reason="comet_ml is not installed")
class TestValidationComet(BaseValidation):
"""
Validation test for comet
"""
def test_comet_sets_env(self, minimal_cfg):
from axolotl.utils.comet_ import setup_comet_env_vars
comet_config = {
"comet_api_key": "foo",
"comet_workspace": "some_workspace",
"comet_project_name": "some_project",
"comet_experiment_key": "some_experiment_key",
"comet_mode": "get_or_create",
"comet_online": False,
"comet_experiment_config": {
"auto_histogram_activation_logging": False,
"auto_histogram_epoch_rate": 2,
"auto_histogram_gradient_logging": True,
"auto_histogram_tensorboard_logging": False,
"auto_histogram_weight_logging": True,
"auto_log_co2": False,
"auto_metric_logging": True,
"auto_metric_step_rate": 15,
"auto_output_logging": False,
"auto_param_logging": True,
"comet_disabled": False,
"display_summary_level": 2,
"distributed_node_identifier": "some_distributed_node_identifier",
"log_code": True,
"log_env_cpu": False,
"log_env_details": True,
"log_env_disk": False,
"log_env_gpu": True,
"log_env_host": False,
"log_env_network": True,
"log_git_metadata": False,
"log_git_patch": True,
"log_graph": False,
"name": "some_name",
"offline_directory": "some_offline_directory",
"parse_args": True,
"tags": ["tag1", "tag2"],
},
}
cfg = DictDefault(comet_config) | minimal_cfg
new_cfg = validate_config(cfg)
setup_comet_env_vars(new_cfg)
comet_env = {
key: value for key, value in os.environ.items() if key.startswith("COMET_")
}
assert (
len(comet_env)
== len(comet_config) + len(comet_config["comet_experiment_config"]) - 1
)
assert comet_env == {
"COMET_API_KEY": "foo",
"COMET_AUTO_LOG_CLI_ARGUMENTS": "true",
"COMET_AUTO_LOG_CO2": "false",
"COMET_AUTO_LOG_CODE": "true",
"COMET_AUTO_LOG_DISABLE": "false",
"COMET_AUTO_LOG_ENV_CPU": "false",
"COMET_AUTO_LOG_ENV_DETAILS": "true",
"COMET_AUTO_LOG_ENV_DISK": "false",
"COMET_AUTO_LOG_ENV_GPU": "true",
"COMET_AUTO_LOG_ENV_HOST": "false",
"COMET_AUTO_LOG_ENV_NETWORK": "true",
"COMET_AUTO_LOG_GIT_METADATA": "false",
"COMET_AUTO_LOG_GIT_PATCH": "true",
"COMET_AUTO_LOG_GRAPH": "false",
"COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS": "false",
"COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE": "2",
"COMET_AUTO_LOG_HISTOGRAM_GRADIENTS": "true",
"COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD": "false",
"COMET_AUTO_LOG_HISTOGRAM_WEIGHTS": "true",
"COMET_AUTO_LOG_METRIC_STEP_RATE": "15",
"COMET_AUTO_LOG_METRICS": "true",
"COMET_AUTO_LOG_OUTPUT_LOGGER": "false",
"COMET_AUTO_LOG_PARAMETERS": "true",
"COMET_DISPLAY_SUMMARY_LEVEL": "2",
"COMET_DISTRIBUTED_NODE_IDENTIFIER": "some_distributed_node_identifier",
"COMET_EXPERIMENT_KEY": "some_experiment_key",
"COMET_OFFLINE_DIRECTORY": "some_offline_directory",
"COMET_PROJECT_NAME": "some_project",
"COMET_START_EXPERIMENT_NAME": "some_name",
"COMET_START_EXPERIMENT_TAGS": "tag1,tag2",
"COMET_START_MODE": "get_or_create",
"COMET_START_ONLINE": "false",
"COMET_WORKSPACE": "some_workspace",
}
for key in comet_env.keys():
os.environ.pop(key, None)

View File

@@ -1,238 +0,0 @@
"""Module for testing the validation module for the dataset config"""
import warnings
from typing import Optional
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate
from axolotl.utils.dict import DictDefault
warnings.filterwarnings("error")
@pytest.fixture(name="minimal_cfg")
def fixture_cfg():
return DictDefault(
{
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
"learning_rate": 0.000001,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
}
)
# pylint: disable=too-many-public-methods (duplicate-code)
class BaseValidation:
"""
Base validation module to setup the log capture
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
class TestValidationCheckDatasetConfig(BaseValidation):
"""
Test the validation for the dataset config to ensure no correct parameters are dropped
"""
def test_dataset_config_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "sharegpt",
"conversation": "chatml",
"shards": 10,
}
]
}
)
checked_cfg = validate_config(cfg)
def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.datasets[0].conversation == cfg.datasets[0].conversation
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
_check_config()
checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
)
_check_config()
def test_dataset_default_chat_template_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"field_messages": "conversations",
"shards": 10,
"message_field_role": "from",
"message_field_content": "value",
}
],
}
)
checked_cfg = validate_config(cfg)
def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.chat_template is None
assert (
checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
)
assert (
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
)
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
assert (
checked_cfg.datasets[0].message_field_role
== cfg.datasets[0].message_field_role
)
assert (
checked_cfg.datasets[0].message_field_content
== cfg.datasets[0].message_field_content
)
_check_config()
checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
)
_check_config()
def test_dataset_partial_default_chat_template_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"chat_template": "chatml",
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"field_messages": "conversations",
"shards": 10,
"message_field_role": "from",
"message_field_content": "value",
}
],
}
)
checked_cfg = validate_config(cfg)
def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.chat_template == ChatTemplate.chatml
assert (
checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
)
assert (
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
)
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
assert (
checked_cfg.datasets[0].message_field_role
== cfg.datasets[0].message_field_role
)
assert (
checked_cfg.datasets[0].message_field_content
== cfg.datasets[0].message_field_content
)
_check_config()
checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
)
_check_config()
def test_dataset_chatml_chat_template_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"chat_template": "chatml",
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"chat_template": "gemma",
"field_messages": "conversations",
"shards": 10,
"message_field_role": "from",
"message_field_content": "value",
}
],
}
)
checked_cfg = validate_config(cfg)
def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.chat_template == cfg.chat_template
assert (
checked_cfg.datasets[0].chat_template == cfg.datasets[0].chat_template
)
assert (
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
)
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
assert (
checked_cfg.datasets[0].message_field_role
== cfg.datasets[0].message_field_role
)
assert (
checked_cfg.datasets[0].message_field_content
== cfg.datasets[0].message_field_content
)
_check_config()
checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
)
_check_config()