diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 2d3c209cc..5b5cc5489 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -23,7 +23,7 @@ jobs: - name: Install dependencies run: | python3 -m pip install jupyter quartodoc - python3 -m pip install -e . --no-deps + python3 -m pip install -e . - name: Build autodoc run: quartodoc build - name: Publish to GitHub Pages (and render) diff --git a/.github/workflows/preview-docs.yml b/.github/workflows/preview-docs.yml index 5af70b0dc..f93cfa660 100644 --- a/.github/workflows/preview-docs.yml +++ b/.github/workflows/preview-docs.yml @@ -8,7 +8,9 @@ on: paths: - '**/*.md' # any Markdown file - '**/*.qmd' # any Quarto file - - '_quarto.yaml' + - '_quarto.yml' + - docs/scripts/generate_config_docs.py + - src/axolotl/utils/schemas/**.py permissions: checks: write @@ -38,7 +40,7 @@ jobs: - name: Install dependencies run: | python3 -m pip install jupyter quartodoc - python3 -m pip install -e . --no-deps + python3 -m pip install -e . - name: Build autodoc run: quartodoc build diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e519314b3..921b5dbf6 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -187,7 +187,7 @@ jobs: if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }} # this job needs to be run on self-hosted GPU runners... runs-on: [self-hosted, modal] - timeout-minutes: 90 + timeout-minutes: 120 needs: [pre-commit, pytest, pytest-sdist] strategy: @@ -245,7 +245,7 @@ jobs: if: github.repository_owner == 'axolotl-ai-cloud' # this job needs to be run on self-hosted GPU runners... runs-on: [self-hosted, modal] - timeout-minutes: 90 + timeout-minutes: 120 # Only run the remainder of the matrix if the first e2e check passed; # this is to save on wasted compute costs for known failures that get caught in the first run needs: [pre-commit, pytest, docker-e2e-tests-1st] diff --git a/.runpod/README.md b/.runpod/README.md index a631c3937..60c661eef 100644 --- a/.runpod/README.md +++ b/.runpod/README.md @@ -328,7 +328,7 @@ The following optimizers are supported: - Use `gradient_checkpointing: true` to reduce memory usage - Adjust `micro_batch_size` and `gradient_accumulation_steps` based on your GPU memory -For more detailed information, please refer to the [documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html). +For more detailed information, please refer to the [documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/config-reference.html). ### Errors: diff --git a/README.md b/README.md index ef5523898..3bfce8df1 100644 --- a/README.md +++ b/README.md @@ -89,7 +89,7 @@ That's it! 
Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge ## 📚 Documentation - [Installation Options](https://docs.axolotl.ai/docs/installation.html) - Detailed setup instructions for different environments -- [Configuration Guide](https://docs.axolotl.ai/docs/config.html) - Full configuration options and examples +- [Configuration Guide](https://docs.axolotl.ai/docs/config-reference.html) - Full configuration options and examples - [Dataset Loading](https://docs.axolotl.ai/docs/dataset_loading.html) - Loading datasets from various sources - [Dataset Guide](https://docs.axolotl.ai/docs/dataset-formats/) - Supported formats and how to use them - [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) diff --git a/_quarto.yml b/_quarto.yml index 9b97095ce..93141aa9e 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -1,5 +1,6 @@ project: type: website + pre-render: docs/scripts/generate_config_docs.py quartodoc: dir: docs/api @@ -235,7 +236,7 @@ website: - docs/installation.qmd - docs/inference.qmd - docs/cli.qmd - - docs/config.qmd + - docs/config-reference.qmd - text: "API Reference" href: docs/api diff --git a/cicd/e2e_tests.py b/cicd/e2e_tests.py index 610e3730d..cb8020c68 100644 --- a/cicd/e2e_tests.py +++ b/cicd/e2e_tests.py @@ -8,7 +8,7 @@ from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd @app.function( image=cicd_image, gpu=GPU_CONFIG, - timeout=90 * 60, # 90 min + timeout=120 * 60, # 90 min cpu=8.0, memory=131072, volumes=VOLUME_CONFIG, diff --git a/cicd/multigpu.py b/cicd/multigpu.py index f028d0f68..2b66f21f9 100644 --- a/cicd/multigpu.py +++ b/cicd/multigpu.py @@ -69,7 +69,7 @@ def run_cmd(cmd: str, run_folder: str): @app.function( image=cicd_image, gpu=GPU_CONFIG, - timeout=90 * 60, + timeout=120 * 60, cpu=16.0, memory=131072 * N_GPUS, volumes=VOLUME_CONFIG, diff --git a/deepspeed_configs/zero2_torch_compile.json b/deepspeed_configs/zero2_torch_compile.json new file mode 100644 index 000000000..c3bcf98cf --- /dev/null +++ b/deepspeed_configs/zero2_torch_compile.json @@ -0,0 +1,31 @@ +{ + "compile": { + "disable": false, + "backend": "inductor" + }, + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu" + }, + "contiguous_gradients": true, + "overlap_comm": true + }, + "bf16": { + "enabled": "auto" + }, + "fp16": { + "enabled": "auto", + "auto_cast": false, + "loss_scale": 0, + "initial_scale_power": 32, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/docs/.gitignore b/docs/.gitignore index 6c3cb2070..89407326f 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -2,3 +2,4 @@ _site/ /api/*.qmd /api/*.html +config-reference.qmd diff --git a/docs/config.qmd b/docs/config.qmd deleted file mode 100644 index d146b4c84..000000000 --- a/docs/config.qmd +++ /dev/null @@ -1,801 +0,0 @@ ---- -title: Config Reference -description: A complete list of all configuration options. 
---- - -```yaml -# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files -# This can also be a relative path to a model on disk -base_model: ./llama-7b-hf -# You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc) -base_model_ignore_patterns: -# If the base_model repo on hf hub doesn't include configuration .json files, -# You can set that here, or leave this empty to default to base_model -base_model_config: ./llama-7b-hf -# You can specify to choose a specific model revision from huggingface hub -revision_of_model: -# Optional tokenizer configuration path in case you want to use a different tokenizer -# than the one defined in the base model -tokenizer_config: -# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too -model_type: AutoModelForCausalLM -# Corresponding tokenizer for the model AutoTokenizer is a good choice -tokenizer_type: AutoTokenizer -# Trust remote code for untrusted source -trust_remote_code: -# use_fast option for tokenizer loading from_pretrained, default to True -tokenizer_use_fast: -# Whether to use the legacy tokenizer setting, defaults to True -tokenizer_legacy: -# Whether to use mistral-common tokenizer. If set to True, it will use the mistral-common tokenizer. -tokenizer_use_mistral_common: -# Resize the model embeddings when new tokens are added to multiples of 32 -# This is reported to improve training speed on some models -resize_token_embeddings_to_32x: -# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink. -shrink_embeddings: -# Optional[bool] Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs -embeddings_skip_upcast: -# Whether to load the model with randomly initialized weights. Useful for -# pre-training a model from scratch or debugging purposes. -random_init_weights: - -# (Internal use only) -# Used to identify which the model is based on -is_falcon_derived_model: -is_llama_derived_model: -is_qwen_derived_model: -# Please note that if you set this to true, `padding_side` will be set to "left" by default -is_mistral_derived_model: - -# optional overrides to the base model configuration -overrides_of_model_config: - # RoPE Scaling https://github.com/huggingface/transformers/pull/24653 - rope_scaling: - type: # linear | dynamic - factor: # float - -# optional overrides the base model loading from_pretrained -overrides_of_model_kwargs: - # use_cache: False - -# optional overrides to the bnb 4bit quantization configuration -# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig -bnb_config_kwargs: - # These are default values - llm_int8_has_fp16_weight: false - bnb_4bit_quant_type: nf4 - bnb_4bit_use_double_quant: true - -# quantization aware training -qat: - activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8" - weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8" - group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization - fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after - -# post-training quantization -quantization: - weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. 
Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8 - activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8" - group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization - quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer. - - -# Whether you are training a 4-bit GPTQ quantized model -gptq: true - -# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer -load_in_8bit: true -# Use bitsandbytes 4 bit -load_in_4bit: - -# Use CUDA bf16 -bf16: true # bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere -# Use CUDA fp16 -fp16: true -# Use CUDA tf32 -tf32: true # require >=ampere -# Note: if bf16 is set to 'auto', and fp16 is set to true, we will prefer the explict fp16 setting - -# No AMP (automatic mixed precision) -bfloat16: true # require >=ampere -float16: true - -# Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset -gpu_memory_limit: 20GiB -# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge -lora_on_cpu: true - -# List[str]. Add plugins to extend the pipeline. -# See `src/axolotl/integrations` for the available plugins or doc below for more details. -# https://docs.axolotl.ai/docs/custom_integrations.html -plugins: - # - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin - -# A list of one or more datasets to finetune the model with -# See https://docs.axolotl.ai/docs/dataset_loading.html for guide on loading datasets -# See https://docs.axolotl.ai/docs/dataset-formats/ for guide on dataset formats -datasets: - # HuggingFace dataset repo | s3:// | gs:// | path to local file or directory - - path: vicgalle/alpaca-gpt4 - # The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection] - type: alpaca # format | format: (chat/instruct) | .load_ - ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file - data_files: # Optional[str] path to source data files - - shards: # Optional[int] split dataset into N pieces (use with shards_idx) - shards_idx: # Optional[int] = 0 the index of sharded dataset to use - - preprocess_shards: # Optional[int] process dataset in N sequential chunks for memory efficiency (exclusive with `shards`) - - name: # Optional[str] name of dataset configuration to load - split: train # Optional[str] name of dataset split to load from - revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets. - trust_remote_code: # Optional[bool] Trust remote code for untrusted source - - # Custom user instruction prompt - - path: repo - type: - # The below are defaults. only set what's needed if you use a different column name. 
- system_prompt: "" - system_format: "{system}" - field_system: system - field_instruction: instruction - field_input: input - field_output: output - - # Customizable to be single line or multi-line - # Use {instruction}/{input} as key to be replaced - # 'format' can include {input} - format: |- - User: {instruction} {input} - Assistant: - # 'no_input_format' cannot include {input} - no_input_format: "{instruction} " - - # For `completion` datsets only, uses the provided field instead of `text` column - field: - - # Using chat template - - path: ... - # Set type to `chat_template` to use this strategy - type: chat_template - # Specify the name of the chat template to use - # The name of the chat template to use for training, following values are supported: - # - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default. - # - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py - # - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml. - # - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field. - chat_template: tokenizer_default - - # Custom jinja chat template. Used only if `chat_template: jinja` or empty. - chat_template_jinja: - - # Key containing the messages (default: "messages") - field_messages: messages - - # Key containing the tools (default: "tools") - # Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step). - field_tools: tools - - # Key containing the system message (default: "system") - # If the system message is not present in the dataset sample, it will be loaded from the field_system property. - field_system: system - - # Mapping of properties from the input dataset to the chat template. - # (default: message_property_mappings={'role':'role', 'content':'content'}) - # If a property exists in the template but not in this mapping, the system will attempt - # to load it directly from the message using the property name as the key. - # Example: In the mapping below, 'from' is loaded from input dataset and used as 'role', - # while 'value' is loaded and used as 'content' in the chat template. - message_property_mappings: - role: from - content: value - # ... - - # Optional[Dict[str, List]]. Roles mapping in the messages. - # The format is {target_role: [source_roles]}. All source roles will be mapped to the target role. - # The default is: - roles: - user: ["human", "user"] - assistant: ["gpt", "assistant"] - system: ["system"] - tool: ["tool"] - - # Optional[bool]. Whether to drop the system turn from the dataset. Only works with chat_template. - # This does not drop the default system message from chat_template if it exists. If you wish to, - # we recommend using a custom jinja template with the default system message removed or - # adding a system turn with empty content. - drop_system_message: - - # Optional[bool]. 
(for Qwen3 template only) Whether to split the assistant content based on a reasoning trace inside delimited tags - # See example at `docs/dataset-formats/conversation.qmd` - split_thinking: - - # IMPORTANT: The following fields determine which parts of the conversation to train on. - # Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train - # See examples at `docs/dataset-formats/conversation.qmd` - # Note: If the below 5 fields are empty, defaults to training only on the last message. - - # Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss. - roles_to_train: ["assistant"] # default - # Optional[str]. Which EOS tokens to train on in the conversation. Possible values are: - # - all: train on all EOS tokens - # - turn (default): train on the EOS token at the end of each trainable turn - # - last: train on the last EOS token in the conversation - # TIP: Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`. - train_on_eos: turn - # Optional[str]. Which EOT (End-of-Turn) tokens to train on in the conversation. Possible values are: - # - all: train on all EOT tokens - # - turn: train on the EOT token at the end of each trainable turn - # - last: train on the last EOT token in the conversation - # If not specified, defaults to the value of train_on_eos for backward compatibility. - train_on_eot: - # The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`. - message_field_training: training - # The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn. - # The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train). - message_field_training_detail: train_detail - - -# If false, the datasets will not be shuffled and will keep their original order in `datasets`. -# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true. -shuffle_merged_datasets: true - -# Deduplicates datasets and test_datasets with identical entries. -dataset_exact_deduplication: true - -# A list of one or more datasets to eval the model with. -# You can use either test_datasets, or val_set_size, but not both. -test_datasets: - - path: /workspace/data/eval.jsonl - ds_type: json - # You need to specify a split. For "json" datasets the default split is called "train". - split: train - type: completion - data_files: - - /workspace/data/eval.jsonl - -# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo' -rl: -rl_beta: # Optional[float]. The beta parameter for the RL training. - -# dpo -dpo_use_weighting: # Optional[bool]. Whether to perform weighting. -rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper. - -# orpo -orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping. - -# kto -kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss. -kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss. 
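# Illustrative sketch (not part of the original reference; placement and values are
# assumptions): a minimal preference-tuning setup combining the RL options documented
# above. Only the option names come from this reference; the values are placeholders.
# rl: kto
# rl_beta: 0.1
# kto_desirable_weight: 1.0
# kto_undesirable_weight: 1.0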
- -# simpo -cpo_alpha: 1.0 # Weight of the BC regularizer -simpo_gamma: 0.5 # Target reward margin for the SimPO loss - -# grpo -trl: - use_vllm: # Optional[bool]. Whether to use VLLM for RL training. - vllm_server_host: # Optional[str]. Host of the vLLM server to connect to. - vllm_server_port: # Optional[int]. Port of the vLLM server to connect to. - vllm_server_timeout: # Optional[int]. Total timeout (in seconds) to wait for the vLLM server to respond. - vllm_guided_decoding_regex: # Optional[str]. Regex for vLLM guided decoding. - - beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use - max_completion_length: # Optional[int]. Maximum length of the completion for RL training. - - reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir. - reward_weights: # Optional[list[float]]. List of reward weights for the reward functions. - - num_generations: # Optional[int]. Number of generations to sample. - log_completions: # Optional[bool]. Whether to log completions. - num_completions_to_print: # Optional[int]. Number of completions to print when log_completions is True. - - sync_ref_model: # Optional[bool]. Whether to sync the reference model. - ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model. - ref_model_sync_steps: # Optional[int]. Sync steps for the reference model. - scale_rewards: # Optional[bool]. Whether to scale rewards by their standard deviation. - - temperature: # Optional[float]. Sampling temperature for the GRPO policy. - top_p: # Optional[float]. Top-p sampling probability for the generation policy. - top_k: # Optional[int]. Top-k sampling for the generation policy. - min_p: # Optional[float]. Minimum probability for the generation policy. - repetition_penalty: # Optional[float]. Penalty for tokens that appear in prompt and generated text. - - num_iterations: # Optional[int]. Number of iterations per batch (μ) for GRPO. - epsilon: # Optional[float]. Epsilon value for clipping in the GRPO algorithm. - epsilon_high: # Optional[float]. Upper-bound epsilon value for clipping in the GRPO algorithm. - use_liger_loss: # Optional[bool]. Whether to use Liger loss for GRPO. - loss_type: # Optional[str]. Loss formulation to use. Supported values: grpo, bnpo, dr_grpo. - mask_truncated_completions: # Optional[bool]. Whether to exclude truncated completions from loss calculation. - - -# reward modelling: `True` or `False` -reward_model: - -# process reward modelling: `True` or `False` -process_reward_model: - -# The name of the chat template to use for training, following values are supported: -# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value. -# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py -# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer. -# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field. 
-# The selected chat template will be saved to the tokenizer_config.json for easier inferencing -# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template. -chat_template: tokenizer_default -# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null. -chat_template_jinja: null -# Optional[List[str]]. Custom EOT (End-of-Turn) tokens to mask/unmask during training. -# These tokens mark the boundaries between conversation turns. -# For example: ["/INST", "", "[/SYSTEM_PROMPT]"] -# If not specified, defaults to just the model's eos_token. -# This is useful for templates that use multiple delimiter tokens. -eot_tokens: - # - "" - # - "[/INST]" - # - "[/SYSTEM_PROMPT]" -# Changes the default system message -default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml. -# Axolotl attempts to save the dataset as an arrow after packing the data together so -# subsequent training attempts load faster, relative path -dataset_prepared_path: data/last_run_prepared -# Push prepared dataset to hub -push_dataset_to_hub: # Optional[str] repo_org/repo_name -# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` -# if not set. -dataset_processes: # defaults to os.cpu_count() if not set -# Keep dataset in memory while preprocessing -# Only needed if cached dataset is taking too much storage -dataset_keep_in_memory: -# push checkpoints to hub -hub_model_id: # private repo path to push finetuned model -# how to push checkpoints to hub -# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy -hub_strategy: -# Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets -# Required to be true when used in combination with `push_dataset_to_hub` -hf_use_auth_token: # boolean -# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval. -val_set_size: 0.04 -# Num shards for whole dataset -dataset_shard_num: -# Index of shard to use for whole dataset -dataset_shard_idx: - -# The maximum length of an input to train with, this should typically be less than 2048 -# as most models have a token/context limit of 2048 -sequence_len: 2048 -# Pad inputs so each step uses constant sized buffers -# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently -pad_to_sequence_len: -# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true' -sample_packing: -# Set to 'false' if getting errors during eval with sample_packing on. -eval_sample_packing: -# You can set these packing optimizations AFTER starting a training at least once. -# The trainer will provide recommended values for these values. -sample_packing_eff_est: -total_num_tokens: -# Increasing the following values helps with packing, but usually only slightly (<%1.) -# The number of samples packed at a time. -sample_packing_group_size: 100000 -# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples. -sample_packing_bin_size: 200 -sample_pack_sequentially: # Optional[bool]. Whether to pack samples sequentially. 
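# Illustrative sketch (not part of the original reference; values are placeholder
# assumptions): how the sample-packing options above are commonly combined for a
# longer-context run. Option names are taken from this reference.
# sequence_len: 4096
# pad_to_sequence_len: true
# sample_packing: true
# eval_sample_packing: false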
- -# whether to concatenate samples during pretraining -pretraining_sample_concatenation: - -curriculum_sampling: # Optional[bool]. Whether to use sequential sampling for curriculum learning - -# Use batch flattening for speedups when not using sample_packing -batch_flattening: - -# Passed through to transformers when loading the model when launched without accelerate -# Use `sequential` when training w/ model parallelism to limit memory -device_map: -# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model. -max_memory: - -# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model -adapter: lora -# If you already have a lora model trained that you want to load, put that here. -# This means after training, if you want to test the model, you should set this to the value of `output_dir`. -# Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`. -lora_model_dir: - -# LoRA hyperparameters -# For more details about the following options, see: -# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2 -lora_r: 8 -lora_alpha: 16 -lora_dropout: 0.05 -lora_target_modules: - - q_proj - - v_proj -# - k_proj -# - o_proj -# - gate_proj -# - down_proj -# - up_proj -lora_target_linear: # If true, will target all linear modules - -# List[int] | int. # The layer indices to transform, otherwise, apply to all layers -# https://huggingface.co/docs/peft/v0.15.0/en/package_reference/lora#peft.LoraConfig.layers_to_transform -peft_layers_to_transform: - -# Optional[bool]. Whether to use DoRA. -# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#weight-decomposed-low-rank-adaptation-dora -peft_use_dora: - -# Optional[bool]. Whether to use RSLoRA. -# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#rank-stabilized-lora -peft_use_rslora: - -# Optional[list[tuple[int, int]]]. List of layer indices to replicate. -# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#memory-efficient-layer-replication-with-lora -peft_layer_replication: - -# bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"] -# How to initialize LoRA weights. Default to True which is MS original implementation. -# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#initialization -peft_init_lora_weights: - -# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens. -# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models. -# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities. -# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994 -lora_modules_to_save: -# - embed_tokens -# - lm_head - -lora_fan_in_fan_out: false - -# Apply custom LoRA autograd functions and activation function Triton kernels for -# speed and memory savings -# See: https://docs.axolotl.ai/docs/lora_optims.html -lora_mlp_kernel: true -lora_qkv_kernel: true -lora_o_kernel: true - -# LoRA+ hyperparameters -# For more details about the following options, see: -# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py` -loraplus_lr_ratio: # loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4. -loraplus_lr_embedding: # loraplus learning rate for lora embedding layers. 
Default value is 1e-6. - -peft: - # Configuration options for loftq initialization for LoRA - # https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization - loftq_config: - loftq_bits: # typically 4 bits - -# ReLoRA configuration -# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed -relora_steps: # Number of steps per ReLoRA restart -relora_warmup_steps: # Number of per-restart warmup steps -relora_anneal_steps: # Number of anneal steps for each relora cycle -relora_prune_ratio: # threshold for optimizer magnitude when pruning -relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings - -# wandb configuration if you're using it -# Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`. -wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb -wandb_project: # Your wandb project name -wandb_entity: # A wandb Team name if using a Team -wandb_watch: -wandb_name: # Set the name of your wandb run -wandb_run_id: # Set the ID of your wandb run -wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training - -# mlflow configuration if you're using it -mlflow_tracking_uri: # URI to mlflow -mlflow_experiment_name: # Your experiment name -mlflow_run_name: # Your run name -hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry - -# Comet configuration if you're using it -# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`. -# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start -use_comet: # Enable or disable Comet integration. -comet_api_key: # API key for Comet. Recommended to set via `comet login`. -comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace. -comet_project_name: # Project name in Comet. Defaults to Uncategorized. -comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key. -comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration. -comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True. -comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details. - -# Tensorboard -use_tensorboard: # Optional[bool] - -# Where to save the full-finetuned model to -output_dir: ./completed-model - -# Whether to use torch.compile and which backend to use -# setting to `auto` will enable torch compile when torch>=2.5.1 -torch_compile: # Optional[Union[Literal["auto"], bool]] -torch_compile_backend: # Optional[str] -torch_compile_mode: # 'default' | 'reduce-overhead' | 'max-autotune' - -# Training hyperparameters - -# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps. -gradient_accumulation_steps: 1 -# The number of samples to include in each batch. This is the number of samples sent to each GPU. 
-# Batch size per gpu = micro_batch_size * gradient_accumulation_steps -micro_batch_size: 2 -eval_batch_size: -num_epochs: 4 -warmup_steps: 100 # cannot use with warmup_ratio -warmup_ratio: 0.05 # cannot use with warmup_steps -learning_rate: 0.00003 -lr_quadratic_warmup: -logging_steps: -eval_steps: # Leave empty to eval at each epoch, integer for every N steps. float for fraction of total steps -evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps -eval_strategy: # Set to `"no"` to skip evaluation, `"epoch"` at end of each epoch, leave empty to infer from `eval_steps`. -save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of each epoch, `"best"` when better result is achieved, leave empty to infer from `save_steps`. -save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps -saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps -save_total_limit: # Checkpoints saved at a time -save_only_model: # Save only the model weights, skipping the optimizer. Using this means you can't resume from checkpoints. -# Maximum number of iterations to train for. It precedes num_epochs which means that -# if both are set, num_epochs will not be guaranteed. -# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps -max_steps: - -# bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time. -include_tokens_per_second: # Optional[bool] - -# whether to find batch size that fits in memory. Passed to underlying transformers Trainer -auto_find_batch_size: # Optional[bool] - -eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 -eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128 -do_causal_lm_eval: # Whether to run causal language model evaluation for metrics in `eval_causal_lm_metrics`. -eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"] - -profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir. - # see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information - # snapshots can be visualized @ https://pytorch.org/memory_viz - -loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training) -loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3) - -# Save model as safetensors (require safetensors package). Default True -save_safetensors: - -# Whether to mask out or include the human's prompt from the training labels -train_on_inputs: false -# Group similarly sized data to minimize padding. -# May be slower to start, as it must download and sort the entire dataset. -# Note that training loss may have an oscillating pattern with this enabled. -group_by_length: false - -# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk". 
-# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing -gradient_checkpointing: false -# additional kwargs to pass to the trainer for gradient checkpointing -# gradient_checkpointing_kwargs: -# use_reentrant: true - -# Stop training after this many evaluation losses have increased in a row -# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback -early_stopping_patience: 3 - -# Specify a scheduler and kwargs to use with the optimizer -# Valid values are driven by the Transformers SchedulerType class, see: -# https://github.com/huggingface/transformers/blob/5f4ecf2d9f867a1255131d2461d75793c0cf1db2/src/transformers/trainer_utils.py#L420 -# Valid values include -# - 'linear' -# - 'cosine' (default) -# - 'cosine_with_restarts' -# - 'polynomial' -# - 'constant' -# - 'constant_with_warmup' -# - 'inverse_sqrt' -# - 'reduce_lr_on_plateau' -# - 'cosine_with_min_lr' -# - 'warmup_stable_decay' - -# Additional schedulers include: -# - 'one_cycle' -# - 'rex' -lr_scheduler: -lr_scheduler_kwargs: -cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr -cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf) - -# For one_cycle optim -lr_div_factor: # Learning rate div factor - -# Specify optimizer -# Valid values are driven by the Transformers OptimizerNames class, see: -# https://github.com/huggingface/transformers/blob/cbf924b76c03828101a34069a96d209314114fd5/src/transformers/training_args.py#L144-L189 -# -# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of -# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used -# in the examples/ for your model and fine-tuning use case. -# -# Valid values for 'optimizer' include: -# - adamw_torch -# - adamw_torch_fused (default) -# - adamw_torch_xla -# - adamw_torch_npu_fused -# - adamw_apex_fused -# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1) -# - adafactor -# - adamw_anyprecision -# - adamw_torch_4bit -# - ademamix -# - sgd -# - adagrad -# - adamw_bnb_8bit -# - adamw_8bit # alias for adamw_bnb_8bit -# - ademamix_8bit -# - lion_8bit -# - lion_32bit -# - paged_adamw_32bit -# - paged_adamw_8bit -# - paged_ademamix_32bit -# - paged_ademamix_8bit -# - paged_lion_32bit -# - paged_lion_8bit -# - rmsprop -# - rmsprop_bnb -# - rmsprop_bnb_8bit -# - rmsprop_bnb_32bit -# - galore_adamw -# - galore_adamw_8bit -# - galore_adafactor -# - galore_adamw_layerwise -# - galore_adamw_8bit_layerwise -# - galore_adafactor_layerwise -# - lomo -# - adalomo -# - grokadamw -# - schedule_free_adamw -# - schedule_free_sgd -# - apollo_adamw -# - apollo_adamw_layerwise -# -# Additional custom optimizers include: -# - optimi_adamw -# - ao_adamw_8bit -# - ao_adamw_fp8 -# - came_pytorch -optimizer: -# Dictionary of arguments to pass to the optimizer -optim_args: -# For Galore Optimizers the following optim_args are available -# rank: # type: int -# update_proj_gap # type: int -# scale # type: float -# proj_type: # type: str, default = std - -# The target modules to optimize, i.e. 
the module names that you would like to train, right now this is used only for GaLore algorithm -optim_target_modules: -# - self_attn # for llama -# - mlp - -# Specify weight decay -weight_decay: -# adamw hyperparams -adam_beta1: -adam_beta2: -adam_beta3: # only used for CAME Optimizer -adam_epsilon: -adam_epsilon2: # only used for CAME Optimizer -# Gradient clipping max norm -max_grad_norm: - -# Augmentation techniques -# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings -# currently only supported on Llama and Mistral -neftune_noise_alpha: - -# Optional[bool]. Whether to bettertransformers -flash_optimum: - -# Note: Only one of the following attention patches can be used at a time. -# For example, if you set `xformers_attention` to `true`, do not set `flash_attention` to `true`. - -# Optional[bool]. Whether to use xformers attention patch https://github.com/facebookresearch/xformers: -xformers_attention: -# Optional[bool]. Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention: -flash_attention: -flash_attn_cross_entropy: # Optional[bool]. Whether to use flash-attention cross entropy implementation - advanced use only -flash_attn_rms_norm: # Optional[bool]. Whether to use flash-attention rms norm implementation - advanced use only -flash_attn_fuse_qkv: # Optional[bool]. Whether to fuse QKV into a single operation -flash_attn_fuse_mlp: # Optional[bool]. Whether to fuse part of the MLP into a single operation -# Optional[bool]. Whether to use scaled-dot-product attention -# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html -sdp_attention: -# Optional[bool]. Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf -s2_attention: - -# Optional[bool]. Whether to use low_cpu_mem_usage -low_cpu_mem_usage: -# Optional[str]. Resume from a specific checkpoint dir -resume_from_checkpoint: -# Optional[bool]. If resume_from_checkpoint isn't set and you simply want it to start where it left off. -# Be careful with this being turned on between different models. -auto_resume_from_checkpoints: false - -## Multimodal section -# int | tuple[int, int] | None . Size to resize images to, width x height. -# Will read from model/processor config if not set. -image_size: -# str. Algorithm to use for image resizing. "bilinear", "bicubic", "lanczos". Default is "bilinear". -image_resize_algorithm: 'bilinear' -## End of multimodal section - -# Don't mess with this, it's here for accelerate and torchrun -local_rank: - -# Add or change special tokens. -# If you add tokens here, you don't need to add them to the `tokens` list. -special_tokens: - # bos_token: "" - # eos_token: "" - # unk_token: "" - # pad_token: "[PAD]" - -# Optional[list[str]]. Add extra tokens to the tokenizer. -tokens: - # - "<|startoftext|>" - # - "<|endoftext|>" - -# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer. -# Only works for tokens that are not part of the base vocab (aka are added_tokens). -# Can be checked if they exist in tokenizer.json added_tokens. -added_tokens_overrides: # Dict[int, str] -# 128041: "<|im_start|>" -# 128042: "<|im_end|>" - -# FSDP -fsdp: -fsdp_config: - -# Deepspeed config path. 
e.g., deepspeed_configs/zero3.json -deepspeed: - -# Advanced DDP Arguments -ddp_timeout: -ddp_bucket_cap_mb: -ddp_broadcast_buffers: - -# Sequence parallelism -# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. -# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. -# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized -# subsequences, or set to 4 to split into four equal-sized subsequences. -# See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details. -sequence_parallel_degree: -# Optional; strides across the key dimension. Larger values use more memory but should make training faster. -# Must evenly divide the number of KV heads in your model. -heads_k_stride: 1 -# One of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to "varlen_llama3" -# in the sample packing case, and "batch_ring" in the non-sample packing case. -ring_attn_func: - -# Path to torch distx for optim 'adamw_anyprecision' -torchdistx_path: - -# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize -pretraining_dataset: - -# Debug mode -debug: - -# Seed -seed: - -# Allow overwrite yml config using from cli -strict: -``` diff --git a/docs/dataset-formats/conversation.qmd b/docs/dataset-formats/conversation.qmd index 290841c08..d1fca9441 100644 --- a/docs/dataset-formats/conversation.qmd +++ b/docs/dataset-formats/conversation.qmd @@ -12,7 +12,7 @@ Chat Template strategy uses a jinja2 template that converts a list of messages i {"conversations": [{"role": "...", "content": "..."}]} ``` -See [configs](../config.qmd) for full configs and supported templates. +See [configs](../config-reference.qmd) for full configs and supported templates. ### Migrating from sharegpt @@ -130,13 +130,13 @@ datasets: ``` ::: {.callout-tip} -See [config documentation](../config.qmd) for detailed explanations of "turn", "last", and "all" options for training on tokens. +See [config documentation](../config-reference.qmd) for detailed explanations of "turn", "last", and "all" options for training on tokens. ::: ::: {.callout-note} Using `eot_tokens` requires each token that exists in `chat_template` to be a single token in the tokenizer. Otherwise, the tokenizer will split the token and cause unexpected behavior. -You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config.qmd) for more details. +You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config-reference.qmd) for more details. ::: - Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`. diff --git a/docs/dataset-formats/inst_tune.qmd b/docs/dataset-formats/inst_tune.qmd index d89c6adaf..f5bd7ab8f 100644 --- a/docs/dataset-formats/inst_tune.qmd +++ b/docs/dataset-formats/inst_tune.qmd @@ -186,4 +186,4 @@ datasets: no_input_format: "[INST] {instruction} [/INST]" ``` -See full config options under [here](../config.qmd). +See full config options under [here](../config-reference.qmd). 
diff --git a/docs/dataset_loading.qmd b/docs/dataset_loading.qmd index b78f86a98..bcffe7f0f 100644 --- a/docs/dataset_loading.qmd +++ b/docs/dataset_loading.qmd @@ -36,7 +36,7 @@ This matches the API of [`datasets.load_dataset`](https://github.com/huggingface For HuggingFace's guide to load different dataset types, see [here](https://huggingface.co/docs/datasets/loading). -For full details on the config, see [config.qmd](config.qmd). +For full details on the config, see [config-reference.qmd](config-reference.qmd). ::: {.callout-note} diff --git a/docs/getting-started.qmd b/docs/getting-started.qmd index 6f1b54348..de059c397 100644 --- a/docs/getting-started.qmd +++ b/docs/getting-started.qmd @@ -55,7 +55,7 @@ output_dir: ./outputs/lora-out - To perform QLoRA finetuning, replace with `load_in_4bit: true` and `adapter: qlora`. ::: -See our [Config options](config.qmd) for more details. +See our [config options](config-reference.qmd) for more details. ### Training {#sec-training} @@ -179,7 +179,7 @@ Now that you have the basics, you might want to: Check our other guides for details on these topics: -- [Configuration Guide](config.qmd) - Full configuration options +- [Configuration Guide](config-reference.qmd) - Full configuration options - [Dataset Loading](dataset_loading.qmd) - Loading datasets from various sources - [Dataset Formats](dataset-formats) - Working with different data formats - [Multi-GPU Training](multi-gpu.qmd) diff --git a/docs/installation.qmd b/docs/installation.qmd index 15f2db57b..c905e93cd 100644 --- a/docs/installation.qmd +++ b/docs/installation.qmd @@ -14,7 +14,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir ## Requirements {#sec-requirements} - NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU -- Python ≥3.10 +- Python ≥3.11 - PyTorch ≥2.5.1 ## Installation Methods {#sec-installation-methods} @@ -153,7 +153,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker. ### Conda/Pip venv {#sec-conda} -1. Install Python ≥3.10 +1. Install Python ≥3.11 2. Install PyTorch: https://pytorch.org/get-started/locally/ 3. Install Axolotl: ```{.bash} diff --git a/docs/quantize.qmd b/docs/quantize.qmd index 294efda8b..113fcafbe 100644 --- a/docs/quantize.qmd +++ b/docs/quantize.qmd @@ -32,7 +32,7 @@ output_dir: # The path to the output directory. Once quantization is complete, your quantized model will be saved in the `{output_dir}/quantized` directory. -You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.md) - you can do this by using the existing QAT configuration file which +You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.qmd) - you can do this by using the existing QAT configuration file which you used to train the model: ```yaml diff --git a/docs/scripts/generate_config_docs.py b/docs/scripts/generate_config_docs.py new file mode 100644 index 000000000..e22da7d05 --- /dev/null +++ b/docs/scripts/generate_config_docs.py @@ -0,0 +1,752 @@ +# type: ignore + +""" +Quarto documentation generation from Pydantic models. Uses Pydantic model source code +to automatically group fields, including inherited fields from parent classes. 
+""" + +import ast +import inspect +import textwrap +import types +import typing +from typing import Any, FrozenSet, Type, Union + +from pydantic import BaseModel + +from axolotl.utils.schemas.config import AxolotlInputConfig + + +class QuartoGenerator: + """Generate Quarto documentation from Pydantic models.""" + + def __init__(self): + self._class_fields_cache = {} + self._inheritance_map_cache = {} + self._nested_models_cache = {} + + def _get_direct_fields(self, cls: Type[BaseModel]) -> FrozenSet[str]: + """Get fields defined directly in a single class (not inherited).""" + if cls in self._class_fields_cache: + return self._class_fields_cache[cls] + + fields = set() + + # Get annotated fields + if hasattr(cls, "__annotations__"): + fields.update(cls.__annotations__.keys()) + + # Filter out private/special methods + fields = {f for f in fields if not f.startswith("_")} + + result = frozenset(fields) + self._class_fields_cache[cls] = result + return result + + def _is_pydantic_model(self, type_obj) -> bool: + """Check if a type is a Pydantic BaseModel.""" + return inspect.isclass(type_obj) and issubclass(type_obj, BaseModel) + + # pylint: disable=too-many-return-statements + def _extract_nested_type(self, field_type) -> Any: + """Extract the actual type from complex type annotations.""" + # Handle Annotated types (Python 3.9+) + if hasattr(typing, "get_origin") and hasattr(typing, "get_args"): + origin = typing.get_origin(field_type) + args = typing.get_args(field_type) + + if origin is not None: + # Handle Annotated[SomeType, ...] - extract the first argument + if hasattr(typing, "Annotated") and origin is typing.Annotated: + if args: + return self._extract_nested_type( + args[0] + ) # Recursively process the actual type + + # Handle list[SomeType], List[SomeType], etc. 
+ elif origin in (list, typing.List): + if args: + return self._extract_nested_type( + args[0] + ) # Extract element type + + # Handle Union types (including | syntax) + elif origin is typing.Union: + # Get non-None types from the Union + non_none_types = [arg for arg in args if arg is not type(None)] + if len(non_none_types) >= 1: + # Prioritize Pydantic models over primitive types + pydantic_models = [ + arg + for arg in non_none_types + if self._is_pydantic_model(arg) + ] + if pydantic_models: + # Return the first Pydantic model found + return self._extract_nested_type(pydantic_models[0]) + + # No Pydantic models, return the first non-None type + return self._extract_nested_type(non_none_types[0]) + + # Handle new Python 3.10+ union syntax (PeftConfig | None) + if hasattr(field_type, "__class__") and field_type.__class__ is types.UnionType: + # Get non-None types from the Union + non_none_types = [ + arg for arg in field_type.__args__ if arg is not type(None) + ] + if len(non_none_types) >= 1: + # Prioritize Pydantic models over primitive types + pydantic_models = [ + arg for arg in non_none_types if self._is_pydantic_model(arg) + ] + if pydantic_models: + return self._extract_nested_type(pydantic_models[0]) + return self._extract_nested_type(non_none_types[0]) + + # Handle old typing.Union syntax (fallback) + if hasattr(field_type, "__origin__"): + if field_type.__origin__ is Union: + # Get non-None types from the Union + non_none_types = [ + arg for arg in field_type.__args__ if arg is not type(None) + ] + if len(non_none_types) >= 1: + # Prioritize Pydantic models over primitive types + pydantic_models = [ + arg for arg in non_none_types if self._is_pydantic_model(arg) + ] + if pydantic_models: + return self._extract_nested_type(pydantic_models[0]) + return self._extract_nested_type(non_none_types[0]) + # Handle other generic types like dict[str, Any], etc. + elif hasattr(field_type, "__args__"): + return field_type + + return field_type + + # pylint: disable=too-many-return-statements + def _extract_all_pydantic_models_from_type( + self, field_type + ) -> list[type[BaseModel]]: + """Extract all Pydantic models from a type annotation, including from Unions.""" + models = [] + + if field_type is None: + return models + + # Handle Annotated types + if hasattr(typing, "get_origin") and hasattr(typing, "get_args"): + origin = typing.get_origin(field_type) + args = typing.get_args(field_type) + + if origin is not None: + # Handle Annotated[SomeType, ...] - extract from the first argument + if hasattr(typing, "Annotated") and origin is typing.Annotated: + if args: + models.extend( + self._extract_all_pydantic_models_from_type(args[0]) + ) + return models + + # Handle list[SomeType], List[SomeType], etc. 
+ if origin in (list, typing.List): + if args: + models.extend( + self._extract_all_pydantic_models_from_type(args[0]) + ) + return models + + # Handle Union types + if origin is typing.Union: + for arg in args: + if arg is not type(None): # Skip None type + models.extend( + self._extract_all_pydantic_models_from_type(arg) + ) + return models + + # Handle new Python 3.10+ union syntax + if hasattr(field_type, "__class__") and field_type.__class__ is types.UnionType: + for arg in field_type.__args__: + if arg is not type(None): # Skip None type + models.extend(self._extract_all_pydantic_models_from_type(arg)) + return models + + # Handle old typing.Union syntax (fallback) + if hasattr(field_type, "__origin__") and field_type.__origin__ is Union: + for arg in field_type.__args__: + if arg is not type(None): # Skip None type + models.extend(self._extract_all_pydantic_models_from_type(arg)) + return models + + # Check if this type itself is a Pydantic model + if self._is_pydantic_model(field_type): + models.append(field_type) + + return models + + def _get_nested_models( + self, model_class: type[BaseModel], visited=None + ) -> dict[str, type[BaseModel]]: + """Get all nested Pydantic models from a model class.""" + if visited is None: + visited = set() + + # Avoid infinite recursion + if model_class in visited: + return {} + + if model_class in self._nested_models_cache: + return self._nested_models_cache[model_class] + + visited.add(model_class) + nested_models = {} + + # Check all fields in the model + for field_info in model_class.model_fields.values(): + field_type = self._extract_nested_type(field_info.annotation) + + if self._is_pydantic_model(field_type): + nested_models[field_type.__name__] = field_type + # Recursively get nested models from this nested model + deeper_nested = self._get_nested_models(field_type, visited.copy()) + nested_models.update(deeper_nested) + + self._nested_models_cache[model_class] = nested_models + return nested_models + + def _build_inheritance_map(self, child_class: Type[BaseModel]): + """Build inheritance map for a class and all its parents.""" + if child_class in self._inheritance_map_cache: + return self._inheritance_map_cache[child_class] + + inheritance_map = {} + + # Get MRO and filter out BaseModel and object + mro_classes = [ + cls + for cls in child_class.__mro__ + if cls not in (BaseModel, object) and hasattr(cls, "__annotations__") + ] + + # Process each class in the MRO + for cls in mro_classes: + inheritance_map[cls] = self._get_direct_fields(cls) + + self._inheritance_map_cache[child_class] = inheritance_map + return inheritance_map + + def _wrap_comment(self, text: str, width: int = 88) -> list[str]: + """Wrap a comment to specified width, accounting for '# ' prefix.""" + if not text.strip(): + return ["#"] + + # Account for "# " prefix (2 characters) + content_width = width - 2 + wrapped_lines = textwrap.wrap(text, width=content_width) + return [f"# {line}" for line in wrapped_lines] + + def _extract_type_from_source( + self, model_class: type[BaseModel], field_name: str + ) -> str: + """Extract the actual type annotation text from source code, checking inheritance chain.""" + # Use inheritance map to check classes efficiently + inheritance_map = self._build_inheritance_map(model_class) + + # Check classes in MRO order + for cls in model_class.__mro__: + if cls in inheritance_map and field_name in inheritance_map[cls]: + type_annotation = self._get_type_from_class_source(cls, field_name) + if type_annotation != "unknown": + return 
type_annotation + + return "unknown" + + def _get_type_from_class_source(self, class_obj: type, field_name: str) -> str: + """Extract type annotation from a specific class's source code.""" + try: + source = inspect.getsource(class_obj) + tree = ast.parse(source) + except (OSError, TypeError): + return "unknown" + + # Find the class definition + for node in tree.body: + if isinstance(node, ast.ClassDef) and node.name == class_obj.__name__: + # Find the field assignment + for body_node in node.body: + if isinstance(body_node, ast.AnnAssign) and isinstance( + body_node.target, ast.Name + ): + if body_node.target.id == field_name and body_node.annotation: + return ast.unparse(body_node.annotation) + break + + return "unknown" + + def _extract_field_groups_from_all_classes( + self, model_class: type[BaseModel] + ) -> list[dict]: + """Extract field groups from all classes in the inheritance hierarchy.""" + all_groups = [] + inheritance_map = self._build_inheritance_map(model_class) + + # Get all Pydantic base classes in MRO order (most specific first) + # This puts AxolotlInputConfig fields first, then parent class fields + pydantic_classes = [ + cls + for cls in model_class.__mro__ + if cls in inheritance_map and inheritance_map[cls] + ] + + # Extract groups from each class + for cls in pydantic_classes: + class_groups = self._extract_field_groups_from_source(cls) + for group in class_groups: + all_groups.append(group) + + # If no groups found, create a default grouping by class + if not all_groups: + for cls in pydantic_classes: + fields_in_class = inheritance_map[cls] + if fields_in_class: + all_groups.append( + { + "fields": list(fields_in_class), + } + ) + + return all_groups + + # pylint: disable=too-many-return-statements + def _extract_field_groups_from_source( + self, model_class: type[BaseModel] + ) -> list[dict]: + """Extract field groups from source code based on blank lines and comments.""" + try: + source = inspect.getsource(model_class) + tree = ast.parse(source) + except (OSError, TypeError): + # Fallback if we can't get source code + fields_in_class = self._get_direct_fields(model_class) + if fields_in_class: + return [ + { + "fields": list(fields_in_class), + } + ] + return [] + + groups = [] + current_group_fields = [] + current_group_comment = None + + # Find the class definition + class_node = None + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == model_class.__name__: + class_node = node + break + + if not class_node: + fields_in_class = self._get_direct_fields(model_class) + if fields_in_class: + return [ + { + "fields": list(fields_in_class), + } + ] + return [] + + # Parse the source lines to detect groupings + source_lines = source.split("\n") + + # Get fields that are actually defined in this specific class + fields_in_class = self._get_direct_fields(model_class) + + # Find assignments that correspond to model fields for THIS class only + field_assignments = [] + for node in class_node.body: + if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name): + field_name = node.target.id + if field_name in fields_in_class: + field_assignments.append( + { + "name": field_name, + "lineno": node.lineno, + "end_lineno": getattr(node, "end_lineno", node.lineno), + } + ) + + if not field_assignments: + if fields_in_class: + return [ + { + "fields": list(fields_in_class), + } + ] + return [] + + # Sort by line number + field_assignments.sort(key=lambda x: x["lineno"]) + + # Group fields based on blank lines and comments + for i, 
field_info in enumerate(field_assignments): + field_name = field_info["name"] + current_line = field_info["lineno"] + + # Check if this starts a new group (blank line before or significant gap) + is_new_group = False + + if i == 0: + is_new_group = True + else: + prev_end_line = field_assignments[i - 1]["end_lineno"] + + # Check for blank lines or comments between fields + lines_between = source_lines[prev_end_line : current_line - 1] + has_blank_line = any(line.strip() == "" for line in lines_between) + has_comment = any( + line.strip().startswith("#") for line in lines_between + ) + + # Start new group if there's a blank line or comment, or significant gap + if has_blank_line or has_comment or (current_line - prev_end_line > 3): + is_new_group = True + + if is_new_group and current_group_fields: + # Save the previous group + groups.append( + { + "fields": current_group_fields.copy(), + "description": current_group_comment, + } + ) + current_group_fields = [] + current_group_comment = None + + current_group_fields.append(field_name) + + # Add the final group + if current_group_fields: + groups.append( + { + "fields": current_group_fields, + "description": current_group_comment, + } + ) + + return groups + + def _generate_field_documentation( + self, + model_class: type[BaseModel], + field_name: str, + field_info: dict, + field_type_str: str, + is_required: bool, + indent_level: int = 0, + visited_models: set = None, + ) -> list[str]: + """Generate documentation for a single field, expanding nested models inline.""" + if visited_models is None: + visited_models = set() + + lines = [] + indent = " " * indent_level + + # Get the actual field type for nested model detection + if field_name in model_class.model_fields: + pydantic_field_info = model_class.model_fields[field_name] + actual_field_type = pydantic_field_info.annotation + else: + actual_field_type = None + + # Add description comment if available + description = field_info.get("description", "") + if description: + wrapped_lines = self._wrap_comment(description, width=88 - len(indent)) + for line in wrapped_lines: + lines.append(f"{indent}{line}") + + # Extract nested Pydantic models from the type annotation + nested_models = self._extract_all_pydantic_models_from_type(actual_field_type) + + # Filter out already visited models to prevent infinite recursion + expandable_models = [ + model for model in nested_models if model not in visited_models + ] + + if expandable_models: + # This field contains Pydantic models that can be expanded + + # Show the field with its full type annotation + field_line = f"{indent}{field_name}: {field_type_str}" + if field_info.get("default") is not None: + field_line += f" = {field_info['default']}" + if is_required: + field_line += " (required)" + lines.append(field_line) + + # Add to visited to prevent infinite recursion + new_visited = visited_models.copy() + new_visited.update(expandable_models) + + # Expand each nested Pydantic model + for i, nested_model in enumerate(expandable_models): + if i > 0: + lines.append("\n") + lines.append(f"{indent} # For {nested_model.__name__}:") + + # Get nested model schema + try: + nested_schema = nested_model.model_json_schema() + nested_properties = nested_schema.get("properties", {}) + nested_required = nested_schema.get("required", []) + except Exception: # pylint: disable=broad-exception-caught + # Fallback: use model fields directly + nested_properties = {} + nested_required = [] + for ( + nested_field_name, + nested_field_info, + ) in 
nested_model.model_fields.items(): + nested_description = "" + if ( + hasattr(nested_field_info, "json_schema_extra") + and nested_field_info.json_schema_extra + ): + nested_description = ( + nested_field_info.json_schema_extra.get( + "description", "" + ) + ) + elif ( + hasattr(nested_field_info, "description") + and nested_field_info.description + ): + nested_description = nested_field_info.description + + nested_default_val = None + if ( + hasattr(nested_field_info, "default") + and nested_field_info.default is not None + ): + if str(nested_field_info.default) != "PydanticUndefined": + nested_default_val = nested_field_info.default + + nested_properties[nested_field_name] = { + "type": "unknown", + "description": nested_description, + "default": nested_default_val, + } + + if nested_field_info.is_required(): + nested_required.append(nested_field_name) + + # Get field groups for the nested model + nested_field_groups = self._extract_field_groups_from_all_classes( + nested_model + ) + + # Generate nested fields with increased indentation + for i, group in enumerate(nested_field_groups): + if not group["fields"]: + continue + + # Add blank line between groups (except before first group) + if i > 0: + lines.append("") + + # Process nested fields + for nested_field_name in group["fields"]: + if nested_field_name not in nested_properties: + continue + + nested_field_info = nested_properties[nested_field_name] + nested_field_type = self._extract_type_from_source( + nested_model, nested_field_name + ) + nested_is_required = nested_field_name in nested_required + + # Recursively generate documentation for nested field + nested_lines = self._generate_field_documentation( + nested_model, + nested_field_name, + nested_field_info, + nested_field_type, + nested_is_required, + indent_level + 1, + new_visited, + ) + lines.extend(nested_lines) + else: + # Regular field (no expandable nested models) + field_line = f"{indent}{field_name}: {field_type_str}" + if field_info.get("default") is not None: + field_line += f" = {field_info['default']}" + if is_required: + field_line += " (required)" + lines.append(field_line) + + return lines + + def generate_qmd( + self, + model_class: type[BaseModel], + title: str | None = None, + expand_nested: bool = True, + ) -> str: + """Auto-generate config reference documentation including inherited fields.""" + + if title is None: + title = f"{model_class.__name__} Reference" + + # Try to get JSON schema, with fallback for serialization issues + try: + schema = model_class.model_json_schema() + properties = schema.get("properties", {}) + required = schema.get("required", []) + except Exception as e: # pylint: disable=broad-exception-caught + print( + f"Warning: Could not generate JSON schema ({e}). Using model fields instead." 
+ ) + # Fallback: use model fields directly + properties = {} + required = [] + for field_name, field_info in model_class.model_fields.items(): + # Extract description from json_schema_extra or field info + description = "" + if ( + hasattr(field_info, "json_schema_extra") + and field_info.json_schema_extra + ): + description = field_info.json_schema_extra.get("description", "") + elif hasattr(field_info, "description") and field_info.description: + description = field_info.description + + # Get default value + default_val = None + if hasattr(field_info, "default") and field_info.default is not None: + # Handle special Pydantic default markers + if str(field_info.default) != "PydanticUndefined": + default_val = field_info.default + + properties[field_name] = { + "type": "unknown", + "description": description, + "default": default_val, + } + + if field_info.is_required(): + required.append(field_name) + + # Extract field groups from all classes in inheritance hierarchy + field_groups = self._extract_field_groups_from_all_classes(model_class) + + # Start building QMD content + qmd_lines = [ + "---", + f"title: {title}", + "description: A complete list of all configuration options.", + "---", + "", + ] + + # Generate one big code block with all fields (inline nested expansion) + qmd_lines.append("```yaml") + + for i, group in enumerate(field_groups): + if not group["fields"]: + continue + + # Add blank line between groups (except before first group) + if i > 0: + qmd_lines.append("") + + # Process fields in the order they appear in source + for field_name in group["fields"]: + if field_name not in properties: + continue + + field_info = properties[field_name] + field_type = self._extract_type_from_source(model_class, field_name) + is_required = field_name in required + + if expand_nested: + # Check if this field has nested models + if field_name in model_class.model_fields: + pydantic_field_info = model_class.model_fields[field_name] + nested_models = self._extract_all_pydantic_models_from_type( + pydantic_field_info.annotation + ) + has_nested = bool(nested_models) + else: + has_nested = False + + # Add blank line before nested config + if has_nested: + qmd_lines.append("") + + # Use the new inline generation method + field_lines = self._generate_field_documentation( + model_class, + field_name, + field_info, + field_type, + is_required, + indent_level=0, + visited_models=set(), + ) + qmd_lines.extend(field_lines) + + # Add blank line after nested config + if has_nested: + qmd_lines.append("") + else: + # Original simple approach + description = field_info.get("description", "") + default = field_info.get("default") + + # Add wrapped comment for description + if description: + wrapped_lines = self._wrap_comment(description) + qmd_lines.extend(wrapped_lines) + + line = f"{field_name}: {field_type}" + if default is not None: + line += f" = {default}" + if is_required: + line += " (required)" + qmd_lines.append(line) + + qmd_lines.append("```") + + # Join all lines and clean up any double newlines + content = "\n".join(qmd_lines) + + # Replace multiple consecutive newlines with just two newlines (one blank line) + import re + + content = re.sub(r"\n{3,}", "\n\n", content) + + # Ensure single newline at the very end + content = content.rstrip("\n") + "\n" + + return content + + +def main(): + generator = QuartoGenerator() + + print("Generating config reference content...") + qmd_content = generator.generate_qmd(AxolotlInputConfig, "Config Reference", True) + + print("Writing to file...") + with 
open("docs/config-reference.qmd", "w", encoding="utf-8") as f: + f.write(qmd_content) + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/favicon.jpg b/favicon.jpg index 43c690244..4ec358746 100644 Binary files a/favicon.jpg and b/favicon.jpg differ diff --git a/requirements.txt b/requirements.txt index cf8caba00..8bd77ab5e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ packaging==23.2 huggingface_hub==0.32.2 peft==0.15.2 -transformers==4.52.3 +transformers==4.52.4 tokenizers>=0.21.1 accelerate==1.7.0 datasets==3.6.0 diff --git a/setup.py b/setup.py index 28f71f789..08c39c71c 100644 --- a/setup.py +++ b/setup.py @@ -118,7 +118,7 @@ extras_require = { "yunchang==0.6.0", ], "deepspeed": [ - "deepspeed==0.17.0", + "deepspeed==0.17.1", "deepspeed-kernels", ], "mamba-ssm": [ diff --git a/src/axolotl/__init__.py b/src/axolotl/__init__.py index 63f28adda..314d22279 100644 --- a/src/axolotl/__init__.py +++ b/src/axolotl/__init__.py @@ -4,4 +4,4 @@ import pkgutil __path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package -__version__ = "0.10.0.dev0" +__version__ = "0.11.0.dev" diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index e399cf3c5..eed43e542 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -386,8 +386,10 @@ class TrainerBuilderBase(abc.ABC): elif self.cfg.eval_steps: training_args_kwargs["eval_strategy"] = "steps" training_args_kwargs["eval_steps"] = self.cfg.eval_steps + training_args_kwargs["eval_on_start"] = True elif self.cfg.eval_strategy: training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy + training_args_kwargs["eval_on_start"] = True def _configure_reporting(self, training_args_kwargs: dict): report_to = [] diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 8ff565dbb..47e33a332 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -21,18 +21,12 @@ from axolotl.core.trainers import ( AxolotlTrainer, ReLoRATrainer, ) -from axolotl.core.training_args import ( - AxolotlPRMConfig, - AxolotlRewardConfig, - AxolotlTrainingArguments, -) from axolotl.integrations.base import PluginManager from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.relora import ReLoRACallback from axolotl.processing_strategies import get_processing_strategy from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.callbacks import ( - EvalFirstStepCallback, LossWatchDogCallback, SaveBetterTransformerModelCallback, bench_eval_callback_factory, @@ -63,7 +57,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): def get_callbacks(self): callbacks = super().get_callbacks() - callbacks.append(EvalFirstStepCallback()) if self.cfg.relora_steps: callbacks.append(ReLoRACallback(self.cfg)) @@ -130,6 +123,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return callbacks def _get_trainer_cls(self): + """ + Gets the trainer class for the given configuration. 
+ """ if self.cfg.plugins: plugin_manager = PluginManager.get_instance() trainer_cls = plugin_manager.get_trainer_cls(self.cfg) @@ -146,6 +142,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return AxolotlTrainer def build(self, total_num_steps): + from axolotl.core.training_args import ( + AxolotlPRMConfig, + AxolotlRewardConfig, + AxolotlTrainingArguments, + ) + training_arguments_kwargs, trainer_kwargs = self._set_base_training_args( total_num_steps ) @@ -314,20 +316,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["image_resize_algorithm"] = ( self.cfg.image_resize_algorithm ) - if self.cfg.kd_ce_alpha is not None: - training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha - if self.cfg.kd_alpha is not None: - training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha - if self.cfg.kd_temperature is not None: - training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature - if self.cfg.kd_zscore_base_temp is not None: - training_arguments_kwargs["kd_zscore_base_temp"] = ( - self.cfg.kd_zscore_base_temp - ) - if self.cfg.kd_top_k_before_softmax is not None: - training_arguments_kwargs["kd_top_k_before_softmax"] = ( - self.cfg.kd_top_k_before_softmax - ) + + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + plugin_training_args = plugin_manager.get_training_args(self.cfg) + if plugin_training_args: + training_arguments_kwargs.update(plugin_training_args) if self.cfg.reward_model: training_args_cls = AxolotlRewardConfig @@ -408,7 +402,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return trainer def build_collator( - self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs + self, + training_args, # type: "AxolotlTrainingArguments" # type: ignore + is_eval=False, + **kwargs, ): if training_args.pretraining: if ( @@ -437,7 +434,19 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): ] ] collator_args = [self.tokenizer] - if self.cfg.reward_model: + + collator_cls_and_kwargs = None + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs( + self.cfg, is_eval=is_eval + ) + + if collator_cls_and_kwargs: + collator = collator_cls_and_kwargs[0] + if kwargs and isinstance(kwargs, dict): + kwargs.update(collator_cls_and_kwargs[1]) + elif self.cfg.reward_model: collator = RewardDataCollatorWithPadding elif use_batch_sampler_collator: # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention, @@ -468,16 +477,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): collator_args.pop(0) kwargs.pop("pad_to_multiple_of", None) kwargs.pop("padding", None) - elif self.cfg.kd_trainer: - from axolotl.integrations.kd.collator import ( - DataCollatorForKD, - KDBatchSamplerDataCollatorForSeq2Seq, - ) - - if self.cfg.sample_packing: - collator = KDBatchSamplerDataCollatorForSeq2Seq - else: - collator = DataCollatorForKD else: collator = DataCollatorForSeq2Seq diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py index 47ace7451..c5f01dd41 100644 --- a/src/axolotl/core/builders/rl.py +++ b/src/axolotl/core/builders/rl.py @@ -12,11 +12,6 @@ from axolotl.core.trainers import ( from axolotl.core.trainers.dpo import DPOStrategy from axolotl.core.trainers.dpo.args import AxolotlDPOConfig from axolotl.core.trainers.grpo import GRPOStrategy -from axolotl.core.training_args import ( - AxolotlCPOConfig, - AxolotlKTOConfig, - AxolotlORPOConfig, -) from axolotl.integrations.base import PluginManager from 
axolotl.loaders.utils import ensure_dtype from axolotl.utils.callbacks.qat import QATCallback @@ -83,6 +78,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase): """ Returns training_args and trainer_kwargs """ + from axolotl.core.training_args import ( + AxolotlCPOConfig, + AxolotlKTOConfig, + AxolotlORPOConfig, + ) + training_args_kwargs, trainer_kwargs = self._set_base_training_args( total_num_steps=total_num_steps ) @@ -150,6 +151,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if blocklist_key in training_args_kwargs: del training_args_kwargs[blocklist_key] + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + plugin_training_args = plugin_manager.get_training_args(self.cfg) + if plugin_training_args: + training_args_kwargs.update(plugin_training_args) + training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg logging_first_step=True, **training_args_kwargs, diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 25ffb4cbf..fbae253d6 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -34,6 +34,7 @@ from axolotl.core.trainers.utils import ( sanitize_kwargs_for_ds_tagging, sanitize_kwargs_for_tagging, ) +from axolotl.utils import get_not_null from axolotl.utils.logging import get_logger from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths @@ -104,7 +105,7 @@ class AxolotlTrainer( ) batch_max_len = train_batch_size * self.args.max_seq_length - return MultipackBatchSampler( + sampler = MultipackBatchSampler( base_sampler, lengths=get_dataset_lengths(dataset), packing_efficiency_estimate=self.args.sample_packing_efficiency, @@ -117,6 +118,9 @@ class AxolotlTrainer( num_processes=self.args.dataset_num_proc, ) + len(sampler) + return sampler + def _get_train_sampler( self, train_dataset: Optional[Dataset] = None ) -> Optional[Sampler]: @@ -224,7 +228,9 @@ class AxolotlTrainer( } if not isinstance(dataset, torch.utils.data.IterableDataset): - dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["drop_last"] = get_not_null( + self.args.dataloader_drop_last, True + ) if sampler_fn is not None: sampler = sampler_fn(dataset) if isinstance(sampler, BatchSampler): diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 2b53c6798..d5be9fc62 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -2,242 +2,17 @@ extra axolotl specific training args """ -from dataclasses import dataclass, field -from typing import Optional +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional, Type -from PIL.Image import Resampling from transformers import TrainingArguments from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig +from axolotl.integrations.config import merge_training_args -@dataclass -class AxolotlTrainingMixins: - """ - Mixin class for the Axolotl training args. - """ - - # pylint: disable=duplicate-code - model_type: Optional[str] = field( - default=None, metadata={"help": "HF model configuration model_type."} - ) - lr_quadratic_warmup: bool = field( - default=False, - metadata={"help": "Use quadratic warmup for cosine scheduling."}, - ) - pretraining: bool = field( - default=False, - metadata={ - "help": "Indicates to trainer whether we are doing continued pretraining." 
- }, - ) - sample_packing: bool = field( - default=False, - metadata={"help": "Use sample packing for efficient training."}, - ) - sample_packing_sequentially: bool = field( - default=False, - metadata={ - "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing." - }, - ) - multipack_real_batches: bool = field( - default=False, - metadata={"help": "Use real batches for efficient training."}, - ) - eval_sample_packing: Optional[bool] = field( - default=None, - metadata={"help": "Use sample packing for efficient evals."}, - ) - sample_packing_efficiency: float = field( - default=1.0, - metadata={"help": "Sample packing efficiency for calculating batch length."}, - ) - sample_packing_bin_size: int = field( - default=200, - metadata={ - "help": "The max number of samples that packed sample can contain after packing. Increase for better packing." - }, - ) - sample_packing_group_size: int = field( - default=100000, - metadata={ - "help": "The number of samples to group together for packing. Increase for better packing." - }, - ) - max_seq_length: int = field( - default=2048, - metadata={"help": "The maximum sequence length the model can handle"}, - ) - dataset_num_proc: int | None = field( - default=None, - metadata={"help": "The number of processes to use for data processing"}, - ) - relora_steps: Optional[int] = field( - default=None, - metadata={"help": "how often to reset for ReLoRA"}, - ) - relora_warmup_steps: Optional[int] = field( - default=None, - metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, - ) - relora_anneal_steps: Optional[int] = field( - default=None, - metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, - ) - relora_prune_ratio: Optional[float] = field( - default=0.9, - metadata={"help": "prune ratio for magnitude pruning of the optimizer"}, - ) - bench_split: Optional[str] = field( - default="eval", metadata={"help": "The benchmark split to run on"} - ) - bench_dataset: Optional[str] = field( - default="pharaouk/dharma-1/dharma_1_mini.json", - metadata={ - "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file" - }, - ) - do_bench_eval: Optional[bool] = field( - default=False, metadata={"help": "Whether to run the Benchmark evaluation."} - ) - do_causal_lm_eval: Optional[bool] = field( - default=False, metadata={"help": "Whether to run the Causal LM evaluation."} - ) - max_bench_samples: Optional[int] = field( - default=None, - metadata={ - "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset." 
- }, - ) - bench_source_max_len: int = field( - default=2048, metadata={"help": "Maximum source sequence length for bench."} - ) - dataloader_prefetch_factor: Optional[int] = field( - default=None, - metadata={"help": "prefetch_factor argument to the dataloader"}, - ) - cosine_min_lr_ratio: Optional[float] = field( - default=None, - metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"}, - ) - cosine_constant_lr_ratio: Optional[float] = field( - default=None, - metadata={ - "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps" - }, - ) - loraplus_lr_ratio: Optional[float] = field( - default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."} - ) - loraplus_lr_embedding: Optional[float] = field( - default=1e-6, - metadata={"help": "loraplus learning rate for lora embedding layers."}, - ) - embedding_lr_scale: Optional[float] = field( - default=None, - metadata={"help": "Scale the learning rate for the embedding layers."}, - ) - lr_groups: Optional[list[dict]] = field( - default=None, - metadata={"help": "Specify learning rate groups for with different LRs."}, - ) - embedding_lr: Optional[float] = field( - default=None, - metadata={"help": "absolute learning rate for the embedding layers."}, - ) - qlora: bool = field( - default=False, - metadata={"help": "whether this is a qlora training"}, - ) - orpo_alpha: Optional[float] = field( - default=None, - ) - lisa_n_layers: Optional[int] = field( - default=None, - metadata={"help": "the number of activate layers in LISA"}, - ) - lisa_step_interval: Optional[int] = field( - default=None, - metadata={"help": "how often to switch layers in LISA"}, - ) - lisa_layers_attribute: Optional[str] = field( - default=None, - metadata={"help": "path under the model to access the layers"}, - ) - curriculum_sampling: Optional[bool] = field( - default=None, - metadata={"help": "whether to use sequential sampling for curriculum learning"}, - ) - alternate_lr_scheduler_type: Optional[str] = field( - default=None, - metadata={ - "help": "workaround to pass an alternate lr scheduler to the HF trainer" - }, - ) - chat_template: Optional[str] = field( - default=None, - metadata={"help": "Chat template converting chat messages to text"}, - ) - - kd_ce_alpha: Optional[float] = field( - default=None, - metadata={ - "help": "The alpha scaling parameter for SFT cross entropy loss when using KD" - }, - ) - - kd_alpha: Optional[float] = field( - default=1.0, - metadata={"help": "The alpha scaling parameter for KD loss"}, - ) - - kd_temperature: Optional[float] = field( - default=1.0, - metadata={ - "help": "the temperature parameter for KL divergence loss when using KD" - }, - ) - - kd_zscore_base_temp: Optional[float] = field( - default=None, - metadata={ - "help": "the base temperature parameter for KL divergence with z-score when using KD" - }, - ) - - kd_top_k_before_softmax: Optional[bool] = field( - default=None, - metadata={ - "help": "Whether to apply top_k_before_softmax to the logits when using KD" - }, - ) - - adam_beta3: Optional[float] = field( - default=None, - metadata={ - "help": "The beta3 hyperparameter used in some optimizers such as CAME" - }, - ) - adam_epsilon2: Optional[float] = field( - default=None, - metadata={ - "help": "The epsilon2 hyperparameter used in some optimizers such as CAME" - }, - ) - - # multi-modal section - - image_size: int | tuple[int, int] | None = field( - default=None, - metadata={"help": "The size of the image to resize to"}, - ) - - image_resize_algorithm: 
Resampling | None = field( - default=None, - metadata={"help": "The algorithm to use for image resizing"}, - ) - - # end of multi-modal section +AxolotlTrainingMixins: Type = merge_training_args() @dataclass diff --git a/src/axolotl/core/training_args_base.py b/src/axolotl/core/training_args_base.py new file mode 100644 index 000000000..8fcaff632 --- /dev/null +++ b/src/axolotl/core/training_args_base.py @@ -0,0 +1,224 @@ +""" +Base Axolotl Training Mixins shared across various trainer configs +""" + +from dataclasses import dataclass, field +from typing import Optional + +from PIL.Image import Resampling + + +@dataclass +class AxolotlTrainingMixins: + """ + Mixin class for the Axolotl training args. + """ + + # pylint: disable=duplicate-code + model_type: Optional[str] = field( + default=None, metadata={"help": "HF model configuration model_type."} + ) + lr_quadratic_warmup: bool = field( + default=False, + metadata={"help": "Use quadratic warmup for cosine scheduling."}, + ) + pretraining: bool = field( + default=False, + metadata={ + "help": "Indicates to trainer whether we are doing continued pretraining." + }, + ) + sample_packing: bool = field( + default=False, + metadata={"help": "Use sample packing for efficient training."}, + ) + sample_packing_sequentially: bool = field( + default=False, + metadata={ + "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing." + }, + ) + multipack_real_batches: bool = field( + default=False, + metadata={"help": "Use real batches for efficient training."}, + ) + eval_sample_packing: Optional[bool] = field( + default=None, + metadata={"help": "Use sample packing for efficient evals."}, + ) + sample_packing_efficiency: float = field( + default=1.0, + metadata={"help": "Sample packing efficiency for calculating batch length."}, + ) + sample_packing_bin_size: int = field( + default=200, + metadata={ + "help": "The max number of samples that packed sample can contain after packing. Increase for better packing." + }, + ) + sample_packing_group_size: int = field( + default=100000, + metadata={ + "help": "The number of samples to group together for packing. Increase for better packing." 
+ }, + ) + max_seq_length: int = field( + default=2048, + metadata={"help": "The maximum sequence length the model can handle"}, + ) + dataset_num_proc: int | None = field( + default=None, + metadata={"help": "The number of processes to use for data processing"}, + ) + relora_steps: Optional[int] = field( + default=None, + metadata={"help": "how often to reset for ReLoRA"}, + ) + relora_warmup_steps: Optional[int] = field( + default=None, + metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, + ) + relora_anneal_steps: Optional[int] = field( + default=None, + metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, + ) + relora_prune_ratio: Optional[float] = field( + default=0.9, + metadata={"help": "prune ratio for magnitude pruning of the optimizer"}, + ) + bench_split: Optional[str] = field( + default="eval", metadata={"help": "The benchmark split to run on"} + ) + bench_dataset: Optional[str] = field( + default="pharaouk/dharma-1/dharma_1_mini.json", + metadata={ + "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file" + }, + ) + do_bench_eval: Optional[bool] = field( + default=False, metadata={"help": "Whether to run the Benchmark evaluation."} + ) + do_causal_lm_eval: Optional[bool] = field( + default=False, metadata={"help": "Whether to run the Causal LM evaluation."} + ) + max_bench_samples: Optional[int] = field( + default=None, + metadata={ + "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset." + }, + ) + bench_source_max_len: int = field( + default=2048, metadata={"help": "Maximum source sequence length for bench."} + ) + dataloader_prefetch_factor: Optional[int] = field( + default=None, + metadata={"help": "prefetch_factor argument to the dataloader"}, + ) + cosine_min_lr_ratio: Optional[float] = field( + default=None, + metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"}, + ) + cosine_constant_lr_ratio: Optional[float] = field( + default=None, + metadata={ + "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps" + }, + ) + loraplus_lr_ratio: Optional[float] = field( + default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."} + ) + loraplus_lr_embedding: Optional[float] = field( + default=1e-6, + metadata={"help": "loraplus learning rate for lora embedding layers."}, + ) + embedding_lr_scale: Optional[float] = field( + default=None, + metadata={"help": "Scale the learning rate for the embedding layers."}, + ) + lr_groups: Optional[list[dict]] = field( + default=None, + metadata={"help": "Specify learning rate groups for with different LRs."}, + ) + embedding_lr: Optional[float] = field( + default=None, + metadata={"help": "absolute learning rate for the embedding layers."}, + ) + qlora: bool = field( + default=False, + metadata={"help": "whether this is a qlora training"}, + ) + orpo_alpha: Optional[float] = field( + default=None, + ) + lisa_n_layers: Optional[int] = field( + default=None, + metadata={"help": "the number of activate layers in LISA"}, + ) + lisa_step_interval: Optional[int] = field( + default=None, + metadata={"help": "how often to switch layers in LISA"}, + ) + lisa_layers_attribute: Optional[str] = field( + default=None, + metadata={"help": "path under the model to access the layers"}, + ) + curriculum_sampling: Optional[bool] = field( + default=None, + metadata={"help": "whether to use sequential sampling for curriculum learning"}, + ) + alternate_lr_scheduler_type: 
Optional[str] = field( + default=None, + metadata={ + "help": "workaround to pass an alternate lr scheduler to the HF trainer" + }, + ) + chat_template: Optional[str] = field( + default=None, + metadata={"help": "Chat template converting chat messages to text"}, + ) + + # kd_ce_alpha: Optional[float] = field( + # default=None, + # metadata={ + # "help": "The alpha scaling parameter for SFT cross entropy loss when using KD" + # }, + # ) + # + # kd_alpha: Optional[float] = field( + # default=1.0, + # metadata={"help": "The alpha scaling parameter for KD loss"}, + # ) + # + # kd_temperature: Optional[float] = field( + # default=1.0, + # metadata={ + # "help": "the temperature parameter for KL divergence loss when using KD" + # }, + # ) + + adam_beta3: Optional[float] = field( + default=None, + metadata={ + "help": "The beta3 hyperparameter used in some optimizers such as CAME" + }, + ) + adam_epsilon2: Optional[float] = field( + default=None, + metadata={ + "help": "The epsilon2 hyperparameter used in some optimizers such as CAME" + }, + ) + + # multi-modal section + + image_size: int | tuple[int, int] | None = field( + default=None, + metadata={"help": "The size of the image to resize to"}, + ) + + image_resize_algorithm: Resampling | None = field( + default=None, + metadata={"help": "The algorithm to use for image resizing"}, + ) + + # end of multi-modal section diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index 0edc9fdea..9162bc745 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -22,6 +22,7 @@ from __future__ import annotations import collections import importlib +import traceback from typing import TYPE_CHECKING, Callable, OrderedDict, Union from peft import PeftModel @@ -83,6 +84,11 @@ class BasePlugin: def get_input_args(self) -> str | None: """Returns a pydantic model for the plugin's input arguments.""" + def get_training_args_mixin(self) -> str | None: + """ + Returns a dataclass model for the plugin's training arguments. + """ + def load_datasets( self, cfg: DictDefault, preprocess: bool = False ) -> Union["TrainDatasetMeta", None]: @@ -158,6 +164,31 @@ class BasePlugin: trainer: The trainer object for training. """ + def get_training_args(self, cfg: DictDefault): # pylint: disable=unused-argument): + """ + Returns custom training arguments to set on TrainingArgs. + + Args: + cfg: The global axolotl configuration. + + Returns: + object: dict containing the training arguments. + """ + + def get_collator_cls_and_kwargs( + self, cfg: DictDefault, is_eval: bool = False + ): # pylint: disable=unused-argument): + """ + Returns a custom class for the collator. + + Args: + cfg: The global axolotl configuration. + is_eval: Whether this is an eval split. + + Returns: + class: The class for the collator. + """ + # pylint: disable=unused-argument def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None: """Creates and returns an optimizer for training. @@ -278,7 +309,7 @@ def load_plugin(plugin_name: str) -> BasePlugin: return plugin -class PluginManager: +class PluginManager: # pylint: disable=too-many-public-methods """The `PluginManager` class is responsible for loading and managing plugins. It should be a singleton so it can be accessed from anywhere in the codebase. 
@@ -337,8 +368,11 @@ class PluginManager: plugin = load_plugin(plugin_name) self.plugins[plugin_name] = plugin LOG.info(f"Plugin loaded successfully: {plugin_name}") - except ImportError: + except ImportError as exc: LOG.error(f"Failed to load plugin: {plugin_name}") + # print stacktrace + traceback.print_exc() + print(f"Error: {exc}") def get_input_args(self) -> list[str]: """Returns a list of Pydantic classes for all registered plugins' input arguments.' @@ -353,6 +387,20 @@ class PluginManager: input_args.append(input_args_from_plugin) return input_args + def get_training_args_mixin(self): + """ + Returns a list of dataclasses for all registered plugins' training args mixins' + + Returns: + list[str]: A list of dataclsses + """ + training_args = [] + for plugin in self.plugins.values(): + training_args_from_plugin = plugin.get_training_args_mixin() + if training_args_from_plugin is not None: + training_args.append(training_args_from_plugin) + return training_args + def load_datasets( self, cfg: DictDefault, preprocess: bool = False ) -> Union["TrainDatasetMeta", None]: @@ -442,6 +490,42 @@ class PluginManager: return trainer_cls return None + def get_training_args(self, cfg): + """ + Calls the get_training_args method of all registered plugins and returns the combined training arguments. + + Parameters: + cfg (dict): The configuration for the plugins. + + Returns: + object: The training arguments + """ + training_args_kwargs = {} + for plugin in self.plugins.values(): + training_args = plugin.get_training_args(cfg) + if training_args is not None: + training_args_kwargs.update(training_args) + + return training_args_kwargs + + def get_collator_cls_and_kwargs(self, cfg, is_eval=False): + """ + Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class. + + Parameters: + cfg (dict): The configuration for the plugins. + is_eval (bool): Whether this is an eval split. + + Returns: + object: The collator class, or None if none was found. + """ + for plugin in self.plugins.values(): + collator = plugin.get_collator_cls_and_kwargs(cfg, is_eval=is_eval) + if collator is not None: + collator_cls, collator_kwargs = collator + return collator_cls, collator_kwargs + return None + def post_trainer_create(self, cfg: DictDefault, trainer: Trainer): """Calls the `post_trainer_create` method of all registered plugins. diff --git a/src/axolotl/integrations/config.py b/src/axolotl/integrations/config.py index b443f228e..f5fc07e9e 100644 --- a/src/axolotl/integrations/config.py +++ b/src/axolotl/integrations/config.py @@ -16,7 +16,7 @@ Module to handle merging the plugins' input arguments with the base configuratio This was moved here to prevent circular imports. """ -from typing import Any, Dict, List +from typing import Any, Dict, List, Type from axolotl.utils.schemas.config import ( AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, @@ -61,3 +61,43 @@ def merge_input_args(): ] return AxolotlConfigWCapabilities, AxolotlInputConfig return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase + + +def merge_training_args() -> Type: + """ + Merges training arguments from registered plugins with the base TrainingArguments. + + This function retrieves the training arguments from registered plugins using the PluginManager. + It then dynamically creates new classes, AxolotlTrainingMixins, + that inherit from the base configurations and include the training arguments from the plugins. 
+
+    Returns:
+        Type: The newly created AxolotlTrainingMixins class.
+    """
+    # pylint: disable=duplicate-code
+    from axolotl.core.training_args_base import (
+        AxolotlTrainingMixins as AxolotlTrainingMixinsBase,
+    )
+    from axolotl.integrations.base import PluginManager
+
+    plugin_manager = PluginManager.get_instance()
+    training_args_mixins: List[str] = plugin_manager.get_training_args_mixin()
+    mixin_classes = []
+    dynamic_input = ""
+    for plugin_args in training_args_mixins:
+        plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
+        dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
+        mixin_classes.append(plugin_cls)
+    if dynamic_input:
+        dynamic_input += f"class AxolotlTrainingMixins(AxolotlTrainingMixinsBase, {', '.join(mixin_classes)}):\n    pass\n"
+
+        namespace: Dict[Any, Any] = {}
+        local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase}
+        exec(  # pylint: disable=exec-used  # nosec B102
+            dynamic_input, {**globals(), **local_vars}, namespace
+        )
+        AxolotlTrainingMixins = namespace[  # pylint: disable=invalid-name
+            "AxolotlTrainingMixins"
+        ]
+        return AxolotlTrainingMixins
+    return AxolotlTrainingMixinsBase
diff --git a/src/axolotl/integrations/kd/__init__.py b/src/axolotl/integrations/kd/__init__.py
index 8a6e3eda1..4c8535a0a 100644
--- a/src/axolotl/integrations/kd/__init__.py
+++ b/src/axolotl/integrations/kd/__init__.py
@@ -15,7 +15,12 @@
 """
 Plugin init to add KD support to Axolotl.
 """
+from typing import Any
+
+from transformers import Trainer
+
 from axolotl.integrations.base import BasePlugin
+from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback
 
 from .args import KDArgs  # pylint: disable=unused-import.  # noqa: F401
 
@@ -28,9 +33,75 @@ class KDPlugin(BasePlugin):
     def get_input_args(self):
         return "axolotl.integrations.kd.KDArgs"
 
+    def get_training_args_mixin(self):
+        return "axolotl.integrations.kd.args.KDTrainingArgsMixin"
+
     def get_trainer_cls(self, cfg):
         if cfg.kd_trainer:
             from .trainer import AxolotlKDTrainer
 
             return AxolotlKDTrainer
         return None
+
+    def get_training_args(self, cfg):
+        return {
+            "kd_ce_alpha": cfg.kd_ce_alpha,
+            "kd_alpha": cfg.kd_alpha,
+            "kd_temperature": cfg.kd_temperature,
+            "kd_beta": cfg.kd_beta,
+            "kd_normalize_topk": cfg.kd_normalize_topk,
+        }
+
+    def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
+        if not cfg.kd_trainer:
+            return None, None
+
+        from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq
+
+        use_batch_sampler_collator = False
+        if is_eval is False and cfg.sample_packing:
+            use_batch_sampler_collator = True
+        if cfg.eval_sample_packing and is_eval:
+            use_batch_sampler_collator = True
+
+        if cfg.kd_online_server_base_url:
+            from .collator_online_teacher import OnlineTeacherCollator
+
+            return OnlineTeacherCollator, {
+                "kd_online_server_base_url": cfg.kd_online_server_base_url,
+                "kd_online_topk": cfg.kd_online_topk,
+                "kd_temperature": cfg.kd_temperature,
+                "kd_online_server": cfg.kd_online_server,
+                "kd_online_timeout": cfg.kd_online_timeout,
+                "kd_normalize_topk": cfg.kd_normalize_topk,
+            }
+
+        if use_batch_sampler_collator:
+            return KDBatchSamplerDataCollatorForSeq2Seq, {}
+        return DataCollatorForKD, {}
+
+    def pre_model_load(self, cfg):
+        from .kernels.models import apply_kernel
+
+        apply_kernel(cfg.model_config_type)
+
+    def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
+        """
+        Adds temp scheduler callback to the Trainer instance.
+
+        Args:
+            cfg (Any): The Axolotl configuration object.
+            trainer (Trainer): Huggingface Trainer instance.
+
+        Returns:
+            list: List containing the configured callback instances.
+        """
+        if cfg.kd_temperature_min is not None and cfg.kd_online_server_base_url:
+            callback = KDTemperatureSchedulerCallback(
+                cfg.kd_temperature,
+                cfg.kd_temperature_min,
+                trainer,
+            )
+            return [callback]
+
+        return []
diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py
index 2fbba2c6a..758bc8917 100644
--- a/src/axolotl/integrations/kd/args.py
+++ b/src/axolotl/integrations/kd/args.py
@@ -15,9 +15,19 @@
 """
 Plugin args for KD support.
 """
-from typing import Optional
+from dataclasses import dataclass
+from enum import Enum
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+
+
+class InferenceServerType(str, Enum):
+    """
+    Online inference server types to handle different request args
+    """
+
+    vllm = "vllm"  # pylint: disable=invalid-name
+    sglang = "sglang"  # pylint: disable=invalid-name
 
 
 class KDArgs(BaseModel):
@@ -25,13 +35,41 @@ class KDArgs(BaseModel):
     Input args for knowledge distillation.
     """
 
-    kd_trainer: Optional[bool] = None  # whether to use KD trainer
-    kd_ce_alpha: Optional[float] = (
+    kd_trainer: bool | None = None  # whether to use KD trainer
+    kd_ce_alpha: float | None = (
         None  # loss coefficient for cross-entropy loss during KD
     )
-    kd_alpha: Optional[float] = None  # loss coefficient for KD loss
-    kd_temperature: Optional[float] = None  # temperature for sampling during KD
-    kd_zscore_base_temp: Optional[float] = None  # base temperature for zscore scaling
-    kd_top_k_before_softmax: Optional[bool] = (
-        None  # whether to sample top k before softmax during KD
+    kd_alpha: float | None = None  # loss coefficient for KD loss
+    kd_temperature: float | None = None  # temperature for sampling during KD
+    kd_beta: float | None = 0.0  # beta coefficient for ratio of fwd and reverse KL
+    kd_normalize_topk: bool | None = (
+        None  # whether to normalize student logits during KD
+    )
+
+    # TODO online kd
+    kd_online_server_base_url: str | None = None
+    kd_online_topk: int | None = None
+    kd_online_server: InferenceServerType | None = Field(
+        default_factory=lambda: InferenceServerType.vllm
+    )
+    kd_online_timeout: int | None = 120
+    kd_temperature_min: float | None = (
+        None  # kd temperature scheduling during online kd
+    )
+
+
+@dataclass
+class KDTrainingArgsMixin:
+    """
+    Additional args for KD training.
+    """
+
+    kd_ce_alpha: float | None = (
+        None  # loss coefficient for cross-entropy loss during KD
+    )
+    kd_alpha: float | None = None  # loss coefficient for KD loss
+    kd_temperature: float | None = None  # temperature for sampling during KD
+    kd_beta: float | None = None  # beta coefficient for ratio of fwd and reverse KL
+    kd_normalize_topk: bool | None = (
+        None  # whether to normalize student logits during KD
     )
diff --git a/src/axolotl/integrations/kd/callbacks.py b/src/axolotl/integrations/kd/callbacks.py
new file mode 100644
index 000000000..911c3d517
--- /dev/null
+++ b/src/axolotl/integrations/kd/callbacks.py
@@ -0,0 +1,36 @@
+"""
+Transformers trainer callbacks to schedule the KD temperature during training
+"""
+
+import math
+
+from transformers.trainer_callback import TrainerCallback
+
+
+class KDTemperatureSchedulerCallback(TrainerCallback):
+    """
+    KD temperature scheduler callback for the trainer.
+ """ + + def __init__(self, temperature_start, temperature_min, trainer): + self.temperature_start = temperature_start + self.temperature_min = temperature_min + self.temperature = temperature_start + + self.trainer = trainer + + def on_step_end( + self, args, state, control, **kwargs + ): # pylint: disable=unused-argument + # cosine decay temperature over the max steps + + progress = state.global_step / state.max_steps + # Cosine decay factor: 0.5 * (1 + cos(pi * progress)) + # This factor goes from 1 (at progress=0) to 0 (at progress=1) + decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress)) + self.temperature = self.temperature_start - ( + (self.temperature_start - self.temperature_min) * (1.0 - decay_factor) + ) + + if hasattr(self.trainer.data_collator, "kd_temperature"): + self.trainer.data_collator.kd_temperature = self.temperature diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py index 7c99a9c3d..f99dfe458 100644 --- a/src/axolotl/integrations/kd/chat_template.py +++ b/src/axolotl/integrations/kd/chat_template.py @@ -15,12 +15,15 @@ """ Chat template prompt strategy loader with KD support """ +import logging from typing import Any, Dict import torch from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader +LOG = logging.getLogger(__name__) + class ChatTemplateStrategyWithKD(ChatTemplateStrategy): """ @@ -101,10 +104,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): # fill with -inf for padding_len tokens for top_k tokens # extend target_logprobs with a padding_len x top_k 2D list filled with -inf - # for causal models, if we start the range at 1, then we don't need to shift in the trainer - # otherwise, we need to shift in the trainer - shift = 0 - for _ in range(shift, input_padding_len): + # we shift for causal models in the trainer, so start the range from 0 + for _ in range(0, input_padding_len): target_logprobs.append([-float("inf")] * top_k) target_token_ids.append(list(range(top_k))) target_mask.append([0] * top_k) @@ -143,6 +144,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): # # Convert from log to probability teacher_probs_t1 = position_logprobs_tensor.exp() + # normalize probabilities to sum to 1 in case they aren't already + teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True) + if teacher_probs_t1_sum > 1e-9: + teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum if self.kd_temperature != self.gen_temperature: # Exponentiate by factor (T1 / T2) exponent = self.gen_temperature / self.kd_temperature @@ -162,12 +167,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): target_logprobs.append(position_logprobs_scaled) target_token_ids.append(position_token_ids) - if shift == 1: - # since we started at index 1 for causal, we need one more padding token - target_logprobs.append([-float("inf")] * top_k) - target_token_ids.append(list(range(top_k))) - target_mask.append([0] * top_k) - # Update sample with transformed logprobs sample["target_logprobs"] = target_logprobs sample["target_token_ids"] = target_token_ids @@ -184,6 +183,117 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): return tokenized_prompt +class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD): + """ + Strat for datasets with complete structured KD logprob data + """ + + def transform_logprobs(self, sample): + """ + Transform logprobs to target format for KD training + """ + # pylint: disable=duplicate-code + + logprobs = sample.pop(self.logprobs_field) 
+ target_seq_len = len(logprobs) + input_seq_len = len(sample["input_ids"]) + input_padding_len = input_seq_len - target_seq_len + # get non-zero top-k (prune None logprobs from vllm data step) + top_k_vals = [ + len(logprobs[i]) + for i in range(len(logprobs)) + if logprobs[i] is not None and len(logprobs[i]) + ] + max_top_k = max(set(top_k_vals), key=top_k_vals.count) + min_top_k = min(set(top_k_vals), key=top_k_vals.count) + top_k = min(max_top_k, min_top_k) + if top_k == 0: + raise ValueError("No non-zero top-k logprobs found.") + + target_logprobs = [] + target_token_ids = [] + target_mask = [] + + if input_padding_len < 0: + # logprobs is longer than target_seq_len, + # so we need to slice from the left/beginning of logprobs + logprobs = logprobs[:-input_seq_len] + input_padding_len = 0 + # target_seq_len = input_seq_len + + # truncate the second dimension of the logprobs to top_k + logprobs = [row[:top_k] for row in logprobs] + + # fill with -inf for padding_len tokens for top_k tokens + # extend target_logprobs with a padding_len x top_k 2D list filled with -inf + + # we shift for causal models in the trainer, so start the range from 0 + for _ in range(0, input_padding_len): + target_logprobs.append([-float("inf")] * top_k) + target_token_ids.append(list(range(top_k))) + target_mask.append([0] * top_k) + + for position in range(input_padding_len, input_seq_len): + if sample["labels"][position] == -100: + target_mask.append([0] * top_k) + else: + target_mask.append([1] * top_k) + + for token_pos_logprobs, pos_target_token_ids in zip( + logprobs, sample["target_token_ids"] + ): + # Convert to a tensor for easier manipulation + position_logprobs_tensor = torch.tensor( + token_pos_logprobs, dtype=torch.float + ) + + # Now we have distribution at T1 in log form, i.e. log p_{T1}(k). 
+ # Next, re-scale to T2 = self.kd_temperature via exponent-based trick + # p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z + # + # Convert from log to probability + teacher_probs_t1 = position_logprobs_tensor.exp() + # normalize probabilities to sum to 1 in case they aren't already + teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True) + if teacher_probs_t1_sum > 1e-9: + teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum + if self.kd_temperature != self.gen_temperature: + # Exponentiate by factor (T1 / T2) + exponent = self.gen_temperature / self.kd_temperature + teacher_probs_t2 = teacher_probs_t1**exponent + else: + teacher_probs_t2 = teacher_probs_t1 + # Re-normalize + teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum( + dim=0, keepdim=True + ) + # Convert back to log + position_logprobs_tensor = torch.log(teacher_probs_t2) + + # Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor + position_logprobs_scaled = position_logprobs_tensor.tolist() + + target_logprobs.append(position_logprobs_scaled) + target_token_ids.append(pos_target_token_ids) + + # Update sample with transformed logprobs + sample["target_logprobs"] = target_logprobs + sample["target_token_ids"] = target_token_ids + sample["target_mask"] = target_mask + + return sample + + def _tokenize_single_prompt(self, prompt): + logprobs = prompt.pop(self.logprobs_field) + target_token_ids = prompt.pop("target_token_ids") + tokenized_prompt = super()._tokenize_single_prompt(prompt) + tokenized_prompt[self.logprobs_field] = logprobs + tokenized_prompt["target_token_ids"] = target_token_ids + tokenized_prompt = self.transform_logprobs(tokenized_prompt) + + return tokenized_prompt + + class KDStrategyLoader(StrategyLoader): """ Load ChatTemplateStrategy with KD support using StrategyLoader. @@ -204,4 +314,14 @@ class KDStrategyLoader(StrategyLoader): return strategy_params -load = KDStrategyLoader() +class KDStrategyLoaderV2(KDStrategyLoader): + """ + Load KD chat template datasets with pre-tokenized logprob data + """ + + def _get_strategy_cls(self, cfg): # pylint: disable=unused-argument + return ChatTemplateStrategyWithKDv2 + + +load_legacy = KDStrategyLoader() +load = KDStrategyLoaderV2() diff --git a/src/axolotl/integrations/kd/collator.py b/src/axolotl/integrations/kd/collator.py index de63869c7..0cc745b78 100644 --- a/src/axolotl/integrations/kd/collator.py +++ b/src/axolotl/integrations/kd/collator.py @@ -47,11 +47,16 @@ class DataCollatorForKD(DataCollatorForSeq2Seq): position_pad_token_id: int = 0 return_tensors: str = "pt" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True + def __call__(self, features, return_tensors=None): if return_tensors is None: return_tensors = self.return_tensors padding_side = self.tokenizer.padding_side + max_len = 0 # Pad labels and position_ids first for feature_name, pad_token_id in [ @@ -102,7 +107,9 @@ class DataCollatorForKD(DataCollatorForSeq2Seq): target_mask_list.append(f.pop("target_mask")) # Determine max lengths - max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list) + max_teacher_seq_len = max_len or max( + len(seq) for seq in target_logprobs_list + ) max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq) padded_target_logprobs = [] @@ -209,7 +216,9 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD): # We want to produce a single "merged" feature dict for each sub-batch. 
out_features = [{} for _ in features] - for i, sub_features in enumerate(features): + for i, sub_features in enumerate( # pylint: disable=too-many-nested-blocks + features + ): # sub_features is a list of dicts, each dict = one sequence’s features # We'll merge them into out_features[i]. # @@ -243,10 +252,17 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD): # For example, input_ids or labels are often arrays. arrays = [] for feat in sub_features: - if field_name in feat: + if field_name in feat and isinstance( + feat[field_name], (list, torch.Tensor) + ): + if isinstance( + feat[field_name][0], (dict, str) + ): # pylint: disable=too-many-nested-blocks + continue arr = np.array(feat[field_name]) arrays.append(arr) - out_features[i][field_name] = np.concatenate(arrays) + if arrays: + out_features[i][field_name] = np.concatenate(arrays) # 3) Now call the parent collator, which will do: # - padding of labels/position_ids diff --git a/src/axolotl/integrations/kd/collator_online_teacher.py b/src/axolotl/integrations/kd/collator_online_teacher.py new file mode 100644 index 000000000..584ace481 --- /dev/null +++ b/src/axolotl/integrations/kd/collator_online_teacher.py @@ -0,0 +1,561 @@ +""" +Packed data loader for online teacher training supporting vllm and sglang. +""" + +import hashlib +import hmac +import logging +from typing import Any, Dict, List, Optional + +import requests +import torch +from orjson import orjson + +from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq +from axolotl.integrations.kd.utils import normalize_logprobs +from axolotl.utils.data.utils import retry_on_request_exceptions + +LOG = logging.getLogger(__name__) + + +def hmac_sha_from_int_list(int_list, key, hash_func=hashlib.sha256): + """ + Create HMAC-SHA hash from a list of integers + + Args: + int_list: List of integers + key: Secret key (string or bytes) + hash_func: Hash function (default: sha256) + + Returns: + HMAC digest as hex string + """ + # Convert key to bytes if it's a string + if isinstance(key, str): + key = key.encode("utf-8") + + # Convert list of ints to bytes + # Method 1: Convert each int to bytes and concatenate + data = b"".join(i.to_bytes(4, byteorder="big") for i in int_list) + + # Create HMAC + h = hmac.new(key, data, hash_func) + return h.hexdigest() + + +class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): + """ + Collator for online teacher training. 
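+
+    For each packed sub-batch, the collected input_ids are posted to a vllm or sglang
+    server to obtain per-position top-k prompt logprobs; the results are attached to
+    each feature as target_token_ids / target_logprobs / target_mask before the parent
+    packed collator handles padding and tensor conversion.
+
+    Illustrative construction (values below are placeholders, not defaults):
+
+        collator = OnlineTeacherCollator(
+            tokenizer,
+            kd_online_server_base_url="http://localhost:8000",
+            kd_online_topk=64,
+            kd_online_server="vllm",
+        )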
+ """ + + DEFAULT_LABEL_PAD_TOKEN_ID: int = -100 + + def __init__( + self, + *args: Any, + kd_online_server_base_url: Optional[str] = None, + kd_online_topk: Optional[int] = None, + kd_temperature: Optional[float] = 1.0, + kd_online_server: Optional[str] = "vllm", + kd_online_timeout: Optional[int] = 120, + kd_cache_dir: Optional[str] = None, + kd_normalize_topk: Optional[bool] = True, + **kwargs: Any, + ): + super().__init__(*args, **kwargs) + + if kd_online_server_base_url is None: + raise ValueError( + "kd_online_server_base_url must be provided for OnlineTeacherDataloader" + ) + if kd_online_topk is None or kd_online_topk <= 0: + raise ValueError( + "kd_online_topk must be a positive integer for OnlineTeacherDataloader" + ) + + self.kd_online_server_base_url = kd_online_server_base_url.rstrip("/") + self.kd_online_topk = kd_online_topk + self.kd_temperature = kd_temperature + self.kd_online_server = kd_online_server + self.http_session = requests.Session() + self.kd_online_timeout = kd_online_timeout + self.kd_cache_dir = kd_cache_dir + self.kd_normalize_topk = kd_normalize_topk + + def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]: + """ + Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs. + """ + if not raw_logprobs or self.kd_online_topk == 0: + return ( + [-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else [] + ) + + raw_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32) + return normalize_logprobs(raw_logprobs_tensor, self.kd_online_topk).tolist() + + @retry_on_request_exceptions(max_retries=10, delay=5) + def fetch_online_logprobs_sglang( + self, batch_input_ids: List[List[int]], labels: List[List[int]] + ): + """ + Fetches logprobs from an online teacher served by sglang for a batch of input_ids. + Assumes API returns token IDs as strings in logprob dictionary keys. + """ + api_endpoint = f"{self.kd_online_server_base_url}/generate" + + payload = { + "input_ids": batch_input_ids, + "return_logprob": True, + "top_logprobs_num": self.kd_online_topk, + "logprob_start_len": 0, + "return_text_in_logprobs": True, + "echo": True, + "sampling_params": { + "max_new_tokens": 0, + "temperature": self.kd_temperature, + "skip_special_tokens": False, + }, + } + + # Initialize with empty lists, so if API call fails, these are returned. + ret_data_target_token_ids: List[List[List[int]]] = [] + ret_data_target_logprobs: List[List[List[float]]] = [] + ret_data_target_mask: List[List[List[int]]] = [] + + try: + response = self.http_session.post( + api_endpoint, json=payload, timeout=self.kd_online_timeout + ) + response.raise_for_status() + api_data: list[dict] = response.json() + + # Ensure api_data is a list, and its length matches batch_input_ids + if not isinstance(api_data, list) or len(api_data) != len(batch_input_ids): + LOG.error( + f"API response format error. Expected a list of {len(batch_input_ids)} " + f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}." 
+ ) + # Return empty data; items processed later will get default empty KD fields + return { + "target_token_ids": ret_data_target_token_ids, + "target_logprobs": ret_data_target_logprobs, + "target_mask": ret_data_target_mask, + } + + for sequence_data, seq_input_ids, seq_labels in zip( + api_data, batch_input_ids, labels + ): + current_target_logprobs = [] + current_target_token_ids = [] + current_target_mask = [] + + meta_info = sequence_data.pop("meta_info", {}) + # Ensure input_top_logprobs is a list + input_top_logprobs: Optional[list[None | list[tuple]]] = meta_info.pop( + "input_top_logprobs", [] + ) + if not isinstance(input_top_logprobs, list): + LOG.warning( + f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence." + ) + input_top_logprobs = [] # Treat as empty + + # basic check that the logprob data len matches the input len, so no need to handle padding + assert len(seq_input_ids) == len(input_top_logprobs) + + for i, _, label in zip( + range(len(seq_input_ids)), seq_input_ids, seq_labels + ): + if i < len(input_top_logprobs) and input_top_logprobs[i] is None: + # this is always the case for the first token. + # there is never logprob data for the first token since that's a true input + # so we replace the None value with padding data + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append([0] * self.kd_online_topk) + current_target_mask.append([0] * self.kd_online_topk) + elif ( + i < len(input_top_logprobs) + and input_top_logprobs[i] is not None + ): + pos_top_logprobs_data = input_top_logprobs[i] + # Ensure pos_top_logprobs_data is a list of lists as expected + if not ( + isinstance(pos_top_logprobs_data, list) + and all( + isinstance(item, list) for item in pos_top_logprobs_data + ) + and len(pos_top_logprobs_data) > 0 + and len(pos_top_logprobs_data[0]) == 3 + ): # [logprob, token_id, token_str] + LOG.warning( + f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position." 
+ ) + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append([0] * self.kd_online_topk) + current_target_mask.append([0] * self.kd_online_topk) + continue + + # pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids + pos_logprobs_raw, pos_token_ids, _ = [ + list(row) for row in zip(*pos_top_logprobs_data) + ] + + # Ensure correct length (top_k) + if len(pos_logprobs_raw) < self.kd_online_topk: + pad_len = self.kd_online_topk - len(pos_logprobs_raw) + pos_logprobs_raw.extend([-float("inf")] * pad_len) + pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id + + # truncate to top_k in case the response was longer + current_target_token_ids.append( + pos_token_ids[: self.kd_online_topk] + ) + + if self.kd_normalize_topk: + normalized_logprobs_for_position = self._normalize_logprobs( + pos_logprobs_raw[: self.kd_online_topk] + ) + current_target_logprobs.append( + normalized_logprobs_for_position + ) + else: + current_target_logprobs.append( + pos_logprobs_raw[: self.kd_online_topk] + ) + + # Mask depends on the corresponding label for the student + if label == self.DEFAULT_LABEL_PAD_TOKEN_ID: + current_target_mask.append([0] * self.kd_online_topk) + else: + current_target_mask.append([1] * self.kd_online_topk) + else: + # Pad if no logprobs for this position (either due to length mismatch or None entry) + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append([0] * self.kd_online_topk) + current_target_mask.append([0] * self.kd_online_topk) + + ret_data_target_token_ids.append(current_target_token_ids) + ret_data_target_logprobs.append(current_target_logprobs) + ret_data_target_mask.append(current_target_mask) + + except requests.exceptions.RequestException as e: + LOG.error(f"Error fetching logprobs from online teacher: {e}") + raise e + # ret_logprobs_data will be returned with empty lists, handled by the caller. + except Exception as e: # Catch other potential errors during processing + LOG.error( + f"Unexpected error processing API response in fetch_online_logprobs: {e}", + exc_info=True, + ) + raise e + + return { + "target_token_ids": ret_data_target_token_ids, + "target_logprobs": ret_data_target_logprobs, + "target_mask": ret_data_target_mask, + } + + @retry_on_request_exceptions(max_retries=10, delay=5) + def fetch_online_logprobs_vllm( + self, batch_input_ids: List[List[int]], labels: List[List[int]] + ): + """ + Fetches logprobs from an online teacher served by vllm for a batch of input_ids. + Assumes API returns token IDs as strings in logprob dictionary keys. + """ + api_endpoint = f"{self.kd_online_server_base_url}/v1/completions" + + payload = { + "prompt": batch_input_ids, + "echo": True, + "logprobs": True, + "prompt_logprobs": self.kd_online_topk, + "top_logprobs": self.kd_online_topk, + "max_new_tokens": 0, + "skip_special_tokens": False, + "temperature": self.kd_temperature, + "sampling_params": { + "max_tokens": 0, + }, + } + + # Initialize with empty lists, so if API call fails, these are returned. 
+ ret_data_target_token_ids: List[List[List[int]]] = [] + ret_data_target_logprobs: List[List[List[float]]] = [] + ret_data_target_mask: List[List[List[int]]] = [] + + try: + headers = {"Accept-Encoding": "deflate, gzip, br, zstd"} + response = self.http_session.post( + api_endpoint, + json=payload, + headers=headers, + timeout=self.kd_online_timeout, + ) + response.raise_for_status() + api_data: dict = orjson.loads(response.content) + choices: list[dict] = api_data["choices"] + + # Ensure api_data is a list, and its length matches batch_input_ids + if not isinstance(choices, list) or len(choices) != len(batch_input_ids): + LOG.error( + f"API response format error. Expected a list of {len(batch_input_ids)} " + f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}." + ) + # Return empty data; items processed later will get default empty KD fields + return { + "target_token_ids": ret_data_target_token_ids, + "target_logprobs": ret_data_target_logprobs, + "target_mask": ret_data_target_mask, + } + + for sequence_data, seq_input_ids, seq_labels in zip( + choices, batch_input_ids, labels + ): + # seq_input_ids: List[int] + # seq_labels: List[int] + + current_target_logprobs = [] + current_target_token_ids = [] + current_target_mask = [] + + # Ensure input_top_logprobs is a list + input_top_logprobs: Optional[list[None | dict[str, dict]]] = ( + sequence_data.pop("prompt_logprobs", []) + ) + + if not isinstance(input_top_logprobs, list): + LOG.warning( + f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence." + ) + input_top_logprobs = [] # Treat as empty + + # basic check that the logprob data len matches the input len, so no need to handle padding + assert len(seq_input_ids) == len(input_top_logprobs) + + seq_len = len(seq_input_ids) + + for i, _, label in zip(range(seq_len), seq_input_ids, seq_labels): + if i < len(input_top_logprobs) and input_top_logprobs[i] is None: + # this is always the case for the first token. + # there is never logprob data for the first token since that's a true input + continue + if ( + i < len(input_top_logprobs) + and input_top_logprobs[i] is not None + ): + pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i] # type: ignore[assignment] + # Ensure pos_top_logprobs_data is a list of lists as expected + if not ( + isinstance(pos_top_logprobs_data, dict) + and all( + isinstance(item, dict) + for item in pos_top_logprobs_data.values() + ) + and len(pos_top_logprobs_data.keys()) > 0 + ): # [logprob, token_id, token_str] + LOG.warning( + f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position." + ) + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append( + list(range(self.kd_online_topk)) + ) + current_target_mask.append([0] * self.kd_online_topk) + continue + + # pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids + pos_token_ids_str = list(pos_top_logprobs_data.keys()) + pos_logprobs_dict = pos_top_logprobs_data.values() + pos_token_ids = [ + int(token_id) for token_id in pos_token_ids_str + ] + pos_logprobs_raw = [ + float(logprob.get("logprob", -float("inf"))) + for logprob in pos_logprobs_dict + ] + + # Ensure correct length (top_k) + if len(pos_logprobs_raw) < self.kd_online_topk: + pad_len = self.kd_online_topk - len(pos_logprobs_raw) + LOG.warning( + f"Padding position {i} with {pad_len} top-k tokens and logprobs." 
+ ) + pos_logprobs_raw.extend([-float("inf")] * pad_len) + pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id + + # truncate to top_k in case the response was longer + current_target_token_ids.append( + pos_token_ids[: self.kd_online_topk] + ) + + if self.kd_normalize_topk: + normalized_logprobs_for_position = self._normalize_logprobs( + pos_logprobs_raw[: self.kd_online_topk] + ) + current_target_logprobs.append( + normalized_logprobs_for_position + ) + else: + current_target_logprobs.append( + pos_logprobs_raw[: self.kd_online_topk] + ) + + # Mask depends on the corresponding label for the student + if label == self.DEFAULT_LABEL_PAD_TOKEN_ID: + current_target_mask.append([0] * self.kd_online_topk) + else: + current_target_mask.append([1] * self.kd_online_topk) + else: + # Pad if no logprobs for this position (either due to length mismatch or None entry) + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append( + list(range(self.kd_online_topk)) + ) + current_target_mask.append([0] * self.kd_online_topk) + for i in range(max(0, seq_len - len(current_target_logprobs))): + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append(list(range(self.kd_online_topk))) + current_target_mask.append([0] * self.kd_online_topk) + + ret_data_target_token_ids.append(current_target_token_ids) + ret_data_target_logprobs.append(current_target_logprobs) + ret_data_target_mask.append(current_target_mask) + + # TODO save and load targets to disk for caching for next epoch + # generate a hmac SHA256 hash over the list seq_input_ids and convert it to an int + # if self.kd_cache_dir: + # hash_input_ids = hmac_sha_from_int_list( + # seq_input_ids, f"{self.kd_online_server_base_url}:{self.kd_online_topk}" + # ) + # with open(f"{self.kd_cache_dir}/{hash_input_ids}.parquet", "wb") as f: + # pd.DataFrame(ret_logprobs_data).to_parquet(f, index=False) + + except requests.exceptions.RequestException as e: + LOG.error(f"Error fetching logprobs from online teacher: {e}") + raise e + # ret_logprobs_data will be returned with empty lists, handled by the caller. 
+ except Exception as e: # Catch other potential errors during processing + LOG.error( + f"Unexpected error processing API response in fetch_online_logprobs: {e}", + exc_info=True, + ) + raise e + + return { + "target_token_ids": ret_data_target_token_ids, + "target_logprobs": ret_data_target_logprobs, + "target_mask": ret_data_target_mask, + } + + def __call__( + self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None + ) -> Dict[str, Any]: + if not features: + return super().__call__(features, return_tensors=return_tensors) + + for ( + sub_batch_features + ) in features: # sub_batch_features is List[Dict[str, Any]] + if not sub_batch_features: + continue + + input_ids_for_api_call: List[List[int]] = [] + labels_for_api_call: List[List[int]] = [] + # Store references to the original item dictionaries to update them in-place + items_for_api_call: List[Dict[str, Any]] = [] + + for item_dict in sub_batch_features: + if not isinstance(item_dict, dict): + LOG.warning( + f"Skipping non-dict item in sub_batch_features: {item_dict}" + ) + continue + + current_input_ids = item_dict.get("input_ids") + current_labels = item_dict.get("labels") + + if current_input_ids is not None and current_labels is not None: + # Ensure input_ids and labels are lists of ints for JSON serialization + input_ids_list = ( + current_input_ids.tolist() + if hasattr(current_input_ids, "tolist") + else list(current_input_ids) + ) + labels_list = ( + current_labels.tolist() + if hasattr(current_labels, "tolist") + else list(current_labels) + ) + + input_ids_for_api_call.append(input_ids_list) + labels_for_api_call.append(labels_list) + items_for_api_call.append(item_dict) + else: + # This item will not get teacher logprobs from the API. + # Initialize KD fields to empty lists so downstream collators handle them uniformly. + item_dict.setdefault("target_token_ids", []) + item_dict.setdefault("target_logprobs", []) + item_dict.setdefault("target_mask", []) + + # print(items_for_api_call) + if items_for_api_call: # Only call API if there's something to process + if self.kd_online_server == "sglang": + api_responses_for_sub_batch = self.fetch_online_logprobs_sglang( + input_ids_for_api_call, labels_for_api_call + ) + else: + api_responses_for_sub_batch = self.fetch_online_logprobs_vllm( + input_ids_for_api_call, labels_for_api_call + ) + + # api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask" + # Each value is a list, corresponding to items_for_api_call + for i, item_to_update in enumerate(items_for_api_call): + # TODO make sure to figure out which input in sub_batch_features to update the batch in the original `features` object so the super class can handle it properly. + if api_responses_for_sub_batch and i < len( + api_responses_for_sub_batch["target_token_ids"] + ): # Check bounds + assert len( + api_responses_for_sub_batch["target_token_ids"][i] + ) == len(item_to_update["input_ids"]) + assert len( + api_responses_for_sub_batch["target_logprobs"][i] + ) == len(item_to_update["input_ids"]) + assert len( + api_responses_for_sub_batch["target_mask"][i] + ) == len(item_to_update["labels"]) + item_to_update["target_token_ids"] = ( + api_responses_for_sub_batch["target_token_ids"][i] + ) + item_to_update["target_logprobs"] = api_responses_for_sub_batch[ + "target_logprobs" + ][i] + item_to_update["target_mask"] = api_responses_for_sub_batch[ + "target_mask" + ][i] + else: + # API call failed for this item, or response was shorter than expected. 
+                        # Ensure KD fields are initialized as empty lists.
+                        LOG.warning(
+                            f"API call failed for item at index {i}, or the API response was too short. "
+                            f"API response keys: {list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}"
+                        )
+                        item_to_update.setdefault("target_token_ids", [])
+                        item_to_update.setdefault("target_logprobs", [])
+                        item_to_update.setdefault("target_mask", [])
+
+        return super().__call__(features, return_tensors=return_tensors)
diff --git a/src/axolotl/integrations/kd/kernels/__init__.py b/src/axolotl/integrations/kd/kernels/__init__.py
index e69de29bb..3f1144a45 100644
--- a/src/axolotl/integrations/kd/kernels/__init__.py
+++ b/src/axolotl/integrations/kd/kernels/__init__.py
@@ -0,0 +1,8 @@
+"""
+Liger Chunked loss optimizations module
+"""
+
+from .liger import LigerFusedLinearKLTopKLogprobLoss
+from .models import apply_kernel
+
+__all__ = ["LigerFusedLinearKLTopKLogprobLoss", "apply_kernel"]
diff --git a/src/axolotl/integrations/kd/kernels/liger.py b/src/axolotl/integrations/kd/kernels/liger.py
new file mode 100644
index 000000000..6356643c2
--- /dev/null
+++ b/src/axolotl/integrations/kd/kernels/liger.py
@@ -0,0 +1,485 @@
+"""
+Liger Kernels for Chunked Top-K Log-Prob Distillation
+"""
+
+import torch
+import torch.nn.functional as F
+from liger_kernel.chunked_loss.fused_linear_distillation import (
+    LigerFusedLinearDistillationBase,
+)
+
+from axolotl.integrations.kd.utils import normalize_logprobs
+
+
+class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
+    """
+    Chunked kl-div loss for top-k logprobs
+    """
+
+    @staticmethod
+    def distillation_loss_fn(
+        student_logits_temp_scaled: torch.Tensor,  # [chunk_size, vocab_size], already temp-scaled
+        target_token_ids_chunk: torch.Tensor,  # [chunk_size, top_k]
+        target_logprobs_chunk: torch.Tensor,  # [chunk_size, top_k], already temp-scaled and normalized logprobs
+        target_mask_chunk: torch.Tensor,  # [chunk_size, top_k]
+        beta: float = 0.0,
+        normalize_topk: bool = True,
+    ) -> torch.Tensor:
+        """
+        Compute Top-K KL divergence loss for a chunk.
+        Args:
+            student_logits_temp_scaled: Student logits, scaled by temperature. Shape: (N, V).
+            target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
+            target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
+            target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
+            beta: Controls the type of KL divergence.
+                0.0 for Forward KL (P_teacher || P_student).
+                1.0 for Reverse KL (P_student || P_teacher).
+                Values in (0, 1) use a generalized Jensen-Shannon divergence
+                (0.5 gives the symmetric JSD).
+            normalize_topk: Whether to normalize the log probabilities
+        Returns:
+            Sum of KL divergence losses for the chunk.
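+        Notes:
+            With P_t the (already normalized) top-k teacher distribution and P_s the
+            student probabilities gathered at the same token ids (optionally
+            re-normalized over the top-k), the per-position losses are:
+            forward KL (beta == 0.0):  sum_k P_t(k) * (log P_t(k) - log P_s(k))
+            reverse KL (beta == 1.0):  sum_k P_s(k) * (log P_s(k) - log P_t(k))
+            otherwise a generalized JSD against M = (1 - beta) * P_s + beta * P_t.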
+ """ + topk = target_token_ids_chunk.shape[-1] + student_logits_temp_scaled = ( # [chunk_size, vocab_size] + student_logits_temp_scaled.float() + ) + target_logprobs_chunk = target_logprobs_chunk.float() + + # Gather student logits for the top-k teacher token IDs + # target_token_ids_chunk: [chunk_size, top_k] + # student_logits_topk_temp_scaled: [chunk_size, top_k] + student_logits_topk_temp_scaled = torch.gather( + student_logits_temp_scaled, dim=-1, index=target_token_ids_chunk + ) + + # Student log-probabilities for the gathered top-k tokens + student_lse = torch.logsumexp( + student_logits_temp_scaled, dim=-1, keepdim=True + ) # [chunk_size, 1] + student_logprobs_topk_temp_scaled = ( + student_logits_topk_temp_scaled - student_lse + ) + + # we have the top-k student logprobs, normalize them + if normalize_topk: + student_logprobs_topk_temp_scaled = normalize_logprobs( + student_logprobs_topk_temp_scaled, topk + ) + + valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k] + + student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask] + teacher_logprobs_valid = target_logprobs_chunk[valid_mask] + + # Teacher probabilities P(y|x_teacher) from logprobs + # target_logprobs_valid are already normalized (log(softmax(teacher_logits/T))) + teacher_probs_valid = teacher_logprobs_valid.exp() + # Student probabilities P_student from log P_student + student_probs_topk_valid = student_logprobs_topk_valid.exp() + + # kd_loss_per_token = torch.zeros_like(target_logprobs_valid) + + # KL divergence: sum(P_teacher * (log P_teacher - log P_student)) + # = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student) + # The distillation loss is often formulated as -sum(P_teacher * log P_student) + # or as sum(P_teacher * (log_softmax_teacher - log_softmax_student)) + # Here, target_logprobs_valid are log_softmax_teacher. + # student_logprobs_topk_valid are log_softmax_student (for the selected K indices). + if beta == 0.0: # Contribution from Forward KL + fwd_kl_per_token = teacher_probs_valid * ( + teacher_logprobs_valid - student_logprobs_topk_valid + ) + kd_loss = fwd_kl_per_token.sum() + elif beta == 1.0: # Contribution from Reverse KL + rev_kl_per_token = student_probs_topk_valid * ( + student_logprobs_topk_valid - teacher_logprobs_valid + ) + kd_loss = rev_kl_per_token.sum() + else: + # JSD - Jensen-Shannon Divergence / Symmetric + mean_probs = ( + 1 - beta + ) * student_probs_topk_valid + beta * teacher_probs_valid + log_mean_probs = mean_probs.log() + student_kl = F.kl_div( + log_mean_probs, + student_logprobs_topk_valid, + reduction="sum", + log_target=True, + ) + teacher_kl = F.kl_div( + log_mean_probs, teacher_logprobs_valid, reduction="sum", log_target=True + ) + jsd_loss = beta * teacher_kl + (1 - beta) * student_kl + kd_loss = jsd_loss + + return kd_loss + + @staticmethod + def _compute_loss_kl_topk( + student_input_chunk: torch.Tensor, + student_weight: torch.Tensor, + # Args for student_bias, target_token_ids_chunk etc. are passed to the lambda wrapped by grad_and_value + # or through `partial`. Let's make them explicit here for clarity. 
+ target_token_ids_chunk: torch.Tensor, + target_logprobs_chunk: torch.Tensor, + target_mask_chunk: torch.Tensor, + target_chunk: torch.Tensor, # For hard loss (true labels) + student_bias: torch.Tensor = None, # This will be one of the grad targets + # Other params passed via `partial` from `forward` + distillation_loss_fn=None, + ignore_index: int = -100, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + compute_ce_loss: bool = True, + temperature: float = 1.0, + beta: float = 0.0, + normalize_topk: bool = True, + ): + # Compute student logits for the chunk from hidden states and LM head + # student_input_chunk: [chunk_size, hidden_dim] + # student_lm_head_weight: [vocab_size, hidden_dim] + # student_logits_chunk: [chunk_size, vocab_size] + student_logits_chunk = F.linear( + student_input_chunk, student_weight, student_bias + ) + + ce_loss = torch.tensor( + 0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype + ) + if compute_ce_loss and weight_hard_loss > 0.0: + ce_loss = F.cross_entropy( + student_logits_chunk.view(-1, student_logits_chunk.shape[-1]), + target_chunk.view(-1), + reduction="sum", + ignore_index=ignore_index, + ) + + soft_loss = torch.tensor( + 0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype + ) + if weight_soft_loss > 0.0: + student_logits_chunk_temp_scaled = student_logits_chunk / temperature + + # Assuming student_weight.shape[0] (vocab_size) is adequate for target_token_ids_chunk.max() + # No explicit padding here; user must ensure vocab alignment or pre-pad student_weight. + + soft_loss = distillation_loss_fn( + student_logits_chunk_temp_scaled, + target_token_ids_chunk, + target_logprobs_chunk, + target_mask_chunk, + beta=beta, + normalize_topk=normalize_topk, + ) + + return soft_loss, ce_loss + + @classmethod + def forward( + cls, + ctx, + student_input: torch.Tensor, # [batch_size, seq_len, dim] + student_lm_head_weight: torch.Tensor, # [dim, vocab_size] + target_token_ids: torch.Tensor, # [batch_size, seq_len, top_k] + target_logprobs: torch.Tensor, # [batch_size, seq_len, top_k] + target_mask: torch.Tensor, # [batch_size, seq_len, top_k] + true_labels: torch.Tensor, # [batch_size, seq_len] + student_lm_head_bias: torch.Tensor = None, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + beta: float = 0.0, + compiled: bool = False, + chunk_size: int = 1024, + compute_ce_loss: bool = True, + normalize_topk: bool = True, + ): + CHUNK_SIZE = chunk_size # pylint: disable=invalid-name + grad_weight_acc = torch.zeros_like(student_lm_head_weight) + grad_inputs_list = [] + grad_bias_acc = ( + torch.zeros_like(student_lm_head_bias) + if student_lm_head_bias is not None + else None + ) + kd_loss_acc = torch.zeros( + (), device=student_input.device, dtype=student_input.dtype + ) + ce_loss_acc = torch.zeros( + (), device=student_input.device, dtype=student_input.dtype + ) + + # This function will be what torch.func.grad_and_value differentiates. + # It takes student_input_chunk, student_weight (full), student_bias (full) as primals. + # Other necessary data (target_*, etc.) are passed as non-differentiable arguments. 
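+        # Note: with has_aux=True, torch.func.grad_and_value differentiates only the
+        # first value returned by the wrapped function (the soft KD loss); the CE loss
+        # comes back as the auxiliary output and is accumulated separately below.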
+ def loss_fn_for_grad( + _student_input_chunk, + _student_lm_head_weight, # full weight + _student_lm_head_bias, # full bias + # Fixed arguments for a given chunk, not differentiated: + _target_token_ids_chunk, + _target_logprobs_chunk, + _target_mask_chunk, + _true_labels_chunk, + ): + return cls._compute_loss_kl_topk( + student_input_chunk=_student_input_chunk, + student_weight=_student_lm_head_weight, + target_token_ids_chunk=_target_token_ids_chunk, + target_logprobs_chunk=_target_logprobs_chunk, + target_mask_chunk=_target_mask_chunk, + target_chunk=_true_labels_chunk, + student_bias=_student_lm_head_bias, + distillation_loss_fn=cls.distillation_loss_fn, + ignore_index=ignore_index, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + compute_ce_loss=compute_ce_loss, + temperature=temperature, + beta=beta, + normalize_topk=normalize_topk, + ) + + def accumulate_chunk_grads( + student_input_chunk_ac, + target_token_ids_chunk_ac, + target_logprobs_chunk_ac, + target_mask_chunk_ac, + true_labels_chunk_ac, + ): + # student_weight and student_bias are closed over from the outer scope (full tensors) + if student_lm_head_bias is not None: + ( + (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), + (chunk_kd_loss, chunk_ce_loss), + ) = torch.func.grad_and_value( + loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True + )( + student_input_chunk_ac, + student_lm_head_weight, + student_lm_head_bias, # primals + target_token_ids_chunk_ac, + target_logprobs_chunk_ac, + target_mask_chunk_ac, + true_labels_chunk_ac, + ) # non-primals + grad_bias_acc.add_(chunk_grad_bias) + else: + argnums_for_grad = (0, 1) # Differentiate wrt input_chunk, weight + ( + (chunk_grad_input, chunk_grad_weight), # No grad for bias + (chunk_kd_loss, chunk_ce_loss), + ) = torch.func.grad_and_value( + loss_fn_for_grad, argnums=argnums_for_grad, has_aux=True + )( + student_input_chunk_ac, + student_lm_head_weight, + None, # Pass None for student_bias primal + target_token_ids_chunk_ac, + target_logprobs_chunk_ac, + target_mask_chunk_ac, + true_labels_chunk_ac, + ) + + grad_weight_acc.add_(chunk_grad_weight) + kd_loss_acc.add_(chunk_kd_loss) + ce_loss_acc.add_(chunk_ce_loss) + + return chunk_grad_input + + if compiled: + accumulate_chunk_grads_compiled = torch.compile( + accumulate_chunk_grads, dynamic=True, backend="inductor" + ) # dynamic=True often helpful + else: + accumulate_chunk_grads_compiled = accumulate_chunk_grads + + # Use the same chunking logic as LigerFusedLinearDistillationBase.forward + B, N, D = student_input.shape # pylint: disable=invalid-name + K = target_token_ids.shape[-1] # pylint: disable=invalid-name + + student_input_flat = student_input.reshape(-1, student_input.shape[-1]) + target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1]) + target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1]) + target_mask_flat = target_mask.reshape(-1, target_mask.shape[-1]) + # pad and shift for cross entropy loss + true_labels = torch.nn.functional.pad(true_labels, (0, 1), value=ignore_index) + true_labels_flat = true_labels[:, 1:].contiguous().view(-1) + + num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE) + + _student_input_chunks = torch.chunk( + student_input_flat, chunks=num_chunks, dim=0 + ) + _target_token_ids_chunks = torch.chunk( + target_token_ids_flat, chunks=num_chunks, dim=0 + ) + _target_logprobs_chunks = torch.chunk( + target_logprobs_flat, chunks=num_chunks, dim=0 + ) + _target_mask_chunks = torch.chunk(target_mask_flat, 
chunks=num_chunks, dim=0) + _true_labels_chunks = torch.chunk(true_labels_flat, chunks=num_chunks, dim=0) + + for i in range(num_chunks): + grad_input_chunk = accumulate_chunk_grads_compiled( + _student_input_chunks[i], + _target_token_ids_chunks[i], + _target_logprobs_chunks[i], + _target_mask_chunks[i], + _true_labels_chunks[i], + ) + grad_inputs_list.append(grad_input_chunk) + + grad_inputs_combined = torch.cat(grad_inputs_list, dim=0) + ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc) + + # For matching None returns in backward for non-tensor/non-grad_requiring inputs + ctx.hyperparams_count = 9 # Corresponds to number of hyperparams after main tensors in fwd signature + ctx.bias_was_none = student_lm_head_bias is None + ctx.orig_dims = (B, N, D, K) + + # since this is packed, there is simply a single batch, so batchmean reduction of kl-div is simply the accumulated sum + # we still need to scale the kd_loss by the temp^2 + kd_loss_acc = kd_loss_acc * (temperature**2) + final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc + + return final_loss + + @staticmethod + def backward(ctx, grad_output): + grad_input_flat, grad_weight, grad_bias_maybe = ( + ctx.saved_tensors + ) # grad_input_flat is (B*N, D) + + # Scale gradients by grad_output if it's not 1.0 + if not torch.equal( + grad_output, + torch.tensor(1.0, device=grad_output.device, dtype=grad_output.dtype), + ): + grad_input_flat = grad_input_flat * grad_output + grad_weight = grad_weight * grad_output + if grad_bias_maybe is not None: + grad_bias_maybe = grad_bias_maybe * grad_output + + # Reshape grad_input_flat to match original student_input shape (B, N, D) + # ctx.orig_dims stores (B, N, D, K) + # We need the first three dimensions for student_input's shape. 
+ # Ensure that orig_dims are not (0,0,0,K) for empty inputs leading to view errors + if ( + ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0 + and grad_input_flat.numel() == 0 + ): + # If original input was empty, gradient should also be empty with correct shape + grad_input_reshaped = torch.zeros( + ctx.orig_dims[0], + ctx.orig_dims[1], + ctx.orig_dims[2], + dtype=grad_input_flat.dtype, + device=grad_input_flat.device, + ) + elif grad_input_flat.numel() == 0 and not ( + ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0 + ): + # This case should ideally not happen if forward path is correct (non-empty input -> non-empty flat grad) + # but as a safeguard: + grad_input_reshaped = torch.zeros( + ctx.orig_dims[0], + ctx.orig_dims[1], + ctx.orig_dims[2], + dtype=grad_input_flat.dtype, + device=grad_input_flat.device, + ) + else: + grad_input_reshaped = grad_input_flat.view( + ctx.orig_dims[0], ctx.orig_dims[1], ctx.orig_dims[2] + ) + + nones_for_hyperparams = [None] * ctx.hyperparams_count + grad_bias_return = grad_bias_maybe if not ctx.bias_was_none else None + + return ( + grad_input_reshaped, # Gradient for student_input (reshaped) + grad_weight, # Gradient for student_lm_head_weight + None, # Gradient for target_token_ids + None, # Gradient for target_logprobs + None, # Gradient for target_mask + None, # Gradient for true_labels + grad_bias_return, # Gradient for student_lm_head_bias + *nones_for_hyperparams, # Grads for weight_hard_loss, ..., compute_ce_loss + ) + + +class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module): + """ + wrapper for chunked top-k logprob kl-d + """ + + def __init__( + self, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + temperature: float = 1.0, # This is the kd_temperature + beta: float = 1.0, + ignore_index: int = -100, + compiled: bool = True, + chunk_size: int = 1024, + compute_ce_loss: bool = True, + normalize_topk: bool = True, + ): + super().__init__() + if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0): + raise ValueError("Loss weights must be between 0.0 and 1.0.") + if temperature <= 0: + raise ValueError("Temperature must be positive.") + + self.weight_hard_loss = weight_hard_loss + self.weight_soft_loss = weight_soft_loss + self.temperature = temperature + self.beta = beta + self.ignore_index = ignore_index + self.compiled = compiled + self.chunk_size = chunk_size + self.compute_ce_loss = compute_ce_loss + self.normalize_topk = normalize_topk + + if not self.compute_ce_loss and self.weight_hard_loss > 0.0: + print( + f"Warning: compute_ce_loss is False, but weight_hard_loss ({self.weight_hard_loss}) > 0. Hard loss will effectively be zero." + ) + # self.weight_hard_loss = 0.0 # Or let user manage this + if self.weight_soft_loss == 0.0: + print( + "Warning: weight_soft_loss is 0.0. Soft (KD) loss will not be computed." 
+ ) + + def forward( + self, + lm_head_weight: torch.Tensor, # Weights of the linear layer in the LM head + student_hidden_states: torch.Tensor, # student_hidden_states before the lm_head + target_token_ids: torch.Tensor, + target_logprobs: torch.Tensor, + target_mask: torch.Tensor, + true_labels: torch.Tensor, + student_bias: torch.Tensor = None, + ) -> torch.Tensor: + return LigerFusedLinearKLTopKLogprobFunction.apply( + student_hidden_states, + lm_head_weight, + target_token_ids, + target_logprobs, + target_mask, + true_labels, + student_bias, + self.weight_hard_loss, + self.weight_soft_loss, + self.ignore_index, + self.temperature, + self.beta, + self.compiled, + self.chunk_size, + self.compute_ce_loss, + self.normalize_topk, + ) diff --git a/src/axolotl/integrations/kd/kernels/models.py b/src/axolotl/integrations/kd/kernels/models.py new file mode 100644 index 000000000..5a7c286bc --- /dev/null +++ b/src/axolotl/integrations/kd/kernels/models.py @@ -0,0 +1,97 @@ +""" +model patcher for chunked top-k kl-div +""" + +from typing import Optional, Union, Unpack + +import torch +from transformers import Cache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.utils import LossKwargs + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): + """ + placeholder kwargs for hf model classes + """ + + +def kldiv_forward_llama_like( + self, + input_ids: Optional[torch.LongTensor] = None, + target_logprobs: Optional[torch.Tensor] = None, + target_token_ids: Optional[torch.LongTensor] = None, + target_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, # pylint: disable=unused-argument + **kwargs: Unpack[KwargsForCausalLM], # type: ignore[misc] +) -> CausalLMOutputWithPast: + # pylint: disable=duplicate-code + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100 + # self.loss_function should be LigerFusedLinearKLTopKLogprobLoss + + loss = self.loss_function( + self.lm_head.weight, + hidden_states, + target_token_ids, + target_logprobs, + target_mask, + true_labels=labels, + ) + num_items_in_batch = kwargs.pop("num_items_in_batch", -1) + if num_items_in_batch is not None and 
num_items_in_batch > 0: + loss = loss / num_items_in_batch + + return CausalLMOutputWithPast( + loss=loss, + logits=None, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def apply_kernel(model_type): + # Dynamically import the module and attention class + module_path = f"transformers.models.{model_type}.modeling_{model_type}" + model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")]) + module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]) + model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") + model_cls.forward = kldiv_forward_llama_like diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py index 3c9515091..74184455f 100644 --- a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py +++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py @@ -16,40 +16,7 @@ loss for top_k KL divergence """ import torch - - -def zscore_standardize( - logits: torch.Tensor, - mask: torch.Tensor = None, - base_temperature: float = 1.0, - eps: float = 1e-9, -): - """ - Z-score standardize along the last dimension of `logits`. - i.e., for each [B, seq_len] row, across K entries: - z = (logits - mean) / std, - then scale by 1 / base_temperature if desired. - - mask can be broadcastable or None. If None, we standardize all elements. - """ - if mask is None: - # shape: [B, seq_len, K] - # Mean and std over dim=-1 - mean = logits.mean(dim=-1, keepdim=True) - var = logits.var(dim=-1, unbiased=False, keepdim=True) - else: - # If you have to exclude some tokens, multiply by mask, etc. - float_mask = mask.to(logits.dtype) - count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0) - mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count - var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count - - std = torch.sqrt(var.clamp_min(eps)) - z = (logits - mean) / std - - # Scale by 1 / base_temperature - z = z / base_temperature - return z +from torch import nn @torch.jit.script @@ -60,7 +27,6 @@ def loss( target_mask: torch.Tensor, num_items_in_batch: int = -1, # Use -1 to indicate "None" kd_temperature: float = 1.0, - top_k_before_softmax: int = 0, ) -> torch.Tensor: """ A KD loss function that is TorchScript-friendly. @@ -77,8 +43,6 @@ def loss( num_items_in_batch (int, optional): The number of items in the batch. kd_temperature (float, optional): The temperature for KD. 
Default: 1.0 - top_k_before_softmax (int, optional): Flag of whether to apply softmax before gathering student top-k logits - Default: 0 """ target_logprobs = target_logprobs.float() @@ -88,46 +52,24 @@ def loss( # student_logits shape: [B, student_seq_len, vocab_size] teacher_seq_len = target_token_ids.shape[1] - if top_k_before_softmax: - # Slice student logits to match teacher-provided sequence length - student_logits_for_kd = student_logits[ - :, :teacher_seq_len, : - ] # [B, teacher_seq_len, vocab_size] + # Slice student logits to match teacher-provided sequence length + student_logits_for_kd = ( + student_logits[:, :teacher_seq_len, :] / kd_temperature + ) # [B, teacher_seq_len, vocab_size] - # Gather student logits for teacher's top-K tokens - student_logits_topk = torch.gather( - student_logits_for_kd, dim=-1, index=target_token_ids - ) # [B, teacher_seq_len, K] + # keep in full precision for numerical stability of loss + student_logits_for_kd = student_logits_for_kd.float() - student_logits_topk = student_logits_topk.float() + # Gather student logits for teacher's top-K tokens + student_logits_topk = torch.gather( + student_logits_for_kd, dim=-1, index=target_token_ids + ) # [B, teacher_seq_len, K] - # Apply KD temperature to student’s logits - if kd_temperature != 1.0: - student_logits_topk = student_logits_topk / kd_temperature + # Compute logsumexp across full vocabulary + student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True) - # Convert student top-k logits to logprobs - student_logprobs_topk = student_logits_topk - torch.logsumexp( - student_logits_topk, dim=-1, keepdim=True - ) # [B, teacher_seq_len, K] - else: - # Slice student logits to match teacher-provided sequence length - student_logits_for_kd = ( - student_logits[:, :teacher_seq_len, :] / kd_temperature - ) # [B, teacher_seq_len, vocab_size] - - # keep in full precision for numerical stability of loss - student_logits_for_kd = student_logits_for_kd.float() - - # Gather student logits for teacher's top-K tokens - student_logits_topk = torch.gather( - student_logits_for_kd, dim=-1, index=target_token_ids - ) # [B, teacher_seq_len, K] - - # Compute logsumexp across full vocabulary - student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True) - - # Convert just the top-k logits to logprobs - student_logprobs_topk = student_logits_topk - student_lse + # Convert just the top-k logits to logprobs + student_logprobs_topk = student_logits_topk - student_lse # Convert teacher_mask to boolean for indexing # In TorchScript, .bool() is sometimes unsupported, so we do: @@ -144,10 +86,6 @@ def loss( kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk) kd_loss = kd_loss_per_token.sum() - # Multiply by T^2 (classical KD scaling) - if kd_temperature != 1.0: - kd_loss = kd_loss * (kd_temperature**2) - # Normalize by number of items (if provided) or by valid tokens if num_items_in_batch > 0: kd_loss = kd_loss / float(num_items_in_batch) @@ -158,80 +96,74 @@ def loss( return kd_loss -def topk_kd_loss_with_zscore( - student_logits: torch.Tensor, # [B, seq_len, vocab_size] - target_token_ids: torch.Tensor, # [B, seq_len, K] - target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space - target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len] - kd_temperature: float = 1.0, # classic KD temperature - zscore_base_temp: float = 1.0, # from the paper - num_items_in_batch: int = -1, -): +class ChunkedTopKKDLoss(nn.Module): """ - A variant of top_k KL divergence 
with Z-score scaling - from "Logit Standardization in Knowledge Distillation". + A wrapper that chunks (splits) the student and teacher outputs along the time dimension + to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies. + + Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to top-K teacher logprobs. """ - target_logprobs = target_logprobs.float() + def __init__(self, num_output_chunks: int = 8, kd_temperature: float = 1.0): + super().__init__() + self.num_output_chunks = num_output_chunks + self.kd_temperature = kd_temperature - B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name - # 1) Gather the student's top-k logits to match teacher - student_logits_for_kd = student_logits[ - :, :teacher_seq_len, : - ] # [B, seq_len, vocab] - student_topk_logits = torch.gather( - student_logits_for_kd, dim=-1, index=target_token_ids - ) # [B, seq_len, K] + def forward( + self, + student_logits: torch.Tensor, # [B, seq_len, vocab_size] + target_token_ids: torch.Tensor, # [B, seq_len, K] + target_logprobs: torch.Tensor, # [B, seq_len, K] + target_mask: torch.Tensor, # [B, seq_len, K] + num_items_in_batch: int = -1, # optional batch size for normalization + ) -> torch.Tensor: - student_topk_logits = student_topk_logits.float() + # 1. Split along the "token" dimension (dim=1). + student_logits_chunks = student_logits.chunk(self.num_output_chunks, dim=1) + token_ids_chunks = target_token_ids.chunk(self.num_output_chunks, dim=1) + logprobs_chunks = target_logprobs.chunk(self.num_output_chunks, dim=1) + mask_chunks = target_mask.chunk(self.num_output_chunks, dim=1) - # 2) If you want to keep the "classical" T scaling, apply it first - if kd_temperature != 1.0: - student_topk_logits = student_topk_logits / kd_temperature + # We'll accumulate a global "sum of losses" and "sum of valid tokens" + # so that our final average is consistent with the entire sequence/batch. + total_loss = 0.0 + total_valid_tokens = 0 - # 3) Convert teacher logprobs -> treat them as “logits” for z-score - # (They differ by +some_constant from real logits, but in z-score - # that constant is subtracted out anyway.) - teacher_logits_for_zscore = target_logprobs # rename variable for clarity + # 2. Loop over each chunk and compute a chunk-specific loss. + for st_chunk, tid_chunk, lp_chunk, msk_chunk in zip( + student_logits_chunks, token_ids_chunks, logprobs_chunks, mask_chunks + ): + # We pass num_items_in_batch=-1 so that the kd_loss + # will average over *this chunk's* valid tokens only. + chunk_loss = loss( + student_logits=st_chunk, + target_token_ids=tid_chunk, + target_logprobs=lp_chunk, + target_mask=msk_chunk, + num_items_in_batch=-1, # ensure per-chunk averaging by valid tokens + kd_temperature=self.kd_temperature, + ) - # 4) Z-score teacher and student - # If target_mask is 2D, expand to 3D for the K dimension - if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len): - target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K) + # kd_loss returns an average over the chunk's valid tokens. + # We want a global average in the end, so we need to re‐weight + # by the number of valid tokens in this chunk and keep track of the total. 
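+            # i.e. final_loss = sum_c(chunk_mean_c * n_c) / sum_c(n_c), which matches
+            # the un-chunked loss when normalizing by the number of valid tokens.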
+ chunk_valid_mask = msk_chunk.to(torch.bool) + chunk_valid_count = chunk_valid_mask.sum() # scalar tensor - teacher_z = zscore_standardize( - teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp - ) - student_z = zscore_standardize( - student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp - ) + # Re-scale "chunk average" back to "chunk sum" + chunk_loss_sum = chunk_loss * chunk_valid_count - # 5) Convert to log-probs for KL - teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True) - student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True) + total_loss += chunk_loss_sum + total_valid_tokens += chunk_valid_count - # 6) Restrict to valid tokens if needed - valid_mask = target_mask.bool() # shape [B, seq_len, K] - teacher_probs_z = teacher_logprobs_z.exp() - teacher_probs_z = teacher_probs_z[valid_mask] - teacher_logprobs_z = teacher_logprobs_z[valid_mask] - student_logprobs_z = student_logprobs_z[valid_mask] + # 3. Normalize *once* at the end. + if num_items_in_batch > 0: + # If the user gave us a manual denominator (e.g. total items in batch), + # we divide by it. Typically used if each item is of different length. + final_loss = total_loss / float(num_items_in_batch) + else: + # Otherwise, divide by total valid tokens across all chunks. + # to get the same result as a non-chunked approach. + final_loss = total_loss / float(total_valid_tokens) - # 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] ) - kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z) - kd_loss = kd_loss_per_token.sum() - - # 8) If using classical KD scaling by T^2 - if kd_temperature != 1.0: - kd_loss = kd_loss * (kd_temperature**2) - - # Optionally scale by zscore_base_temp**2 if you want (paper might differ). - # kd_loss = kd_loss * (zscore_base_temp**2) - - # 9) Normalize - if num_items_in_batch is not None and num_items_in_batch > 0: - kd_loss = kd_loss / float(num_items_in_batch) - else: - kd_loss = kd_loss / float(kd_loss_per_token.size(0)) - - return kd_loss + return final_loss diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index f99f2ca28..7ec43333a 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -18,8 +18,7 @@ KD trainer from axolotl.core.trainers.base import AxolotlTrainer -from .topk_logprob.forward_kl import loss as topk_kd_loss -from .topk_logprob.forward_kl import topk_kd_loss_with_zscore +from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss class AxolotlKDTrainer(AxolotlTrainer): @@ -27,6 +26,18 @@ class AxolotlKDTrainer(AxolotlTrainer): Custom trainer subclass for Knowledge Distillation (KD) """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model_accepts_loss_kwargs = True + self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss( + self.args.kd_ce_alpha, # hard label loss + self.args.kd_alpha, # kd loss + self.args.kd_temperature, + self.args.kd_beta or 0.0, + compute_ce_loss=bool(self.args.kd_ce_alpha), + normalize_topk=self.args.kd_normalize_topk, + ) + def _set_signature_columns_if_needed(self): super()._set_signature_columns_if_needed() columns_to_add = [] @@ -52,12 +63,12 @@ class AxolotlKDTrainer(AxolotlTrainer): Subclass and override for custom behavior. 
""" - - target_logprobs = inputs.pop("target_logprobs") - target_token_ids = inputs.pop("target_token_ids") - target_mask = inputs.pop("target_mask") - - seq_len = target_token_ids.shape[1] + if ( + self.args.sample_packing + and hasattr(inputs, "attention_mask") + and hasattr(inputs, "position_ids") + ): + del inputs["attention_mask"] if self.model_accepts_loss_kwargs: loss_kwargs = {} @@ -65,49 +76,4 @@ class AxolotlKDTrainer(AxolotlTrainer): loss_kwargs["num_items_in_batch"] = num_items_in_batch inputs = {**inputs, **loss_kwargs} outputs = model(**inputs) - - # FIXME: account for tokenizer.padding_side - student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous() - - shift_logits = student_logits.contiguous() - target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous() - target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous() - target_mask_for_loss = target_mask[..., 1:, :].contiguous() - - if self.args.kd_zscore_base_temp: - loss_kd = topk_kd_loss_with_zscore( - shift_logits, - target_token_ids_for_loss, - target_logprobs_for_loss, - target_mask_for_loss, - kd_temperature=self.args.kd_temperature, - zscore_base_temp=self.args.kd_zscore_base_temp, - num_items_in_batch=num_items_in_batch, - ) - else: - loss_kd = topk_kd_loss( - shift_logits, - target_token_ids_for_loss, - target_logprobs_for_loss, - target_mask_for_loss, - num_items_in_batch=num_items_in_batch, - kd_temperature=self.args.kd_temperature, - top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0, - ) - - if self.args.kd_ce_alpha > 0: - kd_alpha = self.args.kd_alpha - loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd - else: - loss = loss_kd - # Save past state if it exists - # TODO: this needs to be fixed and made cleaner later. - if self.args.past_index >= 0: - self._past = outputs[ # pylint: disable=attribute-defined-outside-init - self.args.past_index - ] - - if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: - loss *= self.accelerator.num_processes - - return (loss, outputs) if return_outputs else loss + return outputs[0] diff --git a/src/axolotl/integrations/kd/utils.py b/src/axolotl/integrations/kd/utils.py new file mode 100644 index 000000000..ba60694a5 --- /dev/null +++ b/src/axolotl/integrations/kd/utils.py @@ -0,0 +1,100 @@ +"""Helper KD utils""" + +import math +from typing import List, Union + +import numpy as np +import torch +from torch import FloatTensor, Tensor + + +def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor: + """ + Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs. 
+ """ + # Ensure raw_logprobs matches kd_online_topk length for tensor operations + # This should ideally be handled by the caller ensuring correct padding/truncation first + if logprobs.shape[-1] != topk: + # pad last dimension of logprobs to match topk length with -inf + padding_len = topk - logprobs.shape[-1] + padding_tensor = torch.full( + ( + *logprobs.shape[:-1], + padding_len, + ), # Takes all dimensions of logprobs except the last, then appends padding_needed + float("-inf"), + dtype=logprobs.dtype, + device=logprobs.device, + ) + logprobs = torch.cat((logprobs, padding_tensor), dim=-1) + + # Convert logprobs at T_online to probabilities + # use log sum exp trick to avoid underflow + position_logprobs_lse = torch.logsumexp(logprobs, dim=-1, keepdim=True) + teacher_probs_t_online = torch.exp(logprobs - position_logprobs_lse) + + # Normalize probabilities (sum to 1) + # This is important if the top-k from server aren't a full distribution + teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=-1, keepdim=True) + teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum + + final_logprobs_tensor = torch.log(teacher_probs_t_online) + + return final_logprobs_tensor + + +def strided_chunk_views( + tensor: Union[np.ndarray, torch.Tensor], + chunks: int, + dim: int = 0, + stride: int = 1, + chunk_size: int | None = None, +) -> List[Union[np.ndarray, torch.Tensor]]: + """ + Split a tensor into chunks along a dimension with striding, prioritizing views over copies. + + Args: + tensor: Input tensor (numpy array or torch tensor) + chunks: Number of chunks to create + dim: Dimension along which to chunk (default: 0) + stride: Stride between chunk starting positions (default: 1) + chunk_size: Size of each chunk. If None, calculated automatically (default: None) + + Returns: + List of tensor chunks (views when possible, copies when necessary) + """ + + # Get the size of the specified dimension + dim_size = tensor.shape[dim] + + # Calculate chunk size if not provided + if chunk_size is None: + chunk_size = (dim_size + chunks - 1) // chunks # Ceiling division + + chunks_list = [] + + for i in range(chunks): + start_idx = i * stride + end_idx = min(start_idx + chunk_size, dim_size) + + # Break if we've gone beyond the tensor + if start_idx >= dim_size: + break + + # Create slice objects for all dimensions + slices = [slice(None)] * tensor.ndim + slices[dim] = slice(start_idx, end_idx) + + chunk = tensor[tuple(slices)] + chunks_list.append(chunk) + + return chunks_list + + +def chunk_overlap(input_tensor: Tensor, chunks: int, dim: int = 0, overlap: int = 1): + dim_size = input_tensor.shape[dim] + stride = math.ceil(dim_size / chunks) + + return strided_chunk_views( + input_tensor, chunks, dim, stride=stride, chunk_size=stride + overlap + ) diff --git a/src/axolotl/logging_config.py b/src/axolotl/logging_config.py index b8dc6479d..f570e013c 100644 --- a/src/axolotl/logging_config.py +++ b/src/axolotl/logging_config.py @@ -25,12 +25,20 @@ class AxolotlOrWarnErrorFilter(logging.Filter): def __init__(self, **kwargs): super().__init__(**kwargs) - self.axolotl_level = logging.getLevelNamesMapping()[ - os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL) - ] - self.other_level = logging.getLevelNamesMapping()[ - os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL) - ] + axolotl_log_level = os.getenv( + "AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL + ).upper() + other_log_level = os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL).upper() + + try: + # py311+ only + level_mapping = 
logging.getLevelNamesMapping() + self.axolotl_level = level_mapping[axolotl_log_level] + self.other_level = level_mapping[other_log_level] + except AttributeError: + # For py310, use getLevelName directly + self.axolotl_level = logging.getLevelName(axolotl_log_level) + self.other_level = logging.getLevelName(other_log_level) def filter(self, record: LogRecord) -> bool: # General filter diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py index 3cdbbb6f3..cf936481e 100644 --- a/src/axolotl/prompt_strategies/__init__.py +++ b/src/axolotl/prompt_strategies/__init__.py @@ -17,7 +17,10 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None): return messages_load(tokenizer, cfg, ds_cfg, processor=processor) load_fn = "load" package = "axolotl.prompt_strategies" - if strategy.split(".")[-1].startswith("load_"): + if ( + strategy.split(".")[-1].startswith("load_") + or strategy.split(".")[-1] == "load" + ): load_fn = strategy.split(".")[-1] strategy = ".".join(strategy.split(".")[:-1]) elif len(strategy.split(".")) > 1: diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 4a358928e..0271fca24 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -596,11 +596,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): if ( turn_idx == 0 and turns[0].get("role") == "system" - and ( - "mistral" in self.tokenizer.name_or_path.lower() - or "gemma" - in self.tokenizer.name_or_path.lower() # gemma3 uses gemma tokenizer - ) + and ("mistral" in self.tokenizer.name_or_path.lower()) ): return -1, -1 diff --git a/src/axolotl/prompt_strategies/jinja_template_analyzer.py b/src/axolotl/prompt_strategies/jinja_template_analyzer.py index a5f89cfe5..e16a1e22b 100644 --- a/src/axolotl/prompt_strategies/jinja_template_analyzer.py +++ b/src/axolotl/prompt_strategies/jinja_template_analyzer.py @@ -3,6 +3,7 @@ from typing import Dict, Optional, Set, TypedDict, Union from jinja2 import Environment, meta, nodes +from jinja2.ext import Extension class JinjaTemplateAnalysis(TypedDict): @@ -27,6 +28,18 @@ class JinjaTemplateAnalysis(TypedDict): iteration_target: Optional[Union[str, list[str]]] +class GenerationTagIgnore(Extension): + """ + Ignores the generation and endgeneration tags in Jinja templates. + """ + + tags = {"generation", "endgeneration"} + + def parse(self, parser): + parser.stream.skip(1) + return nodes.Const("") + + class JinjaTemplateAnalyzer: """ Analyzes Jinja templates to extract information about variable usage, @@ -57,7 +70,9 @@ class JinjaTemplateAnalyzer: """ def __init__(self, template: str): - self.env: Environment = Environment(autoescape=True) + self.env: Environment = Environment( + autoescape=True, extensions=[GenerationTagIgnore] + ) self.property_access: Dict[str, Set[str]] = {} self.iteration_targets: Dict[str, Union[str, list[str]]] = {} self.index_access: Dict[str, Set[Union[int, float]]] = {} diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 13ac8ec0d..fa7d56913 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -1,10 +1,13 @@ """Prepare and train a model on a dataset. 
Can also infer from a model or merge lora""" +from __future__ import annotations + import importlib import inspect import os import signal import sys +import typing import weakref from contextlib import ExitStack from pathlib import Path @@ -25,7 +28,6 @@ from axolotl.common.datasets import TrainDatasetMeta from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module fix_untrained_tokens, ) -from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.integrations.base import PluginManager from axolotl.loaders import ( ModelLoader, @@ -45,6 +47,9 @@ try: except ImportError: BetterTransformer = None +if typing.TYPE_CHECKING: + from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder + LOG = get_logger(__name__) @@ -472,7 +477,7 @@ def handle_untrained_tokens_fix( def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[ - HFRLTrainerBuilder | HFCausalTrainerBuilder, + "HFRLTrainerBuilder" | "HFCausalTrainerBuilder", PeftModel | PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py index 3d0ba7c9c..e669413f8 100644 --- a/src/axolotl/utils/__init__.py +++ b/src/axolotl/utils/__init__.py @@ -52,3 +52,10 @@ def patch_optimized_env(): if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None: os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" set_pytorch_cuda_alloc_conf() + + +def get_not_null(value, default=None): + """ + return the value if it's not None, otherwise return the default value + """ + return value if value is not None else default diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 8b8a77611..2a93ceef5 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -53,25 +53,6 @@ IGNORE_INDEX = -100 LOG = get_logger(__name__) -class EvalFirstStepCallback( - TrainerCallback -): # pylint: disable=too-few-public-methods disable=unused-argument - """ - Callback to trigger evals on the first step - """ - - def on_step_end( - self, - args: TrainingArguments, - state: TrainerState, - control: TrainerControl, - **kwargs, - ): - if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1: - control.should_evaluate = True - return control - - class SaveBetterTransformerModelCallback( TrainerCallback ): # pylint: disable=too-few-public-methods diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index bf496d2c5..83a42945b 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -32,11 +32,12 @@ _CHAT_TEMPLATES = { "llava": "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}", "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] 
== 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}", + "phi_4": "{% set system_message = 'You are Phi, a language model trained by Microsoft to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: {Thought section} {Solution section}. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. 
Now, try to solve the following question through the above guidelines:' -%}{%- if messages and messages[0]['role'] == 'system' -%}{%- set system_message = messages[0]['content'] -%}{%- set messages = messages[1:] -%}{%- endif -%}<|im_start|>system<|im_sep|>{{ system_message }}<|im_end|>{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'assistant') %}{{'<|im_start|>assistant<|im_sep|>'}}{% generation %}{{message['content'] + '<|im_end|>'}}{% endgeneration %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|im_sep|>' }}{% endif %}", "deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}", "deepseek_v3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- else %}{{'<|Assistant|>' + message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<|tool▁output▁begin|>' + message['content'] + 
'<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}", "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " 
+ doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = 
get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n', "qwen_25": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n", - "qwen3": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in message.content %}\n {%- set content = message.content.split('')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n 
{%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- endif %}\n{%- endif %}", + "qwen3": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in message.content %}\n {%- set content = message.content.split('')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n 
{%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- else %}\n {{- '\\n\\n' }}\n {%- endif %}\n{%- endif %}", "exaone": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '[|system|][|endofturn|]\n' }}{% endif %}{{ '[|' + message['role'] + '|]' + message['content'] }}{% if message['role'] == 'user' %}{{ '\n' }}{% else %}{{ '[|endofturn|]\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[|assistant|]' }}{% endif %}", "metharme": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Enter RP mode. You shall reply to the user while staying in character. Your responses must be detailed, creative, immersive, and drive the scenario forward.' %}{% endif %}{{ '<|system|>' + system_message }}{% for message in loop_messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|user|>' + content.strip() }}{% elif message['role'] == 'assistant' %}{{ '<|model|>' + content.strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% else %}{{ eos_token }}{% endif %}", "pixtral": '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}\n {%- endif %}\n {%- if message["role"] == "user" %}\n {%- if loop.last and system_message is defined %}\n {{- "[INST]" + system_message + "\n\n" }}\n {%- else %}\n {{- "[INST]" }}\n {%- endif %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["text"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message["content"] }}\n {%- endif %}\n {{- "[/INST]" }}\n {%- elif message["role"] == "assistant" %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["text"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n{%- endif %}\n{%- endfor %}\n{{- eos_token }}\n{%- else %}\n{{- message["content"] + eos_token }}\n{%- endif %}\n {%- else 
%}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}', diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index d8414d117..a28f360be 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -1,7 +1,7 @@ """Data collators for axolotl to pad labels and position_ids for packed sequences""" from dataclasses import dataclass -from typing import Any +from typing import Any, List import numpy as np from transformers import PreTrainedTokenizerBase @@ -163,7 +163,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): def __call__(self, features, return_tensors=None): if not isinstance(features[0], list): - features = [features] + features: List[List[dict]] = [features] out_features = [{} for _ in features] for i, features_ in enumerate(features): for feature in features_[0].keys(): diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index 0ffaa932f..4f7f6f8dd 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -51,6 +51,7 @@ def retry_on_request_exceptions( except ( requests.exceptions.ReadTimeout, requests.exceptions.ConnectionError, + requests.exceptions.HTTPError, huggingface_hub.errors.HfHubHTTPError, ) as exc: if attempt < max_retries - 1: diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index eabfc2d84..7fb5e1b41 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -259,7 +259,7 @@ class MultipackBatchSampler(BatchSampler): batch_max_len: int, # Maximum sequence length (bin capacity) lengths: np.ndarray, # Sequence lengths packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate - drop_last: bool = False, # Whether to drop final batches (might be incomplete) + drop_last: bool = True, # Whether to drop final batches (might be incomplete) num_count_samples: int = 8, # Number of times to estimate batch count sequential: bool = False, # Whether to use sequential packing group_size: int = 100_000, # Size of groups for parallel packing @@ -446,10 +446,18 @@ class MultipackBatchSampler(BatchSampler): if self._len_across_ranks is None: # Sample multiple times to get stable estimate - len_batches = min( # pylint: disable=consider-using-generator - [len(self._batches) for _ in range(self.num_count_samples)] - ) + _sampled_lens = [] + for _ in range(self.num_count_samples): + self._batches = None # Reset cached batches + _sampled_lens.append(len(self.generate_batches(set_stats=False))) + len_batches = min(_sampled_lens) + # Gather minimum across all ranks - self._len_across_ranks = self.gather_len_batches(len_batches) + if self._len_across_ranks is None: + self._len_across_ranks = self.gather_len_batches(len_batches) + else: + self._len_across_ranks = min( + self._len_across_ranks, self.gather_len_batches(len_batches) + ) return self._len_across_ranks diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 33a8f77db..259daa56f 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -12,10 +12,8 @@ from pydantic import ( Field, StringConstraints, field_serializer, - field_validator, model_validator, ) -from transformers.utils.import_utils import is_torch_npu_available from axolotl.utils.logging import get_logger from axolotl.utils.schemas.datasets 
import ( @@ -47,14 +45,13 @@ from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig from axolotl.utils.schemas.quantization import PTQConfig, QATConfig from axolotl.utils.schemas.training import HyperparametersConfig from axolotl.utils.schemas.trl import TRLConfig +from axolotl.utils.schemas.validation import ValidationMixin from axolotl.utils.schemas.vllm import VllmConfig LOG = get_logger(__name__, use_environ=True) -SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} - -# pylint: disable=too-many-public-methods,too-many-ancestors +# pylint: disable=too-many-ancestors class AxolotlInputConfig( ModelInputConfig, ModelOutputConfig, @@ -70,22 +67,54 @@ class AxolotlInputConfig( MultiModalConfig, RemappedParameters, DeprecatedParameters, + ValidationMixin, BaseModel, ): - """Wrapper of all config options""" + """Wrapper of all config options.""" model_config = {"populate_by_name": True} - strict: bool | None = Field(default=False) - resume_from_checkpoint: str | None = None - auto_resume_from_checkpoints: bool | None = None - resize_token_embeddings_to_32x: bool | None = None + strict: bool | None = Field( + default=False, + json_schema_extra={"description": "Allow overwrite yml config using from cli"}, + ) + resume_from_checkpoint: str | None = Field( + default=None, + json_schema_extra={"description": "Resume from a specific checkpoint dir"}, + ) + auto_resume_from_checkpoints: bool | None = Field( + default=None, + json_schema_extra={ + "description": "If resume_from_checkpoint isn't set and you simply want it to start where it left off. Be careful with this being turned on between different models." + }, + ) + resize_token_embeddings_to_32x: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Resize the model embeddings when new tokens are added to multiples of 32. This is reported to improve training speed on some models" + }, + ) mean_resizing_embeddings: bool | None = False # optionally shrink the embeddings when the tokenizer vocab size is smaller - shrink_embeddings: bool | None = None - embeddings_skip_upcast: bool | None = None + shrink_embeddings: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink." + }, + ) + embeddings_skip_upcast: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs" + }, + ) - rl: RLType | None = None + rl: RLType | None = Field( + default=None, + json_schema_extra={ + "description": "Use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'" + }, + ) trl: TRLConfig | None = Field( default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda ) @@ -94,12 +123,25 @@ class AxolotlInputConfig( ) qat: QATConfig | None = None quantization: PTQConfig | None = None - reward_model: bool | None = None - process_reward_model: bool | None = None + reward_model: bool | None = Field( + default=None, + json_schema_extra={"description": "Reward modelling: `True` or `False`"}, + ) + process_reward_model: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Process reward modelling: `True` or `False`" + }, + ) num_labels: int | None = None # Whether to use weighting in DPO trainer. # If `None`, default is `False` in the trainer. 
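+    # Illustrative usage sketch (editorial assumption, not part of this patch): in a
+    # YAML config this is toggled as `dpo_use_weighting: true`, which is expected to
+    # map to the `use_weighting` option of trl's DPOConfig.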
- dpo_use_weighting: bool | None = None + dpo_use_weighting: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to perform weighting in DPO trainer" + }, + ) dpo_use_logits_to_keep: bool | None = None dpo_label_smoothing: float | None = None dpo_norm_loss: bool | None = None @@ -111,7 +153,12 @@ class AxolotlInputConfig( MinLen(1), ] | None - ) = None + ) = Field( + default=None, + json_schema_extra={ + "description": "A list of one or more datasets to finetune the model with" + }, + ) test_datasets: ( Annotated[ @@ -119,22 +166,59 @@ class AxolotlInputConfig( MinLen(1), ] | None - ) = None - shuffle_merged_datasets: bool | None = True - dataset_prepared_path: str | None = None - dataset_shard_num: int | None = None - dataset_shard_idx: int | None = None + ) = Field( + default=None, + json_schema_extra={ + "description": "A list of one or more datasets to eval the model with. You can use either test_datasets, or val_set_size, but not both." + }, + ) + shuffle_merged_datasets: bool | None = Field( + default=True, + json_schema_extra={ + "description": "If false, the datasets will not be shuffled and will keep their original order in `datasets`. The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true." + }, + ) + dataset_prepared_path: str | None = Field( + default=None, + json_schema_extra={ + "description": "Axolotl attempts to save the dataset as an arrow after packing the data together so subsequent training attempts load faster, relative path" + }, + ) + dataset_shard_num: int | None = Field( + default=None, json_schema_extra={"description": "Num shards for whole dataset"} + ) + dataset_shard_idx: int | None = Field( + default=None, + json_schema_extra={"description": "Index of shard to use for whole dataset"}, + ) skip_prepare_dataset: bool | None = False pretraining_dataset: ( Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None ) = Field( default=None, - json_schema_extra={"description": "streaming dataset to use for pretraining"}, + json_schema_extra={ + "description": "Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize" + }, + ) + dataset_processes: int | None = Field( + default=min(32, os.cpu_count()), # type: ignore[type-var] + json_schema_extra={ + "description": "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set." + }, + ) + dataset_exact_deduplication: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Deduplicates datasets and test_datasets with identical entries" + }, + ) + dataset_keep_in_memory: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Keep dataset in memory while preprocessing. 
Only needed if cached dataset is taking too much storage" + }, ) - dataset_processes: int | None = Field(default=min(32, os.cpu_count() or 1)) - dataset_exact_deduplication: bool | None = None - dataset_keep_in_memory: bool | None = None dataloader_pin_memory: bool | None = None dataloader_num_workers: int | None = None dataloader_prefetch_factor: int | None = None @@ -144,75 +228,203 @@ class AxolotlInputConfig( remove_unused_columns: bool | None = None - push_dataset_to_hub: str | None = None - hf_use_auth_token: bool | None = None + push_dataset_to_hub: str | None = Field( + default=None, + json_schema_extra={ + "description": "Push prepared dataset to hub - repo_org/repo_name" + }, + ) + hf_use_auth_token: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets. Required to be true when used in combination with `push_dataset_to_hub`" + }, + ) device: Any | None = None - device_map: Any | None = None + device_map: Any | None = Field( + default=None, + json_schema_extra={ + "description": "Passed through to transformers when loading the model when launched without accelerate. Use `sequential` when training w/ model parallelism to limit memory" + }, + ) world_size: int | None = None - local_rank: int | None = None + local_rank: int | None = Field( + default=None, + json_schema_extra={ + "description": "Don't mess with this, it's here for accelerate and torchrun" + }, + ) ddp: bool | None = None - seed: int | None = None - ddp_timeout: int | None = None - ddp_bucket_cap_mb: int | None = None - ddp_broadcast_buffers: bool | None = None + seed: int | None = Field( + default=None, json_schema_extra={"description": "Seed for reproducibility"} + ) + ddp_timeout: int | None = Field( + default=None, + json_schema_extra={"description": "Advanced DDP Arguments - timeout"}, + ) + ddp_bucket_cap_mb: int | None = Field( + default=None, + json_schema_extra={"description": "Advanced DDP Arguments - bucket cap in MB"}, + ) + ddp_broadcast_buffers: bool | None = Field( + default=None, + json_schema_extra={"description": "Advanced DDP Arguments - broadcast buffers"}, + ) ddp_find_unused_parameters: bool | None = None - eval_table_size: int | None = None - eval_max_new_tokens: int | None = None - do_causal_lm_eval: bool | None = None - eval_causal_lm_metrics: list[str] | None = None + eval_table_size: int | None = Field( + default=None, + json_schema_extra={ + "description": "Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0" + }, + ) + eval_max_new_tokens: int | None = Field( + default=None, + json_schema_extra={ + "description": "Total number of tokens generated for predictions sent to wandb. Default is 128" + }, + ) + do_causal_lm_eval: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to run causal language model evaluation for metrics in `eval_causal_lm_metrics`" + }, + ) + eval_causal_lm_metrics: list[str] | None = Field( + default=None, + json_schema_extra={ + "description": "HF evaluate metrics used during evaluation. 
Default is ['sacrebleu', 'comet', 'ter', 'chrf', 'perplexity']" + }, + ) do_bench_eval: bool | None = None bench_dataset: str | None = None bench_split: str | None = None metric_for_best_model: str | None = None greater_is_better: bool | None = None - loss_watchdog_threshold: float | None = None - loss_watchdog_patience: int | None = None + loss_watchdog_threshold: float | None = Field( + default=None, + json_schema_extra={ + "description": "High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)" + }, + ) + loss_watchdog_patience: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of high-loss steps in a row before the trainer aborts (default: 3)" + }, + ) gc_steps: int | None = None - bf16: Literal["auto"] | bool | None = "auto" - fp16: bool | None = None + bf16: Literal["auto"] | bool | None = Field( + default="auto", + json_schema_extra={ + "description": "Use CUDA bf16. bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere" + }, + ) + fp16: bool | None = Field( + default=None, json_schema_extra={"description": "Use CUDA fp16"} + ) fp8: bool | None = None - bfloat16: bool | None = None # for non-AMP cases - float16: bool | None = None # for non-AMP cases - tf32: bool | None = None + bfloat16: bool | None = Field( + default=None, + json_schema_extra={ + "description": "No AMP (automatic mixed precision) - require >=ampere" + }, + ) # for non-AMP cases + float16: bool | None = Field( + default=None, + json_schema_extra={"description": "No AMP (automatic mixed precision)"}, + ) # for non-AMP cases + tf32: bool | None = Field( + default=None, + json_schema_extra={"description": "Use CUDA tf32 - require >=ampere"}, + ) float32: bool | None = None - # torch_dtype: torch.dtype | None - gradient_checkpointing: Literal["offload", "offload_disk"] | bool | None = Field( - default=False + default=False, + json_schema_extra={ + "description": "Whether to use gradient checkpointing. Available options are: true, false, 'offload', 'offload_disk'. https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing" + }, + ) + gradient_checkpointing_kwargs: dict[str, Any] | None = Field( + default=None, + json_schema_extra={ + "description": "Additional kwargs to pass to the trainer for gradient checkpointing" + }, ) - gradient_checkpointing_kwargs: dict[str, Any] | None = None unfrozen_parameters: list[str] | None = None - sequence_len: int = Field(default=512) + sequence_len: int = Field( + default=512, + json_schema_extra={ + "description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048" + }, + ) min_sample_len: int | None = None max_prompt_len: int = Field( default=512, json_schema_extra={"description": "maximum prompt length for RL training"}, ) - sample_packing: bool | None = None - sample_packing_group_size: int | None = 100_000 - sample_packing_bin_size: int | None = 200 - sample_packing_sequentially: bool | None = None - eval_sample_packing: bool | None = None - pad_to_sequence_len: bool | None = None - curriculum_sampling: bool | None = None + sample_packing: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Use efficient multi-packing with block diagonal attention and per sequence position_ids. 
Recommend set to 'true'" + }, + ) + sample_packing_group_size: int | None = Field( + default=100_000, + json_schema_extra={ + "description": "The number of samples packed at a time. Increasing the following values helps with packing, but usually only slightly (<%1.)" + }, + ) + sample_packing_bin_size: int | None = Field( + default=200, + json_schema_extra={ + "description": "The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples." + }, + ) + sample_packing_sequentially: bool | None = Field( + default=None, + json_schema_extra={"description": "Whether to pack samples sequentially"}, + ) + eval_sample_packing: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Set to 'false' if getting errors during eval with sample_packing on" + }, + ) + pad_to_sequence_len: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Pad inputs so each step uses constant sized buffers. This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently" + }, + ) + curriculum_sampling: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use sequential sampling for curriculum learning" + }, + ) multipack_real_batches: bool | None = None pretraining_sample_concatenation: bool | None = Field( default=None, json_schema_extra={ - "description": "whether to soft pack/concatenate samples during pretraining", + "description": "whether to concatenate samples during pretraining", }, ) - batch_flattening: Literal["auto"] | bool | None = None + batch_flattening: Literal["auto"] | bool | None = Field( + default=None, + json_schema_extra={ + "description": "Use batch flattening for speedups when not using sample_packing" + }, + ) # for PoSE context length extension use_pose: bool | None = None @@ -228,17 +440,60 @@ class AxolotlInputConfig( }, ) - xformers_attention: bool | None = None - sdp_attention: bool | None = None - s2_attention: bool | None = None + xformers_attention: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use xformers attention patch https://github.com/facebookresearch/xformers" + }, + ) + sdp_attention: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use scaled-dot-product attention https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html" + }, + ) + s2_attention: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf" + }, + ) flex_attention: bool | None = None flex_attn_compile_kwargs: dict[str, Any] | None = None - flash_attention: bool | None = None - flash_attn_cross_entropy: bool | None = None - flash_attn_rms_norm: bool | None = None - flash_attn_fuse_qkv: bool | None = None - flash_attn_fuse_mlp: bool | None = None - flash_optimum: bool | None = None + flash_attention: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention" + }, + ) + flash_attn_cross_entropy: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use flash-attention cross entropy implementation - advanced use only" + }, + ) + flash_attn_rms_norm: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use flash-attention rms norm implementation - advanced use only" + }, + ) 
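+    # Illustrative usage sketch (editorial assumption, not part of this patch): the
+    # attention backends above are typically enabled one at a time in the YAML
+    # config, e.g. `flash_attention: true` or `sdp_attention: true`, not combined.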
+ flash_attn_fuse_qkv: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to fuse QKV into a single operation" + }, + ) + flash_attn_fuse_mlp: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to fuse part of the MLP into a single operation" + }, + ) + flash_optimum: bool | None = Field( + default=None, + json_schema_extra={"description": "Whether to use bettertransformers"}, + ) eager_attention: bool | None = None @@ -249,76 +504,273 @@ class AxolotlInputConfig( unsloth_rms_norm: bool | None = None unsloth_rope: bool | None = None - lora_mlp_kernel: bool | None = None - lora_qkv_kernel: bool | None = None - lora_o_kernel: bool | None = None + lora_mlp_kernel: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html" + }, + ) + lora_qkv_kernel: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html" + }, + ) + lora_o_kernel: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html" + }, + ) llama4_linearized_experts: bool | None = None - deepspeed: str | dict[str, Any] | None = None - fsdp: list[str] | None = None - fsdp_config: dict[str, Any] | None = None + deepspeed: str | dict[str, Any] | None = Field( + default=None, + json_schema_extra={ + "description": "Deepspeed config path. e.g., deepspeed_configs/zero3.json" + }, + ) + fsdp: list[str] | None = Field( + default=None, json_schema_extra={"description": "FSDP configuration"} + ) + fsdp_config: dict[str, Any] | None = Field( + default=None, json_schema_extra={"description": "FSDP configuration options"} + ) fsdp_final_state_dict_type: ( Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None ) = None - val_set_size: float | None = Field(default=0.0) + val_set_size: float | None = Field( + default=0.0, + json_schema_extra={ + "description": "How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval." + }, + ) - sequence_parallel_degree: int | None = None - heads_k_stride: int | None = None - ring_attn_func: RingAttnFunc | None = None + sequence_parallel_degree: int | None = Field( + default=None, + json_schema_extra={ + "description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details." + }, + ) + heads_k_stride: int | None = Field( + default=None, + json_schema_extra={ + "description": "Optional; strides across the key dimension. Larger values use more memory but should make training faster. Must evenly divide the number of KV heads in your model." 
+ }, + ) + ring_attn_func: RingAttnFunc | None = Field( + default=None, + json_schema_extra={ + "description": "One of 'varlen_llama3', 'batch_ring', 'batch_zigzag', 'batch_stripe'. Defaults to 'varlen_llama3' in the sample packing case, and 'batch_ring' in the non-sample packing case." + }, + ) - special_tokens: SpecialTokensConfig | None = None - tokens: list[str] | None = None - added_tokens_overrides: dict[int, str] | None = None + special_tokens: SpecialTokensConfig | None = Field( + default=None, + json_schema_extra={ + "description": "Add or change special tokens. If you add tokens here, you don't need to add them to the `tokens` list." + }, + ) + tokens: list[str] | None = Field( + default=None, + json_schema_extra={"description": "Add extra tokens to the tokenizer"}, + ) + added_tokens_overrides: dict[int, str] | None = Field( + default=None, + json_schema_extra={ + "description": "Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer. Only works for tokens that are not part of the base vocab (aka are added_tokens). Can be checked if they exist in tokenizer.json added_tokens." + }, + ) - torch_compile: Literal["auto"] | bool | None = None - torch_compile_backend: str | None = None + torch_compile: Literal["auto"] | bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use torch.compile and which backend to use. setting to `auto` will enable torch compile when torch>=2.5.1" + }, + ) + torch_compile_backend: str | None = Field( + default=None, + json_schema_extra={"description": "Backend to use for torch.compile"}, + ) torch_compile_mode: Literal["default", "reduce-overhead", "max-autotune"] | None = ( None ) - max_steps: int | None = None - warmup_steps: int | None = None - warmup_ratio: float | None = None - eval_steps: int | float | None = None - evals_per_epoch: int | None = None - eval_strategy: str | None = None - save_steps: int | float | None = None - saves_per_epoch: int | None = None - save_strategy: str | None = None - save_total_limit: int | None = None - logging_steps: int | None = None - early_stopping_patience: int | None = None + max_steps: int | None = Field( + default=None, + json_schema_extra={ + "description": "Maximum number of iterations to train for. It precedes num_epochs which means that if both are set, num_epochs will not be guaranteed. e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps" + }, + ) + warmup_steps: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of warmup steps. Cannot use with warmup_ratio" + }, + ) + warmup_ratio: float | None = Field( + default=None, + json_schema_extra={"description": "Warmup ratio. Cannot use with warmup_steps"}, + ) + eval_steps: int | float | None = Field( + default=None, + json_schema_extra={ + "description": "Leave empty to eval at each epoch, integer for every N steps. float for fraction of total steps" + }, + ) + evals_per_epoch: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of times per epoch to run evals, mutually exclusive with eval_steps" + }, + ) + eval_strategy: str | None = Field( + default=None, + json_schema_extra={ + "description": "Set to `no` to skip evaluation, `epoch` at end of each epoch, leave empty to infer from `eval_steps`" + }, + ) + save_steps: int | float | None = Field( + default=None, + json_schema_extra={ + "description": "Leave empty to save at each epoch, integer for every N steps. 
float for fraction of total steps" + }, + ) + saves_per_epoch: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of times per epoch to save a checkpoint, mutually exclusive with save_steps" + }, + ) + save_strategy: str | None = Field( + default=None, + json_schema_extra={ + "description": "Set to `no` to skip checkpoint saves, `epoch` at end of each epoch, `best` when better result is achieved, leave empty to infer from `save_steps`" + }, + ) + save_total_limit: int | None = Field( + default=None, json_schema_extra={"description": "Checkpoints saved at a time"} + ) + logging_steps: int | None = Field( + default=None, json_schema_extra={"description": "Logging frequency"} + ) + early_stopping_patience: int | None = Field( + default=None, + json_schema_extra={ + "description": "Stop training after this many evaluation losses have increased in a row. https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback" + }, + ) load_best_model_at_end: bool | None = False - save_only_model: bool | None = False - use_tensorboard: bool | None = None - profiler_steps: int | None = None - include_tokens_per_second: bool | None = None + save_only_model: bool | None = Field( + default=False, + json_schema_extra={ + "description": "Save only the model weights, skipping the optimizer. Using this means you can't resume from checkpoints." + }, + ) + use_tensorboard: bool | None = Field( + default=None, json_schema_extra={"description": "Use tensorboard for logging"} + ) + profiler_steps: int | None = Field( + default=None, + json_schema_extra={ + "description": "Enable the pytorch profiler to capture the first N steps of training to the output_dir. see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information. Snapshots can be visualized @ https://pytorch.org/memory_viz" + }, + ) + include_tokens_per_second: bool | None = Field( + default=None, + json_schema_extra={ + "description": "bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time." + }, + ) - neftune_noise_alpha: float | None = None + neftune_noise_alpha: float | None = Field( + default=None, + json_schema_extra={ + "description": "NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings. Currently only supported on Llama and Mistral" + }, + ) - orpo_alpha: float | None = None - rpo_alpha: float | None = None - simpo_gamma: float | None = None - cpo_alpha: float | None = None + orpo_alpha: float | None = Field( + default=None, + json_schema_extra={ + "description": "Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping." 
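Since this PR wires the schema into a pre-render docs script, a minimal sketch of how descriptions attached via `json_schema_extra` can be read back out may help reviewers. This is not the actual `docs/scripts/generate_config_docs.py`, whose contents are not part of this diff; it only shows that Pydantic surfaces these descriptions in the generated JSON schema.

```python
# Minimal sketch (not the real generate_config_docs.py): descriptions added via
# json_schema_extra, like the ones above, appear in the model's JSON schema and
# can be rendered into a config reference page.
from pydantic import BaseModel, Field


class Example(BaseModel):
    save_total_limit: int | None = Field(
        default=None, json_schema_extra={"description": "Checkpoints saved at a time"}
    )


for name, prop in Example.model_json_schema()["properties"].items():
    print(f"{name}: {prop.get('description', '')}")
# save_total_limit: Checkpoints saved at a time
```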
+ }, + ) + rpo_alpha: float | None = Field( + default=None, + json_schema_extra={ + "description": "Weighting of NLL term in loss from RPO paper" + }, + ) + simpo_gamma: float | None = Field( + default=None, + json_schema_extra={"description": "Target reward margin for the SimPO loss"}, + ) + cpo_alpha: float | None = Field( + default=None, json_schema_extra={"description": "Weight of the BC regularizer"} + ) - kto_desirable_weight: float | None = None - kto_undesirable_weight: float | None = None - rl_beta: float | None = None + kto_desirable_weight: float | None = Field( + default=None, + json_schema_extra={"description": "Factor for desirable loss term in KTO loss"}, + ) + kto_undesirable_weight: float | None = Field( + default=None, + json_schema_extra={ + "description": "Factor for undesirable loss term in KTO loss" + }, + ) + rl_beta: float | None = Field( + default=None, + json_schema_extra={"description": "The beta parameter for the RL training"}, + ) - max_memory: dict[int | Literal["cpu", "disk"], int | str] | None = None - gpu_memory_limit: int | str | None = None - low_cpu_mem_usage: bool | None = None + max_memory: dict[int | Literal["cpu", "disk"], int | str] | None = Field( + default=None, + json_schema_extra={ + "description": "Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model." + }, + ) + gpu_memory_limit: int | str | None = Field( + default=None, + json_schema_extra={ + "description": "Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset" + }, + ) + low_cpu_mem_usage: bool | None = Field( + default=None, + json_schema_extra={"description": "Whether to use low_cpu_mem_usage"}, + ) chat_template: ( ChatTemplate | Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")] - ) | None = None - chat_template_jinja: str | None = None - chat_template_kwargs: dict[str, Any] | None = None - eot_tokens: list[str] | None = None - default_system_message: str | None = None + ) | None = Field( + default=None, + json_schema_extra={ + "description": "The name of the chat template to use for training, following values are supported: tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value. alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py. tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer. jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field. The selected chat template will be saved to the tokenizer_config.json for easier inferencing" + }, + ) + chat_template_jinja: str | None = Field( + default=None, + json_schema_extra={ + "description": "Custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null." + }, + ) + eot_tokens: list[str] | None = Field( + default=None, + json_schema_extra={ + "description": "Custom EOT (End-of-Turn) tokens to mask/unmask during training. These tokens mark the boundaries between conversation turns. 
For example: ['/INST', '', '[/SYSTEM_PROMPT]']. If not specified, defaults to just the model's eos_token. This is useful for templates that use multiple delimiter tokens." + }, + ) + default_system_message: str | None = Field( + default=None, + json_schema_extra={ + "description": "Changes the default system message. Currently only supports chatml." + }, + ) fix_untrained_tokens: int | list[int] | None = None @@ -326,49 +778,50 @@ class AxolotlInputConfig( is_preprocess: bool | None = None preprocess_iterable: bool | None = None - total_num_tokens: int | None = None + total_num_tokens: int | None = Field( + default=None, + json_schema_extra={"description": "Total number of tokens - internal use"}, + ) total_supervised_tokens: int | None = None - sample_packing_eff_est: float | None = None + sample_packing_eff_est: float | None = Field( + default=None, + json_schema_extra={ + "description": "You can set these packing optimizations AFTER starting a training at least once. The trainer will provide recommended values for these values." + }, + ) axolotl_config_path: str | None = None - is_falcon_derived_model: bool | None = Field(default=None) - is_llama_derived_model: bool | None = Field(default=None) - is_mistral_derived_model: bool | None = Field(default=None) - is_qwen_derived_model: bool | None = Field(default=None) + is_falcon_derived_model: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Internal use only - Used to identify which the model is based on" + }, + ) + is_llama_derived_model: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Internal use only - Used to identify which the model is based on" + }, + ) + is_mistral_derived_model: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Internal use only - Used to identify which the model is based on. Please note that if you set this to true, `padding_side` will be set to 'left' by default" + }, + ) + is_qwen_derived_model: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Internal use only - Used to identify which the model is based on" + }, + ) - plugins: list[str] | None = Field(default=None) - - @field_validator("seed", mode="after") - @classmethod - def set_default_seed(cls, seed): - if seed is None: - LOG.info("`seed` not set in config; setting to 42") - seed = 42 - return seed - - @field_validator("datasets", mode="before") - @classmethod - def deprecate_sharegpt_datasets(cls, datasets): - for _, ds_cfg in enumerate(datasets): - # Handle both dict and pydantic model cases - ds_type = ( - ds_cfg.get("type") - if isinstance(ds_cfg, dict) - else getattr(ds_cfg, "type", None) - ) - if not ds_type: - continue - - # skip if it's a dict (for custom user instruction prompt) - if isinstance(ds_type, dict): - continue - - if isinstance(ds_type, str) and ds_type.startswith("sharegpt"): - raise ValueError( - "`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead." - ) - - return datasets + plugins: list[str] | None = Field( + default=None, + json_schema_extra={ + "description": "Add plugins to extend the pipeline. See `src/axolotl/integrations` for the available plugins or doc below for more details. 
https://docs.axolotl.ai/docs/custom_integrations.html" + }, + ) @field_serializer("datasets") def datasets_serializer( @@ -378,960 +831,9 @@ class AxolotlInputConfig( return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs] return None - @model_validator(mode="before") - @classmethod - def check_attention_fields(cls, data): - fields = ( - "xformers_attention", - "sdp_attention", - "s2_attention", - "flash_attention", - "flex_attention", - ) - non_empty_count = sum(1 for field in fields if data.get(field)) - - if non_empty_count > 1: - raise ValueError(f"Only one of {', '.join(fields)} must be set") - return data - - @model_validator(mode="before") - @classmethod - def check_batch_size_fields(cls, data): - fields = ("micro_batch_size", "gradient_accumulation_steps", "batch_size") - non_empty_count = sum(1 for field in fields if data.get(field)) - - if non_empty_count < 2: - raise ValueError(f"At least two of {', '.join(fields)} must be set") - return data - - @model_validator(mode="before") - @classmethod - def check_pretraining_w_max_steps(cls, data): - if data.get("pretraining_dataset") and not data.get("max_steps"): - raise ValueError( - "max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_pretraining_w_group_by_length(cls, data): - if data.get("pretraining_dataset") and data.get("group_by_length"): - LOG.warning( - "You probably want to disable group_by_length as it will force a streamed dataset to download completely." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_pretraining_split_batches_accelerate(cls, data): - # alternatively set ACCELERATE_SPLIT_BATCHES=False - if data.get("pretraining_dataset"): - accelerator_config = data.get("accelerator_config", {}) - if not accelerator_config: - data["accelerator_config"] = { - "split_batches": False, - "dispatch_batches": False, - } - else: - if accelerator_config.get("split_batches") is None: - data["accelerator_config"]["split_batches"] = False - if accelerator_config.get("dispatch_batches") is None: - data["accelerator_config"]["dispatch_batches"] = False - return data - - @model_validator(mode="before") - @classmethod - def check_gptq_w_revision(cls, data): - if data.get("gptq") and data.get("revision_of_model"): - raise ValueError( - "revision_of_model is not supported for GPTQ models. " - + "Please download the model from HuggingFace Hub manually for correct branch, " - + "point to its path, and remove revision_of_model from the config." 
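The streamed-pretraining validators removed above encode a few config rules; as a quick reference, they amount to the following key combination (sketch; the dataset name is invented for illustration).

```python
# Rules from check_pretraining_w_max_steps / check_pretraining_w_group_by_length,
# expressed as config keys (dataset name invented):
pretrain_cfg = {
    "pretraining_dataset": "my-org/streamed-corpus",
    "max_steps": 10_000,        # required: an iterable dataset has no inferable length
    "group_by_length": False,   # keep off so the stream is not downloaded in full
}
```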
- ) - return data - - @model_validator(mode="before") - @classmethod - # pylint: disable=duplicate-code - def check_chat_template_config(cls, data): - # if chat_template is set to jinja, chat_template_jinja is required - if data.get("chat_template") == ChatTemplate.jinja and not data.get( - "chat_template_jinja" - ): - raise ValueError( - "chat_template_jinja is required when chat_template is set to jinja" - ) - - # If chat_template_jinja is set, set chat_template to jinja - if data.get("chat_template_jinja") and not data.get("chat_template"): - data["chat_template"] = ChatTemplate.jinja - - return data - - @model_validator(mode="before") - @classmethod - def check_sample_packing_wo_flash(cls, data): - if ( - data.get("sample_packing") - and not data.get("flash_attention") - and not data.get("sdp_attention") - and not data.get("flex_attention") - and not data.get("xformers_attention") - ): - LOG.warning( - "sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination." - ) - - return data - - @model_validator(mode="before") - @classmethod - def check_sample_packing_with_s2attn(cls, data): - if data.get("sample_packing") and data.get("s2_attention"): - raise ValueError( - "Received `sample_packing=true` and `s2_attention=true`; however, \ - shifted-sparse attention does not currently support sample packing." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_batch_flattening_fa(cls, data): - if data.get("batch_flattening"): - batch_flattening_auto = data.get("batch_flattening") == "auto" - if not data.get("flash_attention") and not batch_flattening_auto: - raise ValueError("batch_flattening requires flash attention") - if data.get("sample_packing") and not batch_flattening_auto: - raise ValueError("batch_flattening not compatible with sample_packing") - if data.get("micro_batch_size") == 1 and not batch_flattening_auto: - LOG.warning("batch_flattening has no effect with micro_batch_size == 1") - - if ( - batch_flattening_auto - and data.get("flash_attention") - and not data.get("sample_packing") - and data.get("micro_batch_size") > 1 - ): - data["batch_flattening"] = True - elif batch_flattening_auto: - data["batch_flattening"] = False - - return data - - @model_validator(mode="before") - @classmethod - def check_sample_packing_w_rl(cls, data): - if data.get("sample_packing") and data.get("rl"): - raise ValueError("`sample_packing: true` does not work with RLHF training") - return data - - @model_validator(mode="before") - @classmethod - def hint_sample_packing_padding(cls, data): - if data.get("sample_packing"): - pad_to_sequence_len = data.get("pad_to_sequence_len") - if pad_to_sequence_len is False: - LOG.warning( - "`pad_to_sequence_len: true` is recommended when using sample_packing" - ) - elif pad_to_sequence_len is None: - LOG.info( - "Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing" - ) - data["pad_to_sequence_len"] = True - return data - - @model_validator(mode="before") - @classmethod - def hint_reward_model_pad(cls, data): - if data.get("reward_model") and not data.get("pad_to_sequence_len"): - LOG.warning( - "`pad_to_sequence_len: true` is recommended when using reward_model" - ) - if data.get("pad_to_sequence_len") is None: - data["pad_to_sequence_len"] = True - return data - - @model_validator(mode="before") - @classmethod - def check_gas_bsz(cls, data): - if data.get("gradient_accumulation_steps") and data.get("batch_size"): - raise ValueError( - "please set only one of 
gradient_accumulation_steps or batch_size" - ) - return data - - @model_validator(mode="before") - @classmethod - def hint_eval_train_mbsz(cls, data): - if ( - data.get("eval_batch_size") - and data.get("micro_batch_size") - and data.get("eval_batch_size") != data.get("micro_batch_size") - ): - LOG.warning( - "eval_batch_size != micro_batch_size. This can lead to VRAM instability." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_push_ds_auth(cls, data): - if ( - data.get("push_dataset_to_hub") - and data.get("hf_use_auth_token") is not True - ): - raise ValueError( - "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub" - ) - return data - - @model_validator(mode="after") - def check_falcon_fsdp(self): - if (self.base_model and "falcon" in self.base_model.lower()) and self.fsdp: - raise ValueError("FSDP is not supported for falcon models") - return self - - @model_validator(mode="after") - def check_mpt_checkpointing(self): - if ( - self.base_model and "mpt" in self.base_model.lower() - ) and self.gradient_checkpointing: - raise ValueError("gradient_checkpointing is not supported for MPT models") - return self - - @model_validator(mode="after") - def check_offload_grad_checkpointing(self): - if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth": - LOG.warning( - "`unsloth` is deprecated for gradient_checkpointing, use `offload`" - ) - self.gradient_checkpointing = "offload" - return self - - @model_validator(mode="after") - def check_better_transformers(self): - if self.flash_optimum is True: - if self.adapter: - LOG.warning( - "BetterTransformers probably doesn't work with PEFT adapters" - ) - if self.fp16 or self.bf16: - raise ValueError("AMP is not supported with BetterTransformer") - if self.float16 is not True and self.bfloat16 is not True: - LOG.warning( - "You should probably set bfloat16 or float16 to true to " - "load the model in float16 for BetterTransformers" - ) - return self - - @model_validator(mode="after") - def check_adamw_optimizer_params(self): - if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and ( - not self.optimizer or "adamw" not in str(self.optimizer).lower() - ): - LOG.warning("adamw hyperparameters found, but no adamw optimizer set") - return self - - @model_validator(mode="before") - @classmethod - def check_lr_groups(cls, data): - if data.get("lr_groups") and data.get("loraplus_lr_ratio"): - raise ValueError("lr_groups and loraplus_lr_ratio cannot be used together.") - return data - - @model_validator(mode="before") - @classmethod - def check_saves(cls, data): - if ( - data.get("save_strategy") - and data.get("save_steps") - and data.get("save_strategy") != "steps" - ): - raise ValueError( - "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps." - ) - if data.get("saves_per_epoch") and data.get("save_steps"): - raise ValueError( - "save_steps and saves_per_epoch are mutually exclusive and cannot be used together." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_push_save(cls, data): - if data.get("hub_model_id") and ( - data.get("save_strategy") not in ["steps", "epoch", None] - ): - LOG.warning( - "hub_model_id is set without any models being saved. To save a model, set save_strategy." 
- ) - return data - - @model_validator(mode="before") - @classmethod - def check_evals(cls, data): - if ( - data.get("eval_strategy") - and data.get("eval_steps") - and data.get("eval_strategy") != "steps" - ): - raise ValueError( - "eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps." - ) - - if ( - data.get("val_set_size") == 0 - and (data.get("eval_steps") or data.get("eval_strategy")) - and not data.get("test_datasets") - and data.get("eval_strategy") != "no" - ): - raise ValueError( - "eval_steps and eval_strategy are not supported with val_set_size == 0" - ) - if data.get("evals_per_epoch") and data.get("eval_steps"): - raise ValueError( - "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together." - ) - if ( - data.get("evals_per_epoch") - and data.get("eval_strategy") - and data.get("eval_strategy") != "steps" - ): - raise ValueError( - "eval_strategy must be empty or set to `steps` when used with evals_per_epoch." - ) - - if data.get("do_bench_eval") and not ( - data.get("evals_per_epoch") or data.get("eval_steps") - ): - raise ValueError( - "do_bench_eval requires evals_per_epoch or eval_steps to be set." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_test_datasets_bench(cls, data): - if ( - data.get("do_bench_eval") - and not data.get("test_datasets") - and not data.get("val_set_size") - ): - LOG.warning( - "`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset." - ) - data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}] - return data - - @model_validator(mode="before") - @classmethod - def check_eval_packing(cls, data): - # TODO also should check test_datasets and val_set_size as we can skip - # if there are no eval datasets/splits - if ( - data.get("sample_packing") - and data.get("eval_table_size") - and data.get("eval_sample_packing") is not False - ): - raise ValueError( - "eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false." 
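The evaluation-scheduling checks above reduce to a few allowed and disallowed key combinations; a worked summary follows (all values invented).

```python
# Worked examples of the check_evals rules above:
ok = {"val_set_size": 0.05, "eval_steps": 100}             # accepted
also_ok = {"val_set_size": 0.05, "evals_per_epoch": 2}     # accepted
rejected = {"val_set_size": 0, "eval_steps": 100}          # fails: no eval split available
also_rejected = {"eval_steps": 100, "evals_per_epoch": 2}  # fails: mutually exclusive
```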
- ) - if ( - data.get("sample_packing") - and data.get("eval_sample_packing") is None - and not data.get("eval_table_size") - ): - LOG.info( - "explicitly setting `eval_sample_packing` to match `sample_packing`" - ) - data["eval_sample_packing"] = True - - if ( - data.get("sample_packing") - and data.get("eval_sample_packing") is False - and data.get("remove_unused_columns") is None - ): - LOG.info( - "setting `remove_unused_columns: false` for when sample_packing and eval_sample_packing don't match" - ) - data["remove_unused_columns"] = False - - return data - - @model_validator(mode="before") - @classmethod - def check_mm_prepare(cls, data): - if data.get("skip_prepare_dataset"): - if data.get("remove_unused_columns") is None: - LOG.info( - "setting `remove_unused_columns: false` for skip_prepare_dataset" - ) - data["remove_unused_columns"] = False - - return data - - @model_validator(mode="before") - @classmethod - def check_warmup(cls, data): - if data.get("warmup_steps") and data.get("warmup_ratio"): - raise ValueError("warmup_steps and warmup_ratio are mutually exclusive") - return data - - @model_validator(mode="before") - @classmethod - def check_neftune(cls, data): - if data.get("noisy_embedding_alpha") and not data.get("neftune_noise_alpha"): - data["neftune_noise_alpha"] = data["noisy_embedding_alpha"] - del data["noisy_embedding_alpha"] - elif data.get("noisy_embedding_alpha") and not data.get("neftune_noise_alpha"): - raise ValueError( - "noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting" - ) - return data - - @field_validator("neftune_noise_alpha") - @classmethod - def validate_neftune_noise_alpha(cls, neftune_noise_alpha): - if neftune_noise_alpha is not None and neftune_noise_alpha <= 0.0: - raise ValueError("neftune_noise_alpha must be > 0.0") - return neftune_noise_alpha - - @model_validator(mode="after") - def check_rl_beta(self): - if self.dpo_beta and not self.rl_beta: - self.rl_beta = self.dpo_beta - del self.dpo_beta - return self - - @model_validator(mode="after") - def check_simpo_warmup(self): - if self.rl is RLType.SIMPO and self.warmup_ratio: - raise ValueError( - "warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead" - ) - return self - - @model_validator(mode="before") - @classmethod - def check_frozen(cls, data): - if ( - data.get("adapter") - and data.get("peft_layers_to_transform") - and data.get("unfrozen_parameters") - ): - raise ValueError( - "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior." - ) - - return data - - @model_validator(mode="before") - @classmethod - def check_peft_layers_pattern(cls, data): - if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"): - raise ValueError( - "peft_layers_pattern requires peft_layers_to_transform to be set" - ) - return data - - @model_validator(mode="after") - def check_fft_possible_bad_config(self): - if ( - # pylint: disable=too-many-boolean-expressions - not (self.bf16 or self.bfloat16) - and (self.fp16 or self.float16) - and not self.adapter - and not self.flash_attention - and self.sample_packing - ): - LOG.warning( - "Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA." - ) - # ValueError: Attempting to unscale FP16 gradients. 
- # OR - # RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half - return self - - @model_validator(mode="after") - def check_fused_lora(self): - if self.adapter in ["lora", "qlora"] and ( - self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp - ): - raise ValueError("Fused modules are not supported with LoRA/QLoRA") - return self - - @model_validator(mode="after") - def hint_lora_8bit(self): - loftq = ( - self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits - ) - if not self.load_in_8bit and self.adapter == "lora" and not loftq: - LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") - return self - - @model_validator(mode="after") - def check_early_stopping(self): - if self.early_stopping_patience: - if not self.save_steps or not self.eval_steps: - raise ValueError( - "`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps." - ) - if self.save_steps % self.eval_steps != 0: - raise ValueError( - "`early_stopping_patience` requires that eval_steps should evenly divide save_steps." - ) - return self - - @model_validator(mode="after") - def check_relora(self): - if self.relora_steps: - if self.adapter not in ("lora", "qlora"): - raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA") - - if self.fsdp: - raise ValueError("fsdp not supported with ReLoRA") - - if self.deepspeed: - raise ValueError("deepspeed not supported with ReLoRA") - - if self.lr_scheduler == "one_cycle": - raise ValueError( - "ReLoRA is not compatible with the one_cycle scheduler" - ) - - if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp: - raise ValueError("Fused modules are not supported with ReLoRA") - return self - - @model_validator(mode="before") - @classmethod - def check_mem_mismatch(cls, data): - if ( - data.get("max_memory") is not None - and data.get("gpu_memory_limit") is not None - ): - raise ValueError( - "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_use_reentrant_mismatch(cls, data): - if ( - data.get("unfrozen_parameters") - and data.get("gradient_checkpointing_kwargs") - and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") - is True - ): - # https://github.com/huggingface/transformers/issues/21381 - raise ValueError( - "`use_reentrant` must be false when used with partially frozen model." - ) - return data - - @model_validator(mode="before") - @classmethod - def warn_qlora_zero3_w_use_reentrant(cls, data): - if ( - data.get("adapter") == "qlora" - and data.get("gradient_checkpointing_kwargs", {}) - and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") - is False - and data.get("deepspeed", "") is not None - and "zero3" in data.get("deepspeed", "") - ): - # may result in: - # torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: - # Recomputed values for the following tensors have different metadata - # than during the forward pass. 
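The `check_early_stopping` constraint above is easy to misread; a concrete check of the divisibility requirement (values invented):

```python
# early_stopping_patience requires eval_steps to evenly divide save_steps so every
# checkpoint has a matching evaluation loss.
early_stopping_patience = 3
save_steps, eval_steps = 200, 50
assert save_steps % eval_steps == 0  # evals at 50/100/150/200 line up with the save at 200
```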
- LOG.warning( - "qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_val_w_test_datasets(cls, data): - if data.get("test_datasets") and data.get("val_set_size"): - raise ValueError( - "non-zero val_set_size should not be used with test_datasets configuration" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_eval_strategy(cls, data): - if ( - data.get("evaluation_strategy") is not None - and data.get("eval_strategy") is None - ): - LOG.info( - "explicitly setting `eval_strategy` from the `evaluation_strategy`" - ) - data["eval_strategy"] = data.get("evaluation_strategy") - return data - - @model_validator(mode="before") - @classmethod - def check_fsdp_offload_w_8bit_optimizer(cls, data): - if ( - data.get("fsdp") - and "8bit" in data.get("optimizer", "") - and data.get("fsdp_config") - and data["fsdp_config"].get("fsdp_offload_params") - and str(data["fsdp_config"].get("fsdp_version")) != "2" - ): - raise ValueError( - f"FSDP Offload not compatible with {data.get('optimizer')}" - ) - if ( - data.get("fsdp") - and "8bit" in data.get("optimizer", "") - and data.get("fsdp_config") - and str(data["fsdp_config"].get("fsdp_version")) == "2" - ): - if data.get("optimizer", "") in ["adamw_8bit", "adamw_bnb_8bit"]: - # CUDA ops errors with bnb 8bit optimizer + FSDP2 - raise ValueError( - f"FSDP2 not compatible with {data.get('optimizer')}, use `adamw_torch_8bit` instead" - ) - - return data - - @model_validator(mode="before") - @classmethod - def check_fsdp_sharded_state_dict_w_safetensors(cls, data): - if ( - data.get("fsdp") - and data.get("save_safetensors") - and data.get("fsdp_config") - and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT" - ): - raise ValueError( - "FSDP SHARDED_STATE_DICT not compatible with save_safetensors" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_causal_lm_evals(cls, data): - if data.get("do_causal_lm_eval") and data.get("eval_sample_packing"): - raise ValueError( - "do_causal_lm_eval is enabled, eval_sample_packing must be set to False" - ) - - if data.get("eval_causal_lm_metrics"): - if not isinstance(data.get("eval_causal_lm_metrics"), list): - raise ValueError("eval_causal_lm_metrics must be a list") - # only ["sacrebleu", "comet", "ter", "chrf"] supported - if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS: - raise ValueError( - f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_dataset_or_pretraining_dataset(cls, data): - if data.get("datasets") is None and data.get("pretraining_dataset") is None: - raise ValueError("either datasets or pretraining_dataset is required") - return data - - @model_validator(mode="before") - @classmethod - def check_xentropy_patch_conflicts(cls, data): - if data.get("flash_attn_cross_entropy") and data.get( - "unsloth_cross_entropy_loss" - ): - raise ValueError( - "flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_qlora_unsloth(cls, data): - if ( - data.get("unsloth_lora_mlp") - or data.get("unsloth_lora_qkv") - or data.get("unsloth_lora_o") - ): - if data.get("adapter") == "lora" and data.get("load_in_8bit"): - raise ValueError( - "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 
8-bit LoRA" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_lora_kernel_8bit(cls, data): - if ( - data.get("lora_mlp_kernel") - or data.get("lora_qkv_kernel") - or data.get("lora_o_kernel") - ): - if data.get("adapter") == "lora" and data.get("load_in_8bit"): - raise ValueError( - "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_lora_kernel_rl(cls, data): - if ( - data.get("lora_mlp_kernel") - or data.get("lora_qkv_kernel") - or data.get("lora_o_kernel") - ) and data.get("rl"): - raise ValueError( - "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with RL at the moment." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_lora_axolotl_unsloth(cls, data): - is_lora_kernel = any( - data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"] - ) - is_unsloth_lora = any( - data.get(k) - for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"] - ) - if is_lora_kernel and is_unsloth_lora: - raise ValueError( - "both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_torch_compile_deepspeed(cls, data): - if data.get("deepspeed") and data.get("torch_compile"): - raise ValueError( - "torch_compile should be set within your deepspeed config file" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_npu_config(cls, data): - if is_torch_npu_available(): - # check attention config - attn_list = ["flash_attention", "sdp_attention", "s2_attention"] - for attn in attn_list: - if data.get(attn): - raise NotImplementedError( - f"{attn} is currently not supported in Ascend npu, please disable this configuration." - ) - - # check quant config - if data.get("optimizer") is not None and "bit" in data.get("optimizer"): - optimizer = data.get("optimizer") - raise NotImplementedError( - f"{optimizer} is currently not supported in Ascend npu, choose another one please." - ) - - quant_list = ["load_in_8bit", "load_in_4bit"] - for quant in quant_list: - if data.get(quant): - raise NotImplementedError( - f"Quantification is currently not supported in Ascend npu, please disable {quant}." - ) - - # check dtype config - if data.get("tf32"): - raise NotImplementedError( - "tf32 dtype is currently not supported in Ascend npu, please disable this configuration" - ) - - return data - - @model_validator(mode="before") - @classmethod - def check_rl_config_gradient_checkpointing(cls, data): - # TODO: SalmanMohammadi - # Distributed RL with QLoRA + gradient checkpointing - # and use_reentrant = True is broken upstream in TRL - # pylint: disable=too-many-boolean-expressions - if ( - data.get("rl") - and data.get("gradient_checkpointing") - and data.get("gradient_checkpointing_kwargs") - and data.get("gradient_checkpointing_kwargs").get("use_reentrant") - and data.get("load_in_4bit") - and data.get("adapter") == "qlora" - and data.get("capabilities") - and data.get("capabilities").get("n_gpu", 1) > 1 - ): - raise ValueError( - "The `use_reentrant: True` implementation of gradient checkpointing " - "is not supported for distributed RL training with QLoRA. Please set " - "`use_reentrant: False` in `gradient_checkpointing_kwargs`." 
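For reference, a config shape that passes `check_rl_config_gradient_checkpointing` above when combining distributed RL with QLoRA (sketch; the keys are taken from the validator and its error message):

```python
# Distributed RL + QLoRA + gradient checkpointing must use the non-reentrant path:
rl_cfg = {
    "rl": "grpo",
    "adapter": "qlora",
    "load_in_4bit": True,
    "gradient_checkpointing": True,
    "gradient_checkpointing_kwargs": {"use_reentrant": False},
}
```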
- ) - return data - - @model_validator(mode="before") - @classmethod - def check_kto_config(cls, data): - if data.get("rl") == "kto": - if data.get("sample_packing") or data.get("eval_sample_packing"): - raise ValueError("sample_packing is not supported with kto") - - if data.get("remove_unused_columns") is not False: - raise ValueError("Set `remove_unused_columns: False` when using kto") - - return data - - @model_validator(mode="before") - @classmethod - def check_grpo_liger_sequence_parallel(cls, data): - if ( - data.get("rl") == "grpo" - and data.get("trl", {}) - and data.get("trl").get("use_liger_loss") - and data.get("sequence_parallel_degree", 1) > 1 - ): - raise ValueError("GRPO + SP + Liger not currently supported") - return data - - @model_validator(mode="after") - def check_sequence_parallel_degree(self): - if not self.sequence_parallel_degree: - self.sequence_parallel_degree = 1 - elif self.sequence_parallel_degree > 1: - if not self.flash_attention: - raise ValueError( - "flash_attention: true must be set with sequence_parallel_degree > 1" - ) - - if self.sample_packing and getattr(self, "micro_batch_size", 1) > 1: - raise ValueError( - "micro_batch_size must be set to 1 when sample_packing is enabled " - "due to a `ring-flash-attn` requirement" - ) - - try: - import ring_flash_attn # noqa: F401 # pylint:disable=unused-import - except ImportError as exception: - raise ImportError( - "sequence_parallel_degree > 1 but ring_flash_attn is not installed. " - "Please install it with `pip install axolotl[ring-flash-attn] " - "or `pip install ring-flash-attn>=0.1.4`." - ) from exception - - # TODO: monkeypatch / callback to average losses correctly across SP ranks - # / fix gradient scaling across SP ranks. Losses, grads should be scaled - # according to the proportion of non-padding tokens per rank. - LOG.warning( - "Sequence parallelism (SP) is enabled with " - f"sequence_parallel_degree={self.sequence_parallel_degree}. " - "Please note that logged losses may differ slightly to the non-SP " - "losses due to transformers Trainer implementation details. " - "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " - "for more details." - ) - - return self - - @model_validator(mode="after") - def validate_ring_attn_func(self): - if getattr(self, "sequence_parallel_degree", 1) == 1: - return self - - if self.ring_attn_func is not None: - self.ring_attn_func = RingAttnFunc(self.ring_attn_func) - else: - # Default ring attention function selection - sample_packing = getattr(self, "sample_packing", False) - self.ring_attn_func = ( - RingAttnFunc.VARLEN_LLAMA3 - if sample_packing - else RingAttnFunc.BATCH_RING - ) - - return self - - @model_validator(mode="before") - @classmethod - def check_muon_deepspeed_fsdp(cls, data): - if data.get("optimizer") == "muon" and ( - data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config") - ): - raise ValueError( - "Muon optimizer is currently incompatible with DeepSpeed and FSDP" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_tokenizer_use_mistral_common(cls, data): - if data.get("tokenizer_use_mistral_common") is None: - if any( - "magistral" in name.lower() - for name in [ - data.get("base_model", ""), - data.get("base_model_config", ""), - data.get("tokenizer_config", ""), - ] - ): - LOG.warning( - "tokenizer_use_mistral_common auto inferred to True for Magistral models. Please set it to True explicitly if you want to use mistral-common tokenizer." 
- ) - data["tokenizer_use_mistral_common"] = True - - return data - - @field_validator("tokenizer_use_mistral_common", mode="after") - @classmethod - def check_mistral_common_import(cls, tokenizer_use_mistral_common): - if tokenizer_use_mistral_common: - try: - import mistral_common # noqa: F401 # pylint:disable=unused-import - except ImportError as exception: - raise ImportError( - "mistral-common is required for mistral models. Please install it with `pip install axolotl` or `pip install -e .`." - ) from exception - - return tokenizer_use_mistral_common - - @model_validator(mode="before") - @classmethod - def check_mistral_common_incompatible_options(cls, data): - if not data.get("tokenizer_use_mistral_common"): - return data - - # NOTE: mistral-common tokenizer is not compatible with editing tokenizer at the moment - - if data.get("added_tokens_overrides"): - raise ValueError( - "added_tokens_overrides is not supported with mistral-common tokenizer" - ) - - if data.get("special_tokens"): - raise ValueError( - "special_tokens override is not supported with mistral-common tokenizer" - ) - - if data.get("tokens"): - raise ValueError( - "tokens override is not supported with mistral-common tokenizer" - ) - - if data.get("chat_template"): - raise ValueError( - "Setting chat_template is not supported with mistral-common tokenizer" - ) - - return data - class AxolotlConfigWCapabilities(AxolotlInputConfig): - """wrapper to valdiate gpu capabilities with the configured options""" + """wrapper to valdiate GPU capabilities with the configured options""" capabilities: GPUCapabilities env_capabilities: EnvCapabilities @@ -1375,13 +877,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): return data - @model_validator(mode="before") - @classmethod - def check_fsdp_deepspeed(cls, data): - if data.get("deepspeed") and data.get("fsdp"): - raise ValueError("deepspeed and fsdp cannot be used together.") - return data - + # pylint: disable=duplicate-code @model_validator(mode="before") @classmethod def check_multigpu_unsloth(cls, data): @@ -1397,6 +893,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): ) return data + # pylint: disable=duplicate-code @model_validator(mode="before") @classmethod def check_multigpu_lora_kernels(cls, data): diff --git a/src/axolotl/utils/schemas/datasets.py b/src/axolotl/utils/schemas/datasets.py index c71f9be77..d9459feb9 100644 --- a/src/axolotl/utils/schemas/datasets.py +++ b/src/axolotl/utils/schemas/datasets.py @@ -1,6 +1,8 @@ """Pydantic models for datasets-related configuration""" -from pydantic import BaseModel, model_validator +from typing import Literal + +from pydantic import BaseModel, Field, model_validator from axolotl.utils.schemas.enums import ChatTemplate from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic @@ -9,57 +11,178 @@ from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic class UserDefinedPrompterType(BaseModel): """Structure for user defined prompt types""" - system_prompt: str | None = None - system_format: str | None = None + system_prompt: str | None = Field( + default=None, + json_schema_extra={"description": "Custom user instruction prompt"}, + ) + system_format: str | None = Field( + default=None, + json_schema_extra={"description": "Use {system} as key to be replaced"}, + ) field_system: str | None = None field_instruction: str | None = None field_input: str | None = None field_output: str | None = None - format: str | None = None - no_input_format: str | None = None - field: str | None 
= None + format: str | None = Field( + default=None, + json_schema_extra={ + "description": "Customizable to be single line or multi-line. Use {instruction}/{input} as key to be replaced. 'format' can include {input}" + }, + ) + no_input_format: str | None = Field( + default=None, + json_schema_extra={"description": "'no_input_format' cannot include {input}"}, + ) + field: str | None = Field( + default=None, + json_schema_extra={ + "description": "For `completion` datsets only, uses the provided field instead of `text` column" + }, + ) class SFTDataset(BaseModel): """SFT configuration subset""" - path: str | None = None - split: str | None = None - type: str | UserDefinedPrompterType | None = None + path: str | None = Field( + default=None, + json_schema_extra={ + "description": "HuggingFace dataset repo | s3:// | gs:// | path to local file or directory" + }, + ) + split: str | None = Field( + default=None, + json_schema_extra={"description": "name of dataset split to load from"}, + ) + type: str | UserDefinedPrompterType | None = Field( + default=None, + json_schema_extra={ + "description": "The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]" + }, + ) input_transform: str | None = None - shards: int | None = None - shards_idx: int | None = None - preprocess_shards: int | None = None + shards: int | None = Field( + default=None, + json_schema_extra={ + "description": "split dataset into N pieces (use with shards_idx)" + }, + ) + shards_idx: int | None = Field( + default=None, + json_schema_extra={"description": "the index of sharded dataset to use"}, + ) + preprocess_shards: int | None = Field( + default=None, + json_schema_extra={ + "description": "process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)" + }, + ) conversation: str | None = None # Do not make this too strict or it will break the validator to choose different dataset class - chat_template: ChatTemplate | str | None = None - chat_template_jinja: str | None = None - data_files: str | list[str] | None = None + chat_template: ChatTemplate | str | None = Field( + default=None, + json_schema_extra={ + "description": "The name of the chat template to use for training, following values are supported: tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default. alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py. tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml. jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field." + }, + ) + chat_template_jinja: str | None = Field( + default=None, + json_schema_extra={ + "description": "Custom jinja chat template. Used only if `chat_template: jinja` or empty." 
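A sketch of a single `datasets:` entry exercising the `SFTDataset` fields documented above; the repository name and split are invented for illustration.

```python
# One dataset entry, using only fields described in this schema:
dataset_entry = {
    "path": "my-org/my-sft-data",  # HF repo, s3://, gs://, or a local path
    "split": "train",
    "type": "chat_template",
    "chat_template": "chatml",
    "shards": 4,                   # split the dataset into 4 pieces...
    "shards_idx": 0,               # ...and use the first piece
}
```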
+ }, + ) + data_files: str | list[str] | None = Field( + default=None, json_schema_extra={"description": "path to source data files"} + ) input_format: str | None = None - name: str | None = None - ds_type: str | None = None + name: str | None = Field( + default=None, + json_schema_extra={"description": "name of dataset configuration to load"}, + ) + ds_type: str | None = Field( + default=None, + json_schema_extra={"description": "defines the datatype when path is a file"}, + ) field: str | None = None field_human: str | None = None field_model: str | None = None - field_messages: str | None = None - field_tools: str | None = None + field_messages: str | None = Field( + default=None, + json_schema_extra={ + "description": 'Key containing the messages (default: "messages")' + }, + ) + field_tools: str | None = Field( + default=None, + json_schema_extra={ + "description": 'Key containing the tools (default: "tools"). Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).' + }, + ) # deprecated, use message_property_mappings message_field_role: str | None = None # deprecated, use message_property_mappings message_field_content: str | None = None - message_property_mappings: dict[str, str] | None = None - message_field_training: str | None = None - message_field_training_detail: str | None = None - split_thinking: bool | None = None + message_property_mappings: dict[str, str] | None = Field( + default=None, + json_schema_extra={ + "description": "Mapping of properties from the input dataset to the chat template. (default: message_property_mappings={'role':'role', 'content':'content'}) If a property exists in the template but not in this mapping, the system will attempt to load it directly from the message using the property name as the key. Example: In the mapping below, 'from' is loaded from input dataset and used as 'role', while 'value' is loaded and used as 'content' in the chat template." + }, + ) + message_field_training: str | None = Field( + default=None, + json_schema_extra={ + "description": "The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`." + }, + ) + message_field_training_detail: str | None = Field( + default=None, + json_schema_extra={ + "description": "The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn. The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train)." + }, + ) + split_thinking: bool | None = Field( + default=None, + json_schema_extra={ + "description": "(for Qwen3 template only) Whether to split the assistant content based on a reasoning trace inside delimited tags" + }, + ) logprobs_field: str | None = None temperature: float | None = None - roles_to_train: list[str] | None = None - train_on_eos: str | None = None - roles: dict[str, list[str]] | None = None - drop_system_message: bool | None = None - trust_remote_code: bool | None = False - revision: str | None = None + roles_to_train: list[str] | None = Field( + default=None, + json_schema_extra={ + "description": "Roles to train on. The tokens from these roles will be considered for the loss." 
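The `message_field_training_detail` description above implies the following per-turn structure; the surrounding row layout and the `train_detail` key name are assumptions for illustration only.

```python
# Per-turn training detail: character offsets into the message content plus a
# boolean indicating whether that span is trained on.
row = {
    "messages": [
        {
            "role": "assistant",
            "content": "The answer is 42.",
            "train_detail": [
                {"begin_offset": 0, "end_offset": 16, "train": True},
            ],
        },
    ],
}
```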
+ }, + ) + train_on_eos: Literal["all", "turn", "last"] | None = Field( + default=None, + json_schema_extra={ + "description": "Which EOS tokens to train on in the conversation. Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation" + }, + ) + roles: dict[str, list[str]] | None = Field( + default=None, + json_schema_extra={ + "description": 'Roles mapping in the messages. The format is {target_role: [source_roles]}. All source roles will be mapped to the target role. The default is: user: ["human", "user"], assistant: ["gpt", "assistant"], system: ["system"], tool: ["tool"]' + }, + ) + drop_system_message: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to drop the system turn from the dataset. Only works with chat_template. This does not drop the default system message from chat_template if it exists. If you wish to, we recommend using a custom jinja template with the default system message removed or adding a system turn with empty content." + }, + ) + trust_remote_code: bool | None = Field( + default=False, + json_schema_extra={"description": "Trust remote code for untrusted source"}, + ) + revision: str | None = Field( + default=None, + json_schema_extra={ + "description": "The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets." + }, + ) @model_validator(mode="before") @classmethod diff --git a/src/axolotl/utils/schemas/deprecated.py b/src/axolotl/utils/schemas/deprecated.py index b8904136e..972fe0ccf 100644 --- a/src/axolotl/utils/schemas/deprecated.py +++ b/src/axolotl/utils/schemas/deprecated.py @@ -60,10 +60,30 @@ class RemappedParameters(BaseModel): """Parameters that have been remapped to other names""" overrides_of_model_config: dict[str, Any] | None = Field( - default=None, alias="model_config" + default=None, + alias="model_config", + json_schema_extra={ + "description": "optional overrides to the base model configuration" + }, ) overrides_of_model_kwargs: dict[str, Any] | None = Field( - default=None, alias="model_kwargs" + default=None, + alias="model_kwargs", + json_schema_extra={ + "description": "optional overrides the base model loading from_pretrained" + }, + ) + type_of_model: str | None = Field( + default=None, + alias="model_type", + json_schema_extra={ + "description": "If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too" + }, + ) + revision_of_model: str | None = Field( + default=None, + alias="model_revision", + json_schema_extra={ + "description": "You can specify to choose a specific model revision from huggingface hub" + }, ) - type_of_model: str | None = Field(default=None, alias="model_type") - revision_of_model: str | None = Field(default=None, alias="model_revision") diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index d09ab6387..bfef14d53 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -1,5 +1,7 @@ """Enums for Axolotl input config""" +# pylint: disable=invalid-name + from enum import Enum import torch @@ -8,81 +10,81 @@ import torch class TorchIntDType(Enum): """Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4""" - uint1 = getattr(torch, "uint1", None) # pylint: 
disable=invalid-name - uint2 = getattr(torch, "uint2", None) # pylint: disable=invalid-name - uint3 = getattr(torch, "uint3", None) # pylint: disable=invalid-name - uint4 = getattr(torch, "uint4", None) # pylint: disable=invalid-name - uint5 = getattr(torch, "uint5", None) # pylint: disable=invalid-name - uint6 = getattr(torch, "uint6", None) # pylint: disable=invalid-name - uint7 = getattr(torch, "uint7", None) # pylint: disable=invalid-name - int4 = getattr(torch, "int4", None) # pylint: disable=invalid-name - int8 = getattr(torch, "int8", None) # pylint: disable=invalid-name + uint1 = getattr(torch, "uint1", None) + uint2 = getattr(torch, "uint2", None) + uint3 = getattr(torch, "uint3", None) + uint4 = getattr(torch, "uint4", None) + uint5 = getattr(torch, "uint5", None) + uint6 = getattr(torch, "uint6", None) + uint7 = getattr(torch, "uint7", None) + int4 = getattr(torch, "int4", None) + int8 = getattr(torch, "int8", None) class RLType(str, Enum): """RL trainer type configuration subset""" - DPO = "dpo" # pylint: disable=invalid-name - GRPO = "grpo" # pylint: disable=invalid-name - IPO = "ipo" # pylint: disable=invalid-name - ORPO = "orpo" # pylint: disable=invalid-name - KTO = "kto" # pylint: disable=invalid-name - SIMPO = "simpo" # pylint: disable=invalid-name + DPO = "dpo" + GRPO = "grpo" + IPO = "ipo" + ORPO = "orpo" + KTO = "kto" + SIMPO = "simpo" class ChatTemplate(str, Enum): """Chat templates configuration subset""" - alpaca = "alpaca" # pylint: disable=invalid-name - chatml = "chatml" # pylint: disable=invalid-name - mistral_v1 = "mistral_v1" # pylint: disable=invalid-name - mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name - mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name - mistral_v7_tekken = "mistral_v7_tekken" # pylint: disable=invalid-name - gemma = "gemma" # pylint: disable=invalid-name - cohere = "cohere" # pylint: disable=invalid-name - llama3 = "llama3" # pylint: disable=invalid-name - llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name - llama4 = "llama4" # pylint: disable=invalid-name - phi_3 = "phi_3" # pylint: disable=invalid-name - phi_35 = "phi_35" # pylint: disable=invalid-name - deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name - deepseek_v3 = "deepseek_v3" # pylint: disable=invalid-name - jamba = "jamba" # pylint: disable=invalid-name - jinja = "jinja" # pylint: disable=invalid-name - qwen_25 = "qwen_25" # pylint: disable=invalid-name - qwen3 = "qwen3" # pylint: disable=invalid-name - tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name - exaone = "exaone" # pylint: disable=invalid-name - metharme = "metharme" # pylint: disable=invalid-name - pixtral = "pixtral" # pylint: disable=invalid-name - llava = "llava" # pylint: disable=invalid-name - qwen2_vl = "qwen2_vl" # pylint: disable=invalid-name - gemma3 = "gemma3" # pylint: disable=invalid-name - command_a = "command_a" # pylint: disable=invalid-name - command_a_tool_use = "command_a_tool_use" # pylint: disable=invalid-name - command_a_rag = "command_a_rag" # pylint: disable=invalid-name - aya = "aya" # pylint: disable=invalid-name + alpaca = "alpaca" + chatml = "chatml" + mistral_v1 = "mistral_v1" + mistral_v2v3 = "mistral_v2v3" + mistral_v3_tekken = "mistral_v3_tekken" + mistral_v7_tekken = "mistral_v7_tekken" + gemma = "gemma" + cohere = "cohere" + llama3 = "llama3" + llama3_2_vision = "llama3_2_vision" + llama4 = "llama4" + phi_3 = "phi_3" + phi_35 = "phi_35" + deepseek_v2 = "deepseek_v2" + deepseek_v3 = "deepseek_v3" + jamba = 
"jamba" + jinja = "jinja" + qwen_25 = "qwen_25" + qwen3 = "qwen3" + tokenizer_default = "tokenizer_default" + exaone = "exaone" + metharme = "metharme" + pixtral = "pixtral" + llava = "llava" + qwen2_vl = "qwen2_vl" + gemma3 = "gemma3" + command_a = "command_a" + command_a_tool_use = "command_a_tool_use" + command_a_rag = "command_a_rag" + aya = "aya" class CustomSupportedOptimizers(str, Enum): """Custom supported optimizers""" - optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name - ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name - ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name - ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name - adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name - came_pytorch = "came_pytorch" # pylint: disable=invalid-name - muon = "muon" # pylint: disable=invalid-name + optimi_adamw = "optimi_adamw" + ao_adamw_4bit = "ao_adamw_4bit" + ao_adamw_8bit = "ao_adamw_8bit" + ao_adamw_fp8 = "ao_adamw_fp8" + adopt_adamw = "adopt_adamw" + came_pytorch = "came_pytorch" + muon = "muon" class RingAttnFunc(str, Enum): """Enum class for supported `ring-flash-attn` implementations""" - # VARLEN_RING = "varlen_ring" - # VARLEN_ZIGZAG = "varlen_zigzag" VARLEN_LLAMA3 = "varlen_llama3" BATCH_RING = "batch_ring" + # VARLEN_RING = "varlen_ring" + # VARLEN_ZIGZAG = "varlen_zigzag" # BATCH_ZIGZAG = "batch_zigzag" # BATCH_STRIPE = "batch_stripe" diff --git a/src/axolotl/utils/schemas/integrations.py b/src/axolotl/utils/schemas/integrations.py index 4843e3592..7332c7d39 100644 --- a/src/axolotl/utils/schemas/integrations.py +++ b/src/axolotl/utils/schemas/integrations.py @@ -13,10 +13,21 @@ class MLFlowConfig(BaseModel): """MLFlow configuration subset""" use_mlflow: bool | None = None - mlflow_tracking_uri: str | None = None - mlflow_experiment_name: str | None = None - mlflow_run_name: str | None = None - hf_mlflow_log_artifacts: bool | None = None + mlflow_tracking_uri: str | None = Field( + default=None, json_schema_extra={"description": "URI to mlflow"} + ) + mlflow_experiment_name: str | None = Field( + default=None, json_schema_extra={"description": "Your experiment name"} + ) + mlflow_run_name: str | None = Field( + default=None, json_schema_extra={"description": "Your run name"} + ) + hf_mlflow_log_artifacts: bool | None = Field( + default=None, + json_schema_extra={ + "description": "set to true to copy each saved checkpoint on each save to mlflow artifact registry" + }, + ) class LISAConfig(BaseModel): @@ -40,13 +51,33 @@ class WandbConfig(BaseModel): """Wandb configuration subset""" use_wandb: bool | None = None - wandb_name: str | None = None - wandb_run_id: str | None = None - wandb_mode: str | None = None - wandb_project: str | None = None - wandb_entity: str | None = None + wandb_name: str | None = Field( + default=None, + json_schema_extra={"description": "Set the name of your wandb run"}, + ) + wandb_run_id: str | None = Field( + default=None, json_schema_extra={"description": "Set the ID of your wandb run"} + ) + wandb_mode: str | None = Field( + default=None, + json_schema_extra={ + "description": '"offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb' + }, + ) + wandb_project: str | None = Field( + default=None, json_schema_extra={"description": "Your wandb project name"} + ) + wandb_entity: str | None = Field( + default=None, + json_schema_extra={"description": "A wandb Team name if using a Team"}, + ) wandb_watch: str | None = None - wandb_log_model: str | None = None + 
wandb_log_model: str | None = Field( + default=None, + json_schema_extra={ + "description": '"checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training' + }, + ) @model_validator(mode="before") @classmethod @@ -64,14 +95,52 @@ class WandbConfig(BaseModel): class CometConfig(BaseModel): """Comet configuration subset""" - use_comet: bool | None = None - comet_api_key: str | None = None - comet_workspace: str | None = None - comet_project_name: str | None = None - comet_experiment_key: str | None = None - comet_mode: str | None = None - comet_online: bool | None = None - comet_experiment_config: dict[str, Any] | None = None + use_comet: bool | None = Field( + default=None, + json_schema_extra={"description": "Enable or disable Comet integration."}, + ) + comet_api_key: str | None = Field( + default=None, + json_schema_extra={ + "description": "API key for Comet. Recommended to set via `comet login`." + }, + ) + comet_workspace: str | None = Field( + default=None, + json_schema_extra={ + "description": "Workspace name in Comet. Defaults to the user's default workspace." + }, + ) + comet_project_name: str | None = Field( + default=None, + json_schema_extra={ + "description": "Project name in Comet. Defaults to Uncategorized." + }, + ) + comet_experiment_key: str | None = Field( + default=None, + json_schema_extra={ + "description": "Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key." + }, + ) + comet_mode: str | None = Field( + default=None, + json_schema_extra={ + "description": 'Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.' + }, + ) + comet_online: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Set to True to log data to Comet server, or False for offline storage. Default is True." + }, + ) + comet_experiment_config: dict[str, Any] | None = Field( + default=None, + json_schema_extra={ + "description": "Dictionary for additional configuration settings, see the doc for more details." + }, + ) class GradioConfig(BaseModel): diff --git a/src/axolotl/utils/schemas/model.py b/src/axolotl/utils/schemas/model.py index aafb52152..6f995996d 100644 --- a/src/axolotl/utils/schemas/model.py +++ b/src/axolotl/utils/schemas/model.py @@ -12,20 +12,55 @@ class ModelInputConfig(BaseModel): model_config = {"protected_namespaces": ()} - base_model: str - base_model_config: str | None = None + base_model: str = Field( + json_schema_extra={ + "description": "This is the huggingface model that contains *.pt, *.safetensors, or *.bin files. 
This can also be a relative path to a model on disk" + } + ) + base_model_config: str | None = Field( + default=None, + json_schema_extra={ + "description": "If the base_model repo on hf hub doesn't include configuration .json files, you can set that here, or leave this empty to default to base_model" + }, + ) cls_model_config: str | None = None - tokenizer_config: str | None = None - tokenizer_use_fast: bool | None = None - tokenizer_legacy: bool | None = None - tokenizer_use_mistral_common: bool | None = None + tokenizer_config: str | None = Field( + default=None, + json_schema_extra={ + "description": "Optional tokenizer configuration path in case you want to use a different tokenizer than the one defined in the base model" + }, + ) + tokenizer_use_fast: bool | None = Field( + default=None, + json_schema_extra={ + "description": "use_fast option for tokenizer loading from_pretrained, defaults to True" + }, + ) + tokenizer_legacy: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use the legacy tokenizer setting, defaults to True" + }, + ) + tokenizer_use_mistral_common: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use mistral-common tokenizer. If set to True, it will use the mistral-common tokenizer." + }, + ) tokenizer_type: str | None = Field( - default=None, json_schema_extra={"description": "transformers tokenizer class"} + default=None, + json_schema_extra={ + "description": "Corresponding tokenizer for the model. AutoTokenizer is a good choice" + }, ) processor_type: str | None = Field( default=None, json_schema_extra={"description": "transformers processor class"} ) - trust_remote_code: bool | None = None + trust_remote_code: bool | None = Field( + default=None, + json_schema_extra={"description": "Trust remote code for untrusted sources"}, + ) @field_validator("trust_remote_code") @classmethod @@ -40,10 +75,23 @@ class ModelInputConfig(BaseModel): class ModelOutputConfig(BaseModel): """model save configuration subset""" - output_dir: str = Field(default="./model-out") - hub_model_id: str | None = None - hub_strategy: str | None = None - save_safetensors: bool | None = True + output_dir: str = Field( + default="./model-out", + json_schema_extra={"description": "Where to save the full-finetuned model to"}, + ) + hub_model_id: str | None = Field( + default=None, json_schema_extra={"description": "Push checkpoints to hub"} + ) + hub_strategy: str | None = Field( + default=None, + json_schema_extra={"description": "How to push checkpoints to hub"}, + ) + save_safetensors: bool | None = Field( + default=True, + json_schema_extra={ + "description": "Save model as safetensors (requires safetensors package). 
Default True" + }, + ) class SpecialTokensConfig(BaseModel): diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py index 5d408e1fe..4b31ce018 100644 --- a/src/axolotl/utils/schemas/peft.py +++ b/src/axolotl/utils/schemas/peft.py @@ -9,7 +9,7 @@ class LoftQConfig(BaseModel): """LoftQ configuration subset""" loftq_bits: int = Field( - default=4, json_schema_extra={"description": "Quantization bits for LoftQ"} + default=4, json_schema_extra={"description": "typically 4 bits"} ) # loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"}) @@ -17,31 +17,78 @@ class LoftQConfig(BaseModel): class PeftConfig(BaseModel): """peftq configuration subset""" - loftq_config: LoftQConfig | None = None + loftq_config: LoftQConfig | None = Field( + default=None, + json_schema_extra={ + "description": "Configuration options for loftq initialization for LoRA" + }, + ) class LoraConfig(BaseModel): """Peft / LoRA configuration subset""" - load_in_8bit: bool | None = Field(default=False) - load_in_4bit: bool | None = Field(default=False) + load_in_8bit: bool | None = Field( + default=False, + json_schema_extra={ + "description": "This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer" + }, + ) + load_in_4bit: bool | None = Field( + default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"} + ) - adapter: str | None = None - lora_model_dir: str | None = None + adapter: str | None = Field( + default=None, + json_schema_extra={ + "description": "If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model" + }, + ) + lora_model_dir: str | None = Field( + default=None, + json_schema_extra={ + "description": "If you already have a lora model trained that you want to load, put that here. This means after training, if you want to test the model, you should set this to the value of `output_dir`. Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`." + }, + ) lora_r: int | None = None lora_alpha: int | None = None lora_fan_in_fan_out: bool | None = None lora_target_modules: str | list[str] | None = None - lora_target_linear: bool | None = None - lora_modules_to_save: list[str] | None = None + lora_target_linear: bool | None = Field( + default=None, + json_schema_extra={"description": "If true, will target all linear modules"}, + ) + lora_modules_to_save: list[str] | None = Field( + default=None, + json_schema_extra={ + "description": "If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens. For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models. `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities." 
+ }, + ) lora_dropout: float | None = 0.0 - peft_layers_to_transform: list[int] | None = None + peft_layers_to_transform: list[int] | None = Field( + default=None, + json_schema_extra={ + "description": "The layer indices to transform, otherwise, apply to all layers" + }, + ) peft_layers_pattern: list[str] | None = None peft: PeftConfig | None = None - peft_use_dora: bool | None = None - peft_use_rslora: bool | None = None - peft_layer_replication: list[tuple[int, int]] | None = None - peft_init_lora_weights: bool | str | None = None + peft_use_dora: bool | None = Field( + default=None, json_schema_extra={"description": "Whether to use DoRA."} + ) + peft_use_rslora: bool | None = Field( + default=None, json_schema_extra={"description": "Whether to use RSLoRA."} + ) + peft_layer_replication: list[tuple[int, int]] | None = Field( + default=None, + json_schema_extra={"description": "List of layer indices to replicate."}, + ) + peft_init_lora_weights: bool | str | None = Field( + default=None, + json_schema_extra={ + "description": "How to initialize LoRA weights. Default to True which is MS original implementation." + }, + ) qlora_sharded_model_loading: bool | None = Field( default=False, @@ -49,9 +96,24 @@ class LoraConfig(BaseModel): "description": "load qlora model in sharded format for FSDP using answer.ai technique." }, ) - lora_on_cpu: bool | None = None - gptq: bool | None = None - bnb_config_kwargs: dict[str, Any] | None = None + lora_on_cpu: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge" + }, + ) + gptq: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether you are training a 4-bit GPTQ quantized model" + }, + ) + bnb_config_kwargs: dict[str, Any] | None = Field( + default=None, + json_schema_extra={ + "description": "optional overrides to the bnb 4bit quantization configuration" + }, + ) loraplus_lr_ratio: float | None = Field( default=None, @@ -62,7 +124,7 @@ class LoraConfig(BaseModel): loraplus_lr_embedding: float | None = Field( default=1e-6, json_schema_extra={ - "description": "loraplus learning rate for lora embedding layers." + "description": "loraplus learning rate for lora embedding layers. Default value is 1e-6." 
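For illustration, the LoRA options documented above map onto top-level axolotl YAML config keys roughly as follows (a minimal sketch; the rank, alpha, and dropout values are placeholders, not recommendations):

    adapter: lora
    load_in_8bit: true
    lora_r: 16
    lora_alpha: 32
    lora_dropout: 0.05
    lora_target_linear: true      # target all linear modules
    lora_modules_to_save:         # only needed if new tokens were added to the tokenizer
      - embed_tokens
      - lm_head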
}, ) @@ -125,8 +187,29 @@ class LoraConfig(BaseModel): class ReLoRAConfig(BaseModel): """ReLoRA configuration subset""" - relora_steps: int | None = None - relora_warmup_steps: int | None = None - relora_anneal_steps: int | None = None - relora_prune_ratio: float | None = None - relora_cpu_offload: bool | None = None + relora_steps: int | None = Field( + default=None, + json_schema_extra={"description": "Number of steps per ReLoRA restart"}, + ) + relora_warmup_steps: int | None = Field( + default=None, + json_schema_extra={"description": "Number of per-restart warmup steps"}, + ) + relora_anneal_steps: int | None = Field( + default=None, + json_schema_extra={ + "description": "Number of anneal steps for each relora cycle" + }, + ) + relora_prune_ratio: float | None = Field( + default=None, + json_schema_extra={ + "description": "threshold for optimizer magnitude when pruning" + }, + ) + relora_cpu_offload: bool | None = Field( + default=None, + json_schema_extra={ + "description": "True to perform lora weight merges on cpu during restarts, for modest gpu memory savings" + }, + ) diff --git a/src/axolotl/utils/schemas/quantization.py b/src/axolotl/utils/schemas/quantization.py index fe2cdb1fe..090640c7b 100644 --- a/src/axolotl/utils/schemas/quantization.py +++ b/src/axolotl/utils/schemas/quantization.py @@ -15,17 +15,22 @@ class QATConfig(BaseModel): """ activation_dtype: TorchIntDType | None = Field( - default=None, description="Activation dtype" + default=None, + description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"', ) weight_dtype: TorchIntDType = Field( - default=TorchIntDType.int8, description="Weight dtype" + default=TorchIntDType.int8, + description='Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"', ) quantize_embedding: bool | None = Field( default=False, description="Quantize embedding" ) - group_size: int | None = Field(default=32, description="Group size") + group_size: int | None = Field( + default=32, + description="The number of elements in each group for per-group fake quantization", + ) fake_quant_after_n_steps: int | None = Field( - default=None, description="Fake quant after n steps" + default=None, description="The number of steps to apply fake quantization after" ) @field_validator("activation_dtype", "weight_dtype", mode="before") @@ -44,15 +49,20 @@ class PTQConfig(BaseModel): """ weight_dtype: TorchIntDType = Field( - default=TorchIntDType.int8, description="Weight dtype" + default=TorchIntDType.int8, + description="Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8", ) activation_dtype: TorchIntDType | None = Field( - default=None, description="Activation dtype" + default=None, + description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"', ) quantize_embedding: bool | None = Field( - default=None, description="Quantize embedding" + default=None, description="Whether to quantize the embedding layer." 
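As a rough sketch of how the ReLoRA and QAT fields above might appear in a config file (the step counts are arbitrary placeholders, and nesting the quantization options under a `qat:` block is an assumption based on the QATConfig model):

    relora_steps: 150            # steps per ReLoRA restart
    relora_warmup_steps: 10      # per-restart warmup
    relora_cpu_offload: true     # merge LoRA weights on CPU during restarts

    qat:
      weight_dtype: int4
      activation_dtype: int8
      group_size: 32
      fake_quant_after_n_steps: 100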
+ ) + group_size: int | None = Field( + default=32, + description="The number of elements in each group for per-group fake quantization", ) - group_size: int | None = Field(default=32, description="Group size") @field_validator("activation_dtype", "weight_dtype", mode="before") @classmethod diff --git a/src/axolotl/utils/schemas/training.py b/src/axolotl/utils/schemas/training.py index ad7f899ac..4d88cc9e6 100644 --- a/src/axolotl/utils/schemas/training.py +++ b/src/axolotl/utils/schemas/training.py @@ -23,10 +23,17 @@ class LrGroup(BaseModel): class HyperparametersConfig(BaseModel): """Training hyperparams configuration subset""" - gradient_accumulation_steps: int | None = Field(default=1) + gradient_accumulation_steps: int | None = Field( + default=1, + json_schema_extra={ + "description": "If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps." + }, + ) micro_batch_size: int | None = Field( default=1, - json_schema_extra={"description": "per gpu micro batch size for training"}, + json_schema_extra={ + "description": "The number of samples to include in each batch. This is the number of samples sent to each GPU. Batch size per gpu = micro_batch_size * gradient_accumulation_steps" + }, ) batch_size: int | None = Field( default=None, @@ -41,45 +48,99 @@ class HyperparametersConfig(BaseModel): }, ) - auto_find_batch_size: bool | None = None + auto_find_batch_size: bool | None = Field( + default=None, + json_schema_extra={ + "description": "whether to find batch size that fits in memory. Passed to underlying transformers Trainer" + }, + ) - train_on_inputs: bool | None = False - group_by_length: bool | None = None + train_on_inputs: bool | None = Field( + default=False, + json_schema_extra={ + "description": "Whether to mask out or include the human's prompt from the training labels" + }, + ) + group_by_length: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Group similarly sized data to minimize padding. May be slower to start, as it must download and sort the entire dataset. Note that training loss may have an oscillating pattern with this enabled." + }, + ) learning_rate: str | float embedding_lr: float | None = None embedding_lr_scale: float | None = None - weight_decay: float | None = 0.0 - optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = ( - OptimizerNames.ADAMW_TORCH_FUSED + weight_decay: float | None = Field( + default=0.0, json_schema_extra={"description": "Specify weight decay"} + ) + optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = Field( + default=OptimizerNames.ADAMW_TORCH_FUSED, + json_schema_extra={"description": "Specify optimizer"}, ) optim_args: (str | dict[str, Any]) | None = Field( default=None, - json_schema_extra={"description": "Optional arguments to supply to optimizer."}, + json_schema_extra={ + "description": "Dictionary of arguments to pass to the optimizer" + }, ) optim_target_modules: (list[str] | Literal["all_linear"]) | None = Field( default=None, json_schema_extra={ - "description": "The target modules to optimize, i.e. the module names that you would like to train." + "description": "The target modules to optimize, i.e. 
the module names that you would like to train, right now this is used only for GaLore algorithm" + }, + ) + torchdistx_path: str | None = Field( + default=None, + json_schema_extra={ + "description": "Path to torch distx for optim 'adamw_anyprecision'" + }, + ) - torchdistx_path: str | None = None lr_scheduler: (SchedulerType | Literal["one_cycle"] | Literal["rex"]) | None = ( SchedulerType.COSINE ) - lr_scheduler_kwargs: dict[str, Any] | None = None + lr_scheduler_kwargs: dict[str, Any] | None = Field( + default=None, + json_schema_extra={ + "description": "Specify a scheduler and kwargs to use with the optimizer" + }, + ) lr_quadratic_warmup: bool | None = None - cosine_min_lr_ratio: float | None = None - cosine_constant_lr_ratio: float | None = None - lr_div_factor: float | None = None + cosine_min_lr_ratio: float | None = Field( + default=None, + json_schema_extra={ + "description": "decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr" + }, + ) + cosine_constant_lr_ratio: float | None = Field( + default=None, + json_schema_extra={ + "description": "freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step" + }, + ) + lr_div_factor: float | None = Field( + default=None, json_schema_extra={"description": "Learning rate div factor"} + ) lr_groups: list[LrGroup] | None = None - adam_epsilon: float | None = None - adam_epsilon2: float | None = None - adam_beta1: float | None = None - adam_beta2: float | None = None - adam_beta3: float | None = None - max_grad_norm: float | None = None + adam_epsilon: float | None = Field( + default=None, json_schema_extra={"description": "adamw hyperparams"} + ) + adam_epsilon2: float | None = Field( + default=None, json_schema_extra={"description": "only used for CAME Optimizer"} + ) + adam_beta1: float | None = Field( + default=None, json_schema_extra={"description": "adamw hyperparams"} + ) + adam_beta2: float | None = Field( + default=None, json_schema_extra={"description": "adamw hyperparams"} + ) + adam_beta3: float | None = Field( + default=None, json_schema_extra={"description": "only used for CAME Optimizer"} + ) + max_grad_norm: float | None = Field( + default=None, json_schema_extra={"description": "Gradient clipping max norm"} + ) num_epochs: float = Field(default=1.0) @field_validator("batch_size") diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py index 37b71dba8..d1b18a56e 100644 --- a/src/axolotl/utils/schemas/trl.py +++ b/src/axolotl/utils/schemas/trl.py @@ -10,12 +10,14 @@ class TRLConfig(BaseModel): beta: float | None = Field( default=None, - json_schema_extra={"description": "Beta for RL training"}, + json_schema_extra={ + "description": "Beta parameter for RL training. Same as `rl_beta`." + }, ) max_completion_length: int | None = Field( default=None, json_schema_extra={ - "description": "Maximum length of the completion for RL training" + "description": "Maximum length of the completion for RL training."
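For reference, the hyperparameter fields documented above correspond to top-level config keys along these lines (illustrative values only):

    gradient_accumulation_steps: 4
    micro_batch_size: 2          # per-GPU batch = micro_batch_size * gradient_accumulation_steps
    learning_rate: 2e-5
    optimizer: adamw_torch_fused
    lr_scheduler: cosine
    cosine_min_lr_ratio: 0.1     # decay to 10% of peak lr
    weight_decay: 0.0
    max_grad_norm: 1.0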
}, ) @@ -23,81 +25,69 @@ class TRLConfig(BaseModel): # Ref: https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/grpo_config.py#L23 use_vllm: bool = Field( default=False, - json_schema_extra={"description": "Whether to use VLLM for RL training"}, + json_schema_extra={"description": "Whether to use VLLM for RL training."}, ) vllm_server_host: str | None = Field( default="0.0.0.0", # nosec B104 - json_schema_extra={"description": "Host of the vLLM server to connect to"}, + json_schema_extra={"description": "Host of the vLLM server to connect to."}, ) vllm_server_port: int | None = Field( default=8000, - json_schema_extra={"description": "Port of the vLLM server to connect to"}, + json_schema_extra={"description": "Port of the vLLM server to connect to."}, ) vllm_server_timeout: int | None = Field( default=None, json_schema_extra={ - "description": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up " - "after the timeout, a `ConnectionError` is raised." + "description": "Total timeout (in seconds) to wait for the vLLM server to respond." }, ) vllm_guided_decoding_regex: str | None = Field( default=None, - json_schema_extra={ - "description": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled." - }, + json_schema_extra={"description": "Regex for vLLM guided decoding."}, ) reward_funcs: list[str] | None = Field( default=None, - json_schema_extra={"description": "List of reward functions to load"}, + json_schema_extra={ + "description": "List of reward functions to load. Paths must be importable from current dir." + }, ) reward_weights: list[float] | None = Field( default=None, json_schema_extra={ - "description": "Weights for each reward function. Must match the number of reward functions." + "description": "List of reward weights for the reward functions." }, ) num_generations: int | None = Field( default=None, - json_schema_extra={ - "description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value." - }, + json_schema_extra={"description": "Number of generations to sample."}, ) log_completions: bool | None = Field( default=False, - json_schema_extra={"description": "Whether to log completions"}, + json_schema_extra={"description": "Whether to log completions."}, ) num_completions_to_print: int | None = Field( default=None, json_schema_extra={ - "description": "Number of completions to print. If `log_completions` is `True`, this will be the number of completions logged." + "description": "Number of completions to print when log_completions is True." }, ) sync_ref_model: bool | None = Field( default=False, - json_schema_extra={ - "description": ( - "Whether to sync the reference model every `ref_model_sync_steps` " - "steps, using the `ref_model_mixup_alpha` parameter." - ) - }, + json_schema_extra={"description": "Whether to sync the reference model."}, ) ref_model_mixup_alpha: float | None = Field( default=0.9, - json_schema_extra={ - "description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`." - }, + json_schema_extra={"description": "Mixup alpha for the reference model."}, ) ref_model_sync_steps: int | None = Field( default=64, - json_schema_extra={ - "description": "Sync steps for the reference model. Requires `sync_ref_model=True`." 
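A hedged sketch of how the GRPO-related fields above could be set; `rewards.my_reward_func` is a hypothetical importable path and the numeric values are placeholders:

    rl: grpo
    trl:
      use_vllm: true
      vllm_server_host: "0.0.0.0"
      vllm_server_port: 8000
      num_generations: 4
      reward_funcs:
        - rewards.my_reward_func
      sync_ref_model: true
      ref_model_sync_steps: 64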
- }, + json_schema_extra={"description": "Sync steps for the reference model."}, ) scale_rewards: bool = Field( default=True, json_schema_extra={ - "description": "Whether to scale the rewards for GRPO by dividing them by their standard deviation." + "description": "Whether to scale rewards by their standard deviation." }, ) @@ -124,13 +114,13 @@ class TRLConfig(BaseModel): repetition_penalty: float | None = Field( default=None, json_schema_extra={ - "description": "Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far." + "description": "Penalty for tokens that appear in prompt and generated text." }, ) num_iterations: int | None = Field( default=None, json_schema_extra={ - "description": "Number of iterations per batch (denoted as μ in the algorithm) for GRPO." + "description": "Number of iterations per batch (μ) for GRPO." }, ) epsilon: float | None = Field( @@ -152,12 +142,12 @@ class TRLConfig(BaseModel): loss_type: str | None = Field( default=None, json_schema_extra={ - "description": "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`." + "description": "Loss formulation to use. Supported values: grpo, bnpo, dr_grpo." }, ) mask_truncated_completions: bool = Field( default=False, json_schema_extra={ - "description": "When enabled, truncated completions are excluded from the loss calculation." + "description": "Whether to exclude truncated completions from loss calculation." }, ) diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py new file mode 100644 index 000000000..5a6bf43b3 --- /dev/null +++ b/src/axolotl/utils/schemas/validation.py @@ -0,0 +1,1073 @@ +"""Module with validation methods for config pydantic model.""" + +# pylint: disable=too-many-lines + +import logging + +from pydantic import ( + field_validator, + model_validator, +) +from transformers.utils.import_utils import is_torch_npu_available + +from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType + +LOG = logging.getLogger(__name__) + +SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} + + +class DatasetValidationMixin: + """Validation methods related to dataset configuration.""" + + @field_validator("seed", mode="after") + @classmethod + def set_default_seed(cls, seed): + if seed is None: + LOG.info("`seed` not set in config; setting to 42") + seed = 42 + return seed + + @field_validator("datasets", mode="before") + @classmethod + def deprecate_sharegpt_datasets(cls, datasets): + for _, ds_cfg in enumerate(datasets): + ds_type = ( + ds_cfg.get("type") + if isinstance(ds_cfg, dict) + else getattr(ds_cfg, "type", None) + ) + if not ds_type: + continue + + if isinstance(ds_type, dict): + continue + + if isinstance(ds_type, str) and ds_type.startswith("sharegpt"): + raise ValueError( + "`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead." 
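The deprecation above means sharegpt-style dataset entries should be rewritten to use `chat_template`, e.g. (the dataset path is hypothetical):

    datasets:
      - path: my-org/my-conversations
        type: chat_template
        split: train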
+ ) + + return datasets + + @model_validator(mode="before") + @classmethod + def check_dataset_or_pretraining_dataset(cls, data): + if data.get("datasets") is None and data.get("pretraining_dataset") is None: + raise ValueError("either datasets or pretraining_dataset is required") + return data + + @model_validator(mode="before") + @classmethod + def check_push_ds_auth(cls, data): + if ( + data.get("push_dataset_to_hub") + and data.get("hf_use_auth_token") is not True + ): + raise ValueError( + "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_val_w_test_datasets(cls, data): + if data.get("test_datasets") and data.get("val_set_size"): + raise ValueError( + "non-zero val_set_size should not be used with test_datasets configuration" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_test_datasets_bench(cls, data): + if ( + data.get("do_bench_eval") + and not data.get("test_datasets") + and not data.get("val_set_size") + ): + LOG.warning( + "`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset." + ) + data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}] + return data + + @model_validator(mode="before") + @classmethod + def check_eval_packing(cls, data): + # TODO also should check test_datasets and val_set_size as we can skip + # if there are no eval datasets/splits + if ( + data.get("sample_packing") + and data.get("eval_table_size") + and data.get("eval_sample_packing") is not False + ): + raise ValueError( + "eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false." + ) + if ( + data.get("sample_packing") + and data.get("eval_sample_packing") is None + and not data.get("eval_table_size") + ): + LOG.info( + "explicitly setting `eval_sample_packing` to match `sample_packing`" + ) + data["eval_sample_packing"] = True + + if ( + data.get("sample_packing") + and data.get("eval_sample_packing") is False + and data.get("remove_unused_columns") is None + ): + LOG.info( + "setting `remove_unused_columns: false` for when sample_packing and eval_sample_packing don't match" + ) + data["remove_unused_columns"] = False + + return data + + @model_validator(mode="before") + @classmethod + def check_mm_prepare(cls, data): + if data.get("skip_prepare_dataset"): + if data.get("remove_unused_columns") is None: + LOG.info( + "setting `remove_unused_columns: false` for skip_prepare_dataset" + ) + data["remove_unused_columns"] = False + + return data + + +class AttentionValidationMixin: + """Validation methods related to attention mechanisms.""" + + @model_validator(mode="before") + @classmethod + def check_attention_fields(cls, data): + fields = ( + "xformers_attention", + "sdp_attention", + "s2_attention", + "flash_attention", + "flex_attention", + ) + non_empty_count = sum(1 for field in fields if data.get(field)) + + if non_empty_count > 1: + raise ValueError(f"Only one of {', '.join(fields)} must be set") + return data + + @model_validator(mode="before") + @classmethod + def check_sample_packing_without_attention(cls, data): + if ( + data.get("sample_packing") + and not data.get("flash_attention") + and not data.get("sdp_attention") + and not data.get("flex_attention") + and not data.get("xformers_attention") + ): + LOG.warning( + "sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination." 
+ ) + return data + + @model_validator(mode="before") + @classmethod + def check_sample_packing_with_s2attn(cls, data): + if data.get("sample_packing") and data.get("s2_attention"): + raise ValueError( + "Received `sample_packing=true` and `s2_attention=true`; however, \ + shifted-sparse attention does not currently support sample packing." + ) + return data + + +class TrainingValidationMixin: + """Validation methods related to training configuration.""" + + @model_validator(mode="before") + @classmethod + def check_batch_size_fields(cls, data): + fields = ("micro_batch_size", "gradient_accumulation_steps", "batch_size") + non_empty_count = sum(1 for field in fields if data.get(field)) + + if non_empty_count < 2: + raise ValueError(f"At least two of {', '.join(fields)} must be set") + return data + + @model_validator(mode="before") + @classmethod + def hint_sample_packing_padding(cls, data): + if data.get("sample_packing"): + pad_to_sequence_len = data.get("pad_to_sequence_len") + if pad_to_sequence_len is False: + LOG.warning( + "`pad_to_sequence_len: true` is recommended when using sample_packing" + ) + elif pad_to_sequence_len is None: + LOG.info( + "Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing" + ) + data["pad_to_sequence_len"] = True + return data + + @model_validator(mode="before") + @classmethod + def hint_reward_model_pad(cls, data): + if data.get("reward_model") and not data.get("pad_to_sequence_len"): + LOG.warning( + "`pad_to_sequence_len: true` is recommended when using reward_model" + ) + if data.get("pad_to_sequence_len") is None: + data["pad_to_sequence_len"] = True + return data + + @model_validator(mode="before") + @classmethod + def check_gas_bsz(cls, data): + if data.get("gradient_accumulation_steps") and data.get("batch_size"): + raise ValueError( + "please set only one of gradient_accumulation_steps or batch_size" + ) + return data + + @model_validator(mode="before") + @classmethod + def hint_eval_train_mbsz(cls, data): + if ( + data.get("eval_batch_size") + and data.get("micro_batch_size") + and data.get("eval_batch_size") != data.get("micro_batch_size") + ): + LOG.warning( + "eval_batch_size != micro_batch_size. This can lead to VRAM instability." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_warmup(cls, data): + if data.get("warmup_steps") and data.get("warmup_ratio"): + raise ValueError("warmup_steps and warmup_ratio are mutually exclusive") + return data + + @model_validator(mode="before") + @classmethod + def check_saves(cls, data): + if ( + data.get("save_strategy") + and data.get("save_steps") + and data.get("save_strategy") != "steps" + ): + raise ValueError( + "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps." + ) + if data.get("saves_per_epoch") and data.get("save_steps"): + raise ValueError( + "save_steps and saves_per_epoch are mutually exclusive and cannot be used together." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_push_save(cls, data): + if data.get("hub_model_id") and ( + data.get("save_strategy") not in ["steps", "epoch", None] + ): + LOG.warning( + "hub_model_id is set without any models being saved. To save a model, set save_strategy." 
+ ) + return data + + @model_validator(mode="before") + @classmethod + def check_evals(cls, data): + if ( + data.get("eval_strategy") + and data.get("eval_steps") + and data.get("eval_strategy") != "steps" + ): + raise ValueError( + "eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps." + ) + + if ( + data.get("val_set_size") == 0 + and (data.get("eval_steps") or data.get("eval_strategy")) + and not data.get("test_datasets") + and data.get("eval_strategy") != "no" + ): + raise ValueError( + "eval_steps and eval_strategy are not supported with val_set_size == 0" + ) + if data.get("evals_per_epoch") and data.get("eval_steps"): + raise ValueError( + "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together." + ) + if ( + data.get("evals_per_epoch") + and data.get("eval_strategy") + and data.get("eval_strategy") != "steps" + ): + raise ValueError( + "eval_strategy must be empty or set to `steps` when used with evals_per_epoch." + ) + + if data.get("do_bench_eval") and not ( + data.get("evals_per_epoch") or data.get("eval_steps") + ): + raise ValueError( + "do_bench_eval requires evals_per_epoch or eval_steps to be set." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_neftune(cls, data): + if data.get("noisy_embedding_alpha") and not data.get("neftune_noise_alpha"): + data["neftune_noise_alpha"] = data["noisy_embedding_alpha"] + del data["noisy_embedding_alpha"] + elif data.get("noisy_embedding_alpha") and data.get("neftune_noise_alpha"): + raise ValueError( + "noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting" + ) + return data + + @model_validator(mode="after") + def check_fft_possible_bad_config(self): + if ( + # pylint: disable=too-many-boolean-expressions + not (self.bf16 or self.bfloat16) + and (self.fp16 or self.float16) + and not self.adapter + and not self.flash_attention + and self.sample_packing + ): + LOG.warning( + "Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA." + ) + # ValueError: Attempting to unscale FP16 gradients. + # OR + # RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half + return self + + @model_validator(mode="before") + @classmethod + def check_use_reentrant_mismatch(cls, data): + if ( + data.get("unfrozen_parameters") + and data.get("gradient_checkpointing_kwargs") + and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") + is True + ): + # https://github.com/huggingface/transformers/issues/21381 + raise ValueError( + "`use_reentrant` must be false when used with partially frozen model." 
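A combination that satisfies the save/eval validators in this module might look like the following (step counts are placeholders; eval_steps evenly divides save_steps, which the early-stopping check later in this file also requires):

    save_strategy: steps
    save_steps: 100
    eval_steps: 50
    val_set_size: 0.05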
+ ) + return data + + @model_validator(mode="before") + @classmethod + def check_eval_strategy(cls, data): + if ( + data.get("evaluation_strategy") is not None + and data.get("eval_strategy") is None + ): + LOG.info( + "explicitly setting `eval_strategy` from the `evaluation_strategy`" + ) + data["eval_strategy"] = data.get("evaluation_strategy") + return data + + @model_validator(mode="before") + @classmethod + def check_causal_lm_evals(cls, data): + if data.get("do_causal_lm_eval") and data.get("eval_sample_packing"): + raise ValueError( + "do_causal_lm_eval is enabled, eval_sample_packing must be set to False" + ) + + if data.get("eval_causal_lm_metrics"): + if not isinstance(data.get("eval_causal_lm_metrics"), list): + raise ValueError("eval_causal_lm_metrics must be a list") + # only ["sacrebleu", "comet", "ter", "chrf"] supported + if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS: + raise ValueError( + f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_tokenizer_use_mistral_common(cls, data): + if data.get("tokenizer_use_mistral_common") is None: + if any( + "magistral" in name.lower() + for name in [ + data.get("base_model", ""), + data.get("base_model_config", ""), + data.get("tokenizer_config", ""), + ] + ): + LOG.warning( + "tokenizer_use_mistral_common auto inferred to True for Magistral models. Please set it to True explicitly if you want to use mistral-common tokenizer." + ) + data["tokenizer_use_mistral_common"] = True + + return data + + @field_validator("tokenizer_use_mistral_common", mode="after") + @classmethod + def check_mistral_common_import(cls, tokenizer_use_mistral_common): + if tokenizer_use_mistral_common: + try: + import mistral_common # noqa: F401 # pylint:disable=unused-import + except ImportError as exception: + raise ImportError( + "mistral-common is required for mistral models. Please install it with `pip install axolotl` or `pip install -e .`." + ) from exception + + return tokenizer_use_mistral_common + + @model_validator(mode="before") + @classmethod + def check_mistral_common_incompatible_options(cls, data): + if not data.get("tokenizer_use_mistral_common"): + return data + + # NOTE: mistral-common tokenizer is not compatible with editing tokenizer at the moment + + if data.get("added_tokens_overrides"): + raise ValueError( + "added_tokens_overrides is not supported with mistral-common tokenizer" + ) + + if data.get("special_tokens"): + raise ValueError( + "special_tokens override is not supported with mistral-common tokenizer" + ) + + if data.get("tokens"): + raise ValueError( + "tokens override is not supported with mistral-common tokenizer" + ) + + if data.get("chat_template"): + raise ValueError( + "Setting chat_template is not supported with mistral-common tokenizer" + ) + + return data + + +class LoRAValidationMixin: + """Validation methods related to LoRA/QLoRA configuration.""" + + @model_validator(mode="before") + @classmethod + def check_lr_groups(cls, data): + if data.get("lr_groups") and data.get("loraplus_lr_ratio"): + raise ValueError("lr_groups and loraplus_lr_ratio cannot be used together.") + return data + + @model_validator(mode="before") + @classmethod + def check_frozen(cls, data): + if ( + data.get("adapter") + and data.get("peft_layers_to_transform") + and data.get("unfrozen_parameters") + ): + raise ValueError( + "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior." 
+ ) + return data + + @model_validator(mode="before") + @classmethod + def check_peft_layers_pattern(cls, data): + if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"): + raise ValueError( + "peft_layers_pattern requires peft_layers_to_transform to be set" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_qlora_unsloth(cls, data): + if ( + data.get("unsloth_lora_mlp") + or data.get("unsloth_lora_qkv") + or data.get("unsloth_lora_o") + ): + if data.get("adapter") == "lora" and data.get("load_in_8bit"): + raise ValueError( + "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_lora_8bit(cls, data): + if ( + data.get("lora_mlp_kernel") + or data.get("lora_qkv_kernel") + or data.get("lora_o_kernel") + ): + if data.get("adapter") == "lora" and data.get("load_in_8bit"): + raise ValueError( + "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_lora_axolotl_unsloth(cls, data): + is_lora_kernel = any( + data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"] + ) + is_unsloth_lora = any( + data.get(k) + for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"] + ) + if is_lora_kernel and is_unsloth_lora: + raise ValueError( + "both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)" + ) + return data + + @model_validator(mode="after") + def check_fused_lora(self): + if self.adapter in ["lora", "qlora"] and ( + self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp + ): + raise ValueError("Fused modules are not supported with LoRA/QLoRA") + return self + + @model_validator(mode="after") + def hint_lora_8bit(self): + loftq = ( + self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits + ) + if not self.load_in_8bit and self.adapter == "lora" and not loftq: + LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") + return self + + @model_validator(mode="before") + @classmethod + def warn_qlora_zero3_w_use_reentrant(cls, data): + if ( + data.get("adapter") == "qlora" + and data.get("gradient_checkpointing_kwargs", {}) + and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") + is False + and data.get("deepspeed", "") is not None + and "zero3" in data.get("deepspeed", "") + ): + # may result in: + # torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: + # Recomputed values for the following tensors have different metadata + # than during the forward pass. 
+ LOG.warning( + "qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_lora_kernel_8bit(cls, data): + if ( + data.get("lora_mlp_kernel") + or data.get("lora_qkv_kernel") + or data.get("lora_o_kernel") + ): + if data.get("adapter") == "lora" and data.get("load_in_8bit"): + raise ValueError( + "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_lora_kernel_rl(cls, data): + if ( + data.get("lora_mlp_kernel") + or data.get("lora_qkv_kernel") + or data.get("lora_o_kernel") + ) and data.get("rl"): + raise ValueError( + "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with RL at the moment." + ) + return data + + +class RLValidationMixin: + """Validation methods related to RL training configuration.""" + + @model_validator(mode="before") + @classmethod + def check_sample_packing_w_rl(cls, data): + if data.get("sample_packing") and data.get("rl"): + raise ValueError("`sample_packing: true` does not work with RLHF training") + return data + + @model_validator(mode="before") + @classmethod + def check_kto_config(cls, data): + if data.get("rl") == "kto": + if data.get("sample_packing") or data.get("eval_sample_packing"): + raise ValueError("sample_packing is not supported with kto") + + if data.get("remove_unused_columns") is not False: + raise ValueError("Set `remove_unused_columns: False` when using kto") + return data + + @model_validator(mode="before") + @classmethod + def check_grpo_liger_sequence_parallel(cls, data): + if ( + data.get("rl") == "grpo" + and data.get("trl", {}) + and data.get("trl").get("use_liger_loss") + and data.get("sequence_parallel_degree", 1) > 1 + ): + raise ValueError("GRPO + SP + Liger not currently supported") + return data + + @model_validator(mode="before") + @classmethod + def check_rl_config_gradient_checkpointing(cls, data): + # TODO: SalmanMohammadi + # Distributed RL with QLoRA + gradient checkpointing + # and use_reentrant = True is broken upstream in TRL + # pylint: disable=too-many-boolean-expressions + if ( + data.get("rl") + and data.get("gradient_checkpointing") + and data.get("gradient_checkpointing_kwargs") + and data.get("gradient_checkpointing_kwargs").get("use_reentrant") + and data.get("load_in_4bit") + and data.get("adapter") == "qlora" + and data.get("capabilities") + and data.get("capabilities").get("n_gpu", 1) > 1 + ): + raise ValueError( + "The `use_reentrant: True` implementation of gradient checkpointing " + "is not supported for distributed RL training with QLoRA. Please set " + "`use_reentrant: False` in `gradient_checkpointing_kwargs`." 
+ ) + return data + + +class OptimizationValidationMixin: + """Validation methods related to optimization and performance.""" + + @model_validator(mode="after") + def check_adamw_optimizer_params(self): + if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and ( + not self.optimizer or "adamw" not in str(self.optimizer).lower() + ): + LOG.warning("adamw hyperparameters found, but no adamw optimizer set") + return self + + @model_validator(mode="before") + @classmethod + def check_muon_deepspeed_fsdp(cls, data): + if data.get("optimizer") == "muon" and ( + data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config") + ): + raise ValueError( + "Muon optimizer is currently incompatible with DeepSpeed and FSDP" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_batch_flattening_fa(cls, data): + if data.get("batch_flattening"): + batch_flattening_auto = data.get("batch_flattening") == "auto" + if not data.get("flash_attention") and not batch_flattening_auto: + raise ValueError("batch_flattening requires flash attention") + if data.get("sample_packing") and not batch_flattening_auto: + raise ValueError("batch_flattening not compatible with sample_packing") + if data.get("micro_batch_size") == 1 and not batch_flattening_auto: + LOG.warning("batch_flattening has no effect with micro_batch_size == 1") + + if ( + batch_flattening_auto + and data.get("flash_attention") + and not data.get("sample_packing") + and data.get("micro_batch_size") > 1 + ): + data["batch_flattening"] = True + elif batch_flattening_auto: + data["batch_flattening"] = False + + return data + + @model_validator(mode="before") + @classmethod + def check_torch_compile_deepspeed(cls, data): + if data.get("deepspeed") and data.get("torch_compile"): + raise ValueError( + "torch_compile should be set within your deepspeed config file" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_xentropy_patch_conflicts(cls, data): + if data.get("flash_attn_cross_entropy") and data.get( + "unsloth_cross_entropy_loss" + ): + raise ValueError( + "flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_fsdp_offload_w_8bit_optimizer(cls, data): + if ( + data.get("fsdp") + and "8bit" in data.get("optimizer", "") + and data.get("fsdp_config") + and data["fsdp_config"].get("fsdp_offload_params") + and str(data["fsdp_config"].get("fsdp_version")) != "2" + ): + raise ValueError( + f"FSDP Offload not compatible with {data.get('optimizer')}" + ) + if ( + data.get("fsdp") + and "8bit" in data.get("optimizer", "") + and data.get("fsdp_config") + and str(data["fsdp_config"].get("fsdp_version")) == "2" + ): + if data.get("optimizer", "") in ["adamw_8bit", "adamw_bnb_8bit"]: + # CUDA ops errors with bnb 8bit optimizer + FSDP2 + raise ValueError( + f"FSDP2 not compatible with {data.get('optimizer')}, use `adamw_torch_8bit` instead" + ) + + return data + + @model_validator(mode="before") + @classmethod + def check_fsdp_sharded_state_dict_w_safetensors(cls, data): + if ( + data.get("fsdp") + and data.get("save_safetensors") + and data.get("fsdp_config") + and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT" + ): + raise ValueError( + "FSDP SHARDED_STATE_DICT not compatible with save_safetensors" + ) + return data + + +class SystemValidationMixin: + """Validation methods related to system and hardware configuration.""" + + @model_validator(mode="before") + 
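Per the validator above, a distributed QLoRA RL run would need non-reentrant gradient checkpointing, roughly:

    rl: dpo
    adapter: qlora
    load_in_4bit: true
    gradient_checkpointing: true
    gradient_checkpointing_kwargs:
      use_reentrant: false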
@classmethod + def check_mem_mismatch(cls, data): + if ( + data.get("max_memory") is not None + and data.get("gpu_memory_limit") is not None + ): + raise ValueError( + "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_fsdp_deepspeed(cls, data): + if data.get("deepspeed") and data.get("fsdp"): + raise ValueError("deepspeed and fsdp cannot be used together.") + return data + + @model_validator(mode="before") + @classmethod + def check_npu_config(cls, data): + if is_torch_npu_available(): + # check attention config + attn_list = ["flash_attention", "sdp_attention", "s2_attention"] + for attn in attn_list: + if data.get(attn): + raise NotImplementedError( + f"{attn} is currently not supported in Ascend npu, please disable this configuration." + ) + + # check quant config + if data.get("optimizer") is not None and "bit" in data.get("optimizer"): + optimizer = data.get("optimizer") + raise NotImplementedError( + f"{optimizer} is currently not supported in Ascend npu, choose another one please." + ) + + quant_list = ["load_in_8bit", "load_in_4bit"] + for quant in quant_list: + if data.get(quant): + raise NotImplementedError( + f"Quantification is currently not supported in Ascend npu, please disable {quant}." + ) + + # check dtype config + if data.get("tf32"): + raise NotImplementedError( + "tf32 dtype is currently not supported in Ascend npu, please disable this configuration" + ) + + return data + + +class ChatTemplateValidationMixin: + """Validation methods related to chat template configuration.""" + + @model_validator(mode="before") + @classmethod + def check_chat_template_config(cls, data): + # if chat_template is set to jinja, chat_template_jinja is required + if data.get("chat_template") == ChatTemplate.jinja and not data.get( + "chat_template_jinja" + ): + raise ValueError( + "chat_template_jinja is required when chat_template is set to jinja" + ) + + # If chat_template_jinja is set, set chat_template to jinja + if data.get("chat_template_jinja") and not data.get("chat_template"): + data["chat_template"] = ChatTemplate.jinja + + return data + + +class PretrainingValidationMixin: + """Validation methods related to pretraining configuration.""" + + @model_validator(mode="before") + @classmethod + def check_pretraining_w_max_steps(cls, data): + if data.get("pretraining_dataset") and not data.get("max_steps"): + raise ValueError( + "max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_pretraining_w_group_by_length(cls, data): + if data.get("pretraining_dataset") and data.get("group_by_length"): + LOG.warning( + "You probably want to disable group_by_length as it will force a streamed dataset to download completely." 
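The pretraining checks above imply a streaming setup along these lines (the dataset path and step count are placeholders):

    pretraining_dataset: my-org/streaming-corpus
    max_steps: 10000     # required, since the Trainer cannot infer the length of an iterable dataset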
+ ) + return data + + @model_validator(mode="before") + @classmethod + def check_pretraining_split_batches_accelerate(cls, data): + # alternatively set ACCELERATE_SPLIT_BATCHES=False + if data.get("pretraining_dataset"): + accelerator_config = data.get("accelerator_config", {}) + if not accelerator_config: + data["accelerator_config"] = { + "split_batches": False, + "dispatch_batches": False, + } + else: + if accelerator_config.get("split_batches") is None: + data["accelerator_config"]["split_batches"] = False + if accelerator_config.get("dispatch_batches") is None: + data["accelerator_config"]["dispatch_batches"] = False + return data + + +class ModelCompatibilityValidationMixin: + """Validation methods for specific model compatibility.""" + + @model_validator(mode="after") + def check_falcon_fsdp(self): + if (self.base_model and "falcon" in self.base_model.lower()) and self.fsdp: + raise ValueError("FSDP is not supported for falcon models") + return self + + @model_validator(mode="after") + def check_mpt_checkpointing(self): + if ( + self.base_model and "mpt" in self.base_model.lower() + ) and self.gradient_checkpointing: + raise ValueError("gradient_checkpointing is not supported for MPT models") + return self + + @model_validator(mode="after") + def check_offload_grad_checkpointing(self): + if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth": + LOG.warning( + "`unsloth` is deprecated for gradient_checkpointing, use `offload`" + ) + self.gradient_checkpointing = "offload" + return self + + @model_validator(mode="after") + def check_better_transformers(self): + if self.flash_optimum is True: + if self.adapter: + LOG.warning( + "BetterTransformers probably doesn't work with PEFT adapters" + ) + if self.fp16 or self.bf16: + raise ValueError("AMP is not supported with BetterTransformer") + if self.float16 is not True and self.bfloat16 is not True: + LOG.warning( + "You should probably set bfloat16 or float16 to true to " + "load the model in float16 for BetterTransformers" + ) + return self + + @model_validator(mode="before") + @classmethod + def check_gptq_w_revision(cls, data): + if data.get("gptq") and data.get("revision_of_model"): + raise ValueError( + "revision_of_model is not supported for GPTQ models. " + + "Please download the model from HuggingFace Hub manually for correct branch, " + + "point to its path, and remove revision_of_model from the config." + ) + return data + + +class ComplexValidationMixin: + """Complex validation methods that involve multiple systems.""" + + @field_validator("neftune_noise_alpha") + @classmethod + def validate_neftune_noise_alpha(cls, neftune_noise_alpha): + if neftune_noise_alpha is not None and neftune_noise_alpha <= 0.0: + raise ValueError("neftune_noise_alpha must be > 0.0") + return neftune_noise_alpha + + @model_validator(mode="after") + def check_rl_beta(self): + if self.dpo_beta and not self.rl_beta: + self.rl_beta = self.dpo_beta + del self.dpo_beta + return self + + @model_validator(mode="after") + def check_simpo_warmup(self): + if self.rl is RLType.SIMPO and self.warmup_ratio: + raise ValueError( + "warmup_ratio is not supported with the simpo trainer. 
Please use `warmup_steps` instead" + ) + return self + + @model_validator(mode="after") + def check_relora(self): + if self.relora_steps: + if self.adapter not in ("lora", "qlora"): + raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA") + + if self.fsdp: + raise ValueError("fsdp not supported with ReLoRA") + + if self.deepspeed: + raise ValueError("deepspeed not supported with ReLoRA") + + if self.lr_scheduler == "one_cycle": + raise ValueError( + "ReLoRA is not compatible with the one_cycle scheduler" + ) + + if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp: + raise ValueError("Fused modules are not supported with ReLoRA") + return self + + @model_validator(mode="after") + def check_early_stopping(self): + if self.early_stopping_patience: + if not self.save_steps or not self.eval_steps: + raise ValueError( + "`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps." + ) + if self.save_steps % self.eval_steps != 0: + raise ValueError( + "`early_stopping_patience` requires that eval_steps should evenly divide save_steps." + ) + return self + + @model_validator(mode="after") + def check_sequence_parallel_degree(self): + if not self.sequence_parallel_degree: + self.sequence_parallel_degree = 1 + elif self.sequence_parallel_degree > 1: + if not self.flash_attention: + raise ValueError( + "flash_attention: true must be set with sequence_parallel_degree > 1" + ) + + if self.sample_packing and self.micro_batch_size > 1: + raise ValueError( + "micro_batch_size must be set to 1 when sample_packing is enabled " + "due to a `ring-flash-attn` requirement" + ) + + try: + import ring_flash_attn # noqa: F401 # pylint:disable=unused-import + except ImportError as exception: + raise ImportError( + "sequence_parallel_degree > 1 but ring_flash_attn is not installed. " + "Please install it with `pip install axolotl[ring-flash-attn] " + "or `pip install ring-flash-attn>=0.1.4`." + ) from exception + + LOG.warning( + "Sequence parallelism (SP) is enabled with " + f"sequence_parallel_degree={self.sequence_parallel_degree}. " + "Please note that logged losses may differ slightly to the non-SP " + "losses due to transformers Trainer implementation details. " + "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " + "for more details." 
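Tying the sequence-parallel checks above together, an SP-enabled run could look roughly like this (the degree is a placeholder; requires the ring-flash-attn extra to be installed):

    sequence_parallel_degree: 2
    flash_attention: true
    sample_packing: true
    micro_batch_size: 1      # required with sample_packing under sequence parallelism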
+ ) + + return self + + @model_validator(mode="after") + def validate_ring_attn_func(self): + if getattr(self, "sequence_parallel_degree", 1) == 1: + return self + + if self.ring_attn_func is not None: + self.ring_attn_func = RingAttnFunc(self.ring_attn_func) + else: + # Default ring attention function selection + sample_packing = getattr(self, "sample_packing", False) + self.ring_attn_func = ( + RingAttnFunc.VARLEN_LLAMA3 + if sample_packing + else RingAttnFunc.BATCH_RING + ) + + return self + + +# pylint: disable=too-many-ancestors +class ValidationMixin( + DatasetValidationMixin, + AttentionValidationMixin, + TrainingValidationMixin, + LoRAValidationMixin, + RLValidationMixin, + OptimizationValidationMixin, + SystemValidationMixin, + ChatTemplateValidationMixin, + PretrainingValidationMixin, + ModelCompatibilityValidationMixin, + ComplexValidationMixin, +): + """Full validation mixin for Axolotl configuration.""" diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index ec5360fa3..33ddadf78 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -16,7 +16,6 @@ from datasets import IterableDataset, disable_caching, enable_caching from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers.utils import is_torch_bf16_gpu_available -from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2 from axolotl.utils.distributed import reduce_and_broadcast from axolotl.utils.environment import check_cuda_p2p_ib_support @@ -483,6 +482,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree ) ) + if cfg.dataloader_drop_last: + # drop the last batch for each epoch + total_num_steps -= int(math.ceil(cfg.num_epochs)) def calc_sample_packing_eff_est(estimates: List[float]): LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}") @@ -630,6 +632,8 @@ def setup_trainer( A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based on the provided parameters. 
""" + from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder + if ( cfg.torch_compile and cfg.fsdp_config diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py index 2bd1fbf3d..212450e89 100644 --- a/tests/e2e/integrations/test_kd.py +++ b/tests/e2e/integrations/test_kd.py @@ -5,10 +5,9 @@ e2e tests for kd trainer support in Axolotl from pathlib import Path import pytest +import yaml +from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port -from axolotl.common.datasets import load_datasets -from axolotl.train import train -from axolotl.utils.config import normalize_config, prepare_plugins, validate_config from axolotl.utils.dict import DictDefault from tests.e2e.utils import check_tensorboard, require_torch_2_5_1 @@ -17,8 +16,8 @@ from tests.e2e.utils import check_tensorboard, require_torch_2_5_1 @pytest.fixture(name="kd_min_cfg") def min_cfg(temp_dir): return { - "base_model": "osllmai-community/Llama-3.2-1B", - "tokenizer_config": "axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer", + "base_model": "Qwen/Qwen3-0.6B", + "tokenizer_config": "winglian/qwen3-14b-math", "plugins": [ "axolotl.integrations.kd.KDPlugin", "axolotl.integrations.liger.LigerPlugin", @@ -31,20 +30,22 @@ def min_cfg(temp_dir): "kd_ce_alpha": 0.1, "kd_alpha": 0.9, "kd_temperature": 1.0, + "kd_beta": 0.0, + "kd_normalize_topk": True, "dataloader_prefetch_factor": 8, "dataloader_num_workers": 4, "dataloader_pin_memory": True, "datasets": [ { - "path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", - "type": "axolotl.integrations.kd.chat_template", - "field_messages": "messages_combined", + "path": "winglian/OpenThoughts-114k-math-correct-qwen3-14b-math-prepared-topk128-normalized", + "type": "chat_template", "split": "train", - "logprobs_field": "llm_text_generation_vllm_logprobs", - "temperature": 1.0, - "preprocess_shards": 2, + "split_thinking": True, + "eot_tokens": ["<|im_end|>"], + "data_files": ["train/batch-000000.parquet"], }, ], + "skip_prepare_dataset": True, "val_set_size": 0.0, "sequence_len": 2048, "sample_packing": True, @@ -80,17 +81,29 @@ class TestKnowledgeDistillation: def test_llama_kd(self, temp_dir, kd_min_cfg): cfg = DictDefault(kd_min_cfg) # pylint: disable=duplicate-code - cfg = validate_config(cfg) - prepare_plugins(cfg) - normalize_config(cfg) - dataset_meta = load_datasets(cfg=cfg) + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "1", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) - train(cfg=cfg, dataset_meta=dataset_meta) assert (Path(temp_dir) / "model.safetensors").exists() check_tensorboard( temp_dir + "/runs", "train/loss", 1.4, "Train Loss (%s) is too high" ) + @pytest.mark.skip(reason="Chunked KD loss doesn't support PEFT/LoRA") @pytest.mark.parametrize( "load_in_8bit", [True, False], @@ -110,12 +123,22 @@ class TestKnowledgeDistillation: | kd_min_cfg ) # pylint: disable=duplicate-code - cfg = validate_config(cfg) - prepare_plugins(cfg) - normalize_config(cfg) - dataset_meta = load_datasets(cfg=cfg) + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), 
Dumper=yaml.Dumper)) - train(cfg=cfg, dataset_meta=dataset_meta) + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "1", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) assert (Path(temp_dir) / "adapter_model.safetensors").exists() check_tensorboard( temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high" diff --git a/tests/e2e/multigpu/patched/test_sp.py b/tests/e2e/multigpu/patched/test_sp.py index e90def2b7..8883e0135 100644 --- a/tests/e2e/multigpu/patched/test_sp.py +++ b/tests/e2e/multigpu/patched/test_sp.py @@ -91,7 +91,10 @@ class TestSequenceParallelism: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", threshold, "Train Loss is too high" + temp_dir + "/runs", + "train/train_loss", + threshold, + "Train Loss (%s) is too high", ) @pytest.mark.parametrize( diff --git a/tests/e2e/multigpu/solo/test_flex.py b/tests/e2e/multigpu/solo/test_flex.py index 42c3c00c8..c8f14330d 100644 --- a/tests/e2e/multigpu/solo/test_flex.py +++ b/tests/e2e/multigpu/solo/test_flex.py @@ -85,5 +85,5 @@ class TestPackedFlex: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high" ) diff --git a/tests/e2e/multigpu/test_gemma3.py b/tests/e2e/multigpu/test_gemma3.py index 9bff25f40..b4cb6e59d 100644 --- a/tests/e2e/multigpu/test_gemma3.py +++ b/tests/e2e/multigpu/test_gemma3.py @@ -91,5 +91,5 @@ class TestMultiGPUGemma3: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss (%s) is too high" ) diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index 9c4bf5054..a8ed6bda0 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -89,7 +89,7 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" ) @pytest.mark.parametrize( @@ -154,7 +154,7 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" ) def test_dpo_lora_ddp(self, temp_dir): @@ -232,7 +232,7 @@ class TestMultiGPULlama: temp_dir + "/runs", "train/train_loss", loss_threshold, - "Train Loss is too high", + "Train Loss (%s) is too high", ) def test_dpo_qlora_ddp(self, temp_dir): @@ -310,7 +310,7 @@ class TestMultiGPULlama: temp_dir + "/runs", "train/train_loss", loss_threshold, - "Train Loss is too high", + "Train Loss (%s) is too high", ) @pytest.mark.parametrize( @@ -380,7 +380,7 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" ) @pytest.mark.parametrize( @@ -452,7 +452,7 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" ) @require_torch_2_6_0 @@ -533,7 +533,7 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high" ) def test_fsdp_qlora_prequant_packed(self, 
temp_dir): @@ -613,7 +613,7 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" ) @pytest.mark.parametrize( @@ -697,7 +697,7 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.4, "Train Loss (%s) is too high" ) @pytest.mark.parametrize( @@ -771,7 +771,7 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" ) @pytest.mark.parametrize( @@ -845,7 +845,7 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" ) @pytest.mark.skip( @@ -912,5 +912,5 @@ class TestMultiGPULlama: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss (%s) is too high" ) diff --git a/tests/e2e/multigpu/test_ray.py b/tests/e2e/multigpu/test_ray.py index f2c812eb5..22023507a 100644 --- a/tests/e2e/multigpu/test_ray.py +++ b/tests/e2e/multigpu/test_ray.py @@ -75,7 +75,7 @@ class TestMultiGPURay: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" ) @require_torch_lt_2_6_0 @@ -133,5 +133,5 @@ class TestMultiGPURay: ) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high" ) diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py index 4e3cbc50d..ca8b21178 100644 --- a/tests/e2e/patched/test_fa_xentropy.py +++ b/tests/e2e/patched/test_fa_xentropy.py @@ -78,5 +78,5 @@ class TestFAXentropyLlama: check_model_output_exists(temp_dir, cfg) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss (%s) is too high" ) diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py index 9567c0b18..69171481c 100644 --- a/tests/e2e/patched/test_unsloth_qlora.py +++ b/tests/e2e/patched/test_unsloth_qlora.py @@ -73,7 +73,7 @@ class TestUnslothQLoRA: check_model_output_exists(temp_dir, cfg) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high" ) def test_unsloth_llama_qlora_unpacked(self, temp_dir): @@ -123,7 +123,7 @@ class TestUnslothQLoRA: check_model_output_exists(temp_dir, cfg) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high" ) @pytest.mark.parametrize( @@ -178,5 +178,5 @@ class TestUnslothQLoRA: check_model_output_exists(temp_dir, cfg) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high" ) diff --git a/tests/e2e/solo/test_flex.py b/tests/e2e/solo/test_flex.py index 8d1a0c7d1..f6b8c6283 100644 --- a/tests/e2e/solo/test_flex.py +++ b/tests/e2e/solo/test_flex.py @@ -63,5 +63,5 @@ 
class TestPackedFlex(unittest.TestCase): train(cfg=cfg, dataset_meta=dataset_meta) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high" ) diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py index 6944c6f5e..fdebf2173 100644 --- a/tests/e2e/test_llama_pretrain.py +++ b/tests/e2e/test_llama_pretrain.py @@ -69,5 +69,5 @@ class TestPretrainLlama: temp_dir + "/runs", "train/train_loss", loss_threshold, - "Train Loss is too high", + "Train Loss (%s) is too high", ) diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py index 463f7c838..cc2db72e0 100644 --- a/tests/e2e/test_packing_loss.py +++ b/tests/e2e/test_packing_loss.py @@ -62,5 +62,5 @@ class TestPackedLlama(unittest.TestCase): train(cfg=cfg, dataset_meta=dataset_meta) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high" ) diff --git a/tests/e2e/test_qat.py b/tests/e2e/test_qat.py index 964bf3c1c..ef726079d 100644 --- a/tests/e2e/test_qat.py +++ b/tests/e2e/test_qat.py @@ -129,5 +129,5 @@ class TestQATLlama: temp_dir + "/runs", "train/train_loss", loss_threshold, - "Train Loss is too high", + "Train Loss (%s) is too high", ) diff --git a/tests/e2e/test_reward_model_smollm2.py b/tests/e2e/test_reward_model_smollm2.py index 304fda1cc..5d52bcc86 100644 --- a/tests/e2e/test_reward_model_smollm2.py +++ b/tests/e2e/test_reward_model_smollm2.py @@ -66,6 +66,6 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase): train(cfg=cfg, dataset_meta=dataset_meta) check_tensorboard( - temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss is too high" + temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high" ) check_model_output_exists(temp_dir, cfg) diff --git a/tests/prompt_strategies/conftest.py b/tests/prompt_strategies/conftest.py index 98488a988..d440565d2 100644 --- a/tests/prompt_strategies/conftest.py +++ b/tests/prompt_strategies/conftest.py @@ -143,6 +143,12 @@ def fixture_phi35_tokenizer(): return tokenizer +@pytest.fixture(name="phi4_tokenizer", scope="session", autouse=True) +def fixture_phi4_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-reasoning") + return tokenizer + + @pytest.fixture(name="gemma2_tokenizer", scope="session", autouse=True) def fixture_gemma2_tokenizer(): tokenizer = AutoTokenizer.from_pretrained("mlx-community/gemma-2-9b-it-4bit") diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py index fcf860f81..f847cab4a 100644 --- a/tests/prompt_strategies/test_chat_templates_advanced.py +++ b/tests/prompt_strategies/test_chat_templates_advanced.py @@ -33,15 +33,14 @@ PARAMETRIZE_PARAMS = [ "mistralv03_tokenizer_chat_template_jinja", "[/INST]", ), - # TODO: temporarily skip gemma due to gemma3 template - # Re-enable on new chat_template implementation for perf - # ( - # "gemma2_tokenizer", - # "jinja", - # "gemma2_tokenizer_chat_template_jinja", - # "", - # ), + ( + "gemma2_tokenizer", + "jinja", + "gemma2_tokenizer_chat_template_jinja", + "", + ), ("phi35_tokenizer", "phi_35", None, "<|end|>"), + ("phi4_tokenizer", "phi_4", None, "<|im_end|>"), ] @@ -95,11 +94,7 @@ class TestChatTemplateConfigurations: if ( turn_idx == 0 and turn.get("from") in ["system", "context"] - and ( - "mistral" in tokenizer.name_or_path.lower() - or "gemma" - 
in tokenizer.name_or_path.lower() # temporarily skip gemma due to gemma3 template - ) + and ("mistral" in tokenizer.name_or_path.lower()) ): assert ( start_idx == -1 and end_idx == -1 @@ -935,36 +930,14 @@ class TestChatTemplateConfigurations: "messages", ) - if chat_template == "llama3": - assert variables == {"role", "content"}, ( - f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n" - f"Got: {variables}\n" - f"Chat template: {actual_jinja_template}" - ) - elif chat_template == "chatml": - assert variables == {"role", "content"}, ( - f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n" - f"Got: {variables}\n" - f"Chat template: {actual_jinja_template}" - ) - elif chat_template == "jinja" and tokenizer == "mistralv03_tokenizer": - assert variables == {"role", "content", "tool_call_id", "tool_calls"}, ( - f"Expected variables: {'role', 'content', 'tool_call_id', 'tool_calls'} from {tokenizer}/{chat_template}\n" - f"Got: {variables}\n" - f"Chat template: {actual_jinja_template}" - ) - elif chat_template == "jinja" and tokenizer == "gemma2_tokenizer": - assert variables == {"role", "content"}, ( - f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n" - f"Got: {variables}\n" - f"Chat template: {actual_jinja_template}" - ) - elif chat_template == "phi_35": - assert variables == {"role", "content"}, ( - f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n" - f"Got: {variables}\n" - f"Chat template: {actual_jinja_template}" - ) + # Special case for Mistral with additional tool variables + if chat_template == "jinja" and tokenizer == "mistralv03_tokenizer": + expected_variables = {"role", "content", "tool_call_id", "tool_calls"} + # Most chat templates use the standard role and content variables + elif chat_template in ["llama3", "chatml", "phi_35", "phi_4"] or ( + chat_template == "jinja" and tokenizer == "gemma2_tokenizer" + ): + expected_variables = {"role", "content"} else: LOG.warning( f"Unsupported chat template: {chat_template} with {chat_template_jinja}" @@ -973,6 +946,12 @@ class TestChatTemplateConfigurations: f"Unsupported chat template: {chat_template} with {chat_template_jinja}" ) + assert variables == expected_variables, ( + f"Expected variables: {expected_variables} from {tokenizer}/{chat_template}\n" + f"Got: {variables}\n" + f"Chat template: {actual_jinja_template}" + ) + def test_eot_tokens_conflict_with_eos_token( self, tokenizer, diff --git a/tests/prompt_strategies/test_chat_templates_thinking.py b/tests/prompt_strategies/test_chat_templates_thinking.py index 79429b731..e807111aa 100644 --- a/tests/prompt_strategies/test_chat_templates_thinking.py +++ b/tests/prompt_strategies/test_chat_templates_thinking.py @@ -11,8 +11,6 @@ from axolotl.prompt_strategies.chat_template import ( ) from axolotl.utils.dict import DictDefault -from tests.hf_offline_utils import enable_hf_offline - @pytest.fixture(name="messages_w_reasoning") def messages_w_reasoning_fixture(): @@ -59,7 +57,6 @@ def messages_w_reasoning_fixture(): @pytest.fixture(name="qwen3_tokenizer") -@enable_hf_offline def qwen3_tokenizer_fixture( download_qwen3_half_billion_model, ): # pylint: disable=unused-argument diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py index 2b03c62f8..d91f63d94 100644 --- a/tests/test_packed_batch_sampler.py +++ b/tests/test_packed_batch_sampler.py @@ -81,6 +81,7 @@ class TestBatchedSamplerPacking: group_size=100000, bin_size=200, sequential=sequential, + 
drop_last=False, ) loader = DataLoader(