From c67910fa6f11ee69659779bc7009ccd40f34f322 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 5 Jun 2025 07:20:33 -0700 Subject: [PATCH] bump hf deps (#2735) [skip ci] * bump hf deps * upgrade liger-kernel too * install cce from fork for transformers fix * fix reference to vocab size in gemma3 patch * use padding_idx instead of pad_token_id * remove fixed gemma3 patch * use updated cce fork * fix local mllama cce patches w docstring * add test for multipack with trainer setup and fix trainer for trainer refactor upstream * bump modal version * guard for iterable datasetS * mllama model arch layout changed in latest transformers * fix batch sampler with drop_last * fix: address upstream vlm changes for lora * fix: update references to old lora target path * fix: remove mllama fa2 patch due to upstream fix * fix: lora kernel patch path for multimodal models * fix: removed mllama from quarto * run test for came optim on 2.6.0+ * fix fsdp2 patch and remove deprecated patch * make sure to set sequence_parallel_degree for grpo * Add SP test for GRPO * add sp to grpo config for trainer * use reward_funcs as kwarg to grpo trainer * fix the comprehension for reward funcs * reward funcs already passed in as args * init sp_group right before training * fix check for adding models to SP context * make sure to pass args to super * upgrade deepspeed * use updated trl and add reasoning flags for vllm * patch the worker --------- Co-authored-by: NanoCode012 --- .github/workflows/multi-gpu-e2e.yml | 2 +- .github/workflows/tests.yml | 6 +- _quarto.yml | 1 - docs/multimodal.qmd | 2 +- examples/gemma3/gemma-3-4b-qlora.yml | 2 +- examples/gemma3/gemma-3-4b-vision-qlora.yml | 2 +- examples/llama-3-vision/lora-11b.yaml | 2 +- examples/llava/lora-7b.yaml | 2 +- .../mistral/mistral-small-3.1-24B-lora.yml | 2 +- examples/pixtral/lora-12b.yml | 2 +- requirements.txt | 16 +- scripts/cutcrossentropy_install.py | 2 +- setup.py | 2 +- src/axolotl/cli/args.py | 8 + src/axolotl/cli/vllm_serve.py | 87 ++++++- src/axolotl/core/builders/causal.py | 4 +- src/axolotl/core/trainers/base.py | 188 ++++++-------- src/axolotl/core/trainers/grpo/__init__.py | 8 +- src/axolotl/core/trainers/grpo/args.py | 2 + src/axolotl/core/trainers/grpo/trainer.py | 9 + .../cut_cross_entropy/monkeypatch/mllama.py | 13 - src/axolotl/loaders/patch_manager.py | 12 - src/axolotl/monkeypatch/accelerate/fsdp2.py | 119 +++++---- src/axolotl/monkeypatch/attention/mllama.py | 230 ------------------ src/axolotl/monkeypatch/gemma3.py | 230 ------------------ src/axolotl/monkeypatch/lora_kernels.py | 9 +- src/axolotl/train.py | 2 +- src/axolotl/utils/schemas/vllm.py | 9 + tests/e2e/integrations/test_kd.py | 2 +- tests/e2e/multigpu/solo/test_grpo.py | 93 +++++++ tests/e2e/test_llama_vision.py | 4 +- tests/e2e/test_optimizers.py | 8 +- tests/test_packed_dataset.py | 85 +++++++ 33 files changed, 470 insertions(+), 695 deletions(-) delete mode 100644 src/axolotl/monkeypatch/attention/mllama.py delete mode 100644 src/axolotl/monkeypatch/gemma3.py diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml index 8c7692d13..0167df67a 100644 --- a/.github/workflows/multi-gpu-e2e.yml +++ b/.github/workflows/multi-gpu-e2e.yml @@ -59,7 +59,7 @@ jobs: - name: Install Modal run: | python -m pip install --upgrade pip - pip install modal==0.71.8 jinja2 + pip install modal==1.0.2 jinja2 - name: Update env vars run: | echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 69f0a030d..29c5bef38 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -322,7 +322,7 @@ jobs: - name: Install Modal run: | python -m pip install --upgrade pip - pip install modal==0.71.8 jinja2 + pip install modal==1.0.2 jinja2 - name: Update env vars run: | echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV @@ -384,7 +384,7 @@ jobs: - name: Install Modal run: | python -m pip install --upgrade pip - pip install modal==0.71.8 jinja2 + pip install modal==1.0.2 jinja2 - name: Update env vars run: | echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV @@ -424,7 +424,7 @@ jobs: - name: Install Modal run: | python -m pip install --upgrade pip - pip install modal==0.71.8 jinja2 + pip install modal==1.0.2 jinja2 - name: Update env vars run: | echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV diff --git a/_quarto.yml b/_quarto.yml index a970cd08b..9b97095ce 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -129,7 +129,6 @@ quartodoc: - monkeypatch.trainer_fsdp_optim - monkeypatch.transformers_fa_utils - monkeypatch.unsloth_ - - monkeypatch.attention.mllama - monkeypatch.data.batch_dataset_fetcher - monkeypatch.mixtral - monkeypatch.gradient_checkpointing.offload_cpu diff --git a/docs/multimodal.qmd b/docs/multimodal.qmd index 3506db340..ec51a8ec3 100644 --- a/docs/multimodal.qmd +++ b/docs/multimodal.qmd @@ -43,7 +43,7 @@ datasets: # leave the vision model and vision tower frozen # load_in_8bit: true adapter: lora -lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' # (optional) if you want to resize images to a set size image_size: 512 diff --git a/examples/gemma3/gemma-3-4b-qlora.yml b/examples/gemma3/gemma-3-4b-qlora.yml index 29f8cc1e1..0d89d9ffb 100644 --- a/examples/gemma3/gemma-3-4b-qlora.yml +++ b/examples/gemma3/gemma-3-4b-qlora.yml @@ -28,7 +28,7 @@ pad_to_sequence_len: true lora_r: 32 lora_alpha: 16 lora_dropout: 0.05 -lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' wandb_project: wandb_entity: diff --git a/examples/gemma3/gemma-3-4b-vision-qlora.yml b/examples/gemma3/gemma-3-4b-vision-qlora.yml index 3fd9eb5f0..339df92e5 100644 --- a/examples/gemma3/gemma-3-4b-vision-qlora.yml +++ b/examples/gemma3/gemma-3-4b-vision-qlora.yml @@ -30,7 +30,7 @@ pad_to_sequence_len: false lora_r: 32 lora_alpha: 16 lora_dropout: 0.05 -lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' wandb_project: wandb_entity: diff --git a/examples/llama-3-vision/lora-11b.yaml b/examples/llama-3-vision/lora-11b.yaml index f4883e903..2b0ae2c70 100644 --- a/examples/llama-3-vision/lora-11b.yaml +++ b/examples/llama-3-vision/lora-11b.yaml @@ -29,7 +29,7 @@ pad_to_sequence_len: false lora_r: 32 lora_alpha: 16 lora_dropout: 0.05 -lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' wandb_project: wandb_entity: diff --git a/examples/llava/lora-7b.yaml b/examples/llava/lora-7b.yaml index 54edd04dc..5198c8e74 100644 --- a/examples/llava/lora-7b.yaml +++ b/examples/llava/lora-7b.yaml @@ -25,7 +25,7 @@ pad_to_sequence_len: false lora_r: 32 lora_alpha: 16 lora_dropout: 0.05 -lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' wandb_project: wandb_entity: diff --git a/examples/mistral/mistral-small-3.1-24B-lora.yml b/examples/mistral/mistral-small-3.1-24B-lora.yml index 198b3f373..3e3b45862 100644 --- a/examples/mistral/mistral-small-3.1-24B-lora.yml +++ b/examples/mistral/mistral-small-3.1-24B-lora.yml @@ -27,7 +27,7 @@ pad_to_sequence_len: false lora_r: 32 lora_alpha: 16 lora_dropout: 0.05 -lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' wandb_project: wandb_entity: diff --git a/examples/pixtral/lora-12b.yml b/examples/pixtral/lora-12b.yml index dec8e4b5e..6ad0a5e99 100644 --- a/examples/pixtral/lora-12b.yml +++ b/examples/pixtral/lora-12b.yml @@ -25,7 +25,7 @@ pad_to_sequence_len: false lora_r: 32 lora_alpha: 16 lora_dropout: 0.05 -lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' +lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj' wandb_project: wandb_entity: diff --git a/requirements.txt b/requirements.txt index 4e632b0f3..5c5bb0030 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,20 +6,20 @@ triton>=3.0.0 mamba-ssm==1.2.0.post1 xformers>=0.0.23.post1 autoawq==0.2.7.post3 -liger-kernel==0.5.9 +liger-kernel==0.5.10 # END section packaging==23.2 -huggingface_hub==0.31.0 +huggingface_hub==0.32.2 peft==0.15.2 -transformers==4.51.3 +transformers==4.52.3 tokenizers>=0.21.1 -accelerate==1.6.0 -datasets==3.5.1 -deepspeed>=0.15.4 -trl==0.17.0 -hf_xet==1.1.0 +accelerate==1.7.0 +datasets==3.6.0 +deepspeed>=0.17.0 +trl==0.18.1 +hf_xet==1.1.2 hqq==0.2.5 optimum==1.16.2 diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py index bc6213dd9..3ff6dfa8f 100644 --- a/scripts/cutcrossentropy_install.py +++ b/scripts/cutcrossentropy_install.py @@ -25,5 +25,5 @@ if cce_spec: print( UNINSTALL_PREFIX - + 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"' + + 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a1174ca"' ) diff --git a/setup.py b/setup.py index 97e7f5ff5..28f71f789 100644 --- a/setup.py +++ b/setup.py @@ -118,7 +118,7 @@ extras_require = { "yunchang==0.6.0", ], "deepspeed": [ - "deepspeed==0.15.4", + "deepspeed==0.17.0", "deepspeed-kernels", ], "mamba-ssm": [ diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py index b859b99c8..e8571a900 100644 --- a/src/axolotl/cli/args.py +++ b/src/axolotl/cli/args.py @@ -88,6 +88,14 @@ class VllmServeCliArgs: }, ) + enable_reasoning: Optional[bool] = field( + default=None, + ) + + reasoning_parser: Optional[str] = field( + default=None, + ) + @dataclass class QuantizeCliArgs: diff --git a/src/axolotl/cli/vllm_serve.py b/src/axolotl/cli/vllm_serve.py index d3c4ad68d..448b25a7e 100644 --- a/src/axolotl/cli/vllm_serve.py +++ b/src/axolotl/cli/vllm_serve.py @@ -2,14 +2,27 @@ CLI to start the vllm server for online RL """ +import os +from dataclasses import dataclass, field from pathlib import Path from typing import Union +import trl from trl.scripts.vllm_serve import ScriptArguments from axolotl.cli.config import load_cfg +@dataclass +class AxolotlScriptArguments(ScriptArguments): + """ + Additional arguments for the VLLM server + """ + + reasoning_parser: str = field(default="", kw_only=True) + enable_reasoning: bool | None = field(default=None, kw_only=True) + + def do_vllm_serve( config: Union[Path, str], cli_args: dict, @@ -24,6 +37,7 @@ def do_vllm_serve( Returns: process_id: the process id of the started VLLM server """ + patch_vllm_worker() cfg = load_cfg(config) model = cfg.base_model @@ -43,9 +57,16 @@ def do_vllm_serve( enable_prefix_caching = ( cli_args.get("enable_prefix_caching") or cfg.vllm.enable_prefix_caching ) + reasoning_parser = ( + cli_args.get("reasoning_parser") or cfg.vllm.reasoning_parser or "" + ) + enable_reasoning = ( + cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False + ) - vllm_script_args = ScriptArguments( - model, + # pylint: disable=unexpected-keyword-arg + vllm_script_args = AxolotlScriptArguments( + model=model, tensor_parallel_size=tensor_parallel_size, host=host, port=port, @@ -53,5 +74,67 @@ def do_vllm_serve( dtype=dtype, max_model_len=max_model_len, enable_prefix_caching=enable_prefix_caching, + reasoning_parser=reasoning_parser, + enable_reasoning=enable_reasoning, ) vllm_serve_main(vllm_script_args) + + +def patch_vllm_worker(): + from multiprocessing.connection import Connection + + from vllm import LLM + + def llm_worker( + script_args: AxolotlScriptArguments, + data_parallel_rank: int, + master_port: int, + connection: Connection, + ) -> None: + # Set required environment variables for DP to work with vLLM + os.environ["VLLM_DP_RANK"] = str(data_parallel_rank) + os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank) + os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size) + os.environ["VLLM_DP_MASTER_PORT"] = str(master_port) + + llm = LLM( + model=script_args.model, + revision=script_args.revision, + tensor_parallel_size=script_args.tensor_parallel_size, + gpu_memory_utilization=script_args.gpu_memory_utilization, + enforce_eager=script_args.enforce_eager, + dtype=script_args.dtype, + # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can + # directly reuse the KV cache if it shares the same prefix with one of the existing queries. + # This is particularly useful here because we generate completions from the same prompts. + enable_prefix_caching=script_args.enable_prefix_caching, + kv_cache_dtype=script_args.kv_cache_dtype, + max_model_len=script_args.max_model_len, + worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension", + enable_reasoning=script_args.enable_reasoning, + reasoning_parser=script_args.reasoning_parser, + ) + + # Send ready signal to parent process + connection.send({"status": "ready"}) + + while True: + # Wait for commands from the parent process + try: + command = connection.recv() + except KeyboardInterrupt: + llm.collective_rpc(method="close_communicator") + break + + # Handle commands + if command["type"] in ["call", "fire_and_forget"]: + method_name = command["method"] + args, kwargs = command.get("args", ()), command.get("kwargs", {}) + method = getattr(llm, method_name) + result = method(*args, **kwargs) + if command["type"] == "call": + connection.send(result) + elif command["type"] == "shutdown": + break + + trl.scripts.vllm_serve.llm_worker = llm_worker diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 3b1d0b3c2..7a81616ba 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -244,7 +244,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing) training_arguments_kwargs["multipack_real_batches"] = ( - not self.cfg.flash_attention or self.cfg.multipack_real_batches + self.cfg.multipack_real_batches + if self.cfg.multipack_real_batches is not None + else not self.cfg.flash_attention ) training_arguments_kwargs["eval_sample_packing"] = bool( self.cfg.eval_sample_packing diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 25e9f9f0a..70e443cb3 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -6,8 +6,8 @@ from __future__ import annotations import os from collections import defaultdict -from functools import wraps -from typing import Literal +from functools import partial, wraps +from typing import Callable, Literal, Optional import datasets import torch @@ -113,7 +113,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): drop_last=True, ) - def _get_train_sampler(self) -> Sampler | None: + def _get_train_sampler( + self, train_dataset: Optional[Dataset] = None + ) -> Optional[Sampler]: """ Helper method to get the sampler for training. Handles cases for sample packing and curriculum sampling (sequential). @@ -137,7 +139,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): if use_sample_packing: return self._create_multipack_sampler( base_sampler=base_sampler, - dataset=self.train_dataset, + dataset=train_dataset, ) return base_sampler @@ -150,8 +152,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): If the dataset is non-empty, a sampler is returned, the type of which depends on the passed training args. """ - eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset - # Multipacking enabled if training is enabled and eval is not explicitly disabled use_multipack = ( self.args.sample_packing and self.args.eval_sample_packing is not False @@ -172,125 +172,91 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): return base_sampler - def _create_dataloader_params(self, is_eval=False, custom_batch_size=None): - """Create common dataloader parameters for train or eval.""" - batch_size = custom_batch_size or ( - self.args.eval_batch_size if is_eval else self._train_batch_size - ) + def _get_dataloader( + self, + dataset: Dataset, + description: str, + batch_size: int, + sampler_fn: Optional[Callable[[Dataset], torch.utils.data.Sampler]] = None, + is_training: bool = False, + dataloader_key: Optional[str] = None, + ) -> DataLoader: + """Create a [`~torch.utils.data.DataLoader`] from the given dataset.""" - params = { + data_collator = self.data_collator if is_training else self.eval_data_collator + + if dataset.column_names and "length" in dataset.column_names: + dataset = dataset.remove_columns(["length"]) + + if isinstance(dataset, datasets.Dataset): + if is_training: + if not self.args.sample_packing or self.args.pretraining: + dataset = self._remove_unused_columns( + dataset, description="training" + ) + elif ( + not is_training + and self.args.sample_packing + and self.args.eval_sample_packing is not False + ): + batch_size = ( + batch_size + if self.args.sample_packing + else self.args.per_device_eval_batch_size + ) + else: + dataset = self._remove_unused_columns(dataset, description=description) + else: + data_collator = self._get_collator_with_removed_columns( + self.data_collator, description=description + ) + + dataloader_params = { "batch_size": batch_size, - "collate_fn": self.data_collator, + "collate_fn": data_collator, "num_workers": self.args.dataloader_num_workers, "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, } - # Add persistent workers only for training - if not is_eval and hasattr(self.args, "dataloader_persistent_workers"): - params["persistent_workers"] = self.args.dataloader_persistent_workers - - # Add prefetch factor if specified - if self.args.dataloader_prefetch_factor: - params["prefetch_factor"] = self.args.dataloader_prefetch_factor - - return params - - def _prepare_dataloader( - self, dataset, sampler, is_eval=False, custom_batch_size=None - ): - """Prepare a dataloader with the given dataset and sampler.""" - # Get base parameters - dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size) - - # Add sampler configuration if not isinstance(dataset, torch.utils.data.IterableDataset): - if isinstance(sampler, BatchSampler): - # batch_size and batch_sampler are mutually exclusive - dataloader_params["batch_sampler"] = sampler - del dataloader_params["batch_size"] - else: - dataloader_params["sampler"] = sampler - dataloader_params["drop_last"] = self.args.dataloader_drop_last - - if not is_eval: - dataloader_params["worker_init_fn"] = seed_worker - - # Create the dataloader - dataloader = DataLoader(dataset, **dataloader_params) + dataloader_params["drop_last"] = self.args.dataloader_drop_last + if sampler_fn is not None: + sampler = sampler_fn(dataset) + if isinstance(sampler, BatchSampler): + # batch_size and batch_sampler are mutually exclusive + dataloader_params["batch_sampler"] = sampler + del dataloader_params["batch_size"] + del dataloader_params["drop_last"] + else: + dataloader_params["sampler"] = sampler + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + if is_training: + dataloader_params["worker_init_fn"] = partial( + seed_worker, + num_workers=self.args.dataloader_num_workers, + rank=self.args.process_index, + ) if self.args.sample_packing and ( - (not is_eval and not self.args.pretraining) - or (is_eval and self.args.eval_sample_packing is not False) + (is_training and not self.args.pretraining) + or (not is_training and self.args.eval_sample_packing is not False) ): self.accelerator.even_batches = False - return self.accelerator.prepare_data_loader(dataloader) + dataloader = DataLoader(dataset, **dataloader_params) - def get_train_dataloader(self) -> DataLoader: - """Get dataloader for training""" - train_dataset = self.train_dataset - data_collator = self.data_collator # type: ignore + # Accelerator.free_memory() will destroy the references, so + # we need to store the non-prepared version for eval dataloaders. + # fmt: off + if dataloader_key is not None and self.args.dataloader_persistent_workers: + if hasattr(self, "_eval_dataloaders"): + self._eval_dataloaders[dataloader_key] = dataloader # type: ignore # pylint: disable=access-member-before-definition + else: + self._eval_dataloaders = {dataloader_key: dataloader} # pylint: disable=attribute-defined-outside-init + # fmt: on - # Handle dataset preprocessing - if isinstance(train_dataset, datasets.Dataset): - if self.args.sample_packing and not self.args.pretraining: - train_dataset = train_dataset.remove_columns(["length"]) - if not self.args.sample_packing or self.args.pretraining: - train_dataset = self._remove_unused_columns( - train_dataset, description="training" - ) - else: - self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init - data_collator, - description="training", - ) - - # Get sampler and create dataloader - sampler = self._get_train_sampler() - return self._prepare_dataloader(train_dataset, sampler, is_eval=False) - - def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader: - """Get dataloader for evaluation""" - eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset - - # Handle special case: sample packing is enabled but eval_sample_packing is False - if self.args.sample_packing and self.args.eval_sample_packing is False: - self.data_collator = ( # pylint: disable=attribute-defined-outside-init - self.eval_data_collator - ) - if "length" in eval_dataset.column_names: - eval_dataset = eval_dataset.remove_columns(["length"]) - dataloader = super().get_eval_dataloader(eval_dataset) - self.data_collator = ( # pylint: disable=attribute-defined-outside-init - self.train_data_collator - ) - - return dataloader - - if self.args.sample_packing and self.args.eval_sample_packing is not False: - # Get appropriate data collator - self.data_collator = ( # pylint: disable=attribute-defined-outside-init - self.eval_data_collator - if hasattr(self, "eval_data_collator") and self.eval_data_collator - else self.data_collator - ) - if "length" in eval_dataset.column_names: - eval_dataset = eval_dataset.remove_columns(["length"]) - - # Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise - batch_size = ( - self.args.eval_batch_size - if self.args.sample_packing - else self.args.per_device_eval_batch_size - ) - sampler = self._get_eval_sampler(eval_dataset) - dataloader = self._prepare_dataloader( - eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size - ) - - return dataloader - - return super().get_eval_dataloader(eval_dataset) + return self.accelerator.prepare(dataloader) def _get_bench_sampler( self, bench_dataset: Dataset diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index a37c8baca..c0f10be23 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -69,6 +69,9 @@ class GRPOStrategy: grpo_args_kwargs["log_completions"] = trl.log_completions grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print + if cfg.sequence_parallel_degree > 1: + grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree + if trl.reward_weights: grpo_args_kwargs["reward_weights"] = trl.reward_weights @@ -106,7 +109,9 @@ class GRPOStrategy: return grpo_args_kwargs @classmethod - def set_trainer_args(cls, cfg: DictDefault) -> list[Any]: + def set_trainer_args( + cls, cfg: DictDefault + ) -> list[Any]: # pylint: disable=unused-argument trainer_args = [] if cfg.trl and cfg.trl.reward_funcs: reward_funcs = [] @@ -123,6 +128,7 @@ class GRPOStrategy: trainer_kwargs["reward_processing_classes"] = ( cfg.trl.reward_processing_classes ) + return trainer_kwargs @classmethod diff --git a/src/axolotl/core/trainers/grpo/args.py b/src/axolotl/core/trainers/grpo/args.py index 76be88c89..5c8b1a33b 100644 --- a/src/axolotl/core/trainers/grpo/args.py +++ b/src/axolotl/core/trainers/grpo/args.py @@ -12,3 +12,5 @@ from axolotl.core.training_args import AxolotlTrainingMixins @dataclass class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig): """Axolotl GRPO Config for GRPO training""" + + sequence_parallel_degree: int | None = None diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index 5c93c69df..dccc85d80 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -136,6 +136,13 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): f"the valid values for the number of generations are: {possible_values}." ) + self.sp_group = None + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.local_rank = 0 + self.local_world_size = 1 + + def train(self, *args, **kwargs): # Initialize the SP group self.sp_group = get_ring_attn_group() self.rank = dist.get_rank() @@ -143,6 +150,8 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): self.local_rank = dist.get_rank(group=self.sp_group) self.local_world_size = dist.get_world_size(group=self.sp_group) + return super().train(*args, **kwargs) + def _get_train_sampler(self) -> Sampler: effective_batch_size = ( self.args.per_device_train_batch_size diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py index 850764e10..e82853e6c 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py @@ -15,23 +15,14 @@ from cut_cross_entropy.transformers.utils import ( from transformers.cache_utils import Cache from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.mllama.modeling_mllama import ( - MLLAMA_INPUTS_DOCSTRING, _prepare_cross_attention_mask, ) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.utils.deprecation import deprecate_kwarg _PATCH_OPTS: PatchOptions | None = None @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig" -) def cce_forward( self, input_ids: torch.LongTensor | None = None, @@ -164,10 +155,6 @@ def cce_forward( @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class="MllamaConfig" -) def cce_forward_multimodal( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 56888b607..23f79d368 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -116,13 +116,6 @@ class PatchManager: patch_llama4_linearized_modeling() - if self.cfg.model_config_type == "gemma3": - from axolotl.monkeypatch.gemma3 import ( - patch_gemma3conditionalgeneration_forward, - ) - - patch_gemma3conditionalgeneration_forward() - def _apply_fp8_patches(self): """Apply patches for FP8 support.""" if self.cfg.fp8: @@ -212,11 +205,6 @@ class PatchManager: if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")): return - if self.model_config.model_type == "mllama" and self.cfg.flash_attention: - from axolotl.monkeypatch.attention.mllama import patch_mllama - - patch_mllama() - if self.model_config.model_type == "btlm": from axolotl.monkeypatch.btlm_attn_hijack_flash import ( replace_btlm_attn_with_flash_attn, diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py index 6a7d48236..955c06cbe 100644 --- a/src/axolotl/monkeypatch/accelerate/fsdp2.py +++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py @@ -18,27 +18,65 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic Args: accelerator (`Accelerator`): The accelerator instance - model (`torch.nn.Module`): The model to load the state dict into + model (`torch.nn.Module`): + The model to load the state dict into, expected to be on meta device or a VRAM spike can occur full_sd (`dict`): The full state dict to load, can only be on rank 0 """ import torch.distributed as dist from torch.distributed.tensor import distribute_tensor - LOG.info("Broadcasting full state dict to all ranks...") - sharded_sd = model.state_dict() - param_names = sorted(sharded_sd.keys()) + # Model was previously copied to meta device + meta_sharded_sd = model.state_dict() + sharded_sd = {} + + # Rank 0 distributes the full state dict to other ranks + def _infer_parameter_dtype(model, param_name, empty_param): + try: + old_param = model.get_parameter_or_buffer(param_name) + except AttributeError: + # Need this for LORA, as there some params are not *parameters* of sorts + base_param_name, local_param_name = param_name.rsplit(".", 1) + submodule = model.get_submodule(base_param_name) + old_param = getattr(submodule, local_param_name) + + is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") + casting_dtype = None + is_param_float8_e4m3fn = ( + is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn + ) + + if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: + casting_dtype = old_param.dtype + + return old_param is not None and old_param.is_contiguous(), casting_dtype + + def _cast_and_contiguous(tensor, to_contiguous, dtype): + if dtype is not None: + tensor = tensor.to(dtype=dtype) + if to_contiguous: + tensor = tensor.contiguous() + return tensor + + param_names = sorted(meta_sharded_sd.keys()) + for param_name in param_names: - mesh = sharded_sd[param_name].device_mesh + mesh = meta_sharded_sd[param_name].device_mesh if accelerator.is_main_process: - # Use the corresponding tensor from full_sd (assuming the key exists in full_sd) full_param = full_sd[param_name].detach().cuda() dist.broadcast(full_param, src=0, group=mesh.get_group()) sharded_tensor = distribute_tensor( full_param, mesh, sharded_sd[param_name].placements ) + to_contiguous, casting_dtype = _infer_parameter_dtype( + model, + param_name, + full_param, + ) + sharded_tensor = _cast_and_contiguous( + sharded_tensor, to_contiguous, casting_dtype + ) sharded_sd[param_name] = sharded_tensor else: - # Prepare a tensor of matching shape and dtype full_tensor = torch.empty( sharded_sd[param_name].size(), device="cuda", @@ -48,57 +86,19 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic sharded_tensor = distribute_tensor( full_tensor, mesh, sharded_sd[param_name].placements ) + to_contiguous, casting_dtype = _infer_parameter_dtype( + model, + param_name, + full_tensor, + ) + sharded_tensor = _cast_and_contiguous( + sharded_tensor, to_contiguous, casting_dtype + ) sharded_sd[param_name] = sharded_tensor + # we set `assign=True` because our params are on meta device model.load_state_dict(sharded_sd, assign=True) - - -def set_state_dict_type(self, state_dict_type=None): - """ - Set the state dict config based on the `StateDictType`. - """ - import os - - from torch.distributed.fsdp.fully_sharded_data_parallel import ( - FullOptimStateDictConfig, - FullStateDictConfig, - ShardedOptimStateDictConfig, - ShardedStateDictConfig, - StateDictType, - ) - - # Override the state_dict_type if provided, typical use case: - # user trains with sharded, but final save is with full - if state_dict_type is not None: - self.state_dict_type = state_dict_type - - if self.state_dict_type is None: - self.state_dict_type = os.environ.get( - "FSDP_STATE_DICT_TYPE", - "FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT", - ) - if isinstance(self.state_dict_type, str): - if self.state_dict_type.isdigit(): - self.state_dict_type = StateDictType(int(self.state_dict_type)) - else: - self.state_dict_type = StateDictType[self.state_dict_type.upper()] - - if self.state_dict_type == StateDictType.FULL_STATE_DICT: - if self.state_dict_config is None: - self.state_dict_config = FullStateDictConfig( - offload_to_cpu=True, rank0_only=True - ) - if self.optim_state_dict_config is None: - self.optim_state_dict_config = FullOptimStateDictConfig( - offload_to_cpu=True, rank0_only=True - ) - elif self.state_dict_type == StateDictType.SHARDED_STATE_DICT: - if self.state_dict_config is None: - self.state_dict_config = ShardedStateDictConfig(offload_to_cpu=True) - if self.optim_state_dict_config is None: - self.optim_state_dict_config = ShardedOptimStateDictConfig( - offload_to_cpu=True - ) + return model def get_state_dict(self, model, unwrap=True): @@ -208,12 +208,3 @@ def patch_accelerate_fsdp2(): "Accelerator.get_state_dict", get_state_dict, ) - - accelerate.utils.dataclasses.FullyShardedDataParallelPlugin.set_state_dict_type = ( - set_state_dict_type - ) - setattr( - sys.modules["accelerate.utils.dataclasses"], - "FullyShardedDataParallelPlugin.set_state_dict_type", - set_state_dict_type, - ) diff --git a/src/axolotl/monkeypatch/attention/mllama.py b/src/axolotl/monkeypatch/attention/mllama.py deleted file mode 100644 index c9e8fb5e1..000000000 --- a/src/axolotl/monkeypatch/attention/mllama.py +++ /dev/null @@ -1,230 +0,0 @@ -""" -Monkeypatch for Vision Llama for FA2 support -""" - -# pylint: disable=duplicate-code - -from typing import Optional, Tuple - -import torch -from flash_attn.flash_attn_interface import flash_attn_func -from transformers.cache_utils import Cache -from transformers.modeling_flash_attention_utils import _flash_attention_forward -from transformers.models.mllama.configuration_mllama import MllamaTextConfig -from transformers.models.mllama.modeling_mllama import ( - MllamaTextCrossAttention, - MllamaTextSelfAttention, - apply_rotary_pos_emb, - repeat_kv, -) -from transformers.utils import is_flash_attn_greater_or_equal_2_10 - - -class MllamaTextCrossFlashAttention2(MllamaTextCrossAttention): - """ - Mllama flash cross-attention module. This module inherits from `MllamaTextCrossAttention` and - implements the forward pass using Flash Attention for improved performance. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # Check if flash attention version is greater or equal to 2.1 - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[ # pylint: disable=unused-argument - torch.Tensor - ] = None, - output_attentions: bool = False, - use_cache: bool = False, # pylint: disable=unused-argument - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - query_states = self.q_norm(query_states) - - if cross_attention_states is not None: - key_states = self.k_proj(cross_attention_states) - value_states = self.v_proj(cross_attention_states) - key_states = key_states.view( - bsz, -1, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, -1, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - key_states = self.k_norm(key_states) - if past_key_value is not None: - key_states, value_states = past_key_value.update( - key_states, - value_states, - self.layer_idx, - {"cache_position": cache_position}, - ) - elif cache_position[0] != 0: - key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], - ) - else: - raise ValueError( - "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" - ) - - # Transpose to get the expected layout for flash attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - # Apply Flash Attention - dropout_rate = self.dropout if self.training else 0.0 - output = flash_attn_func( - query_states, - key_states, - value_states, - dropout_p=dropout_rate, - softmax_scale=None, - causal=False, - return_attn_probs=output_attentions, - ) - - attn_output = output.contiguous().view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class MllamaTextSelfFlashAttention2(MllamaTextSelfAttention): - """ - Mllama flash self-attention module. This module inherits from `MllamaTextSelfAttention` and - implements the forward pass using Flash Attention for improved performance. - """ - - def __init__(self, config: MllamaTextConfig, layer_idx: int, *args, **kwargs): - super().__init__(config, layer_idx, *args, **kwargs) - - # Check if flash attention version is greater or equal to 2.1 - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, # pylint: disable=unused-argument - past_key_value=None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, # pylint: disable=unused-argument - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x num_heads x head_dim - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin - ) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - # Transpose to get the expected layout for flash attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.dropout if self.training else 0.0 - - # Handle potential silent casting to float32 - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = ( - self.config._pre_quantization_dtype # pylint: disable=protected-access - ) - else: - target_dtype = self.q_proj.weight.dtype - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=True, - ) - - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -def patch_mllama(): - from transformers.models.mllama.modeling_mllama import ( - MLLAMA_TEXT_ATTENTION_CLASSES, - MLLAMA_TEXT_CROSS_ATTENTION_CLASSES, - MLLAMA_VISION_ATTENTION_CLASSES, - MllamaPreTrainedModel, - ) - - MllamaPreTrainedModel._supports_flash_attn_2 = ( # pylint: disable=protected-access - True - ) - MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2 - MLLAMA_TEXT_CROSS_ATTENTION_CLASSES["flash_attention_2"] = ( - MllamaTextCrossFlashAttention2 - ) - # fallback to SDPA - MLLAMA_VISION_ATTENTION_CLASSES["flash_attention_2"] = ( - MLLAMA_VISION_ATTENTION_CLASSES["sdpa"] - ) diff --git a/src/axolotl/monkeypatch/gemma3.py b/src/axolotl/monkeypatch/gemma3.py deleted file mode 100644 index 36f591efd..000000000 --- a/src/axolotl/monkeypatch/gemma3.py +++ /dev/null @@ -1,230 +0,0 @@ -"""Monkeypatch for gemma3 conditional generation forward to fix loss exploding""" - -# pylint: disable=duplicate-code - -from typing import Optional, Tuple, Union - -import torch -from transformers.cache_utils import Cache -from transformers.models.gemma3.modeling_gemma3 import ( - Gemma3CausalLMOutputWithPast, - logger, -) -from transformers.utils import ( - is_torchdynamo_compiling, -) -from transformers.utils.deprecation import deprecate_kwarg - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def new_forward( - self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None, - token_type_ids: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **lm_kwargs, -) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration - - >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf") - >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf") - - >>> prompt = "answer en Where is the cow standing?" - >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_length=30) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "answer en Where is the cow standing?\nbeach" - ```""" - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - is_training = token_type_ids is not None and labels is not None - - # Replace image id with PAD if the image token is OOV, to avoid index-errors - if input_ids is not None and self.config.image_token_index >= self.vocab_size: - special_image_mask = input_ids == self.config.image_token_index - llm_input_ids = input_ids.clone() - llm_input_ids[special_image_mask] = 0 - else: - llm_input_ids = input_ids - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(llm_input_ids) - - if cache_position is None: - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 - ) - cache_position = torch.arange( - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, - ) - - # Merge text and images - if pixel_values is not None: - image_features = self.get_image_features(pixel_values) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor( - self.config.image_token_index, - dtype=torch.long, - device=inputs_embeds.device, - ) - ) - else: - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze( - -1 - ) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to( - inputs_embeds.device - ) - - if ( - not is_torchdynamo_compiling() - and inputs_embeds[special_image_mask].numel() != image_features.numel() - ): - image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of images does not match number of special image tokens in the input text. " - f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " - "tokens from image embeddings." - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - # mask out pad-token-ids in labels for BC - if labels is not None and self.pad_token_id in labels: - logger.warning_once( - "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. " - "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", - ) - labels = torch.where( - input_ids == self.pad_token_id, self.config.ignore_index, labels - ) - - causal_mask = self._update_causal_mask( # pylint: disable=protected-access - attention_mask, - token_type_ids, - past_key_values, - cache_position, - inputs_embeds, - is_training, - ) - outputs = self.language_model( - attention_mask=causal_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - logits_to_keep=logits_to_keep, - **lm_kwargs, - ) - - logits = outputs[0] - loss = None - if labels is not None: - if attention_mask is not None: - # Get the shifted attention mask - shift_attention_mask = attention_mask[:, -logits.shape[1] + 1 :].to( - logits.device - ) # +1 for shift - - # Filter logits and labels based on attention mask - valid_indices = shift_attention_mask != 0 - filtered_logits = logits[..., :-1, :][valid_indices] - filtered_labels = labels[..., 1:][valid_indices.to(labels.device)] - - # TODO: do we need to handle num_items_in_batch given we filter the logits and labels? - - loss = self.loss_function( - logits=filtered_logits, - labels=None, # we pass shift_labels - shift_labels=filtered_labels, - vocab_size=self.config.text_config.vocab_size, - **lm_kwargs, - ) - else: - # Standard case without filtering - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.text_config.vocab_size, - **lm_kwargs, - ) - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Gemma3CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) - - -def patch_gemma3conditionalgeneration_forward(): - from transformers.models.gemma3.modeling_gemma3 import ( - Gemma3ForConditionalGeneration, - ) - - Gemma3ForConditionalGeneration.forward = new_forward diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index 11e0989cf..a7875eefe 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -342,10 +342,11 @@ def apply_lora_kernel_patches( layers = [] # check for multimodal models first - if hasattr(model, "language_model"): - layers = model.language_model.model.layers - elif hasattr(model, "model"): - layers = model.model.model.layers + pretrained_model = model.model + if hasattr(pretrained_model, "language_model"): + layers = pretrained_model.language_model.layers + elif hasattr(pretrained_model, "model"): + layers = pretrained_model.model.layers else: raise NotImplementedError( f"Model type {model.config.model_type} is not supported yet. Please create an Issue." diff --git a/src/axolotl/train.py b/src/axolotl/train.py index b59bd8a75..866a9c454 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -204,7 +204,7 @@ def execute_training( if cfg.sequence_parallel_degree > 1: models = [trainer.model] - if hasattr(trainer, "ref_model"): + if hasattr(trainer, "ref_model") and trainer.ref_model: models.append(trainer.ref_model) stack.enter_context( diff --git a/src/axolotl/utils/schemas/vllm.py b/src/axolotl/utils/schemas/vllm.py index 48441de5e..0ae635589 100644 --- a/src/axolotl/utils/schemas/vllm.py +++ b/src/axolotl/utils/schemas/vllm.py @@ -44,3 +44,12 @@ class VllmConfig(BaseModel): default=8000, json_schema_extra={"description": "Port of the vLLM server to start on"}, ) + + enable_reasoning: bool | None = Field( + default=None, + json_schema_extra={"description": "Enable reasoning for VLLM"}, + ) + reasoning_parser: str | None = Field( + default=None, + json_schema_extra={"description": "Reasoning parser for VLLM"}, + ) diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py index f36eef953..dad777947 100644 --- a/tests/e2e/integrations/test_kd.py +++ b/tests/e2e/integrations/test_kd.py @@ -90,7 +90,7 @@ class TestKnowledgeDistillation: train(cfg=cfg, dataset_meta=dataset_meta) assert (Path(temp_dir) / "model.safetensors").exists() check_tensorboard( - temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high" + temp_dir + "/runs", "train/loss", 1.4, "Train Loss (%s) is too high" ) @pytest.mark.parametrize( diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py index 6c7a9b2e4..8ea2e3ce4 100644 --- a/tests/e2e/multigpu/solo/test_grpo.py +++ b/tests/e2e/multigpu/solo/test_grpo.py @@ -262,6 +262,99 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): **current_env, }, ) + finally: + (recursive_kill(vllm_process)) + + @require_vllm + def test_llama_lora_sp(self, temp_dir): + rnd_reward_suffix = str(random.randint(1000, 9999)) + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "chat_template": "llama3", + "rl": "grpo", + "trl": { + "beta": 0.001, + "max_completion_length": 256, + "use_vllm": True, + "num_generations": 4, + "reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"], + }, + "vllm": { + "max_model_len": 800, + "enable_prefix_caching": True, + }, + "datasets": [ + { + "path": "openai/gsm8k", + "name": "main", + "type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform", + }, + ], + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "sequence_parallel_degree": 2, + "flash_attention": True, + "sequence_len": 1024, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "max_steps": 3, + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 2, + "warmup_steps": 10, + "val_set_size": 0.0, + "output_dir": temp_dir, + "learning_rate": 0.0001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + "use_tensorboard": True, + } + ) + + self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix) + + current_env = os.environ.copy() + env = { + "NCCL_P2P_LEVEL": "LOC", + **current_env, + "CUDA_VISIBLE_DEVICES": "1", + } + vllm_process = start_vllm( + cfg.base_model, + env=env, + quiet=True, + wait=300, + gpu_memory_utilization=0.15, + max_model_len=cfg.vllm.max_model_len, + enable_prefix_caching=cfg.vllm.enable_prefix_caching, + host="0.0.0.0", + port=8000, + ) + + try: + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + str(2), + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ], + env={ + "NCCL_P2P_LEVEL": "LOC", + "NCCL_DEBUG": "INFO", + **current_env, + }, + ) finally: recursive_kill(vllm_process) diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py index b93947f0d..32657c156 100644 --- a/tests/e2e/test_llama_vision.py +++ b/tests/e2e/test_llama_vision.py @@ -33,7 +33,7 @@ class TestLlamaVision(unittest.TestCase): "lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.05, - "lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj", + "lora_target_modules": r"model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj", "val_set_size": 0, "chat_template": "llama3_2_vision", "datasets": [ @@ -81,7 +81,7 @@ class TestLlamaVision(unittest.TestCase): "lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.05, - "lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj", + "lora_target_modules": r"model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj", "val_set_size": 0, "chat_template": "llama3_2_vision", "datasets": [ diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index d0837f191..e812a5f7e 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -10,7 +10,12 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault -from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir +from .utils import ( + check_model_output_exists, + require_torch_2_5_1, + require_torch_2_6_0, + with_temp_dir, +) class TestCustomOptimizers(unittest.TestCase): @@ -196,6 +201,7 @@ class TestCustomOptimizers(unittest.TestCase): check_model_output_exists(temp_dir, cfg) @with_temp_dir + @require_torch_2_6_0 def test_came_pytorch(self, temp_dir): # pylint: disable=duplicate-code cfg = DictDefault( diff --git a/tests/test_packed_dataset.py b/tests/test_packed_dataset.py index 45fc75282..8b29eab21 100644 --- a/tests/test_packed_dataset.py +++ b/tests/test_packed_dataset.py @@ -6,10 +6,16 @@ from pathlib import Path from datasets import Dataset, load_dataset from transformers import AutoTokenizer +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy from axolotl.prompters import AlpacaPrompter +from axolotl.train import setup_model_and_trainer +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault +from tests.e2e.utils import with_temp_dir from tests.hf_offline_utils import enable_hf_offline @@ -67,6 +73,85 @@ class TestPacking(unittest.TestCase): assert example["position_ids"][next_bos_index] == 0 assert example["position_ids"][next_bos_index + 1] == 1 + @with_temp_dir + def test_lora_packing(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "sequence_len": 1024, + "sample_packing": True, + "multipack_real_batches": False, + "eval_sample_packing": True, + "adapter": "lora", + "lora_r": 32, + "lora_alpha": 64, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.2, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 20, + "save_steps": 10, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "fp16": False, + "bf16": False, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + ( + trainer, + _, + _, + _, + _, + ) = setup_model_and_trainer(cfg, dataset_meta) + + sampler = trainer._get_eval_sampler( # pylint: disable=protected-access + trainer.eval_dataset + ) + assert "MultipackBatchSampler" in sampler.__class__.__name__ + assert ( + "V2BatchSamplerDataCollatorForSeq2Seq" + in trainer.eval_data_collator.__class__.__name__ + ) + dataloader = trainer.get_eval_dataloader(trainer.eval_dataset) + dataloader_iter = iter(dataloader) + batch = next(dataloader_iter) + assert batch["input_ids"].shape == (1, 8192) + + sampler = trainer._get_train_sampler( # pylint: disable=protected-access + trainer.train_dataset + ) + assert "MultipackBatchSampler" in sampler.__class__.__name__ + assert ( + "V2BatchSamplerDataCollatorForSeq2Seq" + in trainer.train_data_collator.__class__.__name__ + ) + dataloader = trainer.get_train_dataloader() + dataloader_iter = iter(dataloader) + batch = next(dataloader_iter) + assert batch["input_ids"].shape == (1, 8192) + if __name__ == "__main__": unittest.main()