From 86472715dabf6be53e4f2dccce746c2344f455a5 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sat, 17 May 2025 00:05:55 +0700 Subject: [PATCH 01/19] fix: remove doc string imports in monkeypatches (#2671) [skip ci] --- .../cut_cross_entropy/monkeypatch/cohere.py | 10 ---------- .../cut_cross_entropy/monkeypatch/gemma.py | 10 ---------- .../cut_cross_entropy/monkeypatch/gemma3.py | 12 ------------ .../cut_cross_entropy/monkeypatch/llama.py | 10 ---------- .../cut_cross_entropy/monkeypatch/llama4.py | 13 ------------- .../cut_cross_entropy/monkeypatch/mistral3.py | 8 -------- .../cut_cross_entropy/monkeypatch/qwen2_moe.py | 10 ---------- .../cut_cross_entropy/monkeypatch/qwen2_vl.py | 10 ---------- .../cut_cross_entropy/monkeypatch/qwen3_moe.py | 11 ----------- src/axolotl/integrations/liger/models/deepseekv2.py | 4 ---- src/axolotl/integrations/liger/models/jamba.py | 10 ---------- src/axolotl/monkeypatch/gemma3.py | 8 -------- 12 files changed, 116 deletions(-) diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py index 99e17910e..ea9e10724 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py @@ -20,25 +20,15 @@ from cut_cross_entropy.transformers.utils import ( from transformers.cache_utils import Cache from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.cohere.modeling_cohere import ( - _CONFIG_FOR_DOC, - COHERE_INPUTS_DOCSTRING, KwargsForCausalLM, ) from transformers.processing_utils import Unpack -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.utils.deprecation import deprecate_kwarg _PATCH_OPTS: PatchOptions | None = None @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") 
-@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward( self, input_ids: torch.LongTensor | None = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py index 4c8d2261a..ae3d8c6ef 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py @@ -17,25 +17,15 @@ from cut_cross_entropy.transformers.utils import ( from transformers.cache_utils import Cache from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.gemma.modeling_gemma import ( - _CONFIG_FOR_DOC, - GEMMA_INPUTS_DOCSTRING, KwargsForCausalLM, ) from transformers.processing_utils import Unpack -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.utils.deprecation import deprecate_kwarg _PATCH_OPTS: PatchOptions | None = None @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward( self, input_ids: torch.LongTensor | None = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py index ccf0c160d..644e5cce7 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py @@ -20,15 +20,11 @@ from torch import nn from transformers.cache_utils import Cache, HybridCache from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.gemma3.modeling_gemma3 import ( - _CONFIG_FOR_DOC, - GEMMA3_INPUTS_DOCSTRING, 
Gemma3CausalLMOutputWithPast, logger, ) from transformers.utils import ( - add_start_docstrings_to_model_forward, is_torchdynamo_compiling, - replace_return_docstrings, ) from transformers.utils.deprecation import deprecate_kwarg @@ -38,10 +34,6 @@ _PATCH_OPTS: PatchOptions | None = None @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward( self, input_ids: torch.LongTensor | None = None, @@ -170,10 +162,6 @@ def cce_forward( @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward_multimodal( self, input_ids: torch.LongTensor | None = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py index 42ab996b9..bed411ace 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py @@ -19,15 +19,9 @@ from transformers.modeling_outputs import ( CausalLMOutputWithPast, ) from transformers.models.llama.modeling_llama import ( - _CONFIG_FOR_DOC, - LLAMA_INPUTS_DOCSTRING, KwargsForCausalLM, ) from transformers.processing_utils import Unpack -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.generic import can_return_tuple @@ -36,10 +30,6 @@ _PATCH_OPTS: PatchOptions | None = None @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 
-@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py index 7204f5c90..3143e9c8d 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py @@ -16,22 +16,12 @@ from torch import nn from transformers.cache_utils import Cache from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.llama4.modeling_llama4 import ( - _CONFIG_FOR_DOC, - LLAMA4_INPUTS_DOCSTRING, Llama4CausalLMOutputWithPast, ) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) _PATCH_OPTS: PatchOptions | None = None -@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward( self, input_ids: torch.LongTensor | None = None, @@ -160,9 +150,6 @@ def cce_forward( ) -@replace_return_docstrings( - output_type=Llama4CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward_multimodal( self, input_ids: torch.LongTensor | None = None, # type: ignore diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py index adb65fa8f..aa252701e 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py @@ -19,15 +19,11 @@ from transformers.models.mistral3.modeling_mistral3 import ( Mistral3CausalLMOutputWithPast, ) from transformers.models.mistral.modeling_mistral import ( - _CONFIG_FOR_DOC, - MISTRAL_INPUTS_DOCSTRING, KwargsForCausalLM, ) from transformers.processing_utils import 
Unpack from transformers.utils import ( - add_start_docstrings_to_model_forward, is_torchdynamo_compiling, - replace_return_docstrings, ) from transformers.utils.deprecation import deprecate_kwarg @@ -35,10 +31,6 @@ _PATCH_OPTS: PatchOptions | None = None @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward( self, input_ids: torch.LongTensor | None = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py index 0811bf55a..afe56266e 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py @@ -13,16 +13,10 @@ from cut_cross_entropy.transformers.utils import ( apply_lce, ) from transformers.models.qwen2_moe.modeling_qwen2_moe import ( - _CONFIG_FOR_DOC, - QWEN2MOE_INPUTS_DOCSTRING, MoeCausalLMOutputWithPast, MoeModelOutputWithPast, load_balancing_loss_func, ) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.generic import can_return_tuple @@ -31,10 +25,6 @@ _PATCH_OPTS: PatchOptions | None = None @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py index 250c3ab6b..79af01cfa 100644 --- 
a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py @@ -14,22 +14,12 @@ from cut_cross_entropy.transformers.utils import ( ) from torch.nn import CrossEntropyLoss from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - _CONFIG_FOR_DOC, - QWEN2_VL_INPUTS_DOCSTRING, Qwen2VLCausalLMOutputWithPast, ) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) _PATCH_OPTS: PatchOptions | None = None -@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward_multimodal( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py index c5cd76f94..90466e64b 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py @@ -12,20 +12,13 @@ from cut_cross_entropy.transformers.utils import ( TransformersModelT, apply_lce, ) -from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.qwen3_moe.modeling_qwen3_moe import ( - _CONFIG_FOR_DOC, - QWEN3_MOE_INPUTS_DOCSTRING, KwargsForCausalLM, MoeCausalLMOutputWithPast, MoeModelOutputWithPast, load_balancing_loss_func, ) from transformers.processing_utils import Unpack -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.generic import can_return_tuple @@ -34,10 +27,6 @@ _PATCH_OPTS: PatchOptions | None = None @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") 
-@add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/src/axolotl/integrations/liger/models/deepseekv2.py b/src/axolotl/integrations/liger/models/deepseekv2.py index c29fd4e79..2f0d2a704 100644 --- a/src/axolotl/integrations/liger/models/deepseekv2.py +++ b/src/axolotl/integrations/liger/models/deepseekv2.py @@ -14,10 +14,6 @@ from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import CausalLMOutputWithPast -# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) -# @replace_return_docstrings( -# output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -# ) def lce_forward( self, input_ids: torch.LongTensor = None, diff --git a/src/axolotl/integrations/liger/models/jamba.py b/src/axolotl/integrations/liger/models/jamba.py index 7ab464c88..d25529970 100644 --- a/src/axolotl/integrations/liger/models/jamba.py +++ b/src/axolotl/integrations/liger/models/jamba.py @@ -13,21 +13,11 @@ from liger_kernel.transformers.fused_linear_cross_entropy import ( from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import MoeCausalLMOutputWithPast from transformers.models.jamba.modeling_jamba import ( - _CONFIG_FOR_DOC, - JAMBA_INPUTS_DOCSTRING, HybridMambaAttentionDynamicCache, load_balancing_loss_func, ) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) -@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def lce_forward( self, input_ids: torch.LongTensor = None, diff --git a/src/axolotl/monkeypatch/gemma3.py b/src/axolotl/monkeypatch/gemma3.py index 38183fa0e..36f591efd 100644 --- a/src/axolotl/monkeypatch/gemma3.py +++ b/src/axolotl/monkeypatch/gemma3.py @@ -7,24 
+7,16 @@ from typing import Optional, Tuple, Union import torch from transformers.cache_utils import Cache from transformers.models.gemma3.modeling_gemma3 import ( - _CONFIG_FOR_DOC, - GEMMA3_INPUTS_DOCSTRING, Gemma3CausalLMOutputWithPast, logger, ) from transformers.utils import ( - add_start_docstrings_to_model_forward, is_torchdynamo_compiling, - replace_return_docstrings, ) from transformers.utils.deprecation import deprecate_kwarg @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def new_forward( self, input_ids: torch.LongTensor = None, From 8f8a7afb05aec84ed6bf03e187a745088e2563b0 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 16 May 2025 13:06:08 -0400 Subject: [PATCH 02/19] Add ci and images for CUDA 12.8 for B200s (#2683) [skip ci] * Add ci and images for CUDA 12.8 for B200s * add comments explaining CI [skip e2e] --- .github/workflows/main.yml | 10 ++++++++++ .github/workflows/tests.yml | 9 +++++++++ 2 files changed, 19 insertions(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 4fcf08352..01606f902 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -31,6 +31,11 @@ jobs: python_version: "3.11" pytorch: 2.7.0 axolotl_extras: + - cuda: 128 + cuda_version: 12.8.1 + python_version: "3.11" + pytorch: 2.7.0 + axolotl_extras: runs-on: axolotl-gpu-runner steps: - name: Checkout @@ -94,6 +99,11 @@ jobs: python_version: "3.11" pytorch: 2.7.0 axolotl_extras: + - cuda: 128 + cuda_version: 12.8.1 + python_version: "3.11" + pytorch: 2.7.0 + axolotl_extras: runs-on: axolotl-gpu-runner steps: - name: Checkout diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c296e2314..69f0a030d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -295,6 +295,7 @@ jobs: find 
"$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; docker-e2e-tests-1st: + # Run this job first as a gate for running the remainder of the test matrix if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }} # this job needs to be run on self-hosted GPU runners... runs-on: [self-hosted, modal] @@ -341,6 +342,8 @@ jobs: # this job needs to be run on self-hosted GPU runners... runs-on: [self-hosted, modal] timeout-minutes: 90 + # Only run the remainder of the matrix if the first e2e check passed; + # this is to save on wasted compute costs for known failures that get caught in the first run needs: [pre-commit, pytest, docker-e2e-tests-1st] strategy: @@ -365,6 +368,12 @@ jobs: pytorch: 2.7.0 num_gpus: 1 axolotl_extras: + - cuda: 128 + cuda_version: 12.8.1 + python_version: "3.11" + pytorch: 2.7.0 + num_gpus: 1 + axolotl_extras: steps: - name: Checkout uses: actions/checkout@v4 From c9797de6bb208dc95eb7374e76fefaa9f00a58c8 Mon Sep 17 00:00:00 2001 From: michelyang Date: Fri, 16 May 2025 10:06:20 -0700 Subject: [PATCH 03/19] Add num_proc to fix data set slow processing issue (#2681) [skip ci] --- src/axolotl/utils/data/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index eaa834822..dc5920099 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -72,6 +72,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs): data_set = data_set.map( ds_transform_fn, desc="Mapping RL Dataset", + num_proc=cfg.dataset_processes, **map_kwargs, ) From c837c4a424a60dd34ea44e34a65076fd6a5782d4 Mon Sep 17 00:00:00 2001 From: Eric Meier Date: Fri, 16 May 2025 10:06:46 -0700 Subject: [PATCH 04/19] Add missing init file to liger plugin (#2670) [skip ci] --- src/axolotl/integrations/liger/models/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 
src/axolotl/integrations/liger/models/__init__.py diff --git a/src/axolotl/integrations/liger/models/__init__.py b/src/axolotl/integrations/liger/models/__init__.py new file mode 100644 index 000000000..e69de29bb From f661858fc4d6c0d2329f5b0c1c965a7fc694c8fc Mon Sep 17 00:00:00 2001 From: xzuyn <16216325+xzuyn@users.noreply.github.com> Date: Fri, 16 May 2025 13:06:58 -0400 Subject: [PATCH 05/19] Print dataset name (#2668) [skip ci] --- src/axolotl/utils/data/sft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 5fa0cb60d..6de2d2cf7 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -484,7 +484,7 @@ def get_dataset_wrapper( } LOG.info( - f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}" + f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}" ) if ( From 3a5b495a740a85fc8dc638ae4a78620dadc59721 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sat, 17 May 2025 00:07:40 +0700 Subject: [PATCH 06/19] Fix: improve doc on merge/inference cli visibility (#2674) * feat: improve visibility for merge doc * feat: add tip on reuse config between modes --- docs/getting-started.qmd | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/docs/getting-started.qmd b/docs/getting-started.qmd index a0501ad21..064985e35 100644 --- a/docs/getting-started.qmd +++ b/docs/getting-started.qmd @@ -104,7 +104,7 @@ the `alpaca` dataset format, which has the following format: Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to format them. -2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca +2. 
Prepare your JSONL data in the specified format (in this case, the expected `alpaca` format): ```json @@ -120,6 +120,12 @@ axolotl train my_training.yml ## Common Tasks {#sec-common-tasks} +::: {.callout-tip} + +The same yaml file is used for training, inference, and merging. + +::: + ### Testing Your Model {#sec-testing} After training, test your model: @@ -128,6 +134,16 @@ After training, test your model: axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" ``` +More details can be found in [Inference](inference.qmd). + +### Using a UI {#sec-ui} + +Launch a Gradio interface: + +```bash +axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio +``` + ### Preprocessing Data {#sec-preprocessing} For large datasets, preprocess first: @@ -136,14 +152,22 @@ For large datasets, preprocess first: axolotl preprocess my_training.yml ``` -### Using a UI {#sec-ui} +Please make sure to set `dataset_prepared_path: ` in your config to set the path to save the prepared dataset. -Launch a Gradio interface: +More details can be found in [Dataset Preprocessing](dataset_preprocessing.qmd). + +### Merging LoRA weights {#sec-merging-lora} + +To merge the LoRA weights back into the base model, run: ```bash -axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio +axolotl merge-lora my_training.yml --lora-model-dir="./outputs/lora-out" ``` +The merged model will be saved in the `{output_dir}/merged` directory. + +More details can be found in [Merging LoRA weights](inference.qmd#sec-merging). 
+ ## Next Steps {#sec-next-steps} Now that you have the basics, you might want to: @@ -156,6 +180,7 @@ Now that you have the basics, you might want to: Check our other guides for details on these topics: - [Configuration Guide](config.qmd) - Full configuration options +- [Dataset Loading](dataset-loading.qmd) - Loading datasets from various sources - [Dataset Formats](dataset-formats) - Working with different data formats - [Multi-GPU Training](multi-gpu.qmd) - [Multi-Node Training](multi-node.qmd) From 288653adb6decba7ff71f6f4b5c460542f29f1a7 Mon Sep 17 00:00:00 2001 From: C080 <54465490+C080@users.noreply.github.com> Date: Fri, 16 May 2025 21:46:31 +0200 Subject: [PATCH 07/19] =?UTF-8?q?Fix:=20Make=20MLflow=20config=20artifact?= =?UTF-8?q?=20logging=20respect=20hf=5Fmlflow=5Flog=5Fartifa=E2=80=A6=20(#?= =?UTF-8?q?2675)=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix: Make MLflow config artifact logging respect hf_mlflow_log_artifacts setting * cleanup and lint --------- Co-authored-by: Wing Lian --- src/axolotl/utils/callbacks/mlflow_.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/axolotl/utils/callbacks/mlflow_.py b/src/axolotl/utils/callbacks/mlflow_.py index 47679001f..15ca1ca47 100644 --- a/src/axolotl/utils/callbacks/mlflow_.py +++ b/src/axolotl/utils/callbacks/mlflow_.py @@ -1,6 +1,7 @@ """MLFlow module for trainer callbacks""" import logging +import os from shutil import copyfile from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING @@ -16,6 +17,11 @@ if TYPE_CHECKING: LOG = logging.getLogger("axolotl.callbacks") +def should_log_artifacts() -> bool: + truths = ["TRUE", "1", "YES"] + return os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in truths + + class SaveAxolotlConfigtoMlflowCallback(TrainerCallback): # pylint: disable=duplicate-code """Callback to save axolotl config to mlflow""" @@ -32,13 +38,18 @@ class 
SaveAxolotlConfigtoMlflowCallback(TrainerCallback): ): if is_main_process(): try: - with NamedTemporaryFile( - mode="w", delete=False, suffix=".yml", prefix="axolotl_config_" - ) as temp_file: - copyfile(self.axolotl_config_path, temp_file.name) - mlflow.log_artifact(temp_file.name, artifact_path="") + if should_log_artifacts(): + with NamedTemporaryFile( + mode="w", delete=False, suffix=".yml", prefix="axolotl_config_" + ) as temp_file: + copyfile(self.axolotl_config_path, temp_file.name) + mlflow.log_artifact(temp_file.name, artifact_path="") + LOG.info( + "The Axolotl config has been saved to the MLflow artifacts." + ) + else: LOG.info( - "The Axolotl config has been saved to the MLflow artifacts." + "Skipping logging artifacts to MLflow (hf_mlflow_log_artifacts is false)" ) except (FileNotFoundError, ConnectionError) as err: LOG.warning(f"Error while saving Axolotl config to MLflow: {err}") From 6cb07b9d12984198b8feb0a0abda922be8108e2c Mon Sep 17 00:00:00 2001 From: xzuyn <16216325+xzuyn@users.noreply.github.com> Date: Fri, 16 May 2025 15:46:50 -0400 Subject: [PATCH 08/19] Fix for setting `adam_beta3` and `adam_epsilon2` for CAME Optimizer (#2654) [skip ci] * make setting `adam_beta3` and `adam_epsilon2` work correctly * update config docs so users know args are specific to CAME optim --------- Co-authored-by: Wing Lian --- docs/config.qmd | 2 ++ src/axolotl/core/trainer_builder.py | 6 +++++- src/axolotl/core/training_args.py | 13 +++++++++++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/docs/config.qmd b/docs/config.qmd index 10e5a5895..ac4c3fa4f 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -633,7 +633,9 @@ weight_decay: # adamw hyperparams adam_beta1: adam_beta2: +adam_beta3: # only used for CAME Optimizer adam_epsilon: +adam_epsilon2: # only used for CAME Optimizer # Gradient clipping max norm max_grad_norm: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 25d327dcd..6bd4ef996 100755 --- 
a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -387,8 +387,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1 if self.cfg.adam_beta2: training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2 + if self.cfg.adam_beta3: + training_arguments_kwargs["adam_beta3"] = self.cfg.adam_beta3 if self.cfg.adam_epsilon: training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon + if self.cfg.adam_epsilon2: + training_arguments_kwargs["adam_epsilon2"] = self.cfg.adam_epsilon2 if self.cfg.max_grad_norm: training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm @@ -713,7 +717,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): beta1 = training_arguments_kwargs.get("adam_beta1", 0.9) beta2 = training_arguments_kwargs.get("adam_beta2", 0.999) - beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999) + beta3 = training_arguments_kwargs.get("adam_beta3", 0.9999) eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30) eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16) adam_kwargs["betas"] = (beta1, beta2, beta3) diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 0b14e7661..a81c33801 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -227,6 +227,19 @@ class AxolotlTrainingMixins: }, ) + adam_beta3: Optional[float] = field( + default=None, + metadata={ + "help": "The beta3 hyperparameter used in some optimizers such as CAME" + }, + ) + adam_epsilon2: Optional[float] = field( + default=None, + metadata={ + "help": "The epsilon2 hyperparameter used in some optimizers such as CAME" + }, + ) + # multi-modal section image_size: int | tuple[int, int] | None = field( From a27b909c5c1c2c561a8d503024b89afcce15226f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 16 May 2025 15:47:03 -0400 Subject: [PATCH 09/19] GRPO fixes (peft) (#2676) * don't set peft_config on grpo to 
prevent double peft wrap * remove overrides needed to support bug * fix grpo tests * require more CPU for multigpu to help with torch compile for vllm --- cicd/multigpu.py | 2 +- src/axolotl/core/trainer_builder.py | 3 +- src/axolotl/core/trainers/grpo/trainer.py | 57 +---------------------- tests/e2e/multigpu/solo/test_grpo.py | 6 --- 4 files changed, 5 insertions(+), 63 deletions(-) diff --git a/cicd/multigpu.py b/cicd/multigpu.py index 90d4ce1ee..7de4ae0a7 100644 --- a/cicd/multigpu.py +++ b/cicd/multigpu.py @@ -70,7 +70,7 @@ def run_cmd(cmd: str, run_folder: str): image=cicd_image, gpu=GPU_CONFIG, timeout=90 * 60, - cpu=8.0, + cpu=16.0, memory=131072 * N_GPUS, volumes=VOLUME_CONFIG, ) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 6bd4ef996..d82e4d20b 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1174,7 +1174,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.eval_dataset: trainer_kwargs["eval_dataset"] = self.eval_dataset if self.cfg.adapter and self.peft_config: - trainer_kwargs["peft_config"] = self.peft_config + if self.cfg.rl is not RLType.GRPO: + trainer_kwargs["peft_config"] = self.peft_config if self.cfg.precompute_ref_log_probs is not None: trainer_kwargs["precompute_ref_log_probs"] = ( self.cfg.precompute_ref_log_probs diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index bc3d140b1..8a89de333 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -3,7 +3,6 @@ # pylint: disable=too-many-lines,duplicate-code,protected-access,no-member import warnings -from contextlib import nullcontext from typing import Any import datasets @@ -14,7 +13,7 @@ from accelerate.utils import ( broadcast_object_list, gather, gather_object, - is_peft_model, + is_peft_available, ) from datasets import Dataset, IterableDataset from torch import nn @@ -30,15 +29,13 @@ from 
transformers import ( TrainerCallback, ) from transformers.trainer_utils import seed_worker -from transformers.utils import is_peft_available from trl import GRPOTrainer from trl.data_utils import ( apply_chat_template, is_conversational, maybe_apply_chat_template, ) -from trl.extras.profiling import profiling_context, profiling_decorator -from trl.import_utils import is_deepspeed_available +from trl.extras.profiling import profiling_context from trl.models import unwrap_model_for_generation from trl.trainer.grpo_config import GRPOConfig from trl.trainer.grpo_trainer import RewardFunc, nanstd @@ -52,62 +49,12 @@ if is_peft_available(): # pylint: disable=unused-import from peft import PeftConfig -if is_deepspeed_available(): - import deepspeed - class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): """Extend the base GRPOTrainer for axolotl helpers""" _tag_names = ["trl", "grpo", "axolotl"] - @profiling_decorator - def _move_model_to_vllm(self): - # For DeepSpeed ZeRO-3, we need to gather all parameters before operations - deepspeed_plugin = self.accelerator.state.deepspeed_plugin - zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 - gather_if_zero3 = ( - deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext - ) - - if is_peft_model(self.model): - # With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging - # adapters in a sharded manner is not supported. 
- with gather_if_zero3(list(self.model.parameters())): - self.model.merge_adapter() - - # Update vLLM weights while parameters are gathered - for name, param in self.model.named_parameters(): - # When using PEFT, we need to recover the original parameter name and discard some parameters - name = ( - name.removeprefix("base_model.model.") - .removeprefix("base_model.model.") - .replace(".base_layer", "") - ) - if self.model.prefix in name: - continue - # When module to save, remove its prefix and discard the original module - if "original_module" in name: - continue - name = name.replace("modules_to_save.default.", "") - - if self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - - # Unmerge adapters while parameters are still gathered - self.model.unmerge_adapter() - # Parameters will automatically be repartitioned when exiting the context - else: - # For non-PEFT models, simply gather and update each parameter individually. - for name, param in self.model.named_parameters(): - with gather_if_zero3([param]): - if self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - - # Reset cache on main process - if self.accelerator.is_main_process: - self.vllm_client.reset_prefix_cache() - class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): """Extend the base GRPOTrainer for sequence parallelism handling""" diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py index a1eade531..575b7a620 100644 --- a/tests/e2e/multigpu/solo/test_grpo.py +++ b/tests/e2e/multigpu/solo/test_grpo.py @@ -166,7 +166,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): """ ) - @pytest.mark.skip(reason="flaky test") @pytest.mark.parametrize( "num_gpus", [1, 2], @@ -231,8 +230,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): "NCCL_P2P_LEVEL": "LOC", **current_env, "CUDA_VISIBLE_DEVICES": "1", - "VLLM_DISABLE_COMPILE_CACHE": "1", - # "VLLM_USE_V1": "0", } vllm_process = start_vllm( 
cfg.base_model, @@ -266,7 +263,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): finally: recursive_kill(vllm_process) - @pytest.mark.skip(reason="flaky test") @pytest.mark.parametrize( "num_gpus", [1, 2], @@ -325,8 +321,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): "NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable **current_env, "CUDA_VISIBLE_DEVICES": "1", - "VLLM_DISABLE_COMPILE_CACHE": "1", - # "VLLM_USE_V1": "0", } vllm_process = start_vllm( cfg.base_model, From 6aa41740df7623d1f8995a1efd3b668f4a57c5cf Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 21 May 2025 11:20:20 -0400 Subject: [PATCH 10/19] SP dataloader patching + removing custom sampler / dataloader logic (#2686) * utilize accelerate prepare_data_loader with patching * lint * cleanup, fix * update to support DPO quirk * small change * coderabbit commits, cleanup, remove dead code * quarto fix * patch fix * review comments * moving monkeypatch up one level * fix --- _quarto.yml | 1 - docs/multi-gpu.qmd | 15 +- docs/sequence_parallelism.qmd | 6 +- examples/qwen2/dpo.yaml | 1 - src/axolotl/core/trainer_builder.py | 9 - src/axolotl/core/trainers/base.py | 50 +--- src/axolotl/core/trainers/dpo/trainer.py | 157 +----------- src/axolotl/core/trainers/grpo/trainer.py | 2 +- src/axolotl/core/trainers/mixins/__init__.py | 1 - .../core/trainers/mixins/sequence_parallel.py | 87 ------- src/axolotl/core/training_args.py | 13 - .../attention/ring_attn/__init__.py | 11 - .../monkeypatch/attention/ring_attn/patch.py | 131 ---------- src/axolotl/monkeypatch/ring_attn/__init__.py | 22 ++ .../ring_attn/adapters/__init__.py | 0 .../ring_attn/adapters/batch.py | 0 src/axolotl/monkeypatch/ring_attn/patch.py | 223 ++++++++++++++++++ .../utils/ctx_managers/sequence_parallel.py | 24 +- src/axolotl/utils/models.py | 22 +- tests/e2e/patched/test_sp.py | 6 +- 20 files changed, 304 insertions(+), 477 deletions(-) delete mode 100644 src/axolotl/core/trainers/mixins/sequence_parallel.py 
delete mode 100644 src/axolotl/monkeypatch/attention/ring_attn/__init__.py delete mode 100644 src/axolotl/monkeypatch/attention/ring_attn/patch.py create mode 100644 src/axolotl/monkeypatch/ring_attn/__init__.py rename src/axolotl/monkeypatch/{attention => }/ring_attn/adapters/__init__.py (100%) rename src/axolotl/monkeypatch/{attention => }/ring_attn/adapters/batch.py (100%) create mode 100644 src/axolotl/monkeypatch/ring_attn/patch.py diff --git a/_quarto.yml b/_quarto.yml index dc5071838..c09aecaea 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -60,7 +60,6 @@ quartodoc: - core.trainers.mixins.optimizer - core.trainers.mixins.rng_state_loader - core.trainers.mixins.scheduler - - core.trainers.mixins.sequence_parallel - title: Context Managers desc: Context managers for altering trainer behaviors contents: diff --git a/docs/multi-gpu.qmd b/docs/multi-gpu.qmd index 55eaca6c3..fee7d17e5 100644 --- a/docs/multi-gpu.qmd +++ b/docs/multi-gpu.qmd @@ -87,20 +87,7 @@ We support sequence parallelism (SP) via the allows one to split up sequences across GPUs, which is useful in the event that a single sequence causes OOM errors during model training. -First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`, -or from source with `pip install .[ring-flash-attn]`. - -Your Axolotl YAML config should contain the following lines: - -```{.yaml} -sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU -flash_attention: true # Required with sequence parallelism - -# Optional; strides across the key dimension. Larger values use more memory but will make training faster. -heads_k_stride: 1 -``` - -See our [dedicated guide](sequence_parallelism.qmd) for more details. +See our [dedicated guide](sequence_parallelism.qmd) for more information. 
### FSDP + QLoRA {#sec-fsdp-qlora} diff --git a/docs/sequence_parallelism.qmd b/docs/sequence_parallelism.qmd index 1bff17ce9..b98206135 100644 --- a/docs/sequence_parallelism.qmd +++ b/docs/sequence_parallelism.qmd @@ -41,7 +41,7 @@ When sequence parallelism is enabled: 1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group 2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids -3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences +3. Position IDs are adjusted to maintain proper relative positions 4. The trainer uses special ring communication patterns for attention operations ## Requirements @@ -67,9 +67,11 @@ sequence_len: 8192 ... sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU -flash_attention: true # Required with sequence parallelism # Optional; strides across the key dimension. Larger values use more memory but should make training faster. heads_k_stride: 1 +# Optional; one of "varlen_llama3" or "batch_ring". Defaults to +# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise. +ring_attn_func: ... 
``` diff --git a/examples/qwen2/dpo.yaml b/examples/qwen2/dpo.yaml index 3547c6c98..bd896c2b3 100644 --- a/examples/qwen2/dpo.yaml +++ b/examples/qwen2/dpo.yaml @@ -2,7 +2,6 @@ base_model: Qwen/Qwen2.5-0.5B # Automatically upload checkpoint and final model to HF # hub_model_id: username/custom_model_name - chat_template: qwen_25 rl: dpo datasets: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index d82e4d20b..878dd176a 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -798,11 +798,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): self.cfg.kd_top_k_before_softmax ) - training_arguments_kwargs["sequence_parallel_degree"] = ( - self.cfg.sequence_parallel_degree - ) - training_arguments_kwargs["ring_attn_func"] = self.cfg.ring_attn_func - if self.cfg.reward_model: training_args_cls = AxolotlRewardConfig elif self.cfg.process_reward_model: @@ -1083,10 +1078,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.use_wandb: training_args_kwargs["run_name"] = self.cfg.wandb_name - training_args_kwargs["sequence_parallel_degree"] = ( - self.cfg.sequence_parallel_degree - ) - training_args_cls = None blocklist_args_kwargs = [] if self.cfg.rl is RLType.SIMPO: diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 2f0ce6894..d5cfc23df 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -29,7 +29,6 @@ from axolotl.core.trainers.mixins import ( OptimizerMixin, RngLoaderMixin, SchedulerMixin, - SequenceParallelMixin, ) from axolotl.core.trainers.utils import ( sanitize_kwargs_for_ds_tagging, @@ -40,9 +39,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths LOG = logging.getLogger(__name__) -class AxolotlTrainer( - SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer -): +class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): """Extend 
the base Trainer for axolotl helpers""" args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] @@ -68,10 +65,6 @@ class AxolotlTrainer( if self.args.orpo_alpha: self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") - # Initialize sequence parallelism if enabled - if self.args.sequence_parallel_degree > 1: - self._setup_sequence_parallel() - def _wrap_model(self, model, training=True, dataloader=None): if self.args.torch_compile: torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access @@ -122,8 +115,8 @@ class AxolotlTrainer( def _get_train_sampler(self) -> Sampler | None: """ - Helper method to get the sampler for training. Handles cases for sequence - parallelism, sample packing, and curriculum sampling (sequential). + Helper method to get the sampler for training. Handles cases for sample packing + and curriculum sampling (sequential). Returns: If the dataset is non-empty, a sampler is returned, the type of which @@ -132,9 +125,7 @@ class AxolotlTrainer( use_sample_packing = self.args.sample_packing and not self.args.pretraining # Determine the base sampler first - if self.args.sequence_parallel_degree > 1: - base_sampler = self._sp_get_train_sampler(self.train_dataset) - elif self.args.curriculum_sampling: + if self.args.curriculum_sampling: base_sampler = SequentialSampler(self.train_dataset) elif use_sample_packing: base_sampler = RandomSampler(self.train_dataset) @@ -153,8 +144,7 @@ class AxolotlTrainer( def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None: """ - Helper method to get the sampler for evaluation. Handles sequence parallelism - and sample packing cases. + Helper method to get the sampler for evaluation. Handles sample packing case. 
Returns: If the dataset is non-empty, a sampler is returned, the type of which @@ -168,9 +158,7 @@ class AxolotlTrainer( ) # Determine the base sampler - if self.args.sequence_parallel_degree > 1: - base_sampler = self._sp_get_eval_sampler(eval_dataset) - elif use_multipack: + if use_multipack: base_sampler = SequentialSampler(eval_dataset) else: return super()._get_eval_sampler(eval_dataset) @@ -236,14 +224,6 @@ class AxolotlTrainer( ): self.accelerator.even_batches = False - # Return unprepared dataloader if using sequence parallelism - # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation - # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e., - # slice each batch along the sequence dimension). - if self.args.sequence_parallel_degree > 1: - return dataloader - - # Otherwise prepare with accelerator return self.accelerator.prepare_data_loader(dataloader) def get_train_dataloader(self) -> DataLoader: @@ -287,12 +267,7 @@ class AxolotlTrainer( return dataloader - # Handle sample packing or sequence parallelism - if ( - self.args.sample_packing - and self.args.eval_sample_packing is not False - or self.args.sequence_parallel_degree > 1 - ): + if self.args.sample_packing and self.args.eval_sample_packing is not False: # Get appropriate data collator self.data_collator = ( # pylint: disable=attribute-defined-outside-init self.eval_data_collator @@ -302,17 +277,6 @@ class AxolotlTrainer( if "length" in eval_dataset.column_names: eval_dataset = eval_dataset.remove_columns(["length"]) - # Handle dataset preprocessing for SP - if self.args.sequence_parallel_degree > 1: - if isinstance(eval_dataset, datasets.Dataset): - eval_dataset = self._remove_unused_columns( - eval_dataset, description="evaluation" - ) - else: - self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init - self.data_collator, description="evaluation" - ) - # Use eval_batch_size for sample packing, 
per_device_eval_batch_size otherwise batch_size = ( self.args.eval_batch_size diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py index 1ce7deea7..c2c80c0bc 100644 --- a/src/axolotl/core/trainers/dpo/trainer.py +++ b/src/axolotl/core/trainers/dpo/trainer.py @@ -1,31 +1,15 @@ -""" -DPO trainer for axolotl -""" +"""DPO trainer for axolotl""" import gc -import random from functools import wraps -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Union -import pandas as pd import torch -import wandb -from accelerate import PartialState -from datasets import Dataset, IterableDataset from peft.optimizers import create_loraplus_optimizer from torch import nn -from torch.utils.data import DataLoader -from transformers import ( - BaseImageProcessor, - FeatureExtractionMixin, - PreTrainedTokenizerBase, - ProcessorMixin, - Trainer, -) -from transformers.trainer_utils import EvalLoopOutput +from transformers import Trainer from transformers.utils import is_sagemaker_mp_enabled -from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt -from trl.trainer.utils import log_table_to_comet_experiment +from trl import DPOTrainer from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin from axolotl.core.trainers.utils import ( @@ -38,9 +22,7 @@ if is_sagemaker_mp_enabled(): class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): - """ - Extend the base DPOTrainer for axolotl helpers - """ + """Extend the base DPOTrainer for axolotl helpers.""" tag_names = ["axolotl", "dpo"] @@ -85,8 +67,9 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): @wraps(DPOTrainer.push_to_hub) def push_to_hub(self, *args, **kwargs) -> str: """ - Overwrite the `push_to_hub` method in order to force-add the tags when pushing the - model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. 
+ Overwrite the `push_to_hub` method in order to force-add the tags when pushing + the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` + for more details. """ kwargs = sanitize_kwargs_for_ds_tagging( dataset_tags=self.dataset_tags, kwargs=kwargs @@ -95,64 +78,6 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): return super().push_to_hub(*args, **kwargs) - # TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release - def _prepare_dataset( - self, - dataset: Union[Dataset, IterableDataset], - processing_class: Union[ - PreTrainedTokenizerBase, - BaseImageProcessor, - FeatureExtractionMixin, - ProcessorMixin, - ], - args: DPOConfig, - dataset_name: str, - ) -> Union[Dataset, IterableDataset]: - # Build the kwargs for the `map` function - map_kwargs: Dict[str, Any] = {"writer_batch_size": 10} - if isinstance(dataset, Dataset): # IterableDataset does not support num_proc - map_kwargs["num_proc"] = args.dataset_num_proc - - with PartialState().main_process_first(): - # Extract prompt if needed - if isinstance( - dataset, Dataset - ): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset" - dataset = dataset.map(maybe_extract_prompt, **map_kwargs) - - # Apply the chat template if needed - if isinstance( - dataset, Dataset - ): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" - dataset = dataset.map( - maybe_apply_chat_template, - fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, - **map_kwargs, - ) - - # Tokenize the dataset - if isinstance( - dataset, Dataset - ): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" - - dataset = dataset.map( - self.tokenize_row if not self.is_vision_model else self.process_row, - remove_columns=["chosen", "rejected"], - fn_kwargs={ - "processing_class": 
processing_class, - "max_prompt_length": args.max_prompt_length, - "max_completion_length": args.max_completion_length, - # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token]) - "add_special_tokens": False, - }, - **map_kwargs, - ) - - return dataset - @staticmethod def tokenize_row( features, @@ -192,69 +117,3 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): gc.collect() torch.cuda.empty_cache() return loss - - # TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release - def evaluation_loop( - self, - dataloader: DataLoader, - description: str, - prediction_loss_only: Optional[bool] = None, - ignore_keys: Optional[list[str]] = None, - metric_key_prefix: str = "eval", - ) -> EvalLoopOutput: - """ - Overriding built-in evaluation loop to store metrics for each batch. - Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. - - Works both with or without labels. 
- """ - - # Sample and save to game log if requested (for one batch to save time) - if self.generate_during_eval: - # Generate random indices within the range of the total number of samples - num_samples = len(dataloader.dataset) - random_indices = random.sample( - range(num_samples), k=self.args.eval_batch_size - ) - - # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader - random_batch_dataset = dataloader.dataset.select(random_indices) - random_batch = self.data_collator(random_batch_dataset) - random_batch = self._prepare_inputs(random_batch) - - policy_output_decoded, ref_output_decoded = ( - self.generate_from_model_and_ref(self.model, random_batch) - ) - - table = pd.DataFrame( - columns=["Prompt", "Policy", "Ref Model"], - data=[ - [prompt, pol[len(prompt) :], ref[len(prompt) :]] - for prompt, pol, ref in zip( - random_batch_dataset["prompt"], - policy_output_decoded, - ref_output_decoded, - ) - ], - ) - if "wandb" in self.args.report_to and self.accelerator.is_main_process: - wandb.log({"game_log": wandb.Table(data=table)}) - - if "comet_ml" in self.args.report_to: - log_table_to_comet_experiment( - name="game_log.csv", - table=table, - ) - - # Base evaluation - initial_output = super( # pylint: disable=bad-super-call - DPOTrainer, self - ).evaluation_loop( - dataloader, - description, - prediction_loss_only, - ignore_keys, - metric_key_prefix, - ) - - return initial_output diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index 8a89de333..a603ed860 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -43,7 +43,7 @@ from trl.trainer.utils import pad from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin -from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group +from 
axolotl.monkeypatch.ring_attn.patch import get_ring_attn_group if is_peft_available(): # pylint: disable=unused-import diff --git a/src/axolotl/core/trainers/mixins/__init__.py b/src/axolotl/core/trainers/mixins/__init__.py index 44751b465..a71cb321a 100644 --- a/src/axolotl/core/trainers/mixins/__init__.py +++ b/src/axolotl/core/trainers/mixins/__init__.py @@ -6,4 +6,3 @@ from .optimizer import OptimizerMixin from .rng_state_loader import RngLoaderMixin from .scheduler import SchedulerMixin -from .sequence_parallel import SequenceParallelMixin diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py deleted file mode 100644 index 0f30458cd..000000000 --- a/src/axolotl/core/trainers/mixins/sequence_parallel.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Module for Axolotl trainer sequence parallelism mixin""" - -import torch.distributed as dist -from datasets import Dataset -from torch.utils.data import DistributedSampler, Sampler - -from axolotl.monkeypatch.attention.ring_attn import ( - get_ring_attn_group, -) - - -class SequenceParallelMixin: - """ - Mixin class for sequence parallelism support in trainers. - - This mixin provides functionality for handling sequence parallelism, - specifically for creating appropriate data samplers. - """ - - args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] - - def _setup_sequence_parallel(self): - """Set up sequence parallelism environment.""" - self.ring_attn_group = get_ring_attn_group() - - def _create_sequence_parallel_sampler( - self, - dataset: Dataset, - shuffle: bool = True, - is_eval: bool = False, - ) -> DistributedSampler: - """ - Helper method to create sampler for sequence parallelism (SP). - - We create a distributed sampler with rank equal to the SP group ID, which - means that all ranks in the SP group receive the same sample / set of samples - per training step. 
We also set the number of replicas equal to the number of - SP groups, which is a bit of a hack / unintended use, but works! - - Args: - dataset: Dataset to sample from. - shuffle: Whether to shuffle the dataset. - is_eval: Whether we are creating a sampler for evaluation or training. - - Returns: - Distributed sampler. - """ - num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree - sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree - - return DistributedSampler( - dataset, - num_replicas=num_sp_groups, - rank=sp_group_id, - seed=self.args.seed if shuffle else None, - shuffle=shuffle, - drop_last=not is_eval, - ) - - def _sp_get_train_sampler(self, dataset) -> Sampler | None: - """ - Get a training sampler configured for sequence parallelism. - - Args: - dataset: The training dataset - - Returns: - Configured sequence parallel sampler. - """ - return self._create_sequence_parallel_sampler( - dataset, - shuffle=not self.args.curriculum_sampling, - ) - - def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None: - """ - Get an evaluation sampler configured for sequence parallelism. - - Args: - eval_dataset: The evaluation dataset. - - Returns: - Configured sequence parallel sampler. 
- """ - return self._create_sequence_parallel_sampler( - eval_dataset, shuffle=False, is_eval=True - ) diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index a81c33801..9c93f77c7 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -9,8 +9,6 @@ from PIL.Image import Resampling from transformers import TrainingArguments from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig -from axolotl.utils.schemas.enums import RingAttnFunc - @dataclass class AxolotlTrainingMixins: @@ -216,17 +214,6 @@ class AxolotlTrainingMixins: }, ) - sequence_parallel_degree: Optional[int] = field( - default=1, - metadata={"help": "The number of workers to use in sequence parallelism"}, - ) - ring_attn_func: Optional[RingAttnFunc] = field( - default=None, - metadata={ - "help": "The ring-flash-attn function to use in sequence parallelism" - }, - ) - adam_beta3: Optional[float] = field( default=None, metadata={ diff --git a/src/axolotl/monkeypatch/attention/ring_attn/__init__.py b/src/axolotl/monkeypatch/attention/ring_attn/__init__.py deleted file mode 100644 index a50ad456e..000000000 --- a/src/axolotl/monkeypatch/attention/ring_attn/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Init for ring attention monkeypatch module""" - -# pylint: disable=unused-import -# flake8: noqa - -from .patch import ( - get_ring_attn_group, - register_ring_attn, - set_ring_attn_group, - update_ring_attn_params, -) diff --git a/src/axolotl/monkeypatch/attention/ring_attn/patch.py b/src/axolotl/monkeypatch/attention/ring_attn/patch.py deleted file mode 100644 index 8cbba338a..000000000 --- a/src/axolotl/monkeypatch/attention/ring_attn/patch.py +++ /dev/null @@ -1,131 +0,0 @@ -""" -Ring attention group registration and flash attention patching. 
- -Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention) -package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in -their sequence parallel version of Flash Attention 2. -""" - -import torch -import torch.distributed as dist -from accelerate.logging import get_logger - -from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids -from axolotl.utils.schemas.enums import RingAttnFunc - -LOG = get_logger(__name__) - - -RING_ATTN_GROUP = None - - -def get_ring_attn_group() -> dist.ProcessGroup: - """ - Getter for ring attention group on this rank. - - Returns: - The process group for ring attention for this rank. - """ - return RING_ATTN_GROUP - - -def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None): - """ - Setter for ring attention group on this rank. - - Args: - Process group for ring attention. - """ - global RING_ATTN_GROUP # pylint: disable=global-statement - RING_ATTN_GROUP = ring_attn_group - - -def register_ring_attn( - sequence_parallel_degree: int, - heads_k_stride: int | None, - ring_attn_func: RingAttnFunc | None, -): - """ - Create ring attention group and substitute flash attn with ring flash attn. - - Args: - sequence_parallel_degree: Sequence parallelism factor. - heads_k_stride: Sequence parallelism K head stride size. Passed - through to `ring_flash_attn.substitute_hf_flash_attn`. - ring_attn_func: `ring_flash_attn` ring attention implemention. If sample - packing is enabled, it must be a `varlen` function; otherwise, it must be a - `batch` function. 
- """ - if get_ring_attn_group() is not None: - LOG.info("Ring attention already registered, exiting early...") - return - - LOG.info( - "Enabling ring attention sequence parallelism: " - f"each sequence will be processed across {sequence_parallel_degree} GPUs" - ) - - rank = dist.get_rank() - world_size = dist.get_world_size() - - assert sequence_parallel_degree <= world_size, ( - f"sequence_parallel_degree ({sequence_parallel_degree}) " - f"must be less than or equal to world_size ({world_size})" - ) - assert world_size % sequence_parallel_degree == 0, ( - f"sequence_parallel_degree ({sequence_parallel_degree}) " - f"must evenly divide world_size ({world_size})" - ) - - # Assign ranks to sequence parallel groups - group_assignments = {} - for i in range(world_size // sequence_parallel_degree): - ring_attn_ranks = list( - range( - i * sequence_parallel_degree, - (i + 1) * sequence_parallel_degree, - ) - ) - group = dist.new_group(ranks=ring_attn_ranks, backend="nccl") - - # Track which GPUs are in which groups - for r in ring_attn_ranks: - group_assignments[r] = i - - if rank in ring_attn_ranks: - set_ring_attn_group(group) - - # Log the GPU group assignments - if rank == 0: - LOG.info(f"Sequence parallel group assignments: {group_assignments}") - - if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3: - from ring_flash_attn import substitute_hf_flash_attn - - substitute_hf_flash_attn( - process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1 - ) - elif ring_attn_func is RingAttnFunc.BATCH_RING: - from axolotl.monkeypatch.attention.ring_attn.adapters.batch import ( - substitute_hf_flash_attn, - ) - - substitute_hf_flash_attn( - process_group=get_ring_attn_group(), - ring_attn_func=ring_attn_func, - ) - - -def update_ring_attn_params(position_ids: torch.Tensor | None): - """ - Calculate the cumulative sequence lengths for the current forward pass and pass the - value to the substituted `ring_flash_attn`. 
- - Args: - position_ids: Optional tensor of position IDs (for sample packed data). - """ - from ring_flash_attn import update_ring_flash_attn_params - - cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids) - cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device()) - update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group()) diff --git a/src/axolotl/monkeypatch/ring_attn/__init__.py b/src/axolotl/monkeypatch/ring_attn/__init__.py new file mode 100644 index 000000000..5833b9ce4 --- /dev/null +++ b/src/axolotl/monkeypatch/ring_attn/__init__.py @@ -0,0 +1,22 @@ +"""Init for ring attention monkeypatch module""" + +# pylint: disable=unused-import +# flake8: noqa + +from .patch import ( + get_ring_attn_group, + patch_prepare_data_loader, + patch_prepare_device_mesh, + register_ring_attn, + set_ring_attn_group, + update_ring_attn_params, +) + +__all__ = ( + "get_ring_attn_group", + "patch_prepare_data_loader", + "patch_prepare_device_mesh", + "register_ring_attn", + "set_ring_attn_group", + "update_ring_attn_params", +) diff --git a/src/axolotl/monkeypatch/attention/ring_attn/adapters/__init__.py b/src/axolotl/monkeypatch/ring_attn/adapters/__init__.py similarity index 100% rename from src/axolotl/monkeypatch/attention/ring_attn/adapters/__init__.py rename to src/axolotl/monkeypatch/ring_attn/adapters/__init__.py diff --git a/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py b/src/axolotl/monkeypatch/ring_attn/adapters/batch.py similarity index 100% rename from src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py rename to src/axolotl/monkeypatch/ring_attn/adapters/batch.py diff --git a/src/axolotl/monkeypatch/ring_attn/patch.py b/src/axolotl/monkeypatch/ring_attn/patch.py new file mode 100644 index 000000000..4329d9f13 --- /dev/null +++ b/src/axolotl/monkeypatch/ring_attn/patch.py @@ -0,0 +1,223 @@ +"""Ring attention group registration and flash attention patching. 
+ +Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention) +package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in +their sequence parallel version of Flash Attention 2. + +We also provide some patches for accelerate functions to prepare the dataloader for +sequence parallelism training. +""" + +import inspect + +import accelerate +import torch +import torch.distributed as dist +from accelerate.logging import get_logger + +from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids +from axolotl.utils.schemas.enums import RingAttnFunc + +LOG = get_logger(__name__) + + +RING_ATTN_GROUP = None + +ORIGINAL_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1 + submesh_dp_size = 1 + submesh_tp_size = 1 + if "tp" in torch_device_mesh.mesh_dim_names: + submesh_tp_size = torch_device_mesh["tp"].size() + if "dp" in torch_device_mesh.mesh_dim_names: + submesh_dp_size = torch_device_mesh["dp"].size() + if "fsdp" in torch_device_mesh.mesh_dim_names: + submesh_fsdp_size = torch_device_mesh["fsdp"].size() + process_index = process_index // submesh_tp_size""" + +NEW_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1 + submesh_dp_size = 1 + submesh_tp_size = 1 + submesh_cp_size = 1 + if "cp" in torch_device_mesh.mesh_dim_names: + submesh_cp_size = torch_device_mesh["cp"].size() + if "tp" in torch_device_mesh.mesh_dim_names: + submesh_tp_size = torch_device_mesh["tp"].size() + if "dp" in torch_device_mesh.mesh_dim_names: + submesh_dp_size = torch_device_mesh["dp"].size() + if "fsdp" in torch_device_mesh.mesh_dim_names: + submesh_fsdp_size = torch_device_mesh["fsdp"].size() + process_index = process_index // (submesh_tp_size * submesh_cp_size)""" + + +def get_ring_attn_group() -> dist.ProcessGroup: + """Getter for ring attention group on this rank.""" + return RING_ATTN_GROUP + + +def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None): + """Setter for ring attention group on this rank.""" + global 
RING_ATTN_GROUP # pylint: disable=global-statement + RING_ATTN_GROUP = ring_attn_group + + +def register_ring_attn( + sequence_parallel_degree: int, + heads_k_stride: int | None, + ring_attn_func: RingAttnFunc | None, +): + """Create ring attention group and substitute flash attn with ring flash attn. + + Args: + sequence_parallel_degree: Sequence parallelism factor. + heads_k_stride: Sequence parallelism K head stride size. Passed + through to `ring_flash_attn.substitute_hf_flash_attn`. + ring_attn_func: `ring_flash_attn` ring attention implemention. If sample + packing is enabled, it must be a `varlen` function; otherwise, it must be a + `batch` function. + """ + rank = dist.get_rank() + world_size = dist.get_world_size() + + if rank == 0: + LOG.info( + "Enabling ring attention sequence parallelism: " + f"each sequence will be processed across {sequence_parallel_degree} GPUs" + ) + + assert sequence_parallel_degree <= world_size, ( + f"sequence_parallel_degree ({sequence_parallel_degree}) " + f"must be less than or equal to world_size ({world_size})" + ) + assert world_size % sequence_parallel_degree == 0, ( + f"sequence_parallel_degree ({sequence_parallel_degree}) " + f"must evenly divide world_size ({world_size})" + ) + + # Assign ranks to sequence parallel groups + group_assignments = {} + for i in range(world_size // sequence_parallel_degree): + ring_attn_ranks = list( + range( + i * sequence_parallel_degree, + (i + 1) * sequence_parallel_degree, + ) + ) + group = dist.new_group(ranks=ring_attn_ranks, backend="nccl") + + # Track which GPUs are in which groups + for r in ring_attn_ranks: + group_assignments[r] = i + + if rank in ring_attn_ranks: + set_ring_attn_group(group) + + # Log the GPU group assignments + if rank == 0: + LOG.info(f"Sequence parallel group assignments: {group_assignments}") + + if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3: + from ring_flash_attn import substitute_hf_flash_attn + + substitute_hf_flash_attn( + 
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
+        )
+    elif ring_attn_func is RingAttnFunc.BATCH_RING:
+        from axolotl.monkeypatch.ring_attn.adapters.batch import (
+            substitute_hf_flash_attn,
+        )
+
+        substitute_hf_flash_attn(
+            process_group=get_ring_attn_group(),
+            ring_attn_func=ring_attn_func,
+        )
+
+
+def update_ring_attn_params(position_ids: torch.Tensor | None):
+    """
+    Calculate the cumulative sequence lengths for the current forward pass and pass the
+    value to the substituted `ring_flash_attn`.
+
+    Args:
+        position_ids: Optional tensor of position IDs (for sample packed data).
+    """
+    from ring_flash_attn import update_ring_flash_attn_params
+
+    cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
+    cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
+    update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
+
+
+def patch_prepare_data_loader():
+    """Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
+
+    Raises:
+        RuntimeError: If source code to patch does not exist.
+    """
+    original_fn = accelerate.data_loader.prepare_data_loader
+    original_source = inspect.getsource(original_fn)
+
+    if ORIGINAL_PREPARE_DATALOADER_CODE not in original_source:
+        raise RuntimeError(
+            "SP patch failed - target snippet not found. "
+            "Check accelerate's version or update the patch."
+ ) + + patched_source = original_source.replace( + ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE + ) + + # Create a new function from the patched source + namespace = {} + exec( # pylint: disable=exec-used # nosec B102 + patched_source, accelerate.data_loader.__dict__, namespace + ) + patched_function = namespace["prepare_data_loader"] + + accelerate.data_loader.prepare_data_loader = patched_function + LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support") + + +def patch_prepare_device_mesh(sequence_parallel_degree: int): + """Patches the `Accelerator._prepare_device_mesh` method to create a device mesh + that includes sequence parallelism with the specified degree. + + Args: + sequence_parallel_degree (int): The degree of sequence parallelism to use. + """ + + def _prepare_device_mesh(self): + """Prepare the device mesh for distributed training. The dataloader will + determine how to load data based on the device mesh. + """ + if self.state.torch_tp_plugin: + return self.state.torch_tp_plugin.torch_device_mesh + if ( + self.distributed_type == accelerate.accelerator.DistributedType.DEEPSPEED + and hasattr(self.state, "ds_device_mesh") + ): + return self.state.ds_device_mesh + + # Create device mesh with sequence parallelism + world_size = dist.get_world_size() + mesh_shape = ( + world_size // sequence_parallel_degree, + sequence_parallel_degree, + ) + device_ids = list(range(world_size)) + + # Note that we use "cp" instead of "sp" to match the PyTorch native "context + # parallelism" implementation naming + return dist.DeviceMesh( + "cuda", + torch.tensor(device_ids).reshape(mesh_shape), + mesh_dim_names=("dp", "cp"), + ) + + # Replace the original method with our new method + # pylint: disable=protected-access + accelerate.accelerator.Accelerator._prepare_device_mesh = _prepare_device_mesh + + LOG.info( + "Successfully patched Accelerator._prepare_device_mesh " + f"with sequence_parallel_degree={sequence_parallel_degree}" + 
) diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py index 66044f7f0..6e4f9bada 100644 --- a/src/axolotl/utils/ctx_managers/sequence_parallel.py +++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py @@ -1,6 +1,7 @@ """Module for Axolotl trainer sequence parallelism manager and utilities""" import functools +import inspect import torch import torch.distributed as dist @@ -9,7 +10,7 @@ from torch.utils.hooks import RemovableHandle from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils import ModelOutput -from axolotl.monkeypatch.attention.ring_attn.patch import ( +from axolotl.monkeypatch.ring_attn.patch import ( get_ring_attn_group, update_ring_attn_params, ) @@ -206,12 +207,25 @@ class SequenceParallelContextManager: def __enter__(self): # Forward pre-hook to apply sequence parallelism def sequence_parallel_pre_hook(_, args, kwargs): - # Apply sequence parallelism to kwargs and get original sequence length and padding info - kwargs, self.original_seq_len, self.pad_len = ( - self.apply_sequence_parallelism(batch=kwargs) + # Get parameter names from the model's forward function + forward_params = list( + inspect.signature(self.models[0].forward).parameters.keys() ) - return args, kwargs + updated_kwargs = kwargs.copy() + for i, arg in enumerate(args): + if i < len(forward_params): + updated_kwargs[forward_params[i]] = arg + + # Any excess positional arguments are kept as-is + remaining_args = args[len(forward_params) :] + + # Apply sequence parallelism to updated kwargs + updated_kwargs, self.original_seq_len, self.pad_len = ( + self.apply_sequence_parallelism(updated_kwargs) + ) + + return remaining_args, updated_kwargs # Forward post-hook to gather outputs def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput: diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 316fbec8c..6236f78e8 100644 --- 
a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -59,6 +59,7 @@ from axolotl.monkeypatch.multipack import ( SUPPORTED_MULTIPACK_MODEL_TYPES, patch_for_multipack, ) +from axolotl.monkeypatch.ring_attn.patch import get_ring_attn_group from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.chat_templates import get_chat_template_from_config @@ -681,16 +682,25 @@ class ModelLoader: patch_self_attn_lora(self.cfg) if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1: - from axolotl.monkeypatch.attention.ring_attn import register_ring_attn + from axolotl.monkeypatch.ring_attn import ( + patch_prepare_data_loader, + patch_prepare_device_mesh, + register_ring_attn, + ) # Initialize ring attn for sequence parallelism. This must be done after # model init but before the first forward pass, since it modifies flash # attn to use ring comm for SP training across multiple GPUs. - register_ring_attn( - sequence_parallel_degree=self.cfg.sequence_parallel_degree, - heads_k_stride=self.cfg.heads_k_stride, - ring_attn_func=self.cfg.ring_attn_func, - ) + if get_ring_attn_group() is None: # If already set, this is already patched + register_ring_attn( + sequence_parallel_degree=self.cfg.sequence_parallel_degree, + heads_k_stride=self.cfg.heads_k_stride, + ring_attn_func=self.cfg.ring_attn_func, + ) + patch_prepare_data_loader() + patch_prepare_device_mesh( + sequence_parallel_degree=self.cfg.sequence_parallel_degree + ) def patch_attention(self) -> None: if hasattr(self.model_config, "model_type"): diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py index 8efe62940..83faa779f 100644 --- a/tests/e2e/patched/test_sp.py +++ b/tests/e2e/patched/test_sp.py @@ -10,7 +10,7 @@ import pytest import torch from accelerate.state import PartialState -from axolotl.monkeypatch.attention.ring_attn import ( +from axolotl.monkeypatch.ring_attn import ( 
get_ring_attn_group, register_ring_attn, set_ring_attn_group, @@ -313,13 +313,13 @@ class TestApplySequenceParallelism: # Mock the process group monkeypatch.setattr( - "axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group", + "axolotl.monkeypatch.ring_attn.get_ring_attn_group", MagicMock, ) # Mock update_ring_attn_params monkeypatch.setattr( - "axolotl.monkeypatch.attention.ring_attn.update_ring_attn_params", + "axolotl.monkeypatch.ring_attn.update_ring_attn_params", lambda **kwargs: None, ) From 1c83a1a02081834eb84afa7292eef479133541fa Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 22 May 2025 19:18:27 +0700 Subject: [PATCH 11/19] feat(doc): clarify minimum pytorch and cuda to use blackwell (#2704) [skip ci] --- docs/docker.qmd | 4 ++++ docs/installation.qmd | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/docs/docker.qmd b/docs/docker.qmd index e208d3222..d665eaf5b 100644 --- a/docs/docker.qmd +++ b/docs/docker.qmd @@ -8,6 +8,10 @@ format: This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai). +::: {.callout-important} +For Blackwell GPUs, please use the tags with Pytorch 2.7.0 and CUDA 12.8. +::: + ## Base The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more. diff --git a/docs/installation.qmd b/docs/installation.qmd index 0cf5ffceb..b429992b6 100644 --- a/docs/installation.qmd +++ b/docs/installation.qmd @@ -25,6 +25,10 @@ Please make sure to have Pytorch installed before installing Axolotl in your loc Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/) ::: +::: {.callout-important} +For Blackwell GPUs, please use Pytorch 2.7.0 and CUDA 12.8. 
+::: + ### PyPI Installation (Recommended) {#sec-pypi} ```{.bash} @@ -72,6 +76,10 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \ ``` ::: +::: {.callout-important} +For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.7.0` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.7.0`. +::: + Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available. ## Cloud Environments {#sec-cloud} From 798b5f5cfdc3478b51cd48d53c38a3b5a0d387f2 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 22 May 2025 19:19:12 +0700 Subject: [PATCH 12/19] fix(RL): address plugin rl overwriting trainer_cls (#2697) [skip ci] * fix: plugin rl overwrite trainer_cls * feat(test): add test to catch trainer_cls is not None --- src/axolotl/core/trainer_builder.py | 4 +++- tests/core/test_trainer_builder.py | 25 +++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 878dd176a..863b065e6 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1195,7 +1195,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.plugins: plugin_manager = PluginManager.get_instance() - trainer_cls = plugin_manager.get_trainer_cls(self.cfg) + temp_trainer_cls = plugin_manager.get_trainer_cls(self.cfg) + if temp_trainer_cls is not None: + trainer_cls = temp_trainer_cls sig = inspect.signature(trainer_cls) if "tokenizer" in sig.parameters.keys(): diff --git a/tests/core/test_trainer_builder.py b/tests/core/test_trainer_builder.py index fbfd7a87c..d1ad273ea 100644 --- a/tests/core/test_trainer_builder.py +++ b/tests/core/test_trainer_builder.py @@ -8,6 +8,7 @@ from axolotl.core.trainer_builder import HFRLTrainerBuilder from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, 
load_tokenizer +from axolotl.utils.schemas.enums import RLType @pytest.fixture(name="cfg") @@ -65,3 +66,27 @@ class TestHFRLTrainerBuilder: assert training_arguments.adam_epsilon == 0.00001 assert training_arguments.dataloader_num_workers == 1 assert training_arguments.dataloader_pin_memory is True + + +class TestTrainerClsPlugin: + """ + TestCase class for trainer builder with plugin + """ + + def test_trainer_cls_is_not_none_with_plugin(self, cfg, model, tokenizer): + """ + Test that the trainer cls is not none with plugin + + Fixes #2693 + """ + cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"] + cfg.rl = RLType.KTO + + # Expected AttributeError as we don't pass regular model configs to RL trainer builder + # If it throws `TypeError: None is not a callable object`, trainer_cls could be None + with pytest.raises( + AttributeError, match=r".*'tuple' object has no attribute 'config'.*" + ): + builder = HFRLTrainerBuilder(cfg, model, tokenizer) + + builder.build(100) From aa0492c366d32645481c80b5c60a86f53f7670d7 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 22 May 2025 19:19:59 +0700 Subject: [PATCH 13/19] feat: do not find turn indices if turn is not trainable (#2696) * feat: do not find turn indices if turn is not trainable * fix: handle edge case where train on eos/eot is all * fix: improve warning message --- src/axolotl/prompt_strategies/chat_template.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 638cee559..047a66e94 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -424,6 +424,20 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): LOG.debug(f"Should train: {should_train}") + # turn not trainable, skip having to find the turn indices + # unless last turn and train_on_eos/train_on_eot is all + if not should_train and ( + self.train_on_eos != "all" and 
self.train_on_eot != "all" + ): + if index == len(turns) - 1: + LOG.warning( + "Last turn is not trainable, skipping having to find the turn indices. " + "This may cause incorrect last EOT/EOS token to be unmasked." + "This is likely a dataset design issue. Please ensure last turn is trainable." + ) + + continue + turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index) LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}") From 5f8f8172005543e2fff728c81fd7314be7883d6f Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 22 May 2025 11:18:32 -0400 Subject: [PATCH 14/19] SP context manager update (#2699) * utilize accelerate prepare_data_loader with patching * lint * cleanup, fix * update to support DPO quirk * coderabbit commits, cleanup, remove dead code * fix * move ring attn patching to sp ctx manager * lint * lint * test fix * test fix --- src/axolotl/monkeypatch/ring_attn/patch.py | 6 ++- src/axolotl/train.py | 1 + .../utils/ctx_managers/sequence_parallel.py | 51 ++++++++++++++----- src/axolotl/utils/models.py | 22 -------- tests/e2e/patched/test_sp.py | 34 +++++++++---- 5 files changed, 68 insertions(+), 46 deletions(-) diff --git a/src/axolotl/monkeypatch/ring_attn/patch.py b/src/axolotl/monkeypatch/ring_attn/patch.py index 4329d9f13..7d733cfc1 100644 --- a/src/axolotl/monkeypatch/ring_attn/patch.py +++ b/src/axolotl/monkeypatch/ring_attn/patch.py @@ -51,6 +51,8 @@ NEW_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1 def get_ring_attn_group() -> dist.ProcessGroup: """Getter for ring attention group on this rank.""" + if RING_ATTN_GROUP is None: + raise RuntimeError("register_ring_attn() not yet called") return RING_ATTN_GROUP @@ -69,8 +71,8 @@ def register_ring_attn( Args: sequence_parallel_degree: Sequence parallelism factor. - heads_k_stride: Sequence parallelism K head stride size. Passed - through to `ring_flash_attn.substitute_hf_flash_attn`. + heads_k_stride: Sequence parallelism K head stride size. 
Passed through to + `varlen_llama3` `ring_flash_attn` implementation. ring_attn_func: `ring_flash_attn` ring attention implemention. If sample packing is enabled, it must be a `varlen` function; otherwise, it must be a `batch` function. diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 90ab10e9f..46f722eeb 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -209,6 +209,7 @@ def execute_training( sequence_parallel_degree=cfg.sequence_parallel_degree, gradient_accumulation_steps=cfg.gradient_accumulation_steps, ring_attn_func=cfg.ring_attn_func, + heads_k_stride=cfg.heads_k_stride, ) ) diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py index 6e4f9bada..2ae93acad 100644 --- a/src/axolotl/utils/ctx_managers/sequence_parallel.py +++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py @@ -12,6 +12,9 @@ from transformers.utils import ModelOutput from axolotl.monkeypatch.ring_attn.patch import ( get_ring_attn_group, + patch_prepare_data_loader, + patch_prepare_device_mesh, + register_ring_attn, update_ring_attn_params, ) from axolotl.utils.schemas.enums import RingAttnFunc @@ -169,6 +172,8 @@ class SequenceParallelContextManager: sequence_parallel_degree: Number of processes to split sequences over. gradient_accumulation_steps: Number of steps to accumulate gradients over. ring_attn_func: Which ring attention function to use. Currently unused. + heads_k_stride: Sequence parallelism K head stride size. Passed through to + `varlen_llama3` `ring_flash_attn` implementation. 
""" def __init__( @@ -177,14 +182,17 @@ class SequenceParallelContextManager: sequence_parallel_degree: int, gradient_accumulation_steps: int, ring_attn_func: RingAttnFunc, + heads_k_stride: int | None, ): self.models = models self.sequence_parallel_degree = sequence_parallel_degree self.gradient_accumulation_steps = gradient_accumulation_steps self.ring_attn_func = ring_attn_func - self.process_group = get_ring_attn_group() + self.heads_k_stride = heads_k_stride + self._register_ring_attn() - # Initialize sequence parallel group details + # Set distributed info for local rank + self.process_group = get_ring_attn_group() self.local_rank = dist.get_rank(self.process_group) self.local_world_size = dist.get_world_size(self.process_group) @@ -205,6 +213,33 @@ class SequenceParallelContextManager: ) def __enter__(self): + self._register_model_hooks() + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Remove all hooks + for handle in self.hook_handles: + handle.remove() + self.hook_handles = [] + + # TODO(djsaunde): Un-patch attention and accelerate functions (low priority) + + def _register_ring_attn(self): + # Initialize ring attn for sequence parallelism + register_ring_attn( + sequence_parallel_degree=self.sequence_parallel_degree, + heads_k_stride=self.heads_k_stride, + ring_attn_func=self.ring_attn_func, + ) + + # Patches for accelerate functionality + patch_prepare_data_loader() + patch_prepare_device_mesh( + sequence_parallel_degree=self.sequence_parallel_degree + ) + + def _register_model_hooks(self): # Forward pre-hook to apply sequence parallelism def sequence_parallel_pre_hook(_, args, kwargs): # Get parameter names from the model's forward function @@ -230,7 +265,7 @@ class SequenceParallelContextManager: # Forward post-hook to gather outputs def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput: # Gather the sharded outputs - output = self.gather_outputs(output) + output = self._gather_outputs(output) # Remove 
padding if it was added if self.pad_len > 0: @@ -253,15 +288,7 @@ class SequenceParallelContextManager: model.register_forward_hook(sequence_parallel_post_hook) ) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - # Remove all hooks - for handle in self.hook_handles: - handle.remove() - self.hook_handles = [] - - def gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast: + def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast: """Gather sharded outputs from all ranks and reconstruct the full tensor.""" for key, value in output.items(): if isinstance(value, torch.Tensor) and value.dim() > 1: diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 6236f78e8..cd7499869 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -59,7 +59,6 @@ from axolotl.monkeypatch.multipack import ( SUPPORTED_MULTIPACK_MODEL_TYPES, patch_for_multipack, ) -from axolotl.monkeypatch.ring_attn.patch import get_ring_attn_group from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.chat_templates import get_chat_template_from_config @@ -681,27 +680,6 @@ class ModelLoader: patch_self_attn_lora(self.cfg) - if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1: - from axolotl.monkeypatch.ring_attn import ( - patch_prepare_data_loader, - patch_prepare_device_mesh, - register_ring_attn, - ) - - # Initialize ring attn for sequence parallelism. This must be done after - # model init but before the first forward pass, since it modifies flash - # attn to use ring comm for SP training across multiple GPUs. 
- if get_ring_attn_group() is None: # If already set, this is already patched - register_ring_attn( - sequence_parallel_degree=self.cfg.sequence_parallel_degree, - heads_k_stride=self.cfg.heads_k_stride, - ring_attn_func=self.cfg.ring_attn_func, - ) - patch_prepare_data_loader() - patch_prepare_device_mesh( - sequence_parallel_degree=self.cfg.sequence_parallel_degree - ) - def patch_attention(self) -> None: if hasattr(self.model_config, "model_type"): if self.model_config.model_type == "mllama" and self.cfg.flash_attention: diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py index 83faa779f..2b4d11b30 100644 --- a/tests/e2e/patched/test_sp.py +++ b/tests/e2e/patched/test_sp.py @@ -84,16 +84,16 @@ class TestRingAttention: def test_get_ring_attn_group_no_registration( self, mock_world_size, mock_rank, partial_state ): - """Test that get_ring_attn_group returns None when no group has been registered.""" + """Test that get_ring_attn_group raises RuntimeError when no group has been registered.""" # Setup mocks mock_world_size.return_value = 4 mock_rank.return_value = 0 - # Get the group without registration - group = get_ring_attn_group() - - # Verify that None was returned - assert group is None + # Verify that RuntimeError is raised when no group is registered + with pytest.raises( + RuntimeError, match="register_ring_attn\\(\\) not yet called" + ): + get_ring_attn_group() @patch("torch.distributed.new_group") @patch("torch.distributed.get_rank") @@ -323,8 +323,11 @@ class TestApplySequenceParallelism: lambda **kwargs: None, ) - def test_world_size_one(self, sequence_parallel_batch): + @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group") + def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch): """Test that function returns original batch when world size is 1.""" + mock_get_ring_attn_group.return_value = 0 + result, _, _ = apply_sequence_parallelism( batch=sequence_parallel_batch, local_rank=0, @@ -336,8 
+339,11 @@ class TestApplySequenceParallelism: # Should return the original batch unchanged assert result == sequence_parallel_batch - def test_batch_ring_rank0(self, sequence_parallel_batch): + @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group") + def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch): """Test BATCH_RING sharding for rank 0 in a 2-process group.""" + mock_get_ring_attn_group.return_value = 0 + batch = sequence_parallel_batch seq_len = batch["input_ids"].size(1) @@ -359,8 +365,11 @@ class TestApplySequenceParallelism: result["position_ids"], batch["position_ids"][:, : seq_len // 2] ) - def test_batch_ring_rank1(self, sequence_parallel_batch): + @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group") + def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch): """Test BATCH_RING sharding for rank 1 in a 2-process group.""" + mock_get_ring_attn_group.return_value = 0 + batch = sequence_parallel_batch seq_len = batch["input_ids"].size(1) original_input_ids = batch["input_ids"].clone() @@ -419,8 +428,13 @@ class TestApplySequenceParallelism: # assert torch.equal(result_rank0["input_ids"], rank0_expected) # assert torch.equal(result_rank1["input_ids"], rank1_expected) - def test_partial_application(self, sequence_parallel_batch): + @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group") + def test_partial_application( + self, mock_get_ring_attn_group, sequence_parallel_batch + ): """Test that we can create a partially applied version of the function.""" + mock_get_ring_attn_group.return_value = 0 + batch = sequence_parallel_batch original_input_ids = batch["input_ids"].clone() From 8cde256db2964f14398e861317431afedffc7a26 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 23 May 2025 12:27:38 -0400 Subject: [PATCH 15/19] Remove unused const (#2714) * remove unused const * accidentally commited benchmark plot --- src/axolotl/kernels/geglu.py | 5 +---- 1 file 
changed, 1 insertion(+), 4 deletions(-) diff --git a/src/axolotl/kernels/geglu.py b/src/axolotl/kernels/geglu.py index 0aa035c94..6acbea0d4 100644 --- a/src/axolotl/kernels/geglu.py +++ b/src/axolotl/kernels/geglu.py @@ -1,5 +1,4 @@ -""" -Module for definition of GEGLU Triton kernels. +"""Module for definition of GEGLU Triton kernels. See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202). @@ -12,8 +11,6 @@ import torch import triton import triton.language as tl -SQRT_2_PI: tl.constexpr = 0.7978845608028654 # sqrt(2/Ï€) - @triton.jit def _geglu_fwd_kernel( From b5f1e53a0fbb43528c753f017bf099fa99f42c3e Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 23 May 2025 15:51:11 -0400 Subject: [PATCH 16/19] models.py -> loaders/ module refactor (#2680) * models.py -> loaders/ module refactor * refactor ModelLoader class * plugin manager changes * circular import fix * pytest * pytest * minor improvements * fix * minor changes * fix test * remove dead code * coderabbit comments * lint * fix * coderabbit suggestion I liked * more coderabbit * review comments, yak shaving * lint * updating in light of SP ctx manager changes * review comment * review comment 2 --- src/axolotl/cli/utils.py | 6 +- src/axolotl/common/datasets.py | 2 +- src/axolotl/core/trainer_builder.py | 2 +- src/axolotl/core/trainers/grpo/trainer.py | 2 +- src/axolotl/integrations/base.py | 550 +++--- src/axolotl/loaders/__init__.py | 10 + src/axolotl/loaders/adapter.py | 206 +++ src/axolotl/loaders/constants.py | 21 + src/axolotl/loaders/model.py | 754 ++++++++ src/axolotl/loaders/patch_manager.py | 380 ++++ src/axolotl/loaders/processor.py | 56 + src/axolotl/loaders/tokenizer.py | 281 +++ src/axolotl/loaders/utils.py | 211 +++ .../gradient_checkpointing/__init__.py | 4 +- .../gradient_checkpointing/offload_cpu.py | 0 .../gradient_checkpointing/offload_disk.py | 0 src/axolotl/monkeypatch/peft/utils.py | 2 +- src/axolotl/train.py | 12 +- src/axolotl/utils/config/__init__.py | 3 +- 
.../utils/ctx_managers/sequence_parallel.py | 2 +- src/axolotl/utils/data/rl.py | 2 +- src/axolotl/utils/lora_embeddings.py | 14 - src/axolotl/utils/models.py | 1648 ----------------- src/axolotl/utils/schemas/config.py | 10 + tests/core/test_trainer_builder.py | 8 +- tests/e2e/patched/test_model_patches.py | 6 +- tests/e2e/test_load_model.py | 13 +- tests/patched/test_validation.py | 16 +- tests/test_exact_deduplication.py | 22 +- .../{utils/test_models.py => test_loaders.py} | 37 +- tests/test_lora.py | 6 +- tests/test_tokenizers.py | 2 +- tests/utils/__init__.py | 0 33 files changed, 2249 insertions(+), 2039 deletions(-) create mode 100644 src/axolotl/loaders/__init__.py create mode 100644 src/axolotl/loaders/adapter.py create mode 100644 src/axolotl/loaders/constants.py create mode 100644 src/axolotl/loaders/model.py create mode 100644 src/axolotl/loaders/patch_manager.py create mode 100644 src/axolotl/loaders/processor.py create mode 100644 src/axolotl/loaders/tokenizer.py create mode 100644 src/axolotl/loaders/utils.py rename src/axolotl/{utils => monkeypatch}/gradient_checkpointing/__init__.py (91%) rename src/axolotl/{utils => monkeypatch}/gradient_checkpointing/offload_cpu.py (100%) rename src/axolotl/{utils => monkeypatch}/gradient_checkpointing/offload_disk.py (100%) delete mode 100644 src/axolotl/utils/lora_embeddings.py delete mode 100644 src/axolotl/utils/models.py rename tests/{utils/test_models.py => test_loaders.py} (83%) delete mode 100644 tests/utils/__init__.py diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py index ee00db39d..e681589f3 100644 --- a/src/axolotl/cli/utils.py +++ b/src/axolotl/cli/utils.py @@ -20,8 +20,9 @@ from transformers import ( ProcessorMixin, ) +from axolotl.loaders import load_processor, load_tokenizer +from axolotl.loaders.model import ModelLoader from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_model, load_processor, load_tokenizer LOG = logging.getLogger(__name__) @@ 
-318,7 +319,8 @@ def load_model_and_tokenizer( tokenizer = load_tokenizer(cfg) LOG.info("loading model...") - model, _ = load_model(cfg, tokenizer, inference=inference) + model_loader = ModelLoader(cfg, tokenizer, inference=inference) + model, _ = model_loader.load() processor = None if cfg.is_multimodal: diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index f944cbd6a..e3ffb7ae9 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -10,10 +10,10 @@ from datasets import Dataset import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs +from axolotl.loaders import load_processor, load_tokenizer from axolotl.utils.data import prepare_dataset from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.schemas.enums import RLType from axolotl.utils.tokenization import check_dataset_labels diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 863b065e6..9709f0fd4 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -59,6 +59,7 @@ from axolotl.core.training_args import ( AxolotlTrainingArguments, ) from axolotl.integrations.base import PluginManager +from axolotl.loaders.utils import ensure_dtype from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.relora import ReLoRACallback from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr @@ -86,7 +87,6 @@ from axolotl.utils.collators import ( V2BatchSamplerDataCollatorForSeq2Seq, ) from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator -from axolotl.utils.models import ensure_dtype from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType try: diff --git 
a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index a603ed860..b5b3912cf 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -43,7 +43,7 @@ from trl.trainer.utils import pad from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin -from axolotl.monkeypatch.ring_attn.patch import get_ring_attn_group +from axolotl.monkeypatch.ring_attn import get_ring_attn_group if is_peft_available(): # pylint: disable=unused-import diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index 97cbac693..2beaf667a 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -10,71 +10,73 @@ # License for the specific language governing permissions and limitations under # the License. -""" -Base class for all plugins. +"""Base class for all plugins. A plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl. Plugins can be used to integrate third-party models, modify the training process, or add new features. To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods. """ + +from __future__ import annotations + import collections import importlib import logging -from typing import OrderedDict +from typing import TYPE_CHECKING, Callable, OrderedDict, Union -import torch +from peft import PeftModel +from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler +from transformers import PreTrainedModel, Trainer from axolotl.utils.dict import DictDefault +if TYPE_CHECKING: + from axolotl.common.datasets import TrainDatasetMeta + class BasePlugin: - """ - Base class for all plugins. Defines the interface for plugin methods. - - Attributes: - None + """Base class for all plugins. Defines the interface for plugin methods. 
Methods: - register(cfg): Registers the plugin with the given configuration. - load_datasets(cfg): Loads and preprocesses the dataset for training. - pre_model_load(cfg): Performs actions before the model is loaded. - post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied. - pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded. - post_lora_load(cfg, model): Performs actions after LoRA weights are loaded. - post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters. - post_trainer_create(cfg, trainer): Performs actions after the trainer is created. - create_optimizer(cfg, trainer): Creates and returns an optimizer for training. - create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler. - add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training. - add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training. + register(cfg): Registers the plugin with the given configuration. + load_datasets(cfg): Loads and preprocesses the dataset for training. + pre_model_load(cfg): Performs actions before the model is loaded. + post_model_build(cfg, model): Performs actions after the model is loaded, but + before LoRA adapters are applied. + pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded. + post_lora_load(cfg, model): Performs actions after LoRA weights are loaded. + post_model_load(cfg, model): Performs actions after the model is loaded, + inclusive of any adapters. + post_trainer_create(cfg, trainer): Performs actions after the trainer is + created. + create_optimizer(cfg, trainer): Creates and returns an optimizer for training. + create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and + returns a learning rate scheduler. 
+ add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before + training. + add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after + training. """ def __init__(self): - """ - Initializes the BasePlugin. - """ + """Initializes the BasePlugin.""" def register(self, cfg): # pylint: disable=unused-argument - """ - Registers the plugin with the given configuration. + """Registers the plugin with the given configuration. - Parameters: - cfg (dict): The configuration for the plugin. - - Returns: - None + Args: + cfg: The configuration for the plugin. """ def get_input_args(self) -> str | None: - """ - Returns a pydantic model for the plugin's input arguments. - """ + """Returns a pydantic model for the plugin's input arguments.""" - def load_datasets(self, cfg: DictDefault, preprocess: bool = False): - """ - Loads and preprocesses the dataset for training. + def load_datasets( + self, cfg: DictDefault, preprocess: bool = False + ) -> Union["TrainDatasetMeta", None]: + """Loads and preprocesses the dataset for training. Args: cfg: The configuration for the plugin. @@ -84,181 +86,164 @@ class BasePlugin: dataset_meta: The metadata for the training dataset. """ - def pre_model_load(self, cfg): # pylint: disable=unused-argument - """ - Performs actions before the model is loaded. + def pre_model_load(self, cfg: DictDefault): # pylint: disable=unused-argument + """Performs actions before the model is loaded. Args: - cfg (dict): The configuration for the plugin. + cfg: The configuration for the plugin. + """ + + # pylint: disable=unused-argument + def post_model_build(self, cfg: DictDefault, model: PreTrainedModel): + """Performs actions after the model is built/loaded, but before any adapters are applied. + + Args: + cfg: The configuration for the plugin. + """ + + # pylint: disable=unused-argument + def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel): + """Performs actions before LoRA weights are loaded. 
+ + Args: + cfg: The configuration for the plugin. + model: The loaded model. + """ + + # pylint: disable=unused-argument + def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): + """Performs actions after LoRA weights are loaded. + + Args: + cfg: The configuration for the plugin. + model: The loaded model. + """ + + # pylint: disable=unused-argument + def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): + """Performs actions after the model is loaded. + + Args: + cfg: The configuration for the plugin. + model: The loaded model. + """ + + # pylint: disable=unused-argument + def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None: + """Returns a custom class for the trainer. + + Args: + cfg: The global axolotl configuration. Returns: - None + The first non-`None` trainer class returned by a plugin. """ - def post_model_build(self, cfg, model): # pylint: disable=unused-argument - """ - Performs actions after the model is built/loaded, but before any adapters are applied. + # pylint: disable=unused-argument + def post_trainer_create(self, cfg: DictDefault, trainer: Trainer): + """Performs actions after the trainer is created. Args: - cfg (dict): The configuration for the plugin. + cfg: The configuration for the plugin. + trainer: The trainer object for training. """ - def post_model_load(self, cfg, model): # pylint: disable=unused-argument - """ - Performs actions after the model is loaded. + # pylint: disable=unused-argument + def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None: + """Creates and returns an optimizer for training. Args: - cfg (dict): The configuration for the plugin. - model (object): The loaded model. + cfg: The configuration for the plugin. + trainer: The trainer object for training. Returns: - None - """ - - def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument - """ - Performs actions before LoRA weights are loaded. 
- - Args: - cfg (dict): The configuration for the plugin. - model (object): The loaded model. - - Returns: - None - """ - - def post_lora_load(self, cfg, model): # pylint: disable=unused-argument - """ - Performs actions after LoRA weights are loaded. - - Args: - cfg (dict): The configuration for the plugin. - model (object): The loaded model. - - Returns: - None - """ - - def get_trainer_cls(self, cfg): # pylint: disable=unused-argument): - """ - Returns a custom class for the trainer. - - Args: - cfg (dict): The global axolotl configuration. - - Returns: - class: The class for the trainer. - """ - - def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument - """ - Performs actions after the trainer is created. - - Args: - cfg (dict): The configuration for the plugin. - trainer (object): The trainer object for training. - - Returns: - None - """ - - def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument - """ - Creates and returns an optimizer for training. - - Args: - cfg (dict): The configuration for the plugin. - trainer (object): The trainer object for training. - - Returns: - object: The created optimizer. + The created optimizer. """ + # pylint: disable=unused-argument def create_lr_scheduler( - self, cfg, trainer, optimizer, num_training_steps - ) -> LRScheduler | None: # pylint: disable=unused-argument - """ - Creates and returns a learning rate scheduler. + self, + cfg: DictDefault, + trainer: Trainer, + optimizer: Optimizer, + num_training_steps: int, + ) -> LRScheduler | None: + """Creates and returns a learning rate scheduler. Args: - cfg (dict): The configuration for the plugin. - trainer (object): The trainer object for training. - optimizer (object): The optimizer for training. - num_training_steps (int): Total number of training steps + cfg: The configuration for the plugin. + trainer: The trainer object for training. + optimizer: The optimizer for training. 
+ num_training_steps: Total number of training steps Returns: - object (LRScheduler): The created learning rate scheduler. + The created learning rate scheduler. """ - def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument - """ - setup callbacks before creating the trainer. + # pylint: disable=unused-argument + def add_callbacks_pre_trainer( + self, cfg: DictDefault, model: PreTrainedModel + ) -> list[Callable]: + """Set up callbacks before creating the trainer. Args: - cfg (dict): The configuration for the plugin. - model (object): The loaded model. + cfg: The configuration for the plugin. + model: The loaded model. Returns: - List[callable]: A list of callback functions to be added to the TrainingArgs + A list of callback functions to be added to the `TrainingArgs`. """ return [] + # pylint: disable=unused-argument def add_callbacks_post_trainer( - self, cfg, trainer - ): # pylint: disable=unused-argument - """ - Adds callbacks to the trainer after creating the trainer. - This is useful for callbacks that require access to the model or trainer. + self, cfg: DictDefault, trainer: Trainer + ) -> list[Callable]: + """Adds callbacks to the trainer after creating the trainer. This is useful for + callbacks that require access to the model or trainer. Args: - cfg (dict): The configuration for the plugin. - trainer (object): The trainer object for training. + cfg: The configuration for the plugin. + trainer: The trainer object for training. Returns: - List[callable]: A list of callback functions to be added + A list of callback functions to be added """ return [] - def post_train(self, cfg, model): # pylint: disable=unused-argument - """ - Performs actions after training is complete. + # pylint: disable=unused-argument + def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): + """Performs actions after training is complete. Args: - cfg (dict): The axolotl configuration - model (object): The loaded model. 
- - Returns: - None + cfg: The axolotl configuration. + model: The loaded model. """ - def post_train_unload(self, cfg): # pylint: disable=unused-argument - """ - Performs actions after training is complete and the model is unloaded. + def post_train_unload(self, cfg: DictDefault): # pylint: disable=unused-argument + """Performs actions after training is complete and the model is unloaded. Args: - cfg (dict): The configuration for the plugin. - - Returns: - None + cfg: The configuration for the plugin. """ def load_plugin(plugin_name: str) -> BasePlugin: - """ - Loads a plugin based on the given plugin name. + """Loads a plugin based on the given plugin name. - The plugin name should be in the format "module_name.class_name". - This function splits the plugin name into module and class, imports the module, - retrieves the class from the module, and creates an instance of the class. + The plugin name should be in the format "module_name.class_name". This function + splits the plugin name into module and class, imports the module, retrieves the + class from the module, and creates an instance of the class. - Parameters: - plugin_name (str): The name of the plugin to be loaded. The name should be in the format "module_name.class_name". + Args: + plugin_name: The name of the plugin to be loaded. The name should be in the + format "module_name.class_name". Returns: - BasePlugin: An instance of the loaded plugin. + An instance of the loaded plugin. Raises: - ImportError: If the plugin module cannot be imported. + ImportError: If the plugin module cannot be imported. """ # split the plugin name into module and class module_name, class_name = plugin_name.rsplit(".", 1) @@ -284,28 +269,25 @@ def load_plugin(plugin_name: str) -> BasePlugin: class PluginManager: - """ - The PluginManager class is responsible for loading and managing plugins. - It should be a singleton so it can be accessed from anywhere in the codebase. 
+ """The `PluginManager` class is responsible for loading and managing plugins. It + should be a singleton so it can be accessed from anywhere in the codebase. Attributes: - plugins (List[BasePlugin]): A list of loaded plugins. + plugins: A list of loaded plugins. Methods: - get_instance(): Static method to get the singleton instance of PluginManager. - register(plugin_name: str): Registers a new plugin by its name. - pre_model_load(cfg): Calls the pre_model_load method of all registered plugins. + get_instance(): Static method to get the singleton instance of `PluginManager`. + register(plugin_name: str): Registers a new plugin by its name. + pre_model_load(cfg): Calls the pre_model_load method of all registered plugins. """ plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict() - _instance = None - _cfg = None + _instance: PluginManager | None = None + _cfg: DictDefault | None = None def __new__(cls): - """ - Creates a new instance of PluginManager if it doesn't exist yet. - """ + """Creates a new instance of PluginManager if it doesn't exist yet.""" if cls._instance is None: cls._instance = super(PluginManager, cls).__new__(cls) cls._instance.plugins: OrderedDict[str, BasePlugin] = ( @@ -315,9 +297,8 @@ class PluginManager: @staticmethod def get_instance() -> "PluginManager": - """ - Returns the singleton instance of PluginManager. - If the instance doesn't exist, it creates a new one. + """Returns the singleton instance of PluginManager. If the instance doesn't + exist, it creates a new one. """ if PluginManager._instance is None: PluginManager() @@ -332,17 +313,13 @@ class PluginManager: self._cfg = cfg def register(self, plugin_name: str): - """ - Registers a new plugin by its name. + """Registers a new plugin by its name. - Parameters: - plugin_name (str): The name of the plugin to be registered. - - Returns: - None + Args: + plugin_name: The name of the plugin to be registered. Raises: - ImportError: If the plugin module cannot be imported. 
+ ImportError: If the plugin module cannot be imported. """ # split the plugin name into module and class module_name, class_name = plugin_name.rsplit(".", 1) @@ -284,28 +269,25 @@ def load_plugin(plugin_name: str) -> BasePlugin: class PluginManager: - """ - The PluginManager class is responsible for loading and managing plugins. - It should be a singleton so it can be accessed from anywhere in the codebase. 
""" for plugin in self.plugins.values(): plugin.pre_model_load(cfg) - def post_model_build(self, cfg, model): - """ - Calls the post_model_build method of all registered plugins after the model has been built/loaded, - but before any adapters have been applied. + def post_model_build(self, cfg: DictDefault, model: PreTrainedModel): + """Calls the `post_model_build` method of all registered plugins after the + model has been built / loaded, but before any adapters have been applied. Args: - cfg (dict): The configuration for the plugins. - model (object): The loaded model. + cfg: The configuration for the plugins. + model: The loaded model. """ for plugin in self.plugins.values(): plugin.post_model_build(cfg, model) - def post_model_load(self, cfg, model): - """ - Calls the post_model_load method of all registered plugins after the model has been loaded - inclusive of any adapters + def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel): + """Calls the `pre_lora_load` method of all registered plugins. - Parameters: - cfg (dict): The configuration for the plugins. - model (object): The loaded model. - - Returns: - None - """ - for plugin in self.plugins.values(): - plugin.post_model_load(cfg, model) - - def pre_lora_load(self, cfg, model): - """ - Calls the pre_lora_load method of all registered plugins. - - Parameters: - cfg (dict): The configuration for the plugins. - model (object): The loaded model. - - Returns: - None + Args: + cfg: The configuration for the plugins. + model: The loaded model. """ for plugin in self.plugins.values(): plugin.pre_lora_load(cfg, model) - def post_lora_load(self, cfg, model): - """ - Calls the post_lora_load method of all registered plugins. + def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): + """Calls the `post_lora_load` method of all registered plugins. - Parameters: - cfg (dict): The configuration for the plugins. - model (object): The loaded model. 
- - Returns: - None + Args: + cfg: The configuration for the plugins. + model: The loaded model. """ for plugin in self.plugins.values(): plugin.post_lora_load(cfg, model) - def get_trainer_cls(self, cfg): - """ - Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class. + def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): + """Calls the `post_model_load` method of all registered plugins after the model + has been loaded inclusive of any adapters. - Parameters: - cfg (dict): The configuration for the plugins. + Args: + cfg: The configuration for the plugins. + model: The loaded model. + """ + for plugin in self.plugins.values(): + plugin.post_model_load(cfg, model) + + def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None: + """Calls the `get_trainer_cls` method of all registered plugins and returns the + first non-`None` trainer class. + + Args: + cfg: The configuration for the plugins. Returns: - object: The trainer class, or None if none was found. + The first non-`None` trainer class returned by a plugin. """ for plugin in self.plugins.values(): trainer_cls = plugin.get_trainer_cls(cfg) @@ -471,29 +431,25 @@ class PluginManager: return trainer_cls return None - def post_trainer_create(self, cfg, trainer): - """ - Calls the post_trainer_create method of all registered plugins. + def post_trainer_create(self, cfg: DictDefault, trainer: Trainer): + """Calls the `post_trainer_create` method of all registered plugins. - Parameters: - cfg (dict): The configuration for the plugins. - trainer (object): The trainer object for training. - - Returns: - None + Args: + cfg: The configuration for the plugins. + trainer: The trainer object for training. """ for plugin in self.plugins.values(): plugin.post_trainer_create(cfg, trainer) - def create_optimizer(self, trainer): - """ - Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer. 
+ def create_optimizer(self, trainer: Trainer) -> Optimizer | None: + """Calls the `create_optimizer` method of all registered plugins and returns + the first non-`None` optimizer. - Parameters: - trainer (object): The trainer object for training. + Args: + trainer: The trainer object for training. Returns: - object: The created optimizer, or None if none was found. + The created optimizer, or `None` if none was found. """ for plugin in self.plugins.values(): optimizer = plugin.create_optimizer(self.cfg, trainer) @@ -502,17 +458,17 @@ class PluginManager: return None def create_lr_scheduler( - self, trainer, optimizer, num_training_steps + self, trainer: Trainer, optimizer: Optimizer, num_training_steps: int ) -> LRScheduler | None: - """ - Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler. + """Calls the `create_lr_scheduler` method of all registered plugins and returns + the first non-`None` scheduler. - Parameters: - trainer (object): The trainer object for training. - optimizer (object): The optimizer for training. + Args: + trainer: The trainer object for training. + optimizer: The optimizer for training. Returns: - object: The created learning rate scheduler, or None if none was found. + The created learning rate scheduler, or `None` if not found. """ for plugin in self.plugins.values(): scheduler: LRScheduler | None = plugin.create_lr_scheduler( @@ -525,16 +481,17 @@ class PluginManager: return scheduler return None - def add_callbacks_pre_trainer(self, cfg, model): - """ - Calls the add_callbacks_pre_trainer method of all registered plugins. + def add_callbacks_pre_trainer( + self, cfg: DictDefault, model: PreTrainedModel + ) -> list[Callable]: + """Calls the add_callbacks_pre_trainer method of all registered plugins. - Parameters: - cfg (dict): The configuration for the plugins. - model (object): The loaded model. + Args: + cfg: The configuration for the plugins. + model: The loaded model. 
Returns: - List[callable]: A list of callback functions to be added to the TrainingArgs. + A list of callback functions to be added to the `TrainingArgs`. """ callbacks = [] for plugin in self.plugins.values(): @@ -543,16 +500,17 @@ class PluginManager: callbacks.extend(plugin_callbacks) return callbacks - def add_callbacks_post_trainer(self, cfg, trainer): - """ - Calls the add_callbacks_post_trainer method of all registered plugins. + def add_callbacks_post_trainer( + self, cfg: DictDefault, trainer: Trainer + ) -> list[Callable]: + """Calls the `add_callbacks_post_trainer` method of all registered plugins. - Parameters: - cfg (dict): The configuration for the plugins. - trainer (object): The trainer object for training. + Args: + cfg: The configuration for the plugins. + trainer: The trainer object for training. Returns: - List[callable]: A list of callback functions to be added to the TrainingArgs. + A list of callback functions to be added to the `TrainingArgs`. """ callbacks = [] for plugin in self.plugins.values(): @@ -561,41 +519,31 @@ class PluginManager: callbacks.extend(plugin_callbacks) return callbacks - def post_train(self, cfg, model): - """ - Calls the post_train method of all registered plugins. + def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel): + """Calls the post_train method of all registered plugins. - Parameters: - cfg (dict): The configuration for the plugins. - model (object): The loaded model. - - Returns: - None + Args: + cfg: The configuration for the plugins. + model: The loaded model. """ for plugin in self.plugins.values(): plugin.post_train(cfg, model) - def post_train_unload(self, cfg): - """ - Calls the post_train_unload method of all registered plugins. + def post_train_unload(self, cfg: DictDefault): + """Calls the post_train_unload method of all registered plugins. - Parameters: - cfg (dict): The configuration for the plugins. - model (object): The loaded model. 
- - Returns: - None + Args: + cfg: The configuration for the plugins. + model: The loaded model. """ for plugin in self.plugins.values(): plugin.post_train_unload(cfg) class BaseOptimizerFactory: - """ - Base class for factories to create custom optimizers - """ + """Base class for factories to create custom optimizers""" def __call__( self, opt_model, training_args, **optimizer_kwargs - ) -> "torch.optim.Optimizer": + ) -> Optimizer | None: pass diff --git a/src/axolotl/loaders/__init__.py b/src/axolotl/loaders/__init__.py new file mode 100644 index 000000000..3eef75e58 --- /dev/null +++ b/src/axolotl/loaders/__init__.py @@ -0,0 +1,10 @@ +"""Init for axolotl.loaders module""" + +# pylint: disable=unused-import +# flake8: noqa + +from .adapter import load_adapter, load_lora +from .constants import MULTIMODAL_AUTO_MODEL_MAPPING +from .model import ModelLoader +from .processor import load_processor +from .tokenizer import load_tokenizer diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py new file mode 100644 index 000000000..f7a484e9b --- /dev/null +++ b/src/axolotl/loaders/adapter.py @@ -0,0 +1,206 @@ +"""Adapter loading functionality, including LoRA / QLoRA and associated utils""" + +import logging +import os +import types +from typing import Any + +import bitsandbytes as bnb +import torch +from bitsandbytes.nn import Params4bit +from peft import ( + AdaptionPromptConfig, + LoftQConfig, + LoraConfig, + PeftConfig, + PeftMixedModel, + PeftModel, + get_peft_model, +) +from transformers import PreTrainedModel + +from axolotl.loaders.utils import get_linear_embedding_layers +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger(__name__) + + +def setup_quantized_meta_for_peft(model: torch.nn.Module): + """Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device""" + + def temp_to_method(self, *args, **kwargs): # pylint: disable=unused-argument + return self + + for param in 
model.parameters(): + if isinstance(param, Params4bit): + param.quant_state._orig_to = ( # pylint: disable=protected-access + param.quant_state.to + ) + param.quant_state.to = types.MethodType(temp_to_method, param.quant_state) + + +def setup_quantized_peft_meta_for_training(model: torch.nn.Module): + """Replaces dummy `quant_state.to` method with the original function to allow training to continue""" + for param in model.parameters(): + if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"): + param.quant_state.to = ( + param.quant_state._orig_to # pylint: disable=protected-access + ) + param.quant_state._orig_to = None # pylint: disable=protected-access + + +def find_all_linear_names(model): + cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear) + lora_module_names = set() + for name, module in model.named_modules(): + if ( + isinstance(module, cls) + or "Linear" in module.__class__.__name__ + and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",) + ): + names = name.split(".") + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + embedding_modules = get_linear_embedding_layers(model.config.model_type) + output_embedding = embedding_modules[1] + if output_embedding in lora_module_names: # needed for 16-bit + lora_module_names.remove(output_embedding) + + return list(lora_module_names) + + +def load_lora( + model: PreTrainedModel, + cfg: DictDefault, + inference: bool = False, + config_only: bool = False, +) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]: + lora_target_modules = cfg.lora_target_modules or [] + + if cfg.lora_target_linear: + linear_names = find_all_linear_names(model) + LOG.info(f"found linear modules: {repr(sorted(linear_names))}") + lora_target_modules_as_list = ( + lora_target_modules + if isinstance(lora_target_modules, list) + else [lora_target_modules] + ) + lora_target_modules = list(set(lora_target_modules_as_list + linear_names)) + + 
lora_config_kwargs = {} + loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits + if loftq_bits: + lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits) + lora_config_kwargs["init_lora_weights"] = "loftq" + if cfg.peft_init_lora_weights: + lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights + if cfg.peft_use_dora: + lora_config_kwargs["use_dora"] = cfg.peft_use_dora + LOG.info("Initializing LoRA weights using dora. This might take longer.") + if cfg.peft_use_rslora: + lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora + if cfg.peft_layer_replication: + lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication + + lora_config = LoraConfig( + r=cfg.lora_r, + lora_alpha=cfg.lora_alpha, + target_modules=lora_target_modules, + layers_to_transform=cfg.peft_layers_to_transform, + layers_pattern=cfg.peft_layers_pattern, + lora_dropout=cfg.lora_dropout, + fan_in_fan_out=cfg.lora_fan_in_fan_out, + modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None, + bias="none", + task_type="CAUSAL_LM", + **lora_config_kwargs, + ) + + if config_only: + return None, lora_config + + rank = int(os.environ.get("LOCAL_RANK", 0)) + + if ( + cfg.fsdp + and cfg.adapter + and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + and rank != 0 + ): + setup_quantized_meta_for_peft(model) + + if cfg.lora_model_dir: + LOG.debug("Loading pretrained PEFT - LoRA") + model_kwargs: Any = {} + if cfg.lora_on_cpu: + model_kwargs["max_memory"] = {"cpu": "256GiB"} + model_kwargs["device_map"] = {"": "cpu"} + model = PeftModel.from_pretrained( + model, + cfg.lora_model_dir, + is_trainable=(not inference), + **model_kwargs, + ) + else: + model = get_peft_model(model, lora_config) + + if rank == 0: + try: + model.print_trainable_parameters() + except AttributeError as exc: + LOG.warning( + "Exception caught during model.print_trainable_parameters(): %s", exc + ) + elif ( + cfg.fsdp + and cfg.adapter + and 
cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + and rank != 0 + ): + setup_quantized_peft_meta_for_training(model) + + return model, lora_config + + +def load_adapter( + model: PreTrainedModel, + cfg: DictDefault, + adapter: str | None, + inference: bool = False, +) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel, PeftConfig | None]: + if adapter is None: + return model, None + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + if adapter in ["lora", "qlora"]: + peft_model, lora_config = load_lora(model, cfg, inference=inference) + return peft_model, lora_config + if adapter == "llama-adapter": + peft_model, lora_config = load_llama_adapter(model, cfg) + return peft_model, lora_config + + raise NotImplementedError(f"{adapter} PEFT adapter not available") + + +def load_llama_adapter( + model: PreTrainedModel, cfg: DictDefault +) -> tuple[PeftModel | PeftMixedModel, PeftConfig]: + peft_config = AdaptionPromptConfig( + adapter_layers=cfg.peft_adapter.layers, # layers (L) + adapter_len=cfg.peft_adapter.len, # prompt length (K) + task_type="CAUSAL_LM", + ) + + if cfg.lora_model_dir: + LOG.debug("Loading pretrained PEFT - llama_adapter") + peft_model = PeftModel.from_pretrained( + model, + cfg.lora_model_dir, + torch_dtype=torch.float16, + ) + else: + peft_model = get_peft_model(model, peft_config) + + peft_model.print_trainable_parameters() + + return peft_model, peft_config diff --git a/src/axolotl/loaders/constants.py b/src/axolotl/loaders/constants.py new file mode 100644 index 000000000..c08518dd6 --- /dev/null +++ b/src/axolotl/loaders/constants.py @@ -0,0 +1,21 @@ +"""Shared constants for axolotl.loaders module""" + +from transformers import ( + Gemma3ForConditionalGeneration, + Llama4ForConditionalGeneration, + LlavaForConditionalGeneration, + Mistral3ForConditionalGeneration, + MllamaForConditionalGeneration, + Qwen2_5_VLForConditionalGeneration, + Qwen2VLForConditionalGeneration, +) + +MULTIMODAL_AUTO_MODEL_MAPPING = 
{ + "mllama": MllamaForConditionalGeneration, + "llama4": Llama4ForConditionalGeneration, + "llava": LlavaForConditionalGeneration, + "qwen2_vl": Qwen2VLForConditionalGeneration, + "qwen2_5_vl": Qwen2_5_VLForConditionalGeneration, + "mistral3": Mistral3ForConditionalGeneration, + "gemma3": Gemma3ForConditionalGeneration, +} diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py new file mode 100644 index 000000000..d7ac84a6d --- /dev/null +++ b/src/axolotl/loaders/model.py @@ -0,0 +1,754 @@ +"""Model loader class implementation for loading, configuring, and patching various +models. +""" + +import gc +import logging +import math +import os +from functools import cached_property +from importlib.util import find_spec +from typing import Any + +import peft +import torch +import transformers +import transformers.modeling_utils +from accelerate import init_empty_weights +from peft import PeftConfig, PeftMixedModel, PeftModel, prepare_model_for_kbit_training +from transformers import ( + AutoModelForCausalLM, + AutoModelForVision2Seq, + AwqConfig, + BitsAndBytesConfig, + GPTQConfig, + PreTrainedModel, + PreTrainedTokenizerBase, +) +from transformers.integrations.deepspeed import ( + HfTrainerDeepSpeedConfig, + is_deepspeed_zero3_enabled, +) + +from axolotl.common.architectures import MOE_ARCH_BLOCK +from axolotl.integrations.base import PluginManager +from axolotl.loaders.adapter import load_adapter, load_lora +from axolotl.loaders.constants import MULTIMODAL_AUTO_MODEL_MAPPING +from axolotl.loaders.patch_manager import PatchManager +from axolotl.loaders.utils import ( + get_linear_embedding_layers, + get_module_class_from_name, + load_model_config, +) +from axolotl.models.mamba import fix_mamba_attn_for_loss +from axolotl.utils.bench import log_gpu_memory_usage +from axolotl.utils.dict import DictDefault +from axolotl.utils.distributed import ( + get_device_count, + get_device_type, +) +from axolotl.utils.model_shard_quant import 
load_sharded_model_quant
+from axolotl.utils.schemas.enums import RLType
+
+LOG = logging.getLogger(__name__)
+PLUGIN_MANAGER = PluginManager.get_instance()
+
+
+class ModelLoader:
+    """Manages model configuration, initialization and application of patches during
+    model loading.
+
+    This class orchestrates the entire process of loading a model from configuration to
+    final preparation. It handles device mapping, quantization, attention mechanisms,
+    adapter integration, and various optimizations.
+
+    The loading process includes:
+    - Loading and validating model configuration
+    - Applying monkey patches for optimizations / fixes
+    - Setting up device mapping (including multi-GPU configurations)
+    - Configuring quantization
+    - Setting attention mechanisms (Flash Attention, SDPA, etc.)
+    - Loading and initializing the model
+    - Applying adapters (LoRA, QLoRA, etc.)
+
+    Attributes:
+        model: The loaded model instance (available after load() is called).
+        model_kwargs: Dictionary of keyword arguments passed to model initialization.
+        base_model: Name or path of the base model to load.
+        model_type: Type of model to load (e.g., `AutoModelForCausalLM`).
+        model_config: Configuration object for the model.
+        auto_model_loader: Class used for loading the model (default:
+            `AutoModelForCausalLM`).
+    """
+
+    def __init__(
+        self,
+        cfg: DictDefault,
+        tokenizer: PreTrainedTokenizerBase,
+        *,
+        inference: bool = False,
+        reference_model: bool = False,
+        **kwargs,  # pylint: disable=unused-argument
+    ):
+        """Initializes the ModelLoader.
+
+        Args:
+            cfg: Configuration dictionary with model and training settings.
+            tokenizer: Tokenizer instance associated with the model.
+            processor: Not a parameter here; any value passed is absorbed by **kwargs.
+            inference: Whether the model is being loaded for inference mode. Defaults
+                to False.
+            reference_model: Whether this is a reference model (used in setups like DPO
+                training). Defaults to False.
+ **kwargs: Additional keyword arguments (ignored). + """ + self.cfg = cfg + self.tokenizer = tokenizer + self.inference: bool = inference + self.reference_model: bool = reference_model + + # Init model kwargs + self.model_kwargs: dict[str, Any] = {} + if cfg.overrides_of_model_kwargs: + for key, val in cfg.overrides_of_model_kwargs.items(): + self.model_kwargs[key] = val + + # Init model + self.model: PreTrainedModel | PeftModel | PeftMixedModel + self.base_model = cfg.base_model + self.model_type = cfg.type_of_model + + # Init model config + self.model_config = load_model_config(cfg) + self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name + + # Initialize the patch manager + self.patch_manager = PatchManager( + cfg=cfg, + model_config=self.model_config, + inference=inference, + ) + + @cached_property + def has_flash_attn(self) -> bool: + """Check if flash attention is installed.""" + return find_spec("flash_attn") is not None + + @cached_property + def qlora_fsdp(self): + """Property that determines if FSDP with QLoRA is enabled.""" + return self.cfg.fsdp and self.cfg.adapter == "qlora" + + def load(self) -> tuple[PreTrainedModel, PeftConfig | None]: + """Load and prepare the model with all configurations and patches. + + Returns: + A tuple with the loaded model and its LoRA configuration (if applicable). + """ + # Initial setup and patches + self.patch_manager.apply_pre_model_load_patches() + self._apply_pre_model_load_setup() + + # Build the model + PLUGIN_MANAGER.pre_model_load(self.cfg) + skip_move_to_device = self._build_model() + PLUGIN_MANAGER.post_model_build(self.cfg, self.model) + + # Post-build model configuration + self._apply_post_model_load_setup() + + # Load adapters (LoRA, etc.) 
+ PLUGIN_MANAGER.pre_lora_load(self.cfg, self.model) + lora_config = self._load_adapters() + PLUGIN_MANAGER.post_lora_load(self.cfg, self.model) + + # Apply remaining patches and finalize + self._apply_post_lora_load_setup(skip_move_to_device) + self.patch_manager.apply_post_model_load_patches(self.model) + PLUGIN_MANAGER.post_model_load(self.cfg, self.model) + + return self.model, lora_config + + def _apply_pre_model_load_setup(self): + """Apply patches and setup configurations before model loading.""" + self._set_auto_model_loader() + self._set_device_map_config() + if self.cfg.revision_of_model: + self.model_kwargs["revision"] = self.cfg.revision_of_model + self._set_quantization_config() + self._set_attention_config() + + def _apply_post_model_load_setup(self): + """Configure the model after it has been loaded.""" + # Handle PeftModel if needed + if ( + isinstance(self.model, (peft.PeftModel, peft.PeftModelForCausalLM)) + and not self.qlora_fsdp + ): + self.model = self.model.merge_and_unload() + + self._resize_token_embeddings() + self._adjust_model_config() + self._log_memory_usage() + self._configure_embedding_dtypes() + + def _resize_token_embeddings(self): + """Resize token embeddings if needed.""" + embeddings_len = ( + math.ceil(len(self.tokenizer) / 32) * 32 + if self.cfg.resize_token_embeddings_to_32x + else len(self.tokenizer) + ) + if hasattr(self.model, "get_input_embeddings") and ( + self.model.get_input_embeddings().num_embeddings < embeddings_len + or ( + self.model.get_input_embeddings().num_embeddings > embeddings_len + and self.cfg.shrink_embeddings + ) + ): + resize_kwargs = {} + if self.cfg.mean_resizing_embeddings is not None and ( + self.model_config.model_type != "llava" + ): + resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings + self.model.resize_token_embeddings(embeddings_len, **resize_kwargs) + else: + self.model.tie_weights() + + def _adjust_model_config(self): + if ( + hasattr(self.model, "config") + and 
hasattr(self.model.config, "max_position_embeddings") + and self.model.config.max_position_embeddings + and self.cfg.sequence_len > self.model.config.max_position_embeddings + ): + LOG.warning( + "increasing model.config.max_position_embeddings from " + f"{self.model.config.max_position_embeddings} to {self.cfg.sequence_len}" + ) + self.model.config.max_position_embeddings = self.cfg.sequence_len + + if ( + hasattr(self.model, "config") + and hasattr(self.model.config, "bos_token_id") + and self.model.config.bos_token_id + and self.model.config.bos_token_id != self.tokenizer.bos_token_id + ): + self.model.config.bos_token_id = self.tokenizer.bos_token_id + + if ( + hasattr(self.model, "config") + and hasattr(self.model.config, "eos_token_id") + and self.model.config.eos_token_id + and self.model.config.eos_token_id != self.tokenizer.eos_token_id + ): + self.model.config.eos_token_id = self.tokenizer.eos_token_id + + def _log_memory_usage(self): + """Log device memory usage after model load.""" + if hasattr(self.model, "device") and self.model.device.type in ( + "cuda", + "mps", + "npu", + ): + log_gpu_memory_usage(LOG, "after model load", self.model.device) + + def _configure_embedding_dtypes(self): + """Configure embedding module dtypes.""" + # Get embedding modules + embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type) + + # Initial dtype conversion + if not self.cfg.fsdp: + # We don't run this during FSDP because this will leave mixed and bfloat16 + # dtypes in the model which FSDP doesn't like + if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast: + embedding_modules = [] + self._convert_embedding_modules_dtype( + embedding_modules, + dist_dtype=torch.float32, + before_kbit_train_or_finetune=True, + ) + + # Handle DeepSpeed Zero3 + if is_deepspeed_zero3_enabled(): + self._set_z3_leaf_modules() + + # Apply gradient checkpointing if needed + needs_fa2_dtype = self.cfg.adapter or self.cfg.fsdp + if self.cfg.adapter in ["lora", 
"qlora"]: + needs_fa2_dtype = True + if self.cfg.gradient_checkpointing: + self.model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs + ) + + self._prepare_model_for_quantization() + + # Convert dtypes if needed + should_convert = ( + # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so + # we need to convert them back to fp16/bf16 for flash-attn compatibility. + ( + (needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) + and not self.qlora_fsdp + ) + # CCE requires embedding layers to be in fp16/bf16 for backward pass + or self.cfg.cut_cross_entropy + ) + + if should_convert: + LOG.info("Converting modules to %s", self.cfg.torch_dtype) + self._convert_embedding_modules_dtype( + embedding_modules=embedding_modules, + dist_dtype=self.cfg.torch_dtype, + before_kbit_train_or_finetune=False, + ) + + def _load_adapters(self) -> PeftConfig | None: + """Load LoRA or other adapters.""" + # Load LoRA or adapter + lora_config = None + if not self.reference_model or self.cfg.lora_model_dir: + # If we're not loading the reference model, then we're loading the model + # for training. Then, the DPO trainer doesn't want the PEFT model loaded + # over it, it just wants the LoRA / PEFT config. 
+ if ( + self.cfg.adapter + and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO] + and not self.cfg.merge_lora + ): + _, lora_config = load_lora( + self.model, self.cfg, inference=False, config_only=True + ) + else: + self.model, lora_config = load_adapter( + self.model, self.cfg, self.cfg.adapter + ) + + return lora_config + + def _apply_post_lora_load_setup(self, skip_move_to_device: bool): + """Apply final optimizations and patches.""" + # Place model on accelerator + if ( + self.cfg.ddp + and not self.cfg.load_in_8bit + and not (self.cfg.rl and self.cfg.load_in_4bit) + and not skip_move_to_device + ): + # TODO: validate this conditional + self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}") + + if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: + self.model.is_parallelizable = True + self.model.model_parallel = True + + if not any( + param.requires_grad + for _, param in self.model.named_parameters(recurse=True) + ): + LOG.warning("There are no parameters that require gradient updates") + + if self.cfg.flash_optimum: + from optimum.bettertransformer import BetterTransformer + + self.model = BetterTransformer.transform(self.model) + + if self.cfg.adapter is not None: + log_gpu_memory_usage(LOG, "after adapters", self.model.device) + + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() + + def _set_auto_model_loader(self): + """Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM` + (set at `__init__`). When using a multimodal model, `self.auto_model_loader` + should be set according to the type of the model. 
+ """ + if self.cfg.is_multimodal: + self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get( + self.model_config.model_type, AutoModelForVision2Seq + ) + + def _set_device_map_config(self): + """Setup `device_map` according to config""" + device_map = self.cfg.device_map + max_memory = self.cfg.max_memory + + if self.cfg.gpu_memory_limit: + gpu_memory_limit = ( + str(self.cfg.gpu_memory_limit) + "GiB" + if isinstance(self.cfg.gpu_memory_limit, int) + else self.cfg.gpu_memory_limit + ) + + max_memory = {} + num_device = get_device_count() + for i in range(num_device): + max_memory[i] = gpu_memory_limit + max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything + + if max_memory is not None: + # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py + from accelerate import infer_auto_device_map + + with init_empty_weights(): + model_canvas = self.auto_model_loader.from_config( + self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + ) + model_canvas.tie_weights() + device_map = infer_auto_device_map( + model_canvas, + max_memory=max_memory, + dtype=self.cfg.torch_dtype, + ) + # We can discard max_memory now as we have a device map set up + max_memory = None + + self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype + + if not is_deepspeed_zero3_enabled(): + self.model_kwargs["device_map"] = device_map + + cur_device = get_device_type() + if "mps" in str(cur_device): + self.model_kwargs["device_map"] = "mps:0" + elif "npu" in str(cur_device): + self.model_kwargs["device_map"] = "npu:0" + + # TODO: can we put the reference model on it's own gpu? 
I think we have to move + # logits around to calculate loss + # if cfg.rl: + # if torch.cuda.device_count() > 1: + # if reference_model: + # model_kwargs["device_map"] = "cuda:" + str( + # torch.cuda.current_device() + 1 + # ) + # else: + # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device()) + + def _set_quantization_config(self): + """Set up quantization config (bitsandbytes, awq, gptq, etc.)""" + self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit + self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit + + if self.cfg.gptq: + if not hasattr(self.model_config, "quantization_config"): + LOG.warning( + "model config does not contain quantization_config information" + ) + else: + if self.cfg.gptq_disable_exllama is not None: + self.model_config.quantization_config["disable_exllama"] = ( + self.cfg.gptq_disable_exllama + ) + self.model_kwargs["quantization_config"] = GPTQConfig( + **self.model_config.quantization_config + ) + if ( + self.cfg.adapter in ["qlora", "lora"] + and hasattr(self.model_config, "quantization_config") + and self.model_config.quantization_config["quant_method"] + in ["gptq", "awq", "bitsandbytes"] + ): + if self.model_config.quantization_config["quant_method"] == "gptq": + self.model_kwargs["quantization_config"] = GPTQConfig( + **self.model_config.quantization_config + ) + elif self.model_config.quantization_config["quant_method"] == "awq": + self.model_kwargs["quantization_config"] = AwqConfig( + **self.model_config.quantization_config + ) + elif ( + self.model_config.quantization_config["quant_method"] == "bitsandbytes" + ): + self.model_kwargs["quantization_config"] = BitsAndBytesConfig( + **self.model_config.quantization_config + ) + elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]: + bnb_config = { + "load_in_4bit": True, + "llm_int8_threshold": 6.0, + "llm_int8_has_fp16_weight": False, + "bnb_4bit_compute_dtype": self.cfg.torch_dtype, + "bnb_4bit_use_double_quant": True, + 
"bnb_4bit_quant_type": "nf4", + "bnb_4bit_quant_storage": torch.bfloat16, + } + if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not ( + self.cfg.deepspeed or self.cfg.fsdp + ): + # for some reason, this causes the loss to be off by an order of magnitude + # but deepspeed needs this still in bfloat16 + bnb_config["bnb_4bit_quant_storage"] = torch.float32 + + if self.cfg.bnb_config_kwargs: + bnb_config.update(self.cfg.bnb_config_kwargs) + + self.model_kwargs["quantization_config"] = BitsAndBytesConfig( + **bnb_config, + ) + elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]: + bnb_config = { + "load_in_8bit": True, + } + # Exclude mamba blocks from int8 quantization for jamba + if self.cfg.model_config_type == "jamba": + bnb_config["llm_int8_skip_modules"] = ["mamba"] + self.model_kwargs["quantization_config"] = BitsAndBytesConfig( + **bnb_config, + ) + + # no longer needed per https://github.com/huggingface/transformers/pull/26610 + if "quantization_config" in self.model_kwargs or self.cfg.gptq: + self.model_kwargs.pop("load_in_8bit", None) + self.model_kwargs.pop("load_in_4bit", None) + + def _set_attention_config(self): + """Sample packing uses custom FA2 patch""" + if self.cfg.flex_attention: + self.model_kwargs["attn_implementation"] = "flex_attention" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "flex_attention" + ) + + elif self.cfg.flash_attention: + if not self.cfg.sample_packing and self.cfg.s2_attention: + pass + self.model_kwargs["attn_implementation"] = "flash_attention_2" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "flash_attention_2" + ) + elif self.cfg.sdp_attention: + self.model_kwargs["attn_implementation"] = "sdpa" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "sdpa" + ) + elif self.cfg.eager_attention: + self.model_kwargs["attn_implementation"] = "eager" + self.model_config._attn_implementation = ( # pylint: 
disable=protected-access + "eager" + ) + + if self.cfg.low_cpu_mem_usage: + self.model_kwargs["low_cpu_mem_usage"] = True + + def _configure_zero3_memory_efficient_loading(self): + """Set the deepspeed config to load the model into RAM first before moving + to VRAM. + + We need to return `hf_ds_cfg` as it needs to exist before model loading. + """ + hf_ds_cfg = None + + if os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3": + hf_ds_cfg = HfTrainerDeepSpeedConfig(self.cfg.deepspeed) + hf_ds_cfg.fill_match( + "train_micro_batch_size_per_gpu", self.cfg.micro_batch_size + ) + hf_ds_cfg.fill_match( + "gradient_accumulation_steps", self.cfg.gradient_accumulation_steps + ) + hf_ds_cfg.fill_match( + "train_batch_size", + int(os.getenv("WORLD_SIZE", "1")) + * self.cfg.micro_batch_size + * self.cfg.gradient_accumulation_steps, + ) + if "device_map" in self.model_kwargs: + del self.model_kwargs["device_map"] + + transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True + transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = ( + lambda: True + ) + + return hf_ds_cfg + + def _build_model(self) -> bool: + """Load model, with load strategy depending on config.""" + skip_move_to_device = False + if ( + self.qlora_fsdp + and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + and ( + self.cfg.model_config_type == "dbrx" + or self.cfg.qlora_sharded_model_loading + ) + ): + quant_storage = self.cfg.torch_dtype + quantization_config = getattr( + self.model_config, "quantization_config", None + ) + quantization_config = ( + quantization_config or self.model_kwargs["quantization_config"] + ) + self.model = load_sharded_model_quant( + self.base_model, + self.model_config, + self.cfg, + quant_storage=quant_storage, + quantization_config=quantization_config, + ) + skip_move_to_device = True + elif ( + self.model_config.model_type in ["llama", "llama4"] + and not self.cfg.trust_remote_code + and not self.cfg.gptq + ): + # TODO: Do we need to open this up for all 
models? + if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: + skip_move_to_device = True + if "device_map" in self.model_kwargs: + del self.model_kwargs["device_map"] + + self._configure_zero3_memory_efficient_loading() + + # Load model with random initialization if specified + if self.cfg.random_init_weights: + # AutoModel classes support the from_config method + if self.auto_model_loader in [ + AutoModelForCausalLM, + AutoModelForVision2Seq, + ]: + self.model = self.auto_model_loader.from_config( + config=self.model_config, + ) + else: + self.model = self.auto_model_loader(config=self.model_config) + else: + self.model = self.auto_model_loader.from_pretrained( + self.base_model, + config=self.model_config, + **self.model_kwargs, + ) + elif self.model_type == "MambaLMHeadModel": + # FIXME this is janky at best and hacked together to make it work + MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name + + self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"] + self.model_kwargs["device"] = torch.cuda.current_device() + self.model_kwargs.pop("torch_dtype", None) + self.model_kwargs.pop("device_map", None) + + self.model = MambaLMHeadModel.from_pretrained( + self.base_model, + **self.model_kwargs, + ) + elif ( + self.model_type + and self.model_type != "AutoModelForCausalLM" + and not self.cfg.trust_remote_code + ): + if self.cfg.gptq: + self.model = self.auto_model_loader.from_pretrained( + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, + ) + else: + self.model = getattr(transformers, self.model_type).from_pretrained( + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, + ) + else: + if self.cfg.gptq: + self.model = self.auto_model_loader.from_pretrained( + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + 
**self.model_kwargs, + ) + else: + if ( + self.cfg.fsdp + and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading + ): + # disabling either of these two still leads to VRAM spike before setting back down + skip_move_to_device = True + if "device_map" in self.model_kwargs: + del self.model_kwargs["device_map"] + + self._configure_zero3_memory_efficient_loading() + + self.model = self.auto_model_loader.from_pretrained( + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, + ) + if is_deepspeed_zero3_enabled(): + skip_move_to_device = True + + return skip_move_to_device + + def _set_z3_leaf_modules(self): + from deepspeed.utils import set_z3_leaf_modules + + if self.cfg.model_config_type in MOE_ARCH_BLOCK: + moe_blocks = MOE_ARCH_BLOCK[self.cfg.model_config_type] + moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks + set_z3_leaf_modules( + self.model, + [ + get_module_class_from_name(self.model, module_name) + for module_name in moe_blocks + ], + ) + + def _prepare_model_for_quantization(self): + """Prepare loaded model for quantization.""" + skip_prepare_model_for_kbit_training = False + if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora": + # Qwen doesn't play nicely with LoRA if this is enabled + skip_prepare_model_for_kbit_training = True + + loftq_bits = ( + self.cfg.peft + and self.cfg.peft.loftq_config + and self.cfg.peft.loftq_config.loftq_bits + ) + if self.cfg.adapter == "lora" and loftq_bits: + skip_prepare_model_for_kbit_training = True + + if ( + self.qlora_fsdp + or (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading) + or is_deepspeed_zero3_enabled() + ): + # Make sure everything is in the same dtype + skip_prepare_model_for_kbit_training = True + + if ( + not skip_prepare_model_for_kbit_training + and self.cfg.adapter in ["lora", "qlora"] + and (self.cfg.load_in_8bit or self.cfg.load_in_4bit) + ): + LOG.info("converting 
PEFT model w/ prepare_model_for_kbit_training") + self.model = prepare_model_for_kbit_training( + self.model, use_gradient_checkpointing=self.cfg.gradient_checkpointing + ) + + def _convert_embedding_modules_dtype( + self, + embedding_modules: list[str], + dist_dtype: torch.dtype, + before_kbit_train_or_finetune: bool, + ): + for name, module in self.model.named_modules(): + if "norm" in name: + module.to(dist_dtype) + if before_kbit_train_or_finetune: + if name.endswith(".gate"): + module.to(dist_dtype) + if self.model_config.model_type == "btlm": + # don't upcast lm_head for btlm + continue + if any(m in name for m in embedding_modules) and hasattr(module, "weight"): + module.to(dist_dtype) diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py new file mode 100644 index 000000000..f251f958d --- /dev/null +++ b/src/axolotl/loaders/patch_manager.py @@ -0,0 +1,380 @@ +"""Patch manager class implementation to complement `axolotl.loaders.ModelLoader`. + +Applies pre- and post-model load patches for various fixes and optimizations. +""" + +import importlib.util +import logging +from functools import cached_property + +import addict +import transformers +from transformers import PretrainedConfig, PreTrainedModel + +from axolotl.integrations.base import PluginManager +from axolotl.monkeypatch.multipack import ( + SUPPORTED_MULTIPACK_MODEL_TYPES, + patch_for_multipack, +) +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger(__name__) +PLUGIN_MANAGER = PluginManager.get_instance() + + +class PatchManager: + """Manages the application of patches during the model loading process.""" + + def __init__( + self, + cfg: DictDefault, + model_config: PretrainedConfig | addict.Dict, + inference: bool = False, + ): + """Initialize the `PatchManager`. + + Args: + cfg: Configuration dictionary with model and training settings. + model_config: Configuration object for the model. 
+ inference: Whether the model is being loaded for inference mode. + """ + self.cfg = cfg + self.model_config = model_config + self.inference = inference + + @cached_property + def has_flash_attn(self) -> bool: + """Check if flash attention is installed.""" + return importlib.util.find_spec("flash_attn") is not None + + def apply_pre_model_load_patches(self): + """Apply pre-model load patches based on config.""" + self._apply_flash_attention_patches() + self._apply_fsdp_patches() + self._apply_adapter_patches() + self._apply_flex_attention_patches() + self._apply_model_specific_patches() + self._apply_fp8_patches() + self._apply_flash_attention_peft_patches() + self._apply_gradient_checkpointing_patches() + self._patch_attention() + self._apply_multipack_patches() + self._patch_llama_derived_model() + self._apply_mistral_cross_entropy_patch() + self._apply_unsloth_self_attention_patch() + + def apply_post_model_load_patches(self, model: PreTrainedModel): + """Apply patches that require the model instance.""" + self._apply_llama_flash_attn_patches(model) + self._apply_unsloth_patches(model) + self._apply_lora_kernel_patch(model) + + def _apply_flash_attention_patches(self): + """Apply patches related to Flash Attention.""" + if self.cfg.xformers_attention and self.cfg.sample_packing: + from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2 + + patch_xformers_attn_over_fa2() + self.cfg.flash_attention = True + + def _apply_fsdp_patches(self): + """Apply patches for FSDP configurations.""" + if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2": + from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils + + patch_accelerate_fsdp_utils() + + def _apply_adapter_patches(self): + """Apply patches for adapter configurations.""" + if self.cfg.adapter and self.cfg.embeddings_skip_upcast: + from axolotl.monkeypatch.peft.utils import patch_peft_prep_code + + patch_peft_prep_code() + + def 
_apply_flex_attention_patches(self): + """Apply patches for flexible attention.""" + if self.cfg.flex_attention: + from axolotl.monkeypatch.attention.flex_attn import ( + patch_flex_make_mask, + patch_flex_wrapper, + ) + + flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {} + patch_flex_wrapper(**flex_attn_compile_kwargs) + patch_flex_make_mask() + + def _apply_model_specific_patches(self): + """Apply patches specific to model architectures.""" + if ( + self.cfg.model_config_type == "llama4" + and self.cfg.llama4_linearized_experts + ): + from axolotl.monkeypatch.models.llama4.modeling import ( + patch_llama4_linearized_modeling, + ) + + patch_llama4_linearized_modeling() + + if self.cfg.model_config_type == "gemma3": + from axolotl.monkeypatch.gemma3 import ( + patch_gemma3conditionalgeneration_forward, + ) + + patch_gemma3conditionalgeneration_forward() + + def _apply_fp8_patches(self): + """Apply patches for FP8 support.""" + if self.cfg.fp8: + from axolotl.monkeypatch.trainer_accelerator_args import ( + patch_create_accelerate_code_for_fp8, + ) + + patch_create_accelerate_code_for_fp8() + + def _apply_flash_attention_peft_patches(self): + """Apply patches for Flash Attention with PEFT.""" + if self.cfg.adapter: + from axolotl.monkeypatch.transformers_fa_utils import ( + patch_fa_peft_integration, + ) + + patch_fa_peft_integration() + + def _apply_gradient_checkpointing_patches(self): + """Apply patches for gradient checkpointing.""" + if self.cfg.gradient_checkpointing in ["unsloth", "offload"]: + from axolotl.monkeypatch.gradient_checkpointing import ( + hf_grad_checkpoint_offload_wrapper, + ) + + transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper + if self.cfg.gradient_checkpointing == "offload_disk": + from axolotl.monkeypatch.gradient_checkpointing import ( + hf_grad_checkpoint_disk_offload_wrapper, + ) + + transformers.modeling_utils.checkpoint = ( + hf_grad_checkpoint_disk_offload_wrapper + ) + + def 
_apply_mistral_cross_entropy_patch(self): + """Apply Mistral cross entropy patch if configured.""" + if ( + self.cfg.model_config_type == "mistral" + and self.cfg.flash_attn_cross_entropy_loss + ): + from axolotl.monkeypatch.mistral_attn_hijack_flash import ( + patch_mistral_cross_entropy, + ) + + patch_mistral_cross_entropy() + + def _apply_unsloth_self_attention_patch(self): + """Apply Unsloth self-attention patches if configured.""" + if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: + from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora + + patch_self_attn_lora(self.cfg) + + def _apply_multipack_patches(self): + """Apply multipack patches if necessary.""" + if ( + self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES + and (self.cfg.flash_attention or self.cfg.flex_attention) + and self.cfg.sample_packing + ): + # Get automap config if it exists + auto_map_config = None + if isinstance(self.model_config, dict) and "auto_map" in self.model_config: + auto_map_config = self.model_config["auto_map"] + elif hasattr(self.model_config, "auto_map"): + auto_map_config = self.model_config.auto_map + + # Determine if the model has remote code + if auto_map_config is not None: + has_remote_code = "AutoModelForCausalLM" in auto_map_config + else: + has_remote_code = False + + if has_remote_code and self.cfg.trust_remote_code is False: + # If explicitly set in YAML, prefer that + has_remote_code = self.cfg.trust_remote_code + + patch_for_multipack( + self.cfg.model_config_type, + model_name=self.cfg.base_model, + has_remote_code=has_remote_code, + ) + + if self.cfg.is_llama_derived_model: + self._patch_loss_llama() + + def _patch_attention(self): + """Apply attention-specific patches based on model type.""" + if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")): + return + + if self.model_config.model_type == "mllama" and self.cfg.flash_attention: + from axolotl.monkeypatch.attention.mllama import patch_mllama + + 
patch_mllama() + + if self.model_config.model_type == "btlm": + from axolotl.monkeypatch.btlm_attn_hijack_flash import ( + replace_btlm_attn_with_flash_attn, + ) + + replace_btlm_attn_with_flash_attn(self.cfg.base_model) + + if self.model_config.model_type == "stablelm_epoch" and self.cfg.sample_packing: + from axolotl.monkeypatch.stablelm_attn_hijack_flash import ( + replace_stablelm_attn_with_flash_attn, + ) + + replace_stablelm_attn_with_flash_attn(self.cfg.base_model) + + def _patch_loss_llama(self): + """Patch loss functions and other optimizations for LLaMA models.""" + if self.cfg.flash_attn_cross_entropy and self.has_flash_attn: + from axolotl.monkeypatch.llama_attn_hijack_flash import ( + patch_fa_llama_cross_entropy, + ) + + patch_fa_llama_cross_entropy() + elif self.cfg.unsloth_cross_entropy_loss: + from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch + + integrate_cross_entropy_loss_patch(model_type="llama") + + if self.cfg.flash_attn_rms_norm and self.has_flash_attn: + from axolotl.monkeypatch.llama_attn_hijack_flash import patch_llama_rms_norm + + patch_llama_rms_norm() + elif self.cfg.unsloth_rms_norm: + from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm + + patch_unsloth_layernorm() + + if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: + from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora + + patch_self_attn_lora() + + def _patch_llama_flash_attention(self, packed=False): + """Apply Flash Attention patches for LLaMA models.""" + from axolotl.monkeypatch.llama_attn_hijack_flash import ( + replace_llama_attn_with_flash_attn, + ) + + if packed: + if self.cfg.device not in ["mps", "cpu"] and not self.inference: + LOG.info("patching with flash attention for sample packing") + replace_llama_attn_with_flash_attn( + packed=True, + cross_entropy=self.cfg.flash_attn_cross_entropy, + rms_norm=self.cfg.flash_attn_rms_norm, + ) + elif self.cfg.s2_attention: + LOG.info("patching w/ flash-enabled, 
shifted-sparse attention") + replace_llama_attn_with_flash_attn( + packed=False, + cross_entropy=self.cfg.flash_attn_cross_entropy, + rms_norm=self.cfg.flash_attn_rms_norm, + use_shifted_sparse_attn=True, + ) + elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm: + replace_llama_attn_with_flash_attn( + packed=False, + cross_entropy=self.cfg.flash_attn_cross_entropy, + rms_norm=self.cfg.flash_attn_rms_norm, + ) + + def _patch_llama_xformers_attention(self): + """Apply xformers attention patches for LLaMA models.""" + from axolotl.monkeypatch.llama_attn_hijack_xformers import ( + hijack_llama_attention, + ) + + LOG.info("Patching with xformers attention...") + hijack_llama_attention() + + def _patch_llama_sample_packing(self): + """Apply sample packing patches for LLaMA models.""" + from axolotl.monkeypatch.llama_patch_multipack import ( + hijack_llama_prepare_4d_mask, + ) + + LOG.info("Patching llama _prepare_4d_causal_attention_mask*...") + hijack_llama_prepare_4d_mask() + + def _patch_llama_derived_model(self): + """Modify all llama derived models in one block.""" + if self.cfg.is_llama_derived_model and not ( + self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES + and (self.cfg.flash_attention or self.cfg.flex_attention) + and self.cfg.sample_packing + ): + self._patch_loss_llama() + + if self.cfg.flash_attention: + self._patch_llama_flash_attention(packed=self.cfg.sample_packing) + elif self.cfg.xformers_attention: + self._patch_llama_xformers_attention() + elif self.cfg.sample_packing: + self._patch_llama_sample_packing() + elif self.cfg.s2_attention: + raise NotImplementedError( + "Shifted-sparse attention not currently implemented without flash attention." 
+ ) + + def _apply_llama_flash_attn_patches(self, model): + """Apply LLaMA-specific flash attention patches.""" + if ( + self.model_config.model_type in ["llama", "llama4"] + and not self.cfg.trust_remote_code + and not self.cfg.gptq + and self.cfg.flash_attention + and not self.inference + ): + # TODO(MengqingCao): split these patches seperately + from axolotl.monkeypatch.llama_attn_hijack_flash import ( + is_xformers_swiglu_available, + replace_llama_mlp_with_swiglu, + replace_llama_qkv_with_fused, + ) + + if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): + LOG.info("Patching with SwiGLU...") + replace_llama_mlp_with_swiglu(model) + + if self.cfg.flash_attn_fuse_qkv: + LOG.info("Patching with fused QKV...") + replace_llama_qkv_with_fused(model) + + def _apply_unsloth_patches(self, model): + """Apply unsloth optimization patches.""" + if self.cfg.unsloth_lora_mlp: + from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch + + integrate_lora_mlp_patch(peft_model=model) + + if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: + from axolotl.monkeypatch.unsloth_ import integrate_lora_patch + + integrate_lora_patch(peft_model=model, cfg=self.cfg) + + if self.cfg.unsloth_rope: + from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings + + integrate_rope_embeddings() + + def _apply_lora_kernel_patch(self, model): + """Apply LoRA kernel patches.""" + if ( + self.cfg.lora_mlp_kernel + or self.cfg.lora_qkv_kernel + or self.cfg.lora_o_kernel + ): + from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches + + apply_lora_kernel_patches(model=model, cfg=self.cfg) diff --git a/src/axolotl/loaders/processor.py b/src/axolotl/loaders/processor.py new file mode 100644 index 000000000..57394bc67 --- /dev/null +++ b/src/axolotl/loaders/processor.py @@ -0,0 +1,56 @@ +"""Processor loading functionality for multi-modal models""" + +import logging +from typing import Any + +import transformers +from transformers import ( + 
AutoProcessor, + PreTrainedTokenizerBase, +) + +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger(__name__) + + +def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): + processor_kwargs: dict[str, Any] = {} # Do we actually need this? + + processor_cls = AutoProcessor + if cfg.processor_type: + processor_cls = getattr(transformers, cfg.processor_type) + + processor = processor_cls.from_pretrained( + cfg.processor_config, + trust_remote_code=cfg.trust_remote_code or False, + tokenizer=tokenizer, + **processor_kwargs, + ) + + # Attempt to load image size from processor if available + if ( + cfg.image_size is None + and hasattr(processor, "size") + and any(dim in processor.size for dim in ["width", "height"]) + ): + im_width = None + im_height = None + if "width" in processor.size: + im_width = processor.size["width"] + if "height" in processor.size: + im_height = processor.size["height"] + + # If both width and height are set, use a tuple + if im_width is not None and im_height is not None: + cfg.image_size = (im_width, im_height) + # If only width is set, use as integer + elif im_width is not None: + cfg.image_size = im_width + # If only height is set, use as integer + elif im_height is not None: + cfg.image_size = im_height + + LOG.debug(f"Loaded image size: {cfg.image_size} from processor") + + return processor diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py new file mode 100644 index 000000000..ec9d69e8a --- /dev/null +++ b/src/axolotl/loaders/tokenizer.py @@ -0,0 +1,281 @@ +"""Tokenizer loading functionality and associated utils""" + +import json +import logging +import os + +import transformers +from transformers import ( + AddedToken, + AutoTokenizer, +) + +from axolotl.integrations.base import PluginManager +from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config +from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN +from axolotl.utils.chat_templates 
import get_chat_template_from_config +from axolotl.utils.distributed import ( + barrier, + is_local_main_process, + is_main_process, +) + +LOG = logging.getLogger(__name__) +PLUGIN_MANAGER = PluginManager.get_instance() + + +def modify_tokenizer_files( + tokenizer_path: str, token_mappings: dict[int, str], output_dir: str +) -> str: + """ + Modify tokenizer files to replace added_tokens strings, save to output directory, + and return the path to the modified tokenizer. + + This only works with reserved tokens that were added to the tokenizer, not tokens + already part of the vocab. + + Args: + tokenizer_path: Path or name of the original tokenizer + token_mappings: Dict mapping {token_id (int): new_token_string} + output_dir: Directory to save the modified tokenizer + + Returns: + Path to the modified tokenizer directory + + Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941 + """ + # Create the tokenizer directory in output_dir if it doesn't exist + tokenizer_dir = os.path.join(output_dir, "tokenizer") + os.makedirs(tokenizer_dir, exist_ok=True) + + if is_local_main_process(): # pylint: disable=too-many-nested-blocks + # Load the tokenizer + temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True) + + # Save the tokenizer to the output directory + temp_tokenizer.save_pretrained(tokenizer_dir) + + # Get the token IDs and map them to their new values + token_id_mappings = { + int(token_id): new_value for token_id, new_value in token_mappings.items() + } + + # 1. 
Update tokenizer_config.json - added_tokens_decoder + config_path = os.path.join(tokenizer_dir, "tokenizer_config.json") + if os.path.exists(config_path): + with open(config_path, "r", encoding="utf-8") as f: + config_data = json.load(f) + + # Update added_tokens_decoder + if "added_tokens_decoder" in config_data: + for token_id, new_value in token_id_mappings.items(): + token_id_str = str(token_id) + if token_id_str in config_data["added_tokens_decoder"]: + config_data["added_tokens_decoder"][token_id_str][ + "content" + ] = new_value + else: + raise ValueError( + f"Token ID {token_id_str} not found in added_tokens_decoder" + ) + + # Write the updated config back + with open(config_path, "w", encoding="utf-8") as f: + json.dump(config_data, f, indent=2) + + # 2. Update tokenizer.json - added_tokens + tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json") + if os.path.exists(tokenizer_path): + with open(tokenizer_path, "r", encoding="utf-8") as f: + tokenizer_data = json.load(f) + + # Update added_tokens + if "added_tokens" in tokenizer_data: + for token_id, new_value in token_id_mappings.items(): + for i, token_entry in enumerate(tokenizer_data["added_tokens"]): + if token_entry["id"] == token_id: + tokenizer_data["added_tokens"][i]["content"] = new_value + break + else: + # Reaching this section means the token_id was not found in tokenizer.json added_tokens + raise ValueError( + f"Token ID {token_id} not found in added_tokens" + ) + if "model" in tokenizer_data and "vocab" in tokenizer_data["model"]: + for token_id, new_value in token_id_mappings.items(): + for entry_val, entry_id in tokenizer_data["model"]["vocab"].items(): + if entry_id == token_id: + del tokenizer_data["model"]["vocab"][entry_val] + tokenizer_data["model"]["vocab"][new_value] = token_id + break + + # Write the updated tokenizer data back + with open(tokenizer_path, "w", encoding="utf-8") as f: + json.dump(tokenizer_data, f, indent=2) + + barrier() + return tokenizer_dir + + +def 
load_tokenizer(cfg): + """Load and configure the tokenizer based on the provided config.""" + model_config = load_model_config(cfg) + tokenizer_kwargs = {} + use_fast = True # this is the default + + if cfg.tokenizer_use_fast is not None: + use_fast = cfg.tokenizer_use_fast + if cfg.tokenizer_legacy is not None: + # True is the default w/ https://github.com/huggingface/transformers/pull/25224 + tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy + + tokenizer_cls = AutoTokenizer + if cfg.tokenizer_type: + tokenizer_cls = getattr(transformers, cfg.tokenizer_type) + + # Set base tokenizer path + tokenizer_path = cfg.tokenizer_config + + # Apply token string overrides if specified + if cfg.added_tokens_overrides: + # Modify tokenizer files and get path to modified tokenizer + tokenizer_path = modify_tokenizer_files( + tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir + ) + + tokenizer = tokenizer_cls.from_pretrained( + tokenizer_path, + trust_remote_code=cfg.trust_remote_code or False, + use_fast=use_fast, + **tokenizer_kwargs, + ) + + if ( + tokenizer.__class__.__name__ + in [ + "LlamaTokenizer", + "LlamaTokenizerFast", + "CodeLlamaTokenizer", + "CodeLlamaTokenizerFast", + ] + and hasattr(tokenizer, "pad_token") + and not tokenizer.pad_token + ): + # set a pad_token, but use eos_token so we don't add a new token + tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN + + if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + # Mistral's official FA implementation requires left padding + if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing: + tokenizer.padding_side = "left" + + # Qwen base only has single token, so we need to set the special tokens + if cfg.is_qwen_derived_model: + token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"] + for attr_name in token_ids: + if getattr(tokenizer, 
attr_name) is None: + setattr(tokenizer, attr_name, tokenizer.eod_id) + + token_names = ["bos_token", "eos_token", "pad_token", "unk_token"] + for attr_name in token_names: + if getattr(tokenizer, attr_name) is None: + setattr(tokenizer, attr_name, "<|endoftext|>") + + additional_special_tokens = None + if cfg.special_tokens: + special_tokens = cfg.special_tokens.to_dict() + additional_special_tokens = special_tokens.pop( + "additional_special_tokens", None + ) + lora_modules_to_save = get_linear_embedding_layers(model_config.model_type) + for k, val in special_tokens.items(): + # check if new special token is not already in tokenizer and + # is adapter training to make sure lora_modules_to_save is set + # pylint: disable=too-many-boolean-expressions + if ( + (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val) + and (len(tokenizer.encode(val, add_special_tokens=False)) > 2) + and cfg.adapter + and ( + not cfg.lora_modules_to_save + or not all( + x in cfg.lora_modules_to_save for x in lora_modules_to_save + ) + ) + and k != "pad_token" + ): + lora_modules_to_save = ", ".join( + [f"`{x}`" for x in lora_modules_to_save] + ) + raise ValueError( + f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens." + ) + + tokenizer.add_special_tokens( + {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)} + ) + + # If we add bos_token and eos_token, we need to update the post processor to + # handle them correctly. 
+ # https://github.com/huggingface/transformers/pull/24132 + bos_or_eos_in_special_tokens = ( + "bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens + ) + if ( + tokenizer.__class__.__name__ + in ( + "LlamaTokenizerFast", + "CodeLlamaTokenizerFast", + ) + and bos_or_eos_in_special_tokens + ): + tokenizer.update_post_processor() + + if cfg.tokens: + tokenizer.add_tokens( + [ + AddedToken(token, rstrip=False, lstrip=False, normalized=False) + for token in cfg.tokens + ] + ) + + # Additional special tokens are a List, and need to be treated differently than regular special + # tokens. We add them after we have called `add_tokens` in case these additional special tokens + # are new tokens. + # + # Usage: + # + # ```py + # special_tokens: + # additional_special_tokens: ["<|im_start|>", "<|im_end|>"] + # ``` + if additional_special_tokens is not None: + tokenizer.add_special_tokens( + {"additional_special_tokens": additional_special_tokens} + ) + + if is_main_process(use_environ=True): + LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") + LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") + LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") + LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") + + if cfg.chat_template: + chat_template_string = get_chat_template_from_config( + cfg=cfg, + tokenizer=tokenizer, + ) + if cfg.default_system_message and cfg.chat_template == "chatml": + chat_template_string = chat_template_string.replace( + "You are a helpful assistant.", cfg.default_system_message + ) + + tokenizer.chat_template = chat_template_string + else: + LOG.info( + "No Chat template selected. Consider adding a chat template for easier inference." 
+ ) + return tokenizer diff --git a/src/axolotl/loaders/utils.py b/src/axolotl/loaders/utils.py new file mode 100644 index 000000000..1aae4834d --- /dev/null +++ b/src/axolotl/loaders/utils.py @@ -0,0 +1,211 @@ +"""Utilities for axolotl.loaders module""" + +import contextlib +import logging +from typing import Type + +import addict +import torch +from transformers import AutoConfig, PretrainedConfig, PreTrainedModel + +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger(__name__) + + +def get_module_class_from_name( + module: torch.nn.Module, name: str +) -> Type[torch.nn.Module] | None: + """Gets a class from a module by its name. Copied from `accelerate.utils.dataclasses` + (https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L2805). + + Args: + module: The module to get the class from. + name: The name of the class. + + Returns: + The class type of the matching module, or `None` if no match is found. + """ + modules_children = list(module.children()) + if module.__class__.__name__ == name: + return module.__class__ + + if len(modules_children) == 0: + return None + + for child_module in modules_children: + module_class = get_module_class_from_name(child_module, name) + if module_class is not None: + return module_class + + return None + + +def check_model_config(cfg: DictDefault, model_config: PretrainedConfig): + """Validates and adjusts model config based on `axolotl` config. + + This function performs several important checks and adjustments: + - Disables model caching for better memory efficiency + - Handles multimodal model-specific configurations + - Validates quantization settings + - Ensures proper LoRA configuration when using adapters with new tokens + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + model_config: The model's configuration object from `transformers`. 
+ + Raises: + ValueError: If a multimodal model lacks text configuration, if GPTQ settings + are inconsistent, or if LoRA `modules_to_save` is improperly configured + with new tokens. + """ + if hasattr(model_config, "use_cache"): + model_config.use_cache = False + + if cfg.is_multimodal: + # For multimodal configs, use_cache is set in the text_config + if hasattr(model_config, "get_text_config"): + text_config = model_config.get_text_config() + if hasattr(text_config, "use_cache"): + text_config.use_cache = False + else: + raise ValueError( + "No text config found for multimodal model. Please raise an Issue with model details." + ) + + # Check if image_size is not set and load image size from model config if available + if ( + cfg.image_size is None + and hasattr(model_config, "vision_config") + and hasattr(model_config.vision_config, "image_size") + ): + cfg.image_size = model_config.vision_config.image_size + LOG.debug(f"Loaded image size: {cfg.image_size} from model config") + + quant_config_exists = ( + hasattr(model_config, "quantization_config") + and model_config.quantization_config + ) + + # Detect compressed-tensors config + is_compressed_tensors_config = ( + quant_config_exists + and model_config.quantization_config.get("quant_method") == "compressed-tensors" + ) + + if is_compressed_tensors_config: + if model_config.quantization_config.get("config_groups"): + LOG.warning( + "Found `config_groups` in a compressed-tensors config. " + "QAT integration with llmcompressor is not tested." + ) + # Skip further quant checks for compressed-tensors + return + + quant_config_method_is_gptq = ( + quant_config_exists + and "quant_method" in model_config.quantization_config + and model_config.quantization_config["quant_method"] == "gptq" + ) + + if cfg.gptq and not quant_config_method_is_gptq: + raise ValueError( + "model_config.quantization_config is not set or quant_method is not set to gptq. " + "Please make sure to point to a GPTQ model." 
+ ) + + lora_modules_to_save = get_linear_embedding_layers(model_config.model_type) + if ( + cfg.adapter + and cfg.tokens + and ( + not cfg.lora_modules_to_save + or not all(x in cfg.lora_modules_to_save for x in lora_modules_to_save) + ) + ): + lora_modules_to_save_joined = ", ".join( + map(lambda x: f"`{x}`", lora_modules_to_save) + ) + raise ValueError( + "`lora_modules_to_save` not properly set when adding new tokens. " + f"Please include [{lora_modules_to_save_joined}] in `lora_modules_to_save`." + ) + + +def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict: + """Loads and configures a model configuration from HuggingFace or local sources. + + This function determines the appropriate model config source, loads it, applies any + necessary overrides, and validates it for compatibility with the `axolotl` config. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + + Returns: + A configured model configuration object (`AutoConfig` instance), or a simple + dictionary configuration for special cases like Mamba models. + + Raises: + ValueError: If configuration loading fails for reasons other than special cases + that are handled (e.g., Mamba models). 
+ """ + model_config_name = cfg.base_model_config or cfg.base_model + if not model_config_name and cfg.tokenizer_config: + model_config_name = cfg.tokenizer_config + trust_remote_code = cfg.trust_remote_code is True + config_kwargs = {} + if cfg.revision_of_model: + config_kwargs["revision"] = cfg.revision_of_model + if cfg.num_labels: + # num_labels is used to initialize classifier models + config_kwargs["num_labels"] = cfg.num_labels + try: + model_config = AutoConfig.from_pretrained( + model_config_name, + trust_remote_code=trust_remote_code, + **config_kwargs, + ) + except ValueError as error: + if "mamba" in model_config_name: + return addict.Dict( + { + "model_type": "mamba", + } + ) + raise error + + if cfg.overrides_of_model_config: + for key, val in cfg.overrides_of_model_config.items(): + setattr(model_config, key, val) + + check_model_config(cfg, model_config) + + return model_config + + +def ensure_dtype(model: PreTrainedModel, dtype: torch.dtype = torch.bfloat16): + """Ensures all modules in the model are converted to the specified data type.""" + for name, module in model.named_modules(): + weight_mismatch = False + with contextlib.suppress(AttributeError): + weight_mismatch = module.weight.dtype != dtype + + bias_mismatch = False + with contextlib.suppress(AttributeError): + bias_mismatch = module.bias.dtype != dtype + + if weight_mismatch: + print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}") + if bias_mismatch: + print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}") + if weight_mismatch or bias_mismatch: + module.to(dtype) + + +def get_linear_embedding_layers(model_type: str) -> list[str]: + """Returns layer names of linear embeddings needed for LoRA based on model type.""" + if model_type == "gpt_neox": + return ["embed_in", "embed_out"] + if model_type == "falcon": + return ["word_embeddings", "lm_head"] + return ["embed_tokens", "lm_head"] diff --git a/src/axolotl/utils/gradient_checkpointing/__init__.py 
b/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py similarity index 91% rename from src/axolotl/utils/gradient_checkpointing/__init__.py rename to src/axolotl/monkeypatch/gradient_checkpointing/__init__.py index ae0c559e9..5d631776b 100644 --- a/src/axolotl/utils/gradient_checkpointing/__init__.py +++ b/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py @@ -5,10 +5,10 @@ from functools import partial from packaging import version -from axolotl.utils.gradient_checkpointing.offload_cpu import ( +from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import ( CPU_Offloaded_Gradient_Checkpointer, ) -from axolotl.utils.gradient_checkpointing.offload_disk import ( +from axolotl.monkeypatch.gradient_checkpointing.offload_disk import ( Disco, ) diff --git a/src/axolotl/utils/gradient_checkpointing/offload_cpu.py b/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py similarity index 100% rename from src/axolotl/utils/gradient_checkpointing/offload_cpu.py rename to src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py diff --git a/src/axolotl/utils/gradient_checkpointing/offload_disk.py b/src/axolotl/monkeypatch/gradient_checkpointing/offload_disk.py similarity index 100% rename from src/axolotl/utils/gradient_checkpointing/offload_disk.py rename to src/axolotl/monkeypatch/gradient_checkpointing/offload_disk.py diff --git a/src/axolotl/monkeypatch/peft/utils.py b/src/axolotl/monkeypatch/peft/utils.py index b3703d398..fdc49c5f6 100644 --- a/src/axolotl/monkeypatch/peft/utils.py +++ b/src/axolotl/monkeypatch/peft/utils.py @@ -75,4 +75,4 @@ def patch_peft_prep_code(): exec(prep_code, globals()) # pylint: disable=exec-used # nosec B102 LOG.info("patching prepare_model_for_kbit_training to allow for overrides") peft.utils.other.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821 - axolotl.utils.models.prepare_model_for_kbit_training = 
fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821 + axolotl.loaders.model.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821 diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 46f722eeb..52ec8f22b 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -28,11 +28,15 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module ) from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.integrations.base import PluginManager +from axolotl.loaders import ( + ModelLoader, + load_processor, + load_tokenizer, +) from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.freeze import freeze_layers_except -from axolotl.utils.models import load_model, load_processor, load_tokenizer from axolotl.utils.schemas.enums import RLType from axolotl.utils.trainer import setup_trainer @@ -76,7 +80,8 @@ def setup_model_and_tokenizer( msg += " and peft_config..." 
LOG.debug(msg) - model, peft_config = load_model(cfg, tokenizer, processor=processor) + model_loader = ModelLoader(cfg, tokenizer, processor=processor) + model, peft_config = model_loader.load() if model.generation_config is not None: model.generation_config.do_sample = True @@ -113,7 +118,8 @@ def setup_reference_model( model_ref = None # explicit setting to None else: # load the model again for model_ref/baseline - model_ref, _ = load_model(cfg, tokenizer, reference_model=True) + model_loader = ModelLoader(cfg, tokenizer, reference_model=True) + model_ref, _ = model_loader.load() return model_ref diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index a96cc1286..49e4cfc6f 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -11,9 +11,10 @@ from transformers.utils.import_utils import is_torch_npu_available from axolotl.integrations.base import PluginManager from axolotl.integrations.config import merge_input_args +from axolotl.loaders import MULTIMODAL_AUTO_MODEL_MAPPING +from axolotl.loaders.utils import load_model_config from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.dict import DictDefault -from axolotl.utils.models import MULTIMODAL_AUTO_MODEL_MAPPING, load_model_config from axolotl.utils.schemas.config import ( AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, ) diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py index 2ae93acad..491cb9877 100644 --- a/src/axolotl/utils/ctx_managers/sequence_parallel.py +++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py @@ -10,7 +10,7 @@ from torch.utils.hooks import RemovableHandle from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils import ModelOutput -from axolotl.monkeypatch.ring_attn.patch import ( +from axolotl.monkeypatch.ring_attn import ( get_ring_attn_group, patch_prepare_data_loader, 
patch_prepare_device_mesh, diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index dc5920099..15744d4c6 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -10,6 +10,7 @@ import yaml from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH +from axolotl.loaders import load_tokenizer from axolotl.prompt_strategies.dpo import load as load_dpo from axolotl.prompt_strategies.kto import load as load_kto from axolotl.prompt_strategies.orpo import load as load_orpo @@ -17,7 +18,6 @@ from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_ from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5 from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first -from axolotl.utils.models import load_tokenizer from axolotl.utils.schemas.enums import RLType LOG = logging.getLogger(__name__) diff --git a/src/axolotl/utils/lora_embeddings.py b/src/axolotl/utils/lora_embeddings.py deleted file mode 100644 index 70f56655e..000000000 --- a/src/axolotl/utils/lora_embeddings.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -helpers for lora embeddings -""" - - -def get_linear_embedding_layers(model_type): - """ - returns the linear embedding layers needed for loras, dependent on the model arch - """ - if model_type == "gpt_neox": - return ["embed_in", "embed_out"] - if model_type == "falcon": - return ["word_embeddings", "lm_head"] - return ["embed_tokens", "lm_head"] diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py deleted file mode 100644 index cd7499869..000000000 --- a/src/axolotl/utils/models.py +++ /dev/null @@ -1,1648 +0,0 @@ -"""Module for models and model loading""" - -# pylint: disable=too-many-lines -import gc -import importlib -import logging -import math -import os -import types -from functools import cached_property -from typing import Any, 
Dict, Optional, Tuple - -import addict -import bitsandbytes as bnb -import torch -import transformers -import transformers.modeling_utils -from accelerate import init_empty_weights -from bitsandbytes.nn import Params4bit -from peft import ( - LoftQConfig, - PeftConfig, - PeftModel, - PeftModelForCausalLM, - prepare_model_for_kbit_training, -) -from torch import nn -from transformers import ( - AddedToken, - AutoConfig, - AutoModelForCausalLM, - AutoModelForVision2Seq, - AutoProcessor, - AutoTokenizer, - AwqConfig, - BitsAndBytesConfig, - Gemma3ForConditionalGeneration, - GPTQConfig, - Llama4ForConditionalGeneration, - LlavaForConditionalGeneration, - Mistral3ForConditionalGeneration, - MllamaForConditionalGeneration, - PretrainedConfig, - PreTrainedModel, - PreTrainedTokenizerBase, - ProcessorMixin, - Qwen2_5_VLForConditionalGeneration, - Qwen2VLForConditionalGeneration, -) -from transformers.integrations.deepspeed import ( - HfTrainerDeepSpeedConfig, - is_deepspeed_zero3_enabled, -) - -from axolotl.common.architectures import MOE_ARCH_BLOCK -from axolotl.integrations.base import PluginManager -from axolotl.models.mamba import fix_mamba_attn_for_loss -from axolotl.monkeypatch.multipack import ( - SUPPORTED_MULTIPACK_MODEL_TYPES, - patch_for_multipack, -) -from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN -from axolotl.utils.bench import log_gpu_memory_usage -from axolotl.utils.chat_templates import get_chat_template_from_config -from axolotl.utils.dict import DictDefault -from axolotl.utils.distributed import ( - barrier, - get_device_count, - get_device_type, - is_local_main_process, - is_main_process, -) -from axolotl.utils.gradient_checkpointing import ( - hf_grad_checkpoint_disk_offload_wrapper, - hf_grad_checkpoint_offload_wrapper, -) -from axolotl.utils.lora_embeddings import get_linear_embedding_layers -from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant -from axolotl.utils.schemas.enums import RLType - -LOG = 
logging.getLogger(__name__) -PLUGIN_MANAGER = PluginManager.get_instance() - -MULTIMODAL_AUTO_MODEL_MAPPING = { - "mllama": MllamaForConditionalGeneration, - "llama4": Llama4ForConditionalGeneration, - "llava": LlavaForConditionalGeneration, - "qwen2_vl": Qwen2VLForConditionalGeneration, - "qwen2_5_vl": Qwen2_5_VLForConditionalGeneration, - "mistral3": Mistral3ForConditionalGeneration, - "gemma3": Gemma3ForConditionalGeneration, -} - - -# copied from accelerator.FullyShardedDataParallelPlugin -def get_module_class_from_name(module, name): - """ - Gets a class from a module by its name. - - Args: - module (`torch.nn.Module`): The module to get the class from. - name (`str`): The name of the class. - """ - modules_children = list(module.children()) - if module.__class__.__name__ == name: - return module.__class__ - - if len(modules_children) == 0: - return None - - for child_module in modules_children: - module_class = get_module_class_from_name(child_module, name) - if module_class is not None: - return module_class - - return None - - -def check_model_config(cfg: DictDefault, model_config: PretrainedConfig): - # Set use_cache to False - if hasattr(model_config, "use_cache"): - model_config.use_cache = False - - if cfg.is_multimodal: - # For multimodal configs, use_cache is set in the text_config - if hasattr(model_config, "get_text_config"): - text_config = model_config.get_text_config() - if hasattr(text_config, "use_cache"): - text_config.use_cache = False - else: - raise ValueError( - "No text config found for multimodal model. Please raise an Issue with model details." 
- ) - - # check if image_size is not set and load image size from model config if available - if ( - cfg.image_size is None - and hasattr(model_config, "vision_config") - and hasattr(model_config.vision_config, "image_size") - ): - cfg.image_size = model_config.vision_config.image_size - LOG.debug(f"Loaded image size: {cfg.image_size} from model config") - - quant_config_exists = ( - hasattr(model_config, "quantization_config") - and model_config.quantization_config - ) - - # Detect compressed-tensors config - is_compressed_tensors_config = ( - quant_config_exists - and model_config.quantization_config.get("quant_method") == "compressed-tensors" - ) - - if is_compressed_tensors_config: - if model_config.quantization_config.get("config_groups"): - LOG.warning( - "Found `config_groups` in a compressed-tensors config. " - "QAT integration with llmcompressor is not tested." - ) - # Skip further quant checks for compressed-tensors - return - - quant_config_method_is_gptq = ( - quant_config_exists - and "quant_method" in model_config.quantization_config - and model_config.quantization_config["quant_method"] == "gptq" - ) - - if cfg.gptq and not quant_config_method_is_gptq: - raise ValueError( - "model_config.quantization_config is not set or quant_method is not set to gptq. " - "Please make sure to point to a GPTQ model." - ) - - lora_modules_to_save = get_linear_embedding_layers(model_config.model_type) - if ( - cfg.adapter - and cfg.tokens - and ( - not cfg.lora_modules_to_save - or not all(x in cfg.lora_modules_to_save for x in lora_modules_to_save) - ) - ): - lora_modules_to_save = ", ".join(map(lambda x: f"`{x}`", lora_modules_to_save)) - raise ValueError( - f"`lora_modules_to_save` not properly set when adding new tokens. Please include [{lora_modules_to_save}] in `lora_modules_to_save`." 
- ) - - -def load_model_config(cfg): - model_config_name = cfg.base_model_config or cfg.base_model - if not model_config_name and cfg.tokenizer_config: - model_config_name = cfg.tokenizer_config - trust_remote_code = cfg.trust_remote_code is True - config_kwargs = {} - if cfg.revision_of_model: - config_kwargs["revision"] = cfg.revision_of_model - if cfg.num_labels: - # num_labels is used to initialize classifier models - config_kwargs["num_labels"] = cfg.num_labels - try: - model_config = AutoConfig.from_pretrained( - model_config_name, - trust_remote_code=trust_remote_code, - **config_kwargs, - ) - except ValueError as err: - if "mamba" in model_config_name: - return addict.Dict( - { - "model_type": "mamba", - } - ) - raise err - - if cfg.overrides_of_model_config: - for key, val in cfg.overrides_of_model_config.items(): - setattr(model_config, key, val) - - check_model_config(cfg, model_config) - - return model_config - - -def modify_tokenizer_files( - tokenizer_path: str, token_mappings: Dict[int, str], output_dir: str -) -> str: - """ - Modify tokenizer files to replace added_tokens strings, save to output directory, and return the path to the modified tokenizer. - - This only works with reserved tokens that were added to the tokenizer, not tokens already part of the vocab. 
- - Args: - tokenizer_path: Path or name of the original tokenizer - token_mappings: Dict mapping {token_id (int): new_token_string} - output_dir: Directory to save the modified tokenizer - - Returns: - Path to the modified tokenizer directory - - Ref: https://github.com/huggingface/transformers/issues/27974#issuecomment-1854188941 - """ - - import json - - # Create the tokenizer directory in output_dir if it doesn't exist - tokenizer_dir = os.path.join(output_dir, "tokenizer") - os.makedirs(tokenizer_dir, exist_ok=True) - - if is_local_main_process(): # pylint: disable=too-many-nested-blocks - # Load the tokenizer - temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True) - - # Save the tokenizer to the output directory - temp_tokenizer.save_pretrained(tokenizer_dir) - - # Get the token IDs and map them to their new values - token_id_mappings = { - int(token_id): new_value for token_id, new_value in token_mappings.items() - } - - # 1. Update tokenizer_config.json - added_tokens_decoder - config_path = os.path.join(tokenizer_dir, "tokenizer_config.json") - if os.path.exists(config_path): - with open(config_path, "r", encoding="utf-8") as f: - config_data = json.load(f) - - # Update added_tokens_decoder - if "added_tokens_decoder" in config_data: - for token_id, new_value in token_id_mappings.items(): - token_id_str = str(token_id) - if token_id_str in config_data["added_tokens_decoder"]: - config_data["added_tokens_decoder"][token_id_str][ - "content" - ] = new_value - else: - raise ValueError( - f"Token ID {token_id_str} not found in added_tokens_decoder" - ) - - # Write the updated config back - with open(config_path, "w", encoding="utf-8") as f: - json.dump(config_data, f, indent=2) - - # 2. 
Update tokenizer.json - added_tokens - tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json") - if os.path.exists(tokenizer_path): - with open(tokenizer_path, "r", encoding="utf-8") as f: - tokenizer_data = json.load(f) - - # Update added_tokens - if "added_tokens" in tokenizer_data: - for token_id, new_value in token_id_mappings.items(): - for i, token_entry in enumerate(tokenizer_data["added_tokens"]): - if token_entry["id"] == token_id: - tokenizer_data["added_tokens"][i]["content"] = new_value - break - else: - # Reaching this section means the token_id was not found in tokenizer.json added_tokens - raise ValueError( - f"Token ID {token_id} not found in added_tokens" - ) - if "model" in tokenizer_data and "vocab" in tokenizer_data["model"]: - for token_id, new_value in token_id_mappings.items(): - for entry_val, entry_id in tokenizer_data["model"]["vocab"].items(): - if entry_id == token_id: - del tokenizer_data["model"]["vocab"][entry_val] - tokenizer_data["model"]["vocab"][new_value] = token_id - break - - # Write the updated tokenizer data back - with open(tokenizer_path, "w", encoding="utf-8") as f: - json.dump(tokenizer_data, f, indent=2) - - barrier() - return tokenizer_dir - - -def load_tokenizer(cfg): - """Load and configure the tokenizer based on the provided config.""" - model_config = load_model_config(cfg) - tokenizer_kwargs = {} - use_fast = True # this is the default - - if cfg.tokenizer_use_fast is not None: - use_fast = cfg.tokenizer_use_fast - if cfg.tokenizer_legacy is not None: - # True is the default w/ https://github.com/huggingface/transformers/pull/25224 - tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy - - tokenizer_cls = AutoTokenizer - if cfg.tokenizer_type: - tokenizer_cls = getattr(transformers, cfg.tokenizer_type) - - # Set base tokenizer path - tokenizer_path = cfg.tokenizer_config - - # Apply token string overrides if specified - if cfg.added_tokens_overrides: - # Modify tokenizer files and get path to modified tokenizer 
- tokenizer_path = modify_tokenizer_files( - tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir - ) - - tokenizer = tokenizer_cls.from_pretrained( - tokenizer_path, - trust_remote_code=cfg.trust_remote_code or False, - use_fast=use_fast, - **tokenizer_kwargs, - ) - - if ( - tokenizer.__class__.__name__ - in [ - "LlamaTokenizer", - "LlamaTokenizerFast", - "CodeLlamaTokenizer", - "CodeLlamaTokenizerFast", - ] - and hasattr(tokenizer, "pad_token") - and not tokenizer.pad_token - ): - # set a pad_token, but use eos_token so we don't add a new token - tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN - - if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - # Mistral's official FA implementation requires left padding - if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing: - tokenizer.padding_side = "left" - - # Qwen base only has single token, so we need to set the special tokens - if cfg.is_qwen_derived_model: - token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"] - for attr_name in token_ids: - if getattr(tokenizer, attr_name) is None: - setattr(tokenizer, attr_name, tokenizer.eod_id) - - token_names = ["bos_token", "eos_token", "pad_token", "unk_token"] - for attr_name in token_names: - if getattr(tokenizer, attr_name) is None: - setattr(tokenizer, attr_name, "<|endoftext|>") - - additional_special_tokens = None - if cfg.special_tokens: - special_tokens = cfg.special_tokens.to_dict() - additional_special_tokens = special_tokens.pop( - "additional_special_tokens", None - ) - lora_modules_to_save = get_linear_embedding_layers(model_config.model_type) - for k, val in special_tokens.items(): - # check if new special token is not already in tokenizer and - # is adapter training to make sure lora_modules_to_save is set - # pylint: disable=too-many-boolean-expressions - if ( - 
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val) - and (len(tokenizer.encode(val, add_special_tokens=False)) > 2) - and cfg.adapter - and ( - not cfg.lora_modules_to_save - or not all( - x in cfg.lora_modules_to_save for x in lora_modules_to_save - ) - ) - and k != "pad_token" - ): - lora_modules_to_save = ", ".join( - [f"`{x}`" for x in lora_modules_to_save] - ) - raise ValueError( - f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens." - ) - - tokenizer.add_special_tokens( - {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)} - ) - - # If we add bos_token and eos_token, we need to update the post processor to - # handle them correctly. - # https://github.com/huggingface/transformers/pull/24132 - bos_or_eos_in_special_tokens = ( - "bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens - ) - if ( - tokenizer.__class__.__name__ - in ( - "LlamaTokenizerFast", - "CodeLlamaTokenizerFast", - ) - and bos_or_eos_in_special_tokens - ): - tokenizer.update_post_processor() - - if cfg.tokens: - tokenizer.add_tokens( - [ - AddedToken(token, rstrip=False, lstrip=False, normalized=False) - for token in cfg.tokens - ] - ) - - # Additional special tokens are a List, and need to be treated differently than regular special - # tokens. We add them after we have called `add_tokens` in case these additional special tokens - # are new tokens. 
- # - # Usage: - # - # ```py - # special_tokens: - # additional_special_tokens: ["<|im_start|>", "<|im_end|>"] - # ``` - if additional_special_tokens is not None: - tokenizer.add_special_tokens( - {"additional_special_tokens": additional_special_tokens} - ) - - if is_main_process(use_environ=True): - LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") - LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") - LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") - LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") - - if cfg.chat_template: - chat_template_string = get_chat_template_from_config( - cfg=cfg, - tokenizer=tokenizer, - ) - if cfg.default_system_message and cfg.chat_template == "chatml": - chat_template_string = chat_template_string.replace( - "You are a helpful assistant.", cfg.default_system_message - ) - - tokenizer.chat_template = chat_template_string - else: - LOG.info( - "No Chat template selected. Consider adding a chat template for easier inference." - ) - return tokenizer - - -def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): - processor_kwargs: Dict[str, Any] = {} # do we actually need this? 
- - processor_cls = AutoProcessor - if cfg.processor_type: - processor_cls = getattr(transformers, cfg.processor_type) - - processor = processor_cls.from_pretrained( - cfg.processor_config, - trust_remote_code=cfg.trust_remote_code or False, - tokenizer=tokenizer, - **processor_kwargs, - ) - - # Attempt to load image size from processor if available - if ( - cfg.image_size is None - and hasattr(processor, "size") - and any(dim in processor.size for dim in ["width", "height"]) - ): - im_width = None - im_height = None - if "width" in processor.size: - im_width = processor.size["width"] - if "height" in processor.size: - im_height = processor.size["height"] - - # If both width and height are set, use a tuple - if im_width is not None and im_height is not None: - cfg.image_size = (im_width, im_height) - # If only width is set, use as integer - elif im_width is not None: - cfg.image_size = im_width - # If only height is set, use as integer - elif im_height is not None: - cfg.image_size = im_height - - LOG.debug(f"Loaded image size: {cfg.image_size} from processor") - - return processor - - -class ModelLoader: - """ - ModelLoader: managing all the config and monkey patches while loading model - """ - - def __init__( - self, - cfg: DictDefault, - tokenizer: PreTrainedTokenizerBase, - *, - processor: ProcessorMixin = None, # pylint: disable=unused-argument - inference: bool = False, - reference_model: bool = False, - **kwargs, # pylint: disable=unused-argument - ) -> None: - self.cfg = cfg - self.tokenizer = tokenizer - self.inference: bool = inference - self.reference_model: bool = reference_model - - # init model kwargs - self.model_kwargs: Dict[str, Any] = {} - if cfg.overrides_of_model_kwargs: - for key, val in cfg.overrides_of_model_kwargs.items(): - self.model_kwargs[key] = val - - # init model - self.model: PreTrainedModel - self.base_model = cfg.base_model - self.model_type = cfg.type_of_model - - # init model config - self.model_config = load_model_config(cfg) - 
- self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name - - def apply_patches(self) -> None: - if self.cfg.xformers_attention and self.cfg.sample_packing: - from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2 - - patch_xformers_attn_over_fa2() - self.cfg.flash_attention = True - if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2": - from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils - - patch_accelerate_fsdp_utils() - - if self.cfg.adapter and self.cfg.embeddings_skip_upcast: - from axolotl.monkeypatch.peft.utils import patch_peft_prep_code - - patch_peft_prep_code() - - if self.cfg.flex_attention: - from axolotl.monkeypatch.attention.flex_attn import ( - patch_flex_make_mask, - patch_flex_wrapper, - ) - - flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {} - patch_flex_wrapper(**flex_attn_compile_kwargs) - patch_flex_make_mask() - - # patch gemma3 conditional generation forward before loading plugins - # as it could be overridden by plugins - if self.cfg.model_config_type == "llama4": - if self.cfg.llama4_linearized_experts: - from axolotl.monkeypatch.models.llama4.modeling import ( - patch_llama4_linearized_modeling, - ) - - patch_llama4_linearized_modeling() - - if self.cfg.model_config_type == "gemma3": - from axolotl.monkeypatch.gemma3 import ( - patch_gemma3conditionalgeneration_forward, - ) - - patch_gemma3conditionalgeneration_forward() - - # load any patches from plugins - - PLUGIN_MANAGER.pre_model_load(self.cfg) - - # monkey patch to allow additional Accelerator init kwargs - if self.cfg.fp8: - from axolotl.monkeypatch.trainer_accelerator_args import ( - patch_create_accelerate_code_for_fp8, - ) - - patch_create_accelerate_code_for_fp8() - - if self.cfg.adapter: - from axolotl.monkeypatch.transformers_fa_utils import ( - patch_fa_peft_integration, - ) - - patch_fa_peft_integration() - - if self.cfg.gradient_checkpointing in ["unsloth", "offload"]: - 
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper - if self.cfg.gradient_checkpointing == "offload_disk": - transformers.modeling_utils.checkpoint = ( - hf_grad_checkpoint_disk_offload_wrapper - ) - - if self.cfg.flash_attention: - self.patch_attention() - - if self.cfg.sample_packing and self.cfg.s2_attention: - raise ValueError( - "Received `sample_packing=true` and `s2_attention=true`; however, \ - shifted-sparse attention does not currently support sample packing." - ) - - if ( - self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES - and (self.cfg.flash_attention or self.cfg.flex_attention) - and self.cfg.sample_packing - ): - if "auto_map" in self.model_config: - try: - auto_map_config = self.model_config["auto_map"] - except TypeError: - auto_map_config = self.model_config.auto_map - has_remote_code = "AutoModelForCausalLM" in auto_map_config - else: - has_remote_code = False - - if has_remote_code and self.cfg.trust_remote_code is False: - # if explicitly set in the YAML, we should prefer that, for example if explicitly disabled - has_remote_code = self.cfg.trust_remote_code - patch_for_multipack( - self.cfg.model_config_type, - model_name=self.cfg.base_model, - has_remote_code=has_remote_code, - ) - - if self.cfg.is_llama_derived_model: - self.patch_loss_llama() - elif self.cfg.is_llama_derived_model: - self.patch_llama_derived_model() - - if ( - self.cfg.model_config_type == "mistral" - and self.cfg.flash_attn_cross_entropy_loss - ): - from axolotl.monkeypatch.mistral_attn_hijack_flash import ( - patch_mistral_cross_entropy, - ) - - patch_mistral_cross_entropy() - - if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: - from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora - - patch_self_attn_lora(self.cfg) - - def patch_attention(self) -> None: - if hasattr(self.model_config, "model_type"): - if self.model_config.model_type == "mllama" and self.cfg.flash_attention: - from 
axolotl.monkeypatch.attention.mllama import patch_mllama - - patch_mllama() - - if self.model_config.model_type == "btlm": - from axolotl.monkeypatch.btlm_attn_hijack_flash import ( - replace_btlm_attn_with_flash_attn, - ) - - replace_btlm_attn_with_flash_attn(self.cfg.base_model) - - if ( - self.model_config.model_type == "stablelm_epoch" - and self.cfg.sample_packing - ): - from axolotl.monkeypatch.stablelm_attn_hijack_flash import ( - replace_stablelm_attn_with_flash_attn, - ) - - replace_stablelm_attn_with_flash_attn(self.cfg.base_model) - - @cached_property - def has_flash_attn(self) -> bool: - """Check if flash attention is installed""" - return importlib.util.find_spec("flash_attn") is not None - - def patch_loss_llama(self) -> None: - """Patch loss functions and other optimizations""" - if self.has_flash_attn: - from axolotl.monkeypatch.llama_attn_hijack_flash import ( - patch_fa_llama_cross_entropy, - patch_llama_rms_norm, - ) - - if self.cfg.flash_attn_cross_entropy and self.has_flash_attn: - patch_fa_llama_cross_entropy() - elif self.cfg.unsloth_cross_entropy_loss: - from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch - - integrate_cross_entropy_loss_patch(model_type="llama") - - if self.cfg.flash_attn_rms_norm and self.has_flash_attn: - patch_llama_rms_norm() - elif self.cfg.unsloth_rms_norm: - from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm - - patch_unsloth_layernorm() - - if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: - from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora - - patch_self_attn_lora() - - def patch_llama_derived_model(self): - """Modify all llama derived models in one block""" - self.patch_loss_llama() - - if self.cfg.flash_attention: - from axolotl.monkeypatch.llama_attn_hijack_flash import ( - replace_llama_attn_with_flash_attn, - ) - - if self.cfg.sample_packing: - if self.cfg.device not in ["mps", "cpu"] and not self.inference: - LOG.info("patching with flash attention for 
sample packing") - replace_llama_attn_with_flash_attn( - packed=True, - cross_entropy=self.cfg.flash_attn_cross_entropy, - rms_norm=self.cfg.flash_attn_rms_norm, - ) - elif self.cfg.s2_attention: - LOG.info("patching w/ flash-enabled, shifted-sparse attention") - replace_llama_attn_with_flash_attn( - packed=False, - cross_entropy=self.cfg.flash_attn_cross_entropy, - rms_norm=self.cfg.flash_attn_rms_norm, - use_shifted_sparse_attn=True, - ) - elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm: - replace_llama_attn_with_flash_attn( - packed=False, - cross_entropy=self.cfg.flash_attn_cross_entropy, - rms_norm=self.cfg.flash_attn_rms_norm, - ) - elif self.cfg.xformers_attention: - from axolotl.monkeypatch.llama_attn_hijack_xformers import ( - hijack_llama_attention, - ) - - LOG.info("patching with xformers attention") - hijack_llama_attention() - elif self.cfg.sample_packing: - from axolotl.monkeypatch.llama_patch_multipack import ( - hijack_llama_prepare_4d_mask, - ) - - LOG.info("patching llama _prepare_4d_causal_attention_mask*") - hijack_llama_prepare_4d_mask() - elif self.cfg.s2_attention: - raise NotImplementedError( - "Shifted-sparse attention not currently implemented without flash attention." - ) - - def set_auto_model_loader(self): - """ - Set self.auto_model_loader. Defaults to `transformers.AutoModelForCausalLM` - (set at `__init__`). When using a multimodal model, `self.auto_model_loader` - should be set according to the type of the model. 
- """ - if self.cfg.is_multimodal: - self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get( - self.model_config.model_type, AutoModelForVision2Seq - ) - - def set_device_map_config(self) -> None: - device_map = self.cfg.device_map - max_memory = self.cfg.max_memory - - if self.cfg.gpu_memory_limit: - gpu_memory_limit = ( - str(self.cfg.gpu_memory_limit) + "GiB" - if isinstance(self.cfg.gpu_memory_limit, int) - else self.cfg.gpu_memory_limit - ) - - max_memory = {} - num_device = get_device_count() - for i in range(num_device): - max_memory[i] = gpu_memory_limit - max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything - - if max_memory is not None: - # Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py - from accelerate import infer_auto_device_map - - with init_empty_weights(): - model_canvas = self.auto_model_loader.from_config( - self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - ) - model_canvas.tie_weights() - device_map = infer_auto_device_map( - model_canvas, - max_memory=max_memory, - dtype=self.cfg.torch_dtype, - ) - # We can discard max_memory now as we have a device map set up for us - max_memory = None - - self.model_kwargs["device_map"] = device_map - self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype - - cur_device = get_device_type() - if "mps" in str(cur_device): - self.model_kwargs["device_map"] = "mps:0" - elif "npu" in str(cur_device): - self.model_kwargs["device_map"] = "npu:0" - - # TODO can we put the reference model on it's own gpu? 
I think we have to move logits around to calculate loss - # if cfg.rl: - # if torch.cuda.device_count() > 1: - # if reference_model: - # model_kwargs["device_map"] = "cuda:" + str( - # torch.cuda.current_device() + 1 - # ) - # else: - # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device()) - - if is_deepspeed_zero3_enabled(): - del self.model_kwargs["device_map"] - - def set_quantization_config(self) -> None: - self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit - self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit - - if self.cfg.gptq: - if not hasattr(self.model_config, "quantization_config"): - LOG.warning( - "model config does not contain quantization_config information" - ) - else: - if self.cfg.gptq_disable_exllama is not None: - self.model_config.quantization_config["disable_exllama"] = ( - self.cfg.gptq_disable_exllama - ) - self.model_kwargs["quantization_config"] = GPTQConfig( - **self.model_config.quantization_config - ) - if ( - self.cfg.adapter in ["qlora", "lora"] - and hasattr(self.model_config, "quantization_config") - and self.model_config.quantization_config["quant_method"] - in ["gptq", "awq", "bitsandbytes"] - ): - if self.model_config.quantization_config["quant_method"] == "gptq": - self.model_kwargs["quantization_config"] = GPTQConfig( - **self.model_config.quantization_config - ) - elif self.model_config.quantization_config["quant_method"] == "awq": - self.model_kwargs["quantization_config"] = AwqConfig( - **self.model_config.quantization_config - ) - elif ( - self.model_config.quantization_config["quant_method"] == "bitsandbytes" - ): - self.model_kwargs["quantization_config"] = BitsAndBytesConfig( - **self.model_config.quantization_config - ) - elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]: - bnb_config = { - "load_in_4bit": True, - "llm_int8_threshold": 6.0, - "llm_int8_has_fp16_weight": False, - "bnb_4bit_compute_dtype": self.cfg.torch_dtype, - "bnb_4bit_use_double_quant": True, - 
"bnb_4bit_quant_type": "nf4", - "bnb_4bit_quant_storage": torch.bfloat16, - } - if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not ( - self.cfg.deepspeed or self.cfg.fsdp - ): - # for some reason, this causes the loss to be off by an order of magnitude - # but deepspeed needs this still in bfloat16 - bnb_config["bnb_4bit_quant_storage"] = torch.float32 - - if self.cfg.bnb_config_kwargs: - bnb_config.update(self.cfg.bnb_config_kwargs) - - self.model_kwargs["quantization_config"] = BitsAndBytesConfig( - **bnb_config, - ) - elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]: - bnb_config = { - "load_in_8bit": True, - } - # Exclude mamba blocks from int8 quantization for jamba - if self.cfg.model_config_type == "jamba": - bnb_config["llm_int8_skip_modules"] = ["mamba"] - self.model_kwargs["quantization_config"] = BitsAndBytesConfig( - **bnb_config, - ) - - # no longer needed per https://github.com/huggingface/transformers/pull/26610 - if "quantization_config" in self.model_kwargs or self.cfg.gptq: - self.model_kwargs.pop("load_in_8bit", None) - self.model_kwargs.pop("load_in_4bit", None) - - def set_attention_config(self) -> None: - """ - sample packing uses custom FA2 patch - """ - if self.cfg.flex_attention: - self.model_kwargs["attn_implementation"] = "flex_attention" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "flex_attention" - ) - - elif self.cfg.flash_attention: - if not self.cfg.sample_packing and self.cfg.s2_attention: - pass - self.model_kwargs["attn_implementation"] = "flash_attention_2" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) - elif self.cfg.sdp_attention: - self.model_kwargs["attn_implementation"] = "sdpa" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "sdpa" - ) - elif self.cfg.eager_attention: - self.model_kwargs["attn_implementation"] = "eager" - self.model_config._attn_implementation 
= ( # pylint: disable=protected-access - "eager" - ) - - if self.cfg.low_cpu_mem_usage: - self.model_kwargs["low_cpu_mem_usage"] = True - - def build_model(self, qlora_fsdp) -> bool: - def _configure_zero3_memory_efficient_loading(): - """ - Set the deepspeed config to load the model into RAM first before moving to VRAM. - - We need to return hf_ds_cfg as it needs to exist before model loading. - """ - hf_ds_cfg = None - - if os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3": - hf_ds_cfg = HfTrainerDeepSpeedConfig(self.cfg.deepspeed) - hf_ds_cfg.fill_match( - "train_micro_batch_size_per_gpu", self.cfg.micro_batch_size - ) - hf_ds_cfg.fill_match( - "gradient_accumulation_steps", self.cfg.gradient_accumulation_steps - ) - hf_ds_cfg.fill_match( - "train_batch_size", - int(os.getenv("WORLD_SIZE", "1")) - * self.cfg.micro_batch_size - * self.cfg.gradient_accumulation_steps, - ) - if "device_map" in self.model_kwargs: - del self.model_kwargs["device_map"] - - transformers.modeling_utils.is_deepspeed_zero3_enabled = lambda: True - transformers.integrations.deepspeed.is_deepspeed_zero3_enabled = ( - lambda: True - ) - - return hf_ds_cfg - - skip_move_to_device = False - if ( # pylint: disable=condition-evals-to-constant) - (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading) - and not qlora_fsdp - and False - ): - self.model = load_sharded_model( - self.base_model, - self.model_config, - self.cfg, - torch_dtype=self.cfg.torch_dtype, - ) - skip_move_to_device = True - elif ( - qlora_fsdp - and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading - and ( - self.cfg.model_config_type == "dbrx" - or self.cfg.qlora_sharded_model_loading - ) - ): - quant_storage = self.cfg.torch_dtype - quantization_config = hasattr( - self.model_config, "quantization_config" - ) and getattr(self.model_config, "quantization_config") - quantization_config = ( - quantization_config or self.model_kwargs["quantization_config"] - ) - self.model = load_sharded_model_quant( - 
self.base_model, - self.model_config, - self.cfg, - quant_storage=quant_storage, - quantization_config=quantization_config, - ) - skip_move_to_device = True - elif ( - self.model_config.model_type in ["llama", "llama4"] - and not self.cfg.trust_remote_code - and not self.cfg.gptq - ): - # TODO do we need to open this up for all models? - if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: - skip_move_to_device = True - if "device_map" in self.model_kwargs: - del self.model_kwargs["device_map"] - - _ = _configure_zero3_memory_efficient_loading() - - # Load model with random initialization if specified - if self.cfg.random_init_weights: - # AutoModel classes support the from_config method - if self.auto_model_loader in [ - AutoModelForCausalLM, - AutoModelForVision2Seq, - ]: - self.model = self.auto_model_loader.from_config( - config=self.model_config, - ) - else: - self.model = self.auto_model_loader( - config=self.model_config, - ) - else: - self.model = self.auto_model_loader.from_pretrained( - self.base_model, - config=self.model_config, - **self.model_kwargs, - ) - - # TODO (MengqingCao) split these patches seperately - if self.cfg.flash_attention and not self.inference: - from axolotl.monkeypatch.llama_attn_hijack_flash import ( - is_xformers_swiglu_available, - replace_llama_mlp_with_swiglu, - replace_llama_qkv_with_fused, - ) - - if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): - LOG.info("patching with SwiGLU") - replace_llama_mlp_with_swiglu(self.model) - - if self.cfg.flash_attn_fuse_qkv: - LOG.info("patching with fused QKV") - replace_llama_qkv_with_fused(self.model) - elif self.model_type == "MambaLMHeadModel": - # FIXME this is janky at best and hacked together to make it work - MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name - - self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"] - self.model_kwargs["device"] = torch.cuda.current_device() - del 
self.model_kwargs["torch_dtype"] - del self.model_kwargs["device_map"] - - self.model = MambaLMHeadModel.from_pretrained( - self.base_model, - **self.model_kwargs, - ) - elif ( - self.model_type - and self.model_type != "AutoModelForCausalLM" - and not self.cfg.trust_remote_code - ): - if self.cfg.gptq: - self.model = self.auto_model_loader.from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) - else: - self.model = getattr(transformers, self.model_type).from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) - else: - if self.cfg.gptq: - self.model = self.auto_model_loader.from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) - else: - if ( - self.cfg.fsdp - and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading - ): - # disabling either of these two still leads to VRAM spike before setting back down - skip_move_to_device = True - if "device_map" in self.model_kwargs: - del self.model_kwargs["device_map"] - - _ = _configure_zero3_memory_efficient_loading() - - self.model = self.auto_model_loader.from_pretrained( - self.base_model, - config=self.model_config, - trust_remote_code=self.cfg.trust_remote_code or False, - **self.model_kwargs, - ) - if is_deepspeed_zero3_enabled(): - skip_move_to_device = True - - return skip_move_to_device - - def adjust_model_config(self) -> None: - if ( - hasattr(self.model, "config") - and hasattr(self.model.config, "max_position_embeddings") - and self.model.config.max_position_embeddings - and self.cfg.sequence_len > self.model.config.max_position_embeddings - ): - LOG.warning( - f"increasing model.config.max_position_embeddings from {self.model.config.max_position_embeddings} to {self.cfg.sequence_len}" - ) - self.model.config.max_position_embeddings = 
self.cfg.sequence_len - - if ( - hasattr(self.model, "config") - and hasattr(self.model.config, "bos_token_id") - and self.model.config.bos_token_id - and self.model.config.bos_token_id != self.tokenizer.bos_token_id - ): - self.model.config.bos_token_id = self.tokenizer.bos_token_id - - if ( - hasattr(self.model, "config") - and hasattr(self.model.config, "eos_token_id") - and self.model.config.eos_token_id - and self.model.config.eos_token_id != self.tokenizer.eos_token_id - ): - self.model.config.eos_token_id = self.tokenizer.eos_token_id - - def set_z3_leaf_modules(self) -> None: - from deepspeed.utils import ( # pylint: disable=no-name-in-module - set_z3_leaf_modules, - ) - - if self.cfg.model_config_type in MOE_ARCH_BLOCK: - moe_blocks = MOE_ARCH_BLOCK[self.cfg.model_config_type] - moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks - set_z3_leaf_modules( - self.model, - [ - get_module_class_from_name(self.model, module_name) - for module_name in moe_blocks - ], - ) - - def prepare_model(self, qlora_fsdp: bool) -> None: - skip_prepare_model_for_kbit_training = False - if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora": - # Qwen doesn't play nicely with LoRA if this is enabled - skip_prepare_model_for_kbit_training = True - - loftq_bits = ( - self.cfg.peft - and self.cfg.peft.loftq_config - and self.cfg.peft.loftq_config.loftq_bits - ) - if self.cfg.adapter == "lora" and loftq_bits: - skip_prepare_model_for_kbit_training = True - - if qlora_fsdp or ( - self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading - ): - # make sure everything is in the same dtype - skip_prepare_model_for_kbit_training = True - - if is_deepspeed_zero3_enabled(): - skip_prepare_model_for_kbit_training = True - - if ( - not skip_prepare_model_for_kbit_training - and self.cfg.adapter in ["lora", "qlora"] - and (self.cfg.load_in_8bit or self.cfg.load_in_4bit) - ): - LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") 
- self.model = prepare_model_for_kbit_training( - self.model, use_gradient_checkpointing=self.cfg.gradient_checkpointing - ) - - def convert_embedding_modules_dtype( - self, embedding_modules, dist_dtype, before_kbit_train_or_finetune - ) -> None: - for name, module in self.model.named_modules(): - if "norm" in name: - module.to(dist_dtype) - if before_kbit_train_or_finetune: - if name.endswith(".gate"): - module.to(dist_dtype) - if self.model_config.model_type == "btlm": - # don't upcast lm_head for btlm - continue - if any(m in name for m in embedding_modules): - if hasattr(module, "weight"): - module.to(dist_dtype) - - # TODO: Deprecate this. - def apply_unsloth_lora_patch(self) -> None: - if self.cfg.unsloth_lora_mlp: - from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch - - integrate_lora_mlp_patch(self.model) - if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o: - from axolotl.monkeypatch.unsloth_ import integrate_lora_patch - - integrate_lora_patch(self.model, self.cfg) - if self.cfg.unsloth_rope: - from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings - - integrate_rope_embeddings() - - def apply_lora_patch(self) -> None: - if ( - self.cfg.lora_mlp_kernel - or self.cfg.lora_qkv_kernel - or self.cfg.lora_o_kernel - ): - from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches - - apply_lora_kernel_patches(self.model, self.cfg) - - def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: - self.apply_patches() - self.set_auto_model_loader() - self.set_device_map_config() - if self.cfg.revision_of_model: - self.model_kwargs["revision"] = self.cfg.revision_of_model - self.set_quantization_config() - self.set_attention_config() - - qlora_fsdp = self.cfg.fsdp and self.cfg.adapter == "qlora" - skip_move_to_device = False - - try: - skip_move_to_device = self.build_model(qlora_fsdp) - PLUGIN_MANAGER.post_model_build(self.cfg, self.model) - except Exception as err: # pylint: disable=broad-exception-caught - 
LOG.exception(err) - raise err - - if isinstance(self.model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp: - self.model = self.model.merge_and_unload() - - embeddings_len = ( - math.ceil(len(self.tokenizer) / 32) * 32 - if self.cfg.resize_token_embeddings_to_32x - else len(self.tokenizer) - ) - if hasattr(self.model, "get_input_embeddings") and ( - self.model.get_input_embeddings().num_embeddings < embeddings_len - or ( - self.model.get_input_embeddings().num_embeddings > embeddings_len - and self.cfg.shrink_embeddings - ) - ): - resize_kwargs = {} - if self.cfg.mean_resizing_embeddings is not None and not ( - self.model_config.model_type == "llava" - ): - resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings - self.model.resize_token_embeddings(embeddings_len, **resize_kwargs) - else: - self.model.tie_weights() - - self.adjust_model_config() - - # log device memory usage - if hasattr(self.model, "device") and self.model.device.type in ( - "cuda", - "mps", - "npu", - ): - log_gpu_memory_usage(LOG, "after model load", self.model.device) - - # make sure these are fp32 per Ramesh et al. 
(2021) - embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type) - if not self.cfg.fsdp: - # we don't run this during FSDP because this will leave mixed - # float and bfloat16 dtypes in the model which FSDP doesn't like - if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast: - embedding_modules = [] - self.convert_embedding_modules_dtype( - embedding_modules, - dist_dtype=torch.float32, - before_kbit_train_or_finetune=True, - ) - - if is_deepspeed_zero3_enabled(): - self.set_z3_leaf_modules() - - needs_fa2_dtype = self.cfg.adapter or self.cfg.fsdp - if self.cfg.adapter in ["lora", "qlora"]: - needs_fa2_dtype = True - if self.cfg.gradient_checkpointing: - self.model.gradient_checkpointing_enable( - gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs - ) - - self.prepare_model(qlora_fsdp) - - should_convert = ( - # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to - # convert them back to fp16/bf16 for flash-attn compatibility. 
- ( - (needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) - and not qlora_fsdp - ) - or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass - ) - - if should_convert: - LOG.info("Converting modules to %s", self.cfg.torch_dtype) - self.convert_embedding_modules_dtype( - embedding_modules=embedding_modules, - dist_dtype=self.cfg.torch_dtype, - before_kbit_train_or_finetune=False, - ) - - PLUGIN_MANAGER.pre_lora_load(self.cfg, self.model) - - # --------------------------------------------------------- - # load lora or adapter - # --------------------------------------------------------- - lora_config = None - if not self.reference_model or self.cfg.lora_model_dir: - # if we're not loading the reference model, then we're loading the model for training - # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config - if ( - self.cfg.adapter - and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO] - and not self.cfg.merge_lora - ): - _, lora_config = load_lora( - self.model, self.cfg, inference=False, config_only=True - ) - else: - self.model, lora_config = load_adapter( - self.model, self.cfg, self.cfg.adapter - ) - - # --------------------------------------------------------- - # put model to accelerator - # --------------------------------------------------------- - if ( - self.cfg.ddp - and not self.cfg.load_in_8bit - and not (self.cfg.rl and self.cfg.load_in_4bit) - and not skip_move_to_device - ): - # TODO revaldate this conditional - self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}") - - if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: - setattr(self.model, "is_parallelizable", True) - setattr(self.model, "model_parallel", True) - - # --------------------------------------------------------- - # parameters that require gradient updates - # --------------------------------------------------------- - requires_grad 
= [] - for name, param in self.model.named_parameters(recurse=True): - if param.requires_grad: - requires_grad.append(f"{name}: {param.requires_grad}") - if len(requires_grad) == 0: - LOG.warning("there are no parameters that require gradient updates") - - if self.cfg.flash_optimum: - from optimum.bettertransformer import BetterTransformer - - self.model = BetterTransformer.transform(self.model) - - if self.cfg.adapter is not None: - log_gpu_memory_usage(LOG, "after adapters", self.model.device) - - self.apply_unsloth_lora_patch() - self.apply_lora_patch() - - for _ in range(3): - gc.collect() - torch.cuda.empty_cache() - - PLUGIN_MANAGER.post_model_load(self.cfg, self.model) - return self.model, lora_config - - -def load_model( - cfg: DictDefault, - tokenizer: PreTrainedTokenizerBase, - *, - processor: ProcessorMixin = None, # pylint: disable=unused-argument - inference: bool = False, - reference_model: bool = False, - **kwargs, # pylint: disable=unused-argument -) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: - """ - Load a model for a given configuration and tokenizer. 
- """ - model_loader = ModelLoader( - cfg, - tokenizer, - processor=processor, - inference=inference, - reference_model=reference_model, - **kwargs, - ) - return model_loader.load_model() - - -def load_adapter(model, cfg, adapter, inference=False): - # type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] - - if adapter is None: - return model, None - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - if adapter in ["lora", "qlora"]: - model, lora_config = load_lora(model, cfg, inference=inference) - PLUGIN_MANAGER.post_lora_load(cfg, model) - return model, lora_config - if adapter == "llama-adapter": - model, lora_config = load_llama_adapter(model, cfg) - PLUGIN_MANAGER.post_lora_load(cfg, model) - return model, lora_config - - raise NotImplementedError(f"{adapter} peft adapter not available") - - -def load_llama_adapter(model, cfg): - # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]] - from peft import AdaptionPromptConfig, get_peft_model - - peft_config = AdaptionPromptConfig( - adapter_layers=cfg.peft_adapter.layers, # layers (L) - adapter_len=cfg.peft_adapter.len, # prompt length (K) - task_type="CAUSAL_LM", - ) - - if cfg.lora_model_dir: - LOG.debug("Loading pretrained PEFT - llama_adapter") - model = PeftModel.from_pretrained( - model, - cfg.lora_model_dir, - torch_dtype=torch.float16, - ) - else: - model = get_peft_model(model, peft_config) - - model.print_trainable_parameters() - - return model, peft_config - - -def find_all_linear_names(model): - cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear) - lora_module_names = set() - for name, module in model.named_modules(): - if ( - isinstance(module, cls) - or "Linear" in module.__class__.__name__ - and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",) - ): - names = name.split(".") - lora_module_names.add(names[0] if len(names) == 1 else names[-1]) - 
- embedding_modules = get_linear_embedding_layers(model.config.model_type) - output_embedding = embedding_modules[1] - if output_embedding in lora_module_names: # needed for 16-bit - lora_module_names.remove(output_embedding) - - return list(lora_module_names) - - -def setup_quantized_meta_for_peft(model: nn.Module): - """Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device""" - - def temp_to_method(self, *args, **kwargs): # pylint: disable=unused-argument - return self - - for param in model.parameters(): - if isinstance(param, Params4bit): - param.quant_state._orig_to = ( # pylint: disable=protected-access - param.quant_state.to - ) - param.quant_state.to = types.MethodType(temp_to_method, param.quant_state) - - -def setup_quantized_peft_meta_for_training(model: nn.Module): - """Replaces dummy `quant_state.to` method with the original function to allow training to continue""" - for param in model.parameters(): - if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"): - param.quant_state.to = ( - param.quant_state._orig_to # pylint: disable=protected-access - ) - param.quant_state._orig_to = None # pylint: disable=protected-access - - -def load_lora(model, cfg, inference=False, config_only=False): - # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]] - - from peft import LoraConfig, get_peft_model - - lora_target_modules = cfg.lora_target_modules or [] - - if cfg.lora_target_linear: - linear_names = find_all_linear_names(model) - LOG.info(f"found linear modules: {repr(sorted(linear_names))}") - lora_target_modules_as_list = ( - lora_target_modules - if isinstance(lora_target_modules, list) - else [lora_target_modules] - ) - lora_target_modules = list(set(lora_target_modules_as_list + linear_names)) - - lora_config_kwargs = {} - loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits - if loftq_bits: - 
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits) - lora_config_kwargs["init_lora_weights"] = "loftq" - if cfg.peft_init_lora_weights: - lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights - if cfg.peft_use_dora: - lora_config_kwargs["use_dora"] = cfg.peft_use_dora - LOG.info("Initializing LoRA weights using dora. This might take longer.") - if cfg.peft_use_rslora: - lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora - if cfg.peft_layer_replication: - lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication - - lora_config = LoraConfig( - r=cfg.lora_r, - lora_alpha=cfg.lora_alpha, - target_modules=lora_target_modules, - layers_to_transform=cfg.peft_layers_to_transform, - layers_pattern=cfg.peft_layers_pattern, - lora_dropout=cfg.lora_dropout, - fan_in_fan_out=cfg.lora_fan_in_fan_out, - modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None, - bias="none", - task_type="CAUSAL_LM", - **lora_config_kwargs, - ) - - if config_only: - return None, lora_config - - rank = int(os.environ.get("LOCAL_RANK", 0)) - - if ( - cfg.fsdp - and cfg.adapter - and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading - and rank != 0 - ): - setup_quantized_meta_for_peft(model) - - if cfg.lora_model_dir: - LOG.debug("Loading pretrained PEFT - LoRA") - model_kwargs: Any = {} - if cfg.lora_on_cpu: - model_kwargs["max_memory"] = {"cpu": "256GiB"} - model_kwargs["device_map"] = {"": "cpu"} - model = PeftModel.from_pretrained( - model, - cfg.lora_model_dir, - is_trainable=(not inference), - **model_kwargs, - ) - else: - model = get_peft_model(model, lora_config) - - if rank == 0: - try: - model.print_trainable_parameters() - except AttributeError as exc: - LOG.warning( - "Exception caught during model.print_trainable_parameters(): %s", exc - ) - elif ( - cfg.fsdp - and cfg.adapter - and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading - and rank != 0 - ): - setup_quantized_peft_meta_for_training(model) - - return model, 
lora_config - - -def ensure_dtype(model, dtype=torch.bfloat16): - for name, module in model.named_modules(): - weight_mismatch = False - bias_mismatch = False - try: - weight_mismatch = module.weight.dtype != dtype - except AttributeError: - pass - try: - bias_mismatch = module.bias.dtype != dtype - except AttributeError: - pass - - if weight_mismatch: - print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}") - if bias_mismatch: - print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}") - if weight_mismatch or bias_mismatch: - module.to(dtype) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 8ae9d5c04..cc5f54ac4 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -470,6 +470,16 @@ class AxolotlInputConfig( return data + @model_validator(mode="before") + @classmethod + def check_sample_packing_with_s2attn(cls, data): + if data.get("sample_packing") and data.get("s2_attention"): + raise ValueError( + "Received `sample_packing=true` and `s2_attention=true`; however, \ + shifted-sparse attention does not currently support sample packing." 
+ ) + return data + @model_validator(mode="before") @classmethod def check_batch_flattening_fa(cls, data): diff --git a/tests/core/test_trainer_builder.py b/tests/core/test_trainer_builder.py index d1ad273ea..492578c40 100644 --- a/tests/core/test_trainer_builder.py +++ b/tests/core/test_trainer_builder.py @@ -1,13 +1,11 @@ -""" -unit tests for axolotl.core.trainer_builder -""" +"""Unit tests for axolotl.core.trainer_builder""" import pytest from axolotl.core.trainer_builder import HFRLTrainerBuilder +from axolotl.loaders import ModelLoader, load_tokenizer from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.schemas.enums import RLType @@ -50,7 +48,7 @@ def fixture_tokenizer(cfg): @pytest.fixture(name="model") def fixture_model(cfg, tokenizer): - return load_model(cfg, tokenizer) + return ModelLoader(cfg, tokenizer).load() class TestHFRLTrainerBuilder: diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py index 26090e697..5ea88b001 100644 --- a/tests/e2e/patched/test_model_patches.py +++ b/tests/e2e/patched/test_model_patches.py @@ -6,9 +6,9 @@ import unittest import transformers +from axolotl.loaders import ModelLoader, load_tokenizer from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_model, load_tokenizer from ..utils import with_temp_dir @@ -50,7 +50,7 @@ class TestModelPatches(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) tokenizer = load_tokenizer(cfg) - load_model(cfg, tokenizer, inference=False) + ModelLoader(cfg, tokenizer, inference=False).load() @with_temp_dir def test_mistral_multipack(self, temp_dir): @@ -83,7 +83,7 @@ class TestModelPatches(unittest.TestCase): cfg = validate_config(cfg) normalize_config(cfg) tokenizer = load_tokenizer(cfg) - load_model(cfg, tokenizer, 
inference=False) + ModelLoader(cfg, tokenizer, inference=False).load() assert ( "torch.jit" diff --git a/tests/e2e/test_load_model.py b/tests/e2e/test_load_model.py index 96745c040..5061945b4 100644 --- a/tests/e2e/test_load_model.py +++ b/tests/e2e/test_load_model.py @@ -6,8 +6,8 @@ import tempfile import pytest import torch +from axolotl.loaders import ModelLoader, load_tokenizer from axolotl.utils.dict import DictDefault -from axolotl.utils.models import ModelLoader, load_model, load_tokenizer @pytest.fixture(name="temp_dir") @@ -58,6 +58,8 @@ class TestLoadModelUtils: ModelLoader( cfg=self.cfg, tokenizer="", + inference=False, + reference_model=True, ) ) @@ -71,13 +73,8 @@ class TestLoadModelUtils: ): self.cfg.output_dir = temp_dir self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all - self.model_loader.model, _ = load_model( - self.cfg, - self.model_loader.tokenizer, - inference=False, - reference_model=True, - ) - self.model_loader.convert_embedding_modules_dtype( + self.model_loader.load() + self.model_loader._convert_embedding_modules_dtype( embedding_modules, dist_dtype, before_kbit_train_or_finetune ) for name, module in self.model_loader.model.named_modules(): diff --git a/tests/patched/test_validation.py b/tests/patched/test_validation.py index 683db61b2..1c7325dff 100644 --- a/tests/patched/test_validation.py +++ b/tests/patched/test_validation.py @@ -9,11 +9,11 @@ from typing import Optional import pytest from pydantic import ValidationError +from axolotl.loaders.utils import check_model_config from axolotl.utils import is_comet_available from axolotl.utils.config import validate_config from axolotl.utils.dict import DictDefault from axolotl.utils.mlflow_ import setup_mlflow_env_vars -from axolotl.utils.models import check_model_config from axolotl.utils.schemas.config import AxolotlConfigWCapabilities from axolotl.utils.wandb_ import setup_wandb_env_vars @@ -1215,6 +1215,20 @@ class TestValidation(BaseValidation): cfg, 
capabilities=capabilities, env_capabilities=env_capabilities ) + def test_cfg_throws_error_with_s2_attention_and_sample_packing(self, minimal_cfg): + test_cfg = DictDefault( + { + "s2_attention": True, + "sample_packing": True, + } + | minimal_cfg + ) + with pytest.raises( + ValidationError, + match=r".*shifted-sparse attention does not currently support sample packing*", + ): + validate_config(test_cfg) + class TestTorchCompileValidation(BaseValidation): """ diff --git a/tests/test_exact_deduplication.py b/tests/test_exact_deduplication.py index 1d41a248d..29672c9e5 100644 --- a/tests/test_exact_deduplication.py +++ b/tests/test_exact_deduplication.py @@ -1,7 +1,8 @@ -""" -Test suite for functions in the axolotl.utils.data.utils module, focusing on the deduplicate_and_log_datasets function. +"""Test suite for functions in the `axolotl.utils.data.utils` module, focusing on the +`deduplicate_and_log_datasets` function. -Additionally, this test suite includes tests for functions that indirectly call deduplicate_and_log_datasets during the execution of the preprocess command. +Additionally, this test suite includes tests for functions that indirectly call +`deduplicate_and_log_datasets` during the execution of the preprocess command. 
""" import hashlib @@ -11,20 +12,19 @@ from unittest.mock import patch import pytest from datasets import Dataset +from axolotl.loaders import load_processor, load_tokenizer from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.data import prepare_dataset from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.data.utils import deduplicate_and_log_datasets from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_processor, load_tokenizer from tests.constants import ALPACA_MESSAGES_CONFIG_REVISION from tests.hf_offline_utils import enable_hf_offline def verify_deduplication(actual_dataset, expected_dataset, dataset_name): - """ - Validates deduplication results and size consistency. + """Validates deduplication results and size consistency. Parameters: - actual_dataset: Deduplicated dataset. @@ -49,9 +49,7 @@ def verify_deduplication(actual_dataset, expected_dataset, dataset_name): class TestDeduplicateIndividualFunctions(unittest.TestCase): - """ - test class for deduplication function in data utils - """ + """Test class for deduplication function in data utils""" def setUp(self): # Sample data with duplicates @@ -248,7 +246,7 @@ class TestDeduplicateRLDataset: # pylint: disable=duplicate-code with ( patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset, - patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer, + patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer, ): # Set up the mock to return different values on successive calls mock_load_dataset.side_effect = [ @@ -272,7 +270,7 @@ class TestDeduplicateRLDataset: # pylint: disable=duplicate-code with ( patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset, - patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer, + patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer, ): # Set up the mock to return different values on successive calls 
mock_load_dataset.side_effect = [ @@ -411,7 +409,7 @@ class TestDeduplicateNonRL(unittest.TestCase): class TestWrongCollisions(unittest.TestCase): - """Creating mock datasets for testing wrong collisions""" + """Creating mock datasets for testing wrong collisions.""" def setUp(self): self.train_data = {"text": ["sample 5", "sample 6"], "label": [1, 2]} diff --git a/tests/utils/test_models.py b/tests/test_loaders.py similarity index 83% rename from tests/utils/test_models.py rename to tests/test_loaders.py index bcc1ba5d1..7313a8267 100644 --- a/tests/utils/test_models.py +++ b/tests/test_loaders.py @@ -1,18 +1,18 @@ -"""Module for testing models utils file.""" +"""Module for `axolotl.loaders`.""" -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.utils.import_utils import is_torch_mps_available +from axolotl.loaders import ModelLoader from axolotl.utils.dict import DictDefault -from axolotl.utils.models import ModelLoader, load_model class TestModelsUtils: - """Testing module for models utils.""" + """Testing module for `axolotl.loaders`.""" def setup_method(self) -> None: # load config @@ -50,7 +50,8 @@ class TestModelsUtils: device_map = self.cfg.device_map if is_torch_mps_available(): device_map = "mps" - self.model_loader.set_device_map_config() + # pylint: disable=protected-access + self.model_loader._set_device_map_config() if is_deepspeed_zero3_enabled(): assert "device_map" not in self.model_loader.model_kwargs else: @@ -59,29 +60,6 @@ class TestModelsUtils: # check torch_dtype assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"] - def test_cfg_throws_error_with_s2_attention_and_sample_packing(self): - cfg = DictDefault( - { - "s2_attention": True, - "sample_packing": True, - "base_model": "", - "model_type": 
"AutoModelForCausalLM", - } - ) - - # Mock out call to HF hub - with patch( - "axolotl.utils.models.load_model_config" - ) as mocked_load_model_config: - mocked_load_model_config.return_value = {} - with pytest.raises(ValueError) as exc: - # Should error before hitting tokenizer, so we pass in an empty str - load_model(cfg, tokenizer="") # type: ignore - assert ( - "shifted-sparse attention does not currently support sample packing" - in str(exc.value) - ) - @pytest.mark.parametrize("adapter", ["lora", "qlora", None]) @pytest.mark.parametrize("load_in_8bit", [True, False]) @pytest.mark.parametrize("load_in_4bit", [True, False]) @@ -99,7 +77,8 @@ class TestModelsUtils: self.cfg.gptq = gptq self.cfg.adapter = adapter - self.model_loader.set_quantization_config() + # pylint: disable=protected-access + self.model_loader._set_quantization_config() if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq: assert not ( hasattr(self.model_loader.model_kwargs, "load_in_8bit") diff --git a/tests/test_lora.py b/tests/test_lora.py index 540371bef..6edcdd88e 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -2,9 +2,9 @@ tests for loading loras """ +from axolotl.loaders import ModelLoader, load_tokenizer from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_model, load_tokenizer # pylint: disable=duplicate-code minimal_config = DictDefault( @@ -46,7 +46,7 @@ class TestLoRALoad: cfg = validate_config(cfg) normalize_config(cfg) tokenizer = load_tokenizer(cfg) - load_model(cfg, tokenizer) + ModelLoader(cfg, tokenizer).load() def test_load_lora_weights_empty_dropout(self): cfg = DictDefault( @@ -67,4 +67,4 @@ class TestLoRALoad: normalize_config(cfg) assert cfg.lora_dropout == 0.0 tokenizer = load_tokenizer(cfg) - load_model(cfg, tokenizer) + ModelLoader(cfg, tokenizer).load() diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py index 
ffd51bc29..406462038 100644 --- a/tests/test_tokenizers.py +++ b/tests/test_tokenizers.py @@ -6,8 +6,8 @@ import unittest import pytest +from axolotl.loaders import load_tokenizer from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_tokenizer from tests.hf_offline_utils import enable_hf_offline diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 From a535b68043d25c892ab554622621c927f53027e6 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 23 May 2025 16:28:31 -0400 Subject: [PATCH 17/19] update quarto for model loading refactor (#2716) * update quarto for model loading refactor * fix desc --- _quarto.yml | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/_quarto.yml b/_quarto.yml index c09aecaea..a530e380a 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -54,6 +54,15 @@ quartodoc: - core.trainers.grpo.trainer - core.trainers.grpo.sampler - core.trainers.utils + - title: Model Loading + desc: Functionality for loading and patching models, tokenizers, etc. 
+ contents: + - loaders.model + - loaders.tokenizer + - loaders.processor + - loaders.adapter + - loaders.patch_manager + - loaders.constants - title: Mixins desc: Mixin classes for augmenting trainers contents: @@ -123,11 +132,9 @@ quartodoc: - title: Utils desc: Utility functions contents: - - utils.models - utils.tokenization - utils.chat_templates - utils.lora - - utils.lora_embeddings - utils.model_shard_quant - utils.bench - utils.freeze From d27c35ac445e7da68981006b2437abdba01b1f90 Mon Sep 17 00:00:00 2001 From: xzuyn <16216325+xzuyn@users.noreply.github.com> Date: Fri, 23 May 2025 18:40:43 -0400 Subject: [PATCH 18/19] Liger GraniteMoE (#2715) --- src/axolotl/integrations/liger/__init__.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index 22d988633..c7ac42372 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -175,6 +175,16 @@ class LigerPlugin(BasePlugin): rms_norm=cfg.liger_rms_norm, layer_norm=cfg.liger_layer_norm, ) + elif cfg.model_config_type == "granitemoe": + from liger_kernel.transformers import apply_liger_kernel_to_granite + + apply_liger_kernel_to_granite( + rope=cfg.liger_rope, + cross_entropy=cfg.liger_cross_entropy, + fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, + rms_norm=cfg.liger_rms_norm, + swiglu=cfg.liger_glu_activation, + ) else: logging.warning( f"Unsupported model config type: {cfg.model_config_type}. Liger not applied." 
From 5eb01f3df194ce6d663cebf2de8f3fb8fe7ec8e0 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 23 May 2025 21:16:51 -0400 Subject: [PATCH 19/19] Fix quarto (#2717) * missing modules * fix quarto complaints --- _quarto.yml | 4 +- docs/getting-started.qmd | 2 +- src/axolotl/integrations/base.py | 46 ++++--- src/axolotl/utils/samplers/multipack.py | 158 ++++++++++++------------ 4 files changed, 106 insertions(+), 104 deletions(-) diff --git a/_quarto.yml b/_quarto.yml index a530e380a..df6992d92 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -129,6 +129,8 @@ quartodoc: - monkeypatch.attention.mllama - monkeypatch.data.batch_dataset_fetcher - monkeypatch.mixtral + - monkeypatch.gradient_checkpointing.offload_cpu + - monkeypatch.gradient_checkpointing.offload_disk - title: Utils desc: Utility functions contents: @@ -145,8 +147,6 @@ quartodoc: - utils.optimizers.adopt - utils.data.pretraining - utils.data.sft - - utils.gradient_checkpointing.offload_cpu - - utils.gradient_checkpointing.offload_disk - title: Schemas desc: Pydantic data models for Axolotl config contents: diff --git a/docs/getting-started.qmd b/docs/getting-started.qmd index 064985e35..6f1b54348 100644 --- a/docs/getting-started.qmd +++ b/docs/getting-started.qmd @@ -180,7 +180,7 @@ Now that you have the basics, you might want to: Check our other guides for details on these topics: - [Configuration Guide](config.qmd) - Full configuration options -- [Dataset Loading](dataset-loading.qmd) - Loading datasets from various sources +- [Dataset Loading](dataset_loading.qmd) - Loading datasets from various sources - [Dataset Formats](dataset-formats) - Working with different data formats - [Multi-GPU Training](multi-gpu.qmd) - [Multi-Node Training](multi-node.qmd) diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index 2beaf667a..eb2b29cbe 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -39,31 +39,39 @@ if TYPE_CHECKING: class BasePlugin: 
"""Base class for all plugins. Defines the interface for plugin methods. - Methods: - register(cfg): Registers the plugin with the given configuration. - load_datasets(cfg): Loads and preprocesses the dataset for training. - pre_model_load(cfg): Performs actions before the model is loaded. - post_model_build(cfg, model): Performs actions after the model is loaded, but + A plugin is a reusable, modular, and self-contained piece of code that extends + the functionality of Axolotl. Plugins can be used to integrate third-party models, + modify the training process, or add new features. + + To create a new plugin, you need to inherit from the BasePlugin class and + implement the required methods. + + Note: + Plugin methods include: + - register(cfg): Registers the plugin with the given configuration. + - load_datasets(cfg): Loads and preprocesses the dataset for training. + - pre_model_load(cfg): Performs actions before the model is loaded. + - post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied. - pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded. - post_lora_load(cfg, model): Performs actions after LoRA weights are loaded. - post_model_load(cfg, model): Performs actions after the model is loaded, + - pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded. + - post_lora_load(cfg, model): Performs actions after LoRA weights are loaded. + - post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters. - post_trainer_create(cfg, trainer): Performs actions after the trainer is + - post_trainer_create(cfg, trainer): Performs actions after the trainer is created. - create_optimizer(cfg, trainer): Creates and returns an optimizer for training. - create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and + - create_optimizer(cfg, trainer): Creates and returns an optimizer for training. 
+ - create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler. - add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before + - add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training. - add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after + - add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training. """ def __init__(self): """Initializes the BasePlugin.""" - def register(self, cfg): # pylint: disable=unused-argument + def register(self, cfg: DictDefault): # pylint: disable=unused-argument """Registers the plugin with the given configuration. Args: @@ -275,10 +283,11 @@ class PluginManager: Attributes: plugins: A list of loaded plugins. - Methods: - get_instance(): Static method to get the singleton instance of `PluginManager`. - register(plugin_name: str): Registers a new plugin by its name. - pre_model_load(cfg): Calls the pre_model_load method of all registered plugins. + Note: + Key methods include: + - get_instance(): Static method to get the singleton instance of `PluginManager`. + - register(plugin_name: str): Registers a new plugin by its name. + - pre_model_load(cfg): Calls the pre_model_load method of all registered plugins. """ plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict() @@ -534,7 +543,6 @@ class PluginManager: Args: cfg: The configuration for the plugins. - model: The loaded model. 
""" for plugin in self.plugins.values(): plugin.post_train_unload(cfg) diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 2df2d9e19..1bfa2ec6e 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -7,7 +7,7 @@ import logging import math from concurrent.futures import ProcessPoolExecutor from multiprocessing import cpu_count, get_context -from typing import Iterable, Union +from typing import Iterable, Iterator, Union import numba import numpy as np @@ -20,19 +20,19 @@ LOG.setLevel(logging.INFO) @numba.njit -def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int): - """ - First-fit-decreasing bin packing algorithm check +def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int) -> bool: + """First-fit-decreasing bin packing algorithm check. - Checks if sequences with the given lengths could fit in the specified number of bins + Checks if sequences with the given lengths could fit in the specified number of + bins. Args: - sequence_lengths: Array of sequence lengths - bin_capacity: Maximum capacity of each bin - num_bins: Number of bins available + sequence_lengths: Array of sequence lengths. + bin_capacity: Maximum capacity of each bin. + num_bins: Number of bins available. Returns: - True if all sequences can be packed, False otherwise + `True` if all sequences can be packed, `False` otherwise. """ # Sort sequence lengths in descending order for optimal packing sequence_lengths = np.sort(sequence_lengths)[::-1] @@ -63,20 +63,19 @@ def pack_group( max_bins: int, bin_size: int, safe_mode: bool = True, -): - """ - Pack a group of sequences into bins using First-Fit Decreasing algorithm +) -> list[list[int]]: + """Pack a group of sequences into bins using First-Fit Decreasing algorithm. 
Args: - sequence_lengths: Array of sequence lengths - group_offset: Offset to apply to indices when returning results - bin_capacity: Maximum capacity of each bin - max_bins: Maximum number of bins to use - bin_size: Maximum number of sequences per bin - safe_mode: If True, use a more conservative packing approach + sequence_lengths: Array of sequence lengths. + group_offset: Offset to apply to indices when returning results. + bin_capacity: Maximum capacity of each bin. + max_bins: Maximum number of bins to use. + bin_size: Maximum number of sequences per bin. + safe_mode: If True, use a more conservative packing approach. Returns: - List of bins, where each bin contains indices of sequences assigned to it + List of bins, where each bin contains indices of sequences assigned to it. """ bins_remaining_space: list = [] # Tracks remaining capacity in each bin bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin @@ -111,8 +110,10 @@ def pack_group( return bins_assigned_sequences -# Define a standalone function for multiprocessing -def _process_group(args): +def _process_group( + args: tuple[np.ndarray, int, int, int, int, bool], +) -> list[list[int]]: + """Standalone function for multiprocessing.""" group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode = args return pack_group( group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode @@ -127,22 +128,21 @@ def pack_parallel( num_processes: int | None = None, safe_mode: bool = True, mp_start_method: str | None = "spawn", -): - """ - Pack sequences into bins using parallel processing +) -> list[list[int]]: + """Pack sequences into bins using parallel processing. 
Args: - sequence_lengths: Array of sequence lengths - bin_capacity: Maximum capacity of each bin as total number of tokens - group_size: Number of sequences to process in each group - bin_size: Maximum number of bins to use - num_processes: Number of parallel processes to use - safe_mode: If True, use a more conservative packing approach + sequence_lengths: Array of sequence lengths. + bin_capacity: Maximum capacity of each bin as total number of tokens. + group_size: Number of sequences to process in each group. + bin_size: Maximum number of bins to use. + num_processes: Number of parallel processes to use. + safe_mode: If True, use a more conservative packing approach. mp_start_method: Multiprocessing start method ('fork', 'spawn', 'forkserver'). 'spawn' is often safer with Numba/PyTorch. Set to None to use system default. Returns: - List of bins, where each bin contains indices of sequences assigned to it + List of bins, where each bin contains indices of sequences assigned to it. """ num_items = len(sequence_lengths) if num_processes is None: @@ -191,20 +191,20 @@ def pack_parallel( @numba.njit def allocate_sequentially( sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int -): - """ - Sequential allocator that preserves example order +) -> tuple[list[list[int]], int, int]: + """Sequential allocator that preserves example order. Args: - sequence_lengths: The lengths of all examples - rank: The current rank (for distributed training) - bin_capacity: The capacity of each bin (maximum sequence length) - num_ranks: Number of ranks (processes/GPUs) + sequence_lengths: The lengths of all examples. + rank: The current rank (for distributed training). + bin_capacity: The capacity of each bin (maximum sequence length). + num_ranks: Number of ranks (processes / GPUs). 
Returns: - rank_batches: List of batches for the current rank - total_tokens_used: Number of actual example tokens - total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity) + rank_batches: List of batches for the current rank. + total_tokens_used: Number of actual example tokens. + total_token_slots: Maximum theoretical number of example tokens (number of bins + * bin capacity). """ result = [] total_used = 0 @@ -240,8 +240,7 @@ def allocate_sequentially( class MultipackBatchSampler(BatchSampler): - """ - Batch sampler class for efficient packing of variable-length sequences + """Batch sampler class for efficient packing of variable-length sequences This sampler packs sequences into fixed-capacity bins (batches) to maximize GPU memory utilization and training throughput by reducing padding. @@ -250,6 +249,9 @@ class MultipackBatchSampler(BatchSampler): sequential packing (preserving original sequence order). """ + _batches: list[list[list[int]]] | None = None + _len_across_ranks: int | None = None + def __init__( self, sampler: Union[Sampler[int], Iterable[int]], @@ -287,11 +289,6 @@ class MultipackBatchSampler(BatchSampler): # The number of times to calculate batches to determine minimum packed dataset length self.num_count_samples = num_count_samples - # Minimum packed dataset length across all ranks (determined by gather/broadcast) - self.len_across_ranks = None - - # Cache for batches - self._batches = None if self.sequential and not isinstance(sampler, SequentialSampler): LOG.warning( @@ -303,16 +300,15 @@ class MultipackBatchSampler(BatchSampler): self.epoch = epoch self._batches = None # Invalidate batch cache - def generate_batches(self, set_stats=False): - """ - Generate packed batches for training + def generate_batches(self, set_stats: bool = False) -> list[list[list[int]]]: + """Generate packed batches for training. 
Args: - set_stats: Whether to update efficiency statistics + set_stats: Whether to update efficiency statistics. Returns: - List of batches, where each batch contains multiple bins, - and each bin contains multiple sequence indices + List of batches, where each batch contains multiple bins, and each bin + contains multiple sequence indices. """ if self._batches is not None: return self._batches @@ -375,23 +371,21 @@ class MultipackBatchSampler(BatchSampler): self._batches = batches return batches - def __iter__(self): - """ - Return an iterator over batches + def __iter__(self) -> Iterator[list[list[int]]]: + """Return an iterator over batches. - The batches are truncated to match the minimum number of batches across all ranks - to ensure distributed training balance + The batches are truncated to match the minimum number of batches across all + ranks to ensure distributed training balance. """ batches = self.generate_batches(set_stats=True) - if self.len_across_ranks: + if self._len_across_ranks: # Truncate batches to ensure all ranks have the same number of batches - batches = batches[: self.len_across_ranks] + batches = batches[: self._len_across_ranks] return iter(batches) - def efficiency(self): - """ - Calculate the packing efficiency (ratio of tokens used to total token slots) - Higher is better - 1.0 would mean perfect packing with no wasted space + def efficiency(self) -> float: + """Calculate the packing efficiency (ratio of tokens used to total token slots). + Higher is better - 1.0 would mean perfect packing with no wasted space. 
""" if self.total_token_slots == 0: self.generate_batches(set_stats=True) @@ -400,10 +394,12 @@ class MultipackBatchSampler(BatchSampler): # Return a Python float instead of potentially a numpy float return float(self.total_tokens_used / self.total_token_slots) - def gather_efficiency(self): - """ - Gather and synchronize packing efficiency estimates across all distributed ranks - Returns a conservative efficiency estimate based on the measurements + def gather_efficiency(self) -> float: + """Gather and synchronize packing efficiency estimates across all distributed + ranks. + + Returns: + A conservative efficiency estimate based on the measurements. """ def calc_sample_packing_eff_est(estimates: list[float]): @@ -424,13 +420,12 @@ class MultipackBatchSampler(BatchSampler): ) return sample_packing_eff_est - def gather_len_batches(self, num): - """ - Gather and synchronize batch counts across all distributed ranks - Returns the minimum number of batches available on any rank + def gather_len_batches(self, num: int) -> int: + """Gather and synchronize batch counts across all distributed ranks. Returns + the minimum number of batches available on any rank. """ - def calc_min_len(estimates: list[(int, float)]): + def calc_min_len(estimates: list[int]) -> int: LOG.info(f"gather_len_batches: {repr(estimates)}") return math.floor(min(estimates)) @@ -438,22 +433,21 @@ class MultipackBatchSampler(BatchSampler): min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len) return min_len_batches - def __len__(self): - """ - Return the total number of batches that will be yielded by this sampler + def __len__(self) -> int: + """Return the total number of batches that will be yielded by this sampler. - This is calculated as the minimum number of batches available on any rank - to ensure balanced distributed training + This is calculated as the minimum number of batches available on any rank to + ensure balanced distributed training. 
""" if self._batches is None: self._batches = self.generate_batches(set_stats=True) - if self.len_across_ranks is None: + if self._len_across_ranks is None: # Sample multiple times to get stable estimate len_batches = min( # pylint: disable=consider-using-generator [len(self._batches) for _ in range(self.num_count_samples)] ) # Gather minimum across all ranks - self.len_across_ranks = self.gather_len_batches(len_batches) + self._len_across_ranks = self.gather_len_batches(len_batches) - return self.len_across_ranks + return self._len_across_ranks