diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 4fcf08352..01606f902 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -31,6 +31,11 @@ jobs:
             python_version: "3.11"
             pytorch: 2.7.0
             axolotl_extras:
+          - cuda: 128
+            cuda_version: 12.8.1
+            python_version: "3.11"
+            pytorch: 2.7.0
+            axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
@@ -94,6 +99,11 @@ jobs:
             python_version: "3.11"
             pytorch: 2.7.0
             axolotl_extras:
+          - cuda: 128
+            cuda_version: 12.8.1
+            python_version: "3.11"
+            pytorch: 2.7.0
+            axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index c296e2314..69f0a030d 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -295,6 +295,7 @@ jobs:
           find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;

   docker-e2e-tests-1st:
+    # Run this job first as a gate for the remainder of the test matrix
    if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
    # this job needs to be run on self-hosted GPU runners...
    runs-on: [self-hosted, modal]
@@ -341,6 +342,8 @@ jobs:
    # this job needs to be run on self-hosted GPU runners...
    runs-on: [self-hosted, modal]
    timeout-minutes: 90
+    # Only run the remainder of the matrix if the first e2e check passed;
+    # this avoids wasting compute on known failures that the first run already catches
    needs: [pre-commit, pytest, docker-e2e-tests-1st]

    strategy:
@@ -365,6 +368,12 @@ jobs:
             pytorch: 2.7.0
             num_gpus: 1
             axolotl_extras:
+          - cuda: 128
+            cuda_version: 12.8.1
+            python_version: "3.11"
+            pytorch: 2.7.0
+            num_gpus: 1
+            axolotl_extras:
     steps:
       - name: Checkout
         uses: actions/checkout@v4
diff --git a/_quarto.yml b/_quarto.yml
index 56ebe9d68..15d385711 100644
--- a/_quarto.yml
+++ b/_quarto.yml
@@ -62,7 +62,6 @@ quartodoc:
         - core.trainers.mixins.optimizer
         - core.trainers.mixins.rng_state_loader
         - core.trainers.mixins.scheduler
-        - core.trainers.mixins.sequence_parallel
       - title: Context Managers
         desc: Context managers for altering trainer behaviors
         contents:
@@ -141,7 +140,8 @@ quartodoc:
         - utils.optimizers.adopt
         - utils.data.pretraining
         - utils.data.sft
-        - utils.gradient_checkpointing.unsloth
+        - utils.gradient_checkpointing.offload_cpu
+        - utils.gradient_checkpointing.offload_disk
       - title: Schemas
         desc: Pydantic data models for Axolotl config
         contents:
diff --git a/cicd/multigpu.py b/cicd/multigpu.py
index 90d4ce1ee..7de4ae0a7 100644
--- a/cicd/multigpu.py
+++ b/cicd/multigpu.py
@@ -70,7 +70,7 @@ def run_cmd(cmd: str, run_folder: str):
         image=cicd_image,
         gpu=GPU_CONFIG,
         timeout=90 * 60,
-        cpu=8.0,
+        cpu=16.0,
         memory=131072 * N_GPUS,
         volumes=VOLUME_CONFIG,
     )
diff --git a/docs/config.qmd b/docs/config.qmd
index 2e0e25987..298fb6aa9 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -540,7 +540,7 @@ train_on_inputs: false
 # Note that training loss may have an oscillating pattern with this enabled.
 group_by_length: false

-# Whether to use gradient checkpointing. Available options are: true, false, "offload".
+# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk".
 # https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
 gradient_checkpointing: false
 # additional kwargs to pass to the trainer for gradient checkpointing
@@ -634,7 +634,9 @@ weight_decay:
 # adamw hyperparams
 adam_beta1:
 adam_beta2:
+adam_beta3: # only used for CAME Optimizer
 adam_epsilon:
+adam_epsilon2: # only used for CAME Optimizer

 # Gradient clipping max norm
 max_grad_norm:
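The two `docs/config.qmd` hunks above introduce the new `offload_disk` gradient checkpointing mode and the CAME-only optimizer hyperparameters. As a rough illustration (not part of this PR), a config exercising both might look like the following, assuming CAME is selected via `optimizer: came_pytorch` and using the CAME paper's default values:

```yaml
# Offload checkpointed activations to disk rather than CPU RAM
gradient_checkpointing: offload_disk

optimizer: came_pytorch
adam_beta1: 0.9
adam_beta2: 0.999
adam_beta3: 0.9999   # read only by CAME; ignored by AdamW
adam_epsilon: 1e-30
adam_epsilon2: 1e-16 # read only by CAME; ignored by AdamW
```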
diff --git a/docs/getting-started.qmd b/docs/getting-started.qmd
index a0501ad21..064985e35 100644
--- a/docs/getting-started.qmd
+++ b/docs/getting-started.qmd
@@ -104,7 +104,7 @@ the `alpaca` dataset format, which has the following format:
 Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
 format them.

-2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca
+2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca`
 format):

 ```json
@@ -120,6 +120,12 @@ axolotl train my_training.yml

 ## Common Tasks {#sec-common-tasks}

+::: {.callout-tip}
+
+The same YAML file is used for training, inference, and merging.
+
+:::
+
 ### Testing Your Model {#sec-testing}

 After training, test your model:
@@ -128,6 +134,16 @@ After training, test your model:
 ```bash
 axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
 ```

+More details can be found in [Inference](inference.qmd).
+
+### Using a UI {#sec-ui}
+
+Launch a Gradio interface:
+
+```bash
+axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
+```
+
 ### Preprocessing Data {#sec-preprocessing}

 For large datasets, preprocess first:
@@ -136,14 +152,22 @@
 axolotl preprocess my_training.yml
 ```

-### Using a UI {#sec-ui}
+Make sure to set `dataset_prepared_path:` in your config to specify where the prepared dataset should be saved.

-Launch a Gradio interface:
+More details can be found in [Dataset Preprocessing](dataset_preprocessing.qmd).
+
+### Merging LoRA weights {#sec-merging-lora}
+
+To merge the LoRA weights back into the base model, run:

 ```bash
-axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
+axolotl merge-lora my_training.yml --lora-model-dir="./outputs/lora-out"
 ```

+The merged model will be saved in the `{output_dir}/merged` directory.
+
+More details can be found in [Merging LoRA weights](inference.qmd#sec-merging).
+
 ## Next Steps {#sec-next-steps}

 Now that you have the basics, you might want to:
@@ -156,6 +180,7 @@ Now that you have the basics, you might want to:
 Check our other guides for details on these topics:

 - [Configuration Guide](config.qmd) - Full configuration options
+- [Dataset Loading](dataset-loading.qmd) - Loading datasets from various sources
 - [Dataset Formats](dataset-formats) - Working with different data formats
 - [Multi-GPU Training](multi-gpu.qmd)
 - [Multi-Node Training](multi-node.qmd)
diff --git a/docs/multi-gpu.qmd b/docs/multi-gpu.qmd
index 55eaca6c3..fee7d17e5 100644
--- a/docs/multi-gpu.qmd
+++ b/docs/multi-gpu.qmd
@@ -87,20 +87,7 @@ We support sequence parallelism (SP) via the
 allows one to split up sequences across GPUs, which is useful in the event that
 a single sequence causes OOM errors during model training.

-First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`,
-or from source with `pip install .[ring-flash-attn]`.
- -Your Axolotl YAML config should contain the following lines: - -```{.yaml} -sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU -flash_attention: true # Required with sequence parallelism - -# Optional; strides across the key dimension. Larger values use more memory but will make training faster. -heads_k_stride: 1 -``` - -See our [dedicated guide](sequence_parallelism.qmd) for more details. +See our [dedicated guide](sequence_parallelism.qmd) for more information. ### FSDP + QLoRA {#sec-fsdp-qlora} diff --git a/docs/sequence_parallelism.qmd b/docs/sequence_parallelism.qmd index 1bff17ce9..b98206135 100644 --- a/docs/sequence_parallelism.qmd +++ b/docs/sequence_parallelism.qmd @@ -41,7 +41,7 @@ When sequence parallelism is enabled: 1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group 2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids -3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences +3. Position IDs are adjusted to maintain proper relative positions 4. The trainer uses special ring communication patterns for attention operations ## Requirements @@ -67,9 +67,11 @@ sequence_len: 8192 ... sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU -flash_attention: true # Required with sequence parallelism # Optional; strides across the key dimension. Larger values use more memory but should make training faster. heads_k_stride: 1 +# Optional; one of "varlen_llama3" or "batch_ring". Defaults to +# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise. +ring_attn_func: ... ``` diff --git a/examples/qwen2/dpo.yaml b/examples/qwen2/dpo.yaml index 3547c6c98..bd896c2b3 100644 --- a/examples/qwen2/dpo.yaml +++ b/examples/qwen2/dpo.yaml @@ -2,7 +2,6 @@ base_model: Qwen/Qwen2.5-0.5B # Automatically upload checkpoint and final model to HF # hub_model_id: username/custom_model_name - chat_template: qwen_25 rl: dpo datasets: diff --git a/src/axolotl/core/trainer_builder/base.py b/src/axolotl/core/trainer_builder/base.py index 8fbf1efe8..c3d5faa3c 100644 --- a/src/axolotl/core/trainer_builder/base.py +++ b/src/axolotl/core/trainer_builder/base.py @@ -465,8 +465,6 @@ class TrainerBuilderBase(abc.ABC): "save_only_model", "include_tokens_per_second", "weight_decay", - "sequence_parallel_degree", - "ring_attn_func", "seed", ]: if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: diff --git a/src/axolotl/core/trainer_builder/rl.py b/src/axolotl/core/trainer_builder/rl.py index c45edbe4a..52be54f79 100644 --- a/src/axolotl/core/trainer_builder/rl.py +++ b/src/axolotl/core/trainer_builder/rl.py @@ -185,7 +185,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.eval_dataset: trainer_kwargs["eval_dataset"] = self.eval_dataset if self.cfg.adapter and self.peft_config: - trainer_kwargs["peft_config"] = self.peft_config + if self.cfg.rl is not RLType.GRPO: + trainer_kwargs["peft_config"] = self.peft_config if self.cfg.precompute_ref_log_probs is not None: trainer_kwargs["precompute_ref_log_probs"] = ( self.cfg.precompute_ref_log_probs diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 2f0ce6894..d5cfc23df 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -29,7 +29,6 @@ from axolotl.core.trainers.mixins import ( OptimizerMixin, RngLoaderMixin, SchedulerMixin, - SequenceParallelMixin, ) from axolotl.core.trainers.utils import ( 
sanitize_kwargs_for_ds_tagging, @@ -40,9 +39,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths LOG = logging.getLogger(__name__) -class AxolotlTrainer( - SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer -): +class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): """Extend the base Trainer for axolotl helpers""" args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] @@ -68,10 +65,6 @@ class AxolotlTrainer( if self.args.orpo_alpha: self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") - # Initialize sequence parallelism if enabled - if self.args.sequence_parallel_degree > 1: - self._setup_sequence_parallel() - def _wrap_model(self, model, training=True, dataloader=None): if self.args.torch_compile: torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access @@ -122,8 +115,8 @@ class AxolotlTrainer( def _get_train_sampler(self) -> Sampler | None: """ - Helper method to get the sampler for training. Handles cases for sequence - parallelism, sample packing, and curriculum sampling (sequential). + Helper method to get the sampler for training. Handles cases for sample packing + and curriculum sampling (sequential). Returns: If the dataset is non-empty, a sampler is returned, the type of which @@ -132,9 +125,7 @@ class AxolotlTrainer( use_sample_packing = self.args.sample_packing and not self.args.pretraining # Determine the base sampler first - if self.args.sequence_parallel_degree > 1: - base_sampler = self._sp_get_train_sampler(self.train_dataset) - elif self.args.curriculum_sampling: + if self.args.curriculum_sampling: base_sampler = SequentialSampler(self.train_dataset) elif use_sample_packing: base_sampler = RandomSampler(self.train_dataset) @@ -153,8 +144,7 @@ class AxolotlTrainer( def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None: """ - Helper method to get the sampler for evaluation. Handles sequence parallelism - and sample packing cases. + Helper method to get the sampler for evaluation. Handles sample packing case. Returns: If the dataset is non-empty, a sampler is returned, the type of which @@ -168,9 +158,7 @@ class AxolotlTrainer( ) # Determine the base sampler - if self.args.sequence_parallel_degree > 1: - base_sampler = self._sp_get_eval_sampler(eval_dataset) - elif use_multipack: + if use_multipack: base_sampler = SequentialSampler(eval_dataset) else: return super()._get_eval_sampler(eval_dataset) @@ -236,14 +224,6 @@ class AxolotlTrainer( ): self.accelerator.even_batches = False - # Return unprepared dataloader if using sequence parallelism - # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation - # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e., - # slice each batch along the sequence dimension). 
- if self.args.sequence_parallel_degree > 1: - return dataloader - - # Otherwise prepare with accelerator return self.accelerator.prepare_data_loader(dataloader) def get_train_dataloader(self) -> DataLoader: @@ -287,12 +267,7 @@ class AxolotlTrainer( return dataloader - # Handle sample packing or sequence parallelism - if ( - self.args.sample_packing - and self.args.eval_sample_packing is not False - or self.args.sequence_parallel_degree > 1 - ): + if self.args.sample_packing and self.args.eval_sample_packing is not False: # Get appropriate data collator self.data_collator = ( # pylint: disable=attribute-defined-outside-init self.eval_data_collator @@ -302,17 +277,6 @@ class AxolotlTrainer( if "length" in eval_dataset.column_names: eval_dataset = eval_dataset.remove_columns(["length"]) - # Handle dataset preprocessing for SP - if self.args.sequence_parallel_degree > 1: - if isinstance(eval_dataset, datasets.Dataset): - eval_dataset = self._remove_unused_columns( - eval_dataset, description="evaluation" - ) - else: - self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init - self.data_collator, description="evaluation" - ) - # Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise batch_size = ( self.args.eval_batch_size diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py index 1ce7deea7..c2c80c0bc 100644 --- a/src/axolotl/core/trainers/dpo/trainer.py +++ b/src/axolotl/core/trainers/dpo/trainer.py @@ -1,31 +1,15 @@ -""" -DPO trainer for axolotl -""" +"""DPO trainer for axolotl""" import gc -import random from functools import wraps -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Union -import pandas as pd import torch -import wandb -from accelerate import PartialState -from datasets import Dataset, IterableDataset from peft.optimizers import create_loraplus_optimizer from torch import nn -from torch.utils.data import DataLoader -from transformers import ( - BaseImageProcessor, - FeatureExtractionMixin, - PreTrainedTokenizerBase, - ProcessorMixin, - Trainer, -) -from transformers.trainer_utils import EvalLoopOutput +from transformers import Trainer from transformers.utils import is_sagemaker_mp_enabled -from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt -from trl.trainer.utils import log_table_to_comet_experiment +from trl import DPOTrainer from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin from axolotl.core.trainers.utils import ( @@ -38,9 +22,7 @@ if is_sagemaker_mp_enabled(): class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): - """ - Extend the base DPOTrainer for axolotl helpers - """ + """Extend the base DPOTrainer for axolotl helpers.""" tag_names = ["axolotl", "dpo"] @@ -85,8 +67,9 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): @wraps(DPOTrainer.push_to_hub) def push_to_hub(self, *args, **kwargs) -> str: """ - Overwrite the `push_to_hub` method in order to force-add the tags when pushing the - model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. + Overwrite the `push_to_hub` method in order to force-add the tags when pushing + the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` + for more details. 
""" kwargs = sanitize_kwargs_for_ds_tagging( dataset_tags=self.dataset_tags, kwargs=kwargs @@ -95,64 +78,6 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): return super().push_to_hub(*args, **kwargs) - # TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release - def _prepare_dataset( - self, - dataset: Union[Dataset, IterableDataset], - processing_class: Union[ - PreTrainedTokenizerBase, - BaseImageProcessor, - FeatureExtractionMixin, - ProcessorMixin, - ], - args: DPOConfig, - dataset_name: str, - ) -> Union[Dataset, IterableDataset]: - # Build the kwargs for the `map` function - map_kwargs: Dict[str, Any] = {"writer_batch_size": 10} - if isinstance(dataset, Dataset): # IterableDataset does not support num_proc - map_kwargs["num_proc"] = args.dataset_num_proc - - with PartialState().main_process_first(): - # Extract prompt if needed - if isinstance( - dataset, Dataset - ): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset" - dataset = dataset.map(maybe_extract_prompt, **map_kwargs) - - # Apply the chat template if needed - if isinstance( - dataset, Dataset - ): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" - dataset = dataset.map( - maybe_apply_chat_template, - fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, - **map_kwargs, - ) - - # Tokenize the dataset - if isinstance( - dataset, Dataset - ): # `IterableDataset.map` does not support `desc` - map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" - - dataset = dataset.map( - self.tokenize_row if not self.is_vision_model else self.process_row, - remove_columns=["chosen", "rejected"], - fn_kwargs={ - "processing_class": processing_class, - "max_prompt_length": args.max_prompt_length, - "max_completion_length": args.max_completion_length, - # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token]) - "add_special_tokens": False, - }, - **map_kwargs, - ) - - return dataset - @staticmethod def tokenize_row( features, @@ -192,69 +117,3 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): gc.collect() torch.cuda.empty_cache() return loss - - # TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release - def evaluation_loop( - self, - dataloader: DataLoader, - description: str, - prediction_loss_only: Optional[bool] = None, - ignore_keys: Optional[list[str]] = None, - metric_key_prefix: str = "eval", - ) -> EvalLoopOutput: - """ - Overriding built-in evaluation loop to store metrics for each batch. - Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. - - Works both with or without labels. 
- """ - - # Sample and save to game log if requested (for one batch to save time) - if self.generate_during_eval: - # Generate random indices within the range of the total number of samples - num_samples = len(dataloader.dataset) - random_indices = random.sample( - range(num_samples), k=self.args.eval_batch_size - ) - - # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader - random_batch_dataset = dataloader.dataset.select(random_indices) - random_batch = self.data_collator(random_batch_dataset) - random_batch = self._prepare_inputs(random_batch) - - policy_output_decoded, ref_output_decoded = ( - self.generate_from_model_and_ref(self.model, random_batch) - ) - - table = pd.DataFrame( - columns=["Prompt", "Policy", "Ref Model"], - data=[ - [prompt, pol[len(prompt) :], ref[len(prompt) :]] - for prompt, pol, ref in zip( - random_batch_dataset["prompt"], - policy_output_decoded, - ref_output_decoded, - ) - ], - ) - if "wandb" in self.args.report_to and self.accelerator.is_main_process: - wandb.log({"game_log": wandb.Table(data=table)}) - - if "comet_ml" in self.args.report_to: - log_table_to_comet_experiment( - name="game_log.csv", - table=table, - ) - - # Base evaluation - initial_output = super( # pylint: disable=bad-super-call - DPOTrainer, self - ).evaluation_loop( - dataloader, - description, - prediction_loss_only, - ignore_keys, - metric_key_prefix, - ) - - return initial_output diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index bc3d140b1..a603ed860 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -3,7 +3,6 @@ # pylint: disable=too-many-lines,duplicate-code,protected-access,no-member import warnings -from contextlib import nullcontext from typing import Any import datasets @@ -14,7 +13,7 @@ from accelerate.utils import ( broadcast_object_list, gather, gather_object, - is_peft_model, + is_peft_available, ) from datasets import Dataset, IterableDataset from torch import nn @@ -30,15 +29,13 @@ from transformers import ( TrainerCallback, ) from transformers.trainer_utils import seed_worker -from transformers.utils import is_peft_available from trl import GRPOTrainer from trl.data_utils import ( apply_chat_template, is_conversational, maybe_apply_chat_template, ) -from trl.extras.profiling import profiling_context, profiling_decorator -from trl.import_utils import is_deepspeed_available +from trl.extras.profiling import profiling_context from trl.models import unwrap_model_for_generation from trl.trainer.grpo_config import GRPOConfig from trl.trainer.grpo_trainer import RewardFunc, nanstd @@ -46,68 +43,18 @@ from trl.trainer.utils import pad from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin -from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group +from axolotl.monkeypatch.ring_attn.patch import get_ring_attn_group if is_peft_available(): # pylint: disable=unused-import from peft import PeftConfig -if is_deepspeed_available(): - import deepspeed - class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): """Extend the base GRPOTrainer for axolotl helpers""" _tag_names = ["trl", "grpo", "axolotl"] - @profiling_decorator - def _move_model_to_vllm(self): - # For DeepSpeed ZeRO-3, we need to gather all parameters before operations - deepspeed_plugin = self.accelerator.state.deepspeed_plugin - zero_stage_3 = 
deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 - gather_if_zero3 = ( - deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext - ) - - if is_peft_model(self.model): - # With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging - # adapters in a sharded manner is not supported. - with gather_if_zero3(list(self.model.parameters())): - self.model.merge_adapter() - - # Update vLLM weights while parameters are gathered - for name, param in self.model.named_parameters(): - # When using PEFT, we need to recover the original parameter name and discard some parameters - name = ( - name.removeprefix("base_model.model.") - .removeprefix("base_model.model.") - .replace(".base_layer", "") - ) - if self.model.prefix in name: - continue - # When module to save, remove its prefix and discard the original module - if "original_module" in name: - continue - name = name.replace("modules_to_save.default.", "") - - if self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - - # Unmerge adapters while parameters are still gathered - self.model.unmerge_adapter() - # Parameters will automatically be repartitioned when exiting the context - else: - # For non-PEFT models, simply gather and update each parameter individually. - for name, param in self.model.named_parameters(): - with gather_if_zero3([param]): - if self.accelerator.is_main_process: - self.vllm_client.update_named_param(name, param.data) - - # Reset cache on main process - if self.accelerator.is_main_process: - self.vllm_client.reset_prefix_cache() - class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): """Extend the base GRPOTrainer for sequence parallelism handling""" diff --git a/src/axolotl/core/trainers/mixins/__init__.py b/src/axolotl/core/trainers/mixins/__init__.py index 44751b465..a71cb321a 100644 --- a/src/axolotl/core/trainers/mixins/__init__.py +++ b/src/axolotl/core/trainers/mixins/__init__.py @@ -6,4 +6,3 @@ from .optimizer import OptimizerMixin from .rng_state_loader import RngLoaderMixin from .scheduler import SchedulerMixin -from .sequence_parallel import SequenceParallelMixin diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py deleted file mode 100644 index 0f30458cd..000000000 --- a/src/axolotl/core/trainers/mixins/sequence_parallel.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Module for Axolotl trainer sequence parallelism mixin""" - -import torch.distributed as dist -from datasets import Dataset -from torch.utils.data import DistributedSampler, Sampler - -from axolotl.monkeypatch.attention.ring_attn import ( - get_ring_attn_group, -) - - -class SequenceParallelMixin: - """ - Mixin class for sequence parallelism support in trainers. - - This mixin provides functionality for handling sequence parallelism, - specifically for creating appropriate data samplers. - """ - - args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] - - def _setup_sequence_parallel(self): - """Set up sequence parallelism environment.""" - self.ring_attn_group = get_ring_attn_group() - - def _create_sequence_parallel_sampler( - self, - dataset: Dataset, - shuffle: bool = True, - is_eval: bool = False, - ) -> DistributedSampler: - """ - Helper method to create sampler for sequence parallelism (SP). 
- - We create a distributed sampler with rank equal to the SP group ID, which - means that all ranks in the SP group receive the same sample / set of samples - per training step. We also set the number of replicas equal to the number of - SP groups, which is a bit of a hack / unintended use, but works! - - Args: - dataset: Dataset to sample from. - shuffle: Whether to shuffle the dataset. - is_eval: Whether we are creating a sampler for evaluation or training. - - Returns: - Distributed sampler. - """ - num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree - sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree - - return DistributedSampler( - dataset, - num_replicas=num_sp_groups, - rank=sp_group_id, - seed=self.args.seed if shuffle else None, - shuffle=shuffle, - drop_last=not is_eval, - ) - - def _sp_get_train_sampler(self, dataset) -> Sampler | None: - """ - Get a training sampler configured for sequence parallelism. - - Args: - dataset: The training dataset - - Returns: - Configured sequence parallel sampler. - """ - return self._create_sequence_parallel_sampler( - dataset, - shuffle=not self.args.curriculum_sampling, - ) - - def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None: - """ - Get an evaluation sampler configured for sequence parallelism. - - Args: - eval_dataset: The evaluation dataset. - - Returns: - Configured sequence parallel sampler. - """ - return self._create_sequence_parallel_sampler( - eval_dataset, shuffle=False, is_eval=True - ) diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 0b14e7661..9c93f77c7 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -9,8 +9,6 @@ from PIL.Image import Resampling from transformers import TrainingArguments from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig -from axolotl.utils.schemas.enums import RingAttnFunc - @dataclass class AxolotlTrainingMixins: @@ -216,14 +214,16 @@ class AxolotlTrainingMixins: }, ) - sequence_parallel_degree: Optional[int] = field( - default=1, - metadata={"help": "The number of workers to use in sequence parallelism"}, - ) - ring_attn_func: Optional[RingAttnFunc] = field( + adam_beta3: Optional[float] = field( default=None, metadata={ - "help": "The ring-flash-attn function to use in sequence parallelism" + "help": "The beta3 hyperparameter used in some optimizers such as CAME" + }, + ) + adam_epsilon2: Optional[float] = field( + default=None, + metadata={ + "help": "The epsilon2 hyperparameter used in some optimizers such as CAME" }, ) diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py index 99e17910e..ea9e10724 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py @@ -20,25 +20,15 @@ from cut_cross_entropy.transformers.utils import ( from transformers.cache_utils import Cache from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.cohere.modeling_cohere import ( - _CONFIG_FOR_DOC, - COHERE_INPUTS_DOCSTRING, KwargsForCausalLM, ) from transformers.processing_utils import Unpack -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.utils.deprecation import deprecate_kwarg _PATCH_OPTS: PatchOptions | None = None @deprecate_kwarg("num_logits_to_keep", version="4.50", 
new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward( self, input_ids: torch.LongTensor | None = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py index 4c8d2261a..ae3d8c6ef 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py @@ -17,25 +17,15 @@ from cut_cross_entropy.transformers.utils import ( from transformers.cache_utils import Cache from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.gemma.modeling_gemma import ( - _CONFIG_FOR_DOC, - GEMMA_INPUTS_DOCSTRING, KwargsForCausalLM, ) from transformers.processing_utils import Unpack -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.utils.deprecation import deprecate_kwarg _PATCH_OPTS: PatchOptions | None = None @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward( self, input_ids: torch.LongTensor | None = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py index ccf0c160d..644e5cce7 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py @@ -20,15 +20,11 @@ from torch import nn from transformers.cache_utils import Cache, HybridCache from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.gemma3.modeling_gemma3 import ( - _CONFIG_FOR_DOC, - GEMMA3_INPUTS_DOCSTRING, Gemma3CausalLMOutputWithPast, logger, ) from transformers.utils import ( - add_start_docstrings_to_model_forward, is_torchdynamo_compiling, - replace_return_docstrings, ) from transformers.utils.deprecation import deprecate_kwarg @@ -38,10 +34,6 @@ _PATCH_OPTS: PatchOptions | None = None @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward( self, input_ids: torch.LongTensor | None = None, @@ -170,10 +162,6 @@ def cce_forward( @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward_multimodal( self, input_ids: torch.LongTensor | None = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py index 42ab996b9..bed411ace 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py @@ -19,15 +19,9 @@ from transformers.modeling_outputs import ( CausalLMOutputWithPast, ) from transformers.models.llama.modeling_llama import ( - _CONFIG_FOR_DOC, - LLAMA_INPUTS_DOCSTRING, KwargsForCausalLM, ) from 
transformers.processing_utils import Unpack -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.generic import can_return_tuple @@ -36,10 +30,6 @@ _PATCH_OPTS: PatchOptions | None = None @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py index 7204f5c90..3143e9c8d 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py @@ -16,22 +16,12 @@ from torch import nn from transformers.cache_utils import Cache from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.llama4.modeling_llama4 import ( - _CONFIG_FOR_DOC, - LLAMA4_INPUTS_DOCSTRING, Llama4CausalLMOutputWithPast, ) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) _PATCH_OPTS: PatchOptions | None = None -@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward( self, input_ids: torch.LongTensor | None = None, @@ -160,9 +150,6 @@ def cce_forward( ) -@replace_return_docstrings( - output_type=Llama4CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward_multimodal( self, input_ids: torch.LongTensor | None = None, # type: ignore diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py index adb65fa8f..aa252701e 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py @@ -19,15 +19,11 @@ from transformers.models.mistral3.modeling_mistral3 import ( Mistral3CausalLMOutputWithPast, ) from transformers.models.mistral.modeling_mistral import ( - _CONFIG_FOR_DOC, - MISTRAL_INPUTS_DOCSTRING, KwargsForCausalLM, ) from transformers.processing_utils import Unpack from transformers.utils import ( - add_start_docstrings_to_model_forward, is_torchdynamo_compiling, - replace_return_docstrings, ) from transformers.utils.deprecation import deprecate_kwarg @@ -35,10 +31,6 @@ _PATCH_OPTS: PatchOptions | None = None @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward( self, input_ids: torch.LongTensor | None = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py index 0811bf55a..afe56266e 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py @@ -13,16 +13,10 @@ from cut_cross_entropy.transformers.utils import ( apply_lce, ) from transformers.models.qwen2_moe.modeling_qwen2_moe import ( - 
_CONFIG_FOR_DOC, - QWEN2MOE_INPUTS_DOCSTRING, MoeCausalLMOutputWithPast, MoeModelOutputWithPast, load_balancing_loss_func, ) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.generic import can_return_tuple @@ -31,10 +25,6 @@ _PATCH_OPTS: PatchOptions | None = None @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py index 250c3ab6b..79af01cfa 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py @@ -14,22 +14,12 @@ from cut_cross_entropy.transformers.utils import ( ) from torch.nn import CrossEntropyLoss from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - _CONFIG_FOR_DOC, - QWEN2_VL_INPUTS_DOCSTRING, Qwen2VLCausalLMOutputWithPast, ) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) _PATCH_OPTS: PatchOptions | None = None -@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward_multimodal( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py index c5cd76f94..90466e64b 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py @@ -12,20 +12,13 @@ from cut_cross_entropy.transformers.utils import ( TransformersModelT, apply_lce, ) -from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.qwen3_moe.modeling_qwen3_moe import ( - _CONFIG_FOR_DOC, - QWEN3_MOE_INPUTS_DOCSTRING, KwargsForCausalLM, MoeCausalLMOutputWithPast, MoeModelOutputWithPast, load_balancing_loss_func, ) from transformers.processing_utils import Unpack -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.generic import can_return_tuple @@ -34,10 +27,6 @@ _PATCH_OPTS: PatchOptions | None = None @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/src/axolotl/monkeypatch/attention/ring_attn/adapters/__init__.py b/src/axolotl/integrations/liger/models/__init__.py similarity index 100% rename from src/axolotl/monkeypatch/attention/ring_attn/adapters/__init__.py rename to src/axolotl/integrations/liger/models/__init__.py diff --git a/src/axolotl/integrations/liger/models/deepseekv2.py b/src/axolotl/integrations/liger/models/deepseekv2.py index c29fd4e79..2f0d2a704 100644 --- 
a/src/axolotl/integrations/liger/models/deepseekv2.py +++ b/src/axolotl/integrations/liger/models/deepseekv2.py @@ -14,10 +14,6 @@ from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import CausalLMOutputWithPast -# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) -# @replace_return_docstrings( -# output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -# ) def lce_forward( self, input_ids: torch.LongTensor = None, diff --git a/src/axolotl/integrations/liger/models/jamba.py b/src/axolotl/integrations/liger/models/jamba.py index 7ab464c88..d25529970 100644 --- a/src/axolotl/integrations/liger/models/jamba.py +++ b/src/axolotl/integrations/liger/models/jamba.py @@ -13,21 +13,11 @@ from liger_kernel.transformers.fused_linear_cross_entropy import ( from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import MoeCausalLMOutputWithPast from transformers.models.jamba.modeling_jamba import ( - _CONFIG_FOR_DOC, - JAMBA_INPUTS_DOCSTRING, HybridMambaAttentionDynamicCache, load_balancing_loss_func, ) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) -@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def lce_forward( self, input_ids: torch.LongTensor = None, diff --git a/src/axolotl/monkeypatch/attention/ring_attn/__init__.py b/src/axolotl/monkeypatch/attention/ring_attn/__init__.py deleted file mode 100644 index a50ad456e..000000000 --- a/src/axolotl/monkeypatch/attention/ring_attn/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Init for ring attention monkeypatch module""" - -# pylint: disable=unused-import -# flake8: noqa - -from .patch import ( - get_ring_attn_group, - register_ring_attn, - set_ring_attn_group, - update_ring_attn_params, -) diff --git a/src/axolotl/monkeypatch/attention/ring_attn/patch.py b/src/axolotl/monkeypatch/attention/ring_attn/patch.py deleted file mode 100644 index 8cbba338a..000000000 --- a/src/axolotl/monkeypatch/attention/ring_attn/patch.py +++ /dev/null @@ -1,131 +0,0 @@ -""" -Ring attention group registration and flash attention patching. - -Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention) -package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in -their sequence parallel version of Flash Attention 2. -""" - -import torch -import torch.distributed as dist -from accelerate.logging import get_logger - -from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids -from axolotl.utils.schemas.enums import RingAttnFunc - -LOG = get_logger(__name__) - - -RING_ATTN_GROUP = None - - -def get_ring_attn_group() -> dist.ProcessGroup: - """ - Getter for ring attention group on this rank. - - Returns: - The process group for ring attention for this rank. - """ - return RING_ATTN_GROUP - - -def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None): - """ - Setter for ring attention group on this rank. - - Args: - Process group for ring attention. - """ - global RING_ATTN_GROUP # pylint: disable=global-statement - RING_ATTN_GROUP = ring_attn_group - - -def register_ring_attn( - sequence_parallel_degree: int, - heads_k_stride: int | None, - ring_attn_func: RingAttnFunc | None, -): - """ - Create ring attention group and substitute flash attn with ring flash attn. - - Args: - sequence_parallel_degree: Sequence parallelism factor. 
- heads_k_stride: Sequence parallelism K head stride size. Passed - through to `ring_flash_attn.substitute_hf_flash_attn`. - ring_attn_func: `ring_flash_attn` ring attention implemention. If sample - packing is enabled, it must be a `varlen` function; otherwise, it must be a - `batch` function. - """ - if get_ring_attn_group() is not None: - LOG.info("Ring attention already registered, exiting early...") - return - - LOG.info( - "Enabling ring attention sequence parallelism: " - f"each sequence will be processed across {sequence_parallel_degree} GPUs" - ) - - rank = dist.get_rank() - world_size = dist.get_world_size() - - assert sequence_parallel_degree <= world_size, ( - f"sequence_parallel_degree ({sequence_parallel_degree}) " - f"must be less than or equal to world_size ({world_size})" - ) - assert world_size % sequence_parallel_degree == 0, ( - f"sequence_parallel_degree ({sequence_parallel_degree}) " - f"must evenly divide world_size ({world_size})" - ) - - # Assign ranks to sequence parallel groups - group_assignments = {} - for i in range(world_size // sequence_parallel_degree): - ring_attn_ranks = list( - range( - i * sequence_parallel_degree, - (i + 1) * sequence_parallel_degree, - ) - ) - group = dist.new_group(ranks=ring_attn_ranks, backend="nccl") - - # Track which GPUs are in which groups - for r in ring_attn_ranks: - group_assignments[r] = i - - if rank in ring_attn_ranks: - set_ring_attn_group(group) - - # Log the GPU group assignments - if rank == 0: - LOG.info(f"Sequence parallel group assignments: {group_assignments}") - - if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3: - from ring_flash_attn import substitute_hf_flash_attn - - substitute_hf_flash_attn( - process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1 - ) - elif ring_attn_func is RingAttnFunc.BATCH_RING: - from axolotl.monkeypatch.attention.ring_attn.adapters.batch import ( - substitute_hf_flash_attn, - ) - - substitute_hf_flash_attn( - process_group=get_ring_attn_group(), - ring_attn_func=ring_attn_func, - ) - - -def update_ring_attn_params(position_ids: torch.Tensor | None): - """ - Calculate the cumulative sequence lengths for the current forward pass and pass the - value to the substituted `ring_flash_attn`. - - Args: - position_ids: Optional tensor of position IDs (for sample packed data). 
- """ - from ring_flash_attn import update_ring_flash_attn_params - - cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids) - cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device()) - update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group()) diff --git a/src/axolotl/monkeypatch/gemma3.py b/src/axolotl/monkeypatch/gemma3.py index 38183fa0e..36f591efd 100644 --- a/src/axolotl/monkeypatch/gemma3.py +++ b/src/axolotl/monkeypatch/gemma3.py @@ -7,24 +7,16 @@ from typing import Optional, Tuple, Union import torch from transformers.cache_utils import Cache from transformers.models.gemma3.modeling_gemma3 import ( - _CONFIG_FOR_DOC, - GEMMA3_INPUTS_DOCSTRING, Gemma3CausalLMOutputWithPast, logger, ) from transformers.utils import ( - add_start_docstrings_to_model_forward, is_torchdynamo_compiling, - replace_return_docstrings, ) from transformers.utils.deprecation import deprecate_kwarg @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def new_forward( self, input_ids: torch.LongTensor = None, diff --git a/src/axolotl/monkeypatch/ring_attn/__init__.py b/src/axolotl/monkeypatch/ring_attn/__init__.py new file mode 100644 index 000000000..5833b9ce4 --- /dev/null +++ b/src/axolotl/monkeypatch/ring_attn/__init__.py @@ -0,0 +1,22 @@ +"""Init for ring attention monkeypatch module""" + +# pylint: disable=unused-import +# flake8: noqa + +from .patch import ( + get_ring_attn_group, + patch_prepare_data_loader, + patch_prepare_device_mesh, + register_ring_attn, + set_ring_attn_group, + update_ring_attn_params, +) + +__all__ = ( + "get_ring_attn_group", + "patch_prepare_data_loader", + "patch_prepare_device_mesh", + "register_ring_attn", + "set_ring_attn_group", + "update_ring_attn_params", +) diff --git a/src/axolotl/monkeypatch/ring_attn/adapters/__init__.py b/src/axolotl/monkeypatch/ring_attn/adapters/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py b/src/axolotl/monkeypatch/ring_attn/adapters/batch.py similarity index 100% rename from src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py rename to src/axolotl/monkeypatch/ring_attn/adapters/batch.py diff --git a/src/axolotl/monkeypatch/ring_attn/patch.py b/src/axolotl/monkeypatch/ring_attn/patch.py new file mode 100644 index 000000000..4329d9f13 --- /dev/null +++ b/src/axolotl/monkeypatch/ring_attn/patch.py @@ -0,0 +1,223 @@ +"""Ring attention group registration and flash attention patching. + +Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention) +package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in +their sequence parallel version of Flash Attention 2. + +We also provide some patches for accelerate functions to prepare the dataloader for +sequence parallelism training. 
+"""

+import inspect
+
+import accelerate
+import torch
+import torch.distributed as dist
+from accelerate.logging import get_logger
+
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+from axolotl.utils.schemas.enums import RingAttnFunc
+
+LOG = get_logger(__name__)
+
+
+RING_ATTN_GROUP = None
+
+ORIGINAL_PREPARE_DATALOADER_CODE = """        submesh_fsdp_size = 1
+        submesh_dp_size = 1
+        submesh_tp_size = 1
+        if "tp" in torch_device_mesh.mesh_dim_names:
+            submesh_tp_size = torch_device_mesh["tp"].size()
+        if "dp" in torch_device_mesh.mesh_dim_names:
+            submesh_dp_size = torch_device_mesh["dp"].size()
+        if "fsdp" in torch_device_mesh.mesh_dim_names:
+            submesh_fsdp_size = torch_device_mesh["fsdp"].size()
+        process_index = process_index // submesh_tp_size"""
+
+NEW_PREPARE_DATALOADER_CODE = """        submesh_fsdp_size = 1
+        submesh_dp_size = 1
+        submesh_tp_size = 1
+        submesh_cp_size = 1
+        if "cp" in torch_device_mesh.mesh_dim_names:
+            submesh_cp_size = torch_device_mesh["cp"].size()
+        if "tp" in torch_device_mesh.mesh_dim_names:
+            submesh_tp_size = torch_device_mesh["tp"].size()
+        if "dp" in torch_device_mesh.mesh_dim_names:
+            submesh_dp_size = torch_device_mesh["dp"].size()
+        if "fsdp" in torch_device_mesh.mesh_dim_names:
+            submesh_fsdp_size = torch_device_mesh["fsdp"].size()
+        process_index = process_index // (submesh_tp_size * submesh_cp_size)"""
+
+
+def get_ring_attn_group() -> dist.ProcessGroup:
+    """Getter for ring attention group on this rank."""
+    return RING_ATTN_GROUP
+
+
+def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
+    """Setter for ring attention group on this rank."""
+    global RING_ATTN_GROUP  # pylint: disable=global-statement
+    RING_ATTN_GROUP = ring_attn_group
+
+
+def register_ring_attn(
+    sequence_parallel_degree: int,
+    heads_k_stride: int | None,
+    ring_attn_func: RingAttnFunc | None,
+):
+    """Create ring attention group and substitute flash attn with ring flash attn.
+
+    Args:
+        sequence_parallel_degree: Sequence parallelism factor.
+        heads_k_stride: Sequence parallelism K head stride size. Passed
+            through to `ring_flash_attn.substitute_hf_flash_attn`.
+        ring_attn_func: `ring_flash_attn` ring attention implementation. If sample
+            packing is enabled, it must be a `varlen` function; otherwise, it must be a
+            `batch` function.
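+
+    Example (illustrative; assumes a process group of 8 ranks is already
+    initialized)::
+
+        # Creates two ring attention groups, ranks [0, 1, 2, 3] and
+        # ranks [4, 5, 6, 7]; each group cooperates on one sequence per step.
+        register_ring_attn(
+            sequence_parallel_degree=4,
+            heads_k_stride=1,
+            ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
+        )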
+    """
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+
+    if rank == 0:
+        LOG.info(
+            "Enabling ring attention sequence parallelism: "
+            f"each sequence will be processed across {sequence_parallel_degree} GPUs"
+        )
+
+    assert sequence_parallel_degree <= world_size, (
+        f"sequence_parallel_degree ({sequence_parallel_degree}) "
+        f"must be less than or equal to world_size ({world_size})"
+    )
+    assert world_size % sequence_parallel_degree == 0, (
+        f"sequence_parallel_degree ({sequence_parallel_degree}) "
+        f"must evenly divide world_size ({world_size})"
+    )
+
+    # Assign ranks to sequence parallel groups
+    group_assignments = {}
+    for i in range(world_size // sequence_parallel_degree):
+        ring_attn_ranks = list(
+            range(
+                i * sequence_parallel_degree,
+                (i + 1) * sequence_parallel_degree,
+            )
+        )
+        group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
+
+        # Track which GPUs are in which groups
+        for r in ring_attn_ranks:
+            group_assignments[r] = i
+
+        if rank in ring_attn_ranks:
+            set_ring_attn_group(group)
+
+    # Log the GPU group assignments
+    if rank == 0:
+        LOG.info(f"Sequence parallel group assignments: {group_assignments}")
+
+    if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
+        from ring_flash_attn import substitute_hf_flash_attn
+
+        substitute_hf_flash_attn(
+            process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
+        )
+    elif ring_attn_func is RingAttnFunc.BATCH_RING:
+        from axolotl.monkeypatch.ring_attn.adapters.batch import (
+            substitute_hf_flash_attn,
+        )
+
+        substitute_hf_flash_attn(
+            process_group=get_ring_attn_group(),
+            ring_attn_func=ring_attn_func,
+        )
+
+
+def update_ring_attn_params(position_ids: torch.Tensor | None):
+    """
+    Calculate the cumulative sequence lengths for the current forward pass and pass the
+    value to the substituted `ring_flash_attn`.
+
+    Args:
+        position_ids: Optional tensor of position IDs (for sample packed data).
+    """
+    from ring_flash_attn import update_ring_flash_attn_params
+
+    cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
+    cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
+    update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
+
+
+def patch_prepare_data_loader():
+    """Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
+
+    Raises:
+        RuntimeError: If the source code to patch does not exist.
+    """
+    original_fn = accelerate.data_loader.prepare_data_loader
+    original_source = inspect.getsource(original_fn)
+
+    if ORIGINAL_PREPARE_DATALOADER_CODE not in original_source:
+        raise RuntimeError(
+            "SP patch failed - target snippet not found. "
+            "Check accelerate's version or update the patch."
+        )
+
+    patched_source = original_source.replace(
+        ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
+    )
+
+    # Create a new function from the patched source
+    namespace = {}
+    exec(  # pylint: disable=exec-used  # nosec B102
+        patched_source, accelerate.data_loader.__dict__, namespace
+    )
+    patched_function = namespace["prepare_data_loader"]
+
+    accelerate.data_loader.prepare_data_loader = patched_function
+    LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
+
+
+def patch_prepare_device_mesh(sequence_parallel_degree: int):
+    """Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
+    that includes sequence parallelism with the specified degree.
+
+    Args:
+        sequence_parallel_degree (int): The degree of sequence parallelism to use.
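+
+    Example (illustrative): with a world size of 8 and
+    `sequence_parallel_degree=2`, the patched method builds a device mesh of
+    shape (4, 2) with dimension names ("dp", "cp"): four data parallel
+    replicas, each spanning two context parallel ranks.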
+ """ + + def _prepare_device_mesh(self): + """Prepare the device mesh for distributed training. The dataloader will + determine how to load data based on the device mesh. + """ + if self.state.torch_tp_plugin: + return self.state.torch_tp_plugin.torch_device_mesh + if ( + self.distributed_type == accelerate.accelerator.DistributedType.DEEPSPEED + and hasattr(self.state, "ds_device_mesh") + ): + return self.state.ds_device_mesh + + # Create device mesh with sequence parallelism + world_size = dist.get_world_size() + mesh_shape = ( + world_size // sequence_parallel_degree, + sequence_parallel_degree, + ) + device_ids = list(range(world_size)) + + # Note that we use "cp" instead of "sp" to match the PyTorch native "context + # parallelism" implementation naming + return dist.DeviceMesh( + "cuda", + torch.tensor(device_ids).reshape(mesh_shape), + mesh_dim_names=("dp", "cp"), + ) + + # Replace the original method with our new method + # pylint: disable=protected-access + accelerate.accelerator.Accelerator._prepare_device_mesh = _prepare_device_mesh + + LOG.info( + "Successfully patched Accelerator._prepare_device_mesh " + f"with sequence_parallel_degree={sequence_parallel_degree}" + ) diff --git a/src/axolotl/utils/callbacks/mlflow_.py b/src/axolotl/utils/callbacks/mlflow_.py index 056fb51cc..43fd4dab0 100644 --- a/src/axolotl/utils/callbacks/mlflow_.py +++ b/src/axolotl/utils/callbacks/mlflow_.py @@ -1,6 +1,7 @@ """MLFlow module for trainer callbacks""" import logging +import os from shutil import copyfile from tempfile import NamedTemporaryFile from typing import TYPE_CHECKING @@ -16,6 +17,11 @@ if TYPE_CHECKING: LOG = logging.getLogger("axolotl.callbacks") +def should_log_artifacts() -> bool: + truths = ["TRUE", "1", "YES"] + return os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in truths + + class SaveAxolotlConfigtoMlflowCallback(TrainerCallback): # pylint: disable=duplicate-code """Callback to save axolotl config to mlflow""" @@ -32,13 +38,18 @@ class SaveAxolotlConfigtoMlflowCallback(TrainerCallback): ): if is_main_process(): try: - with NamedTemporaryFile( - mode="w", delete=False, suffix=".yml", prefix="axolotl_config_" - ) as temp_file: - copyfile(self.axolotl_config_path, temp_file.name) - mlflow.log_artifact(temp_file.name, artifact_path="") + if should_log_artifacts(): + with NamedTemporaryFile( + mode="w", delete=False, suffix=".yml", prefix="axolotl_config_" + ) as temp_file: + copyfile(self.axolotl_config_path, temp_file.name) + mlflow.log_artifact(temp_file.name, artifact_path="") + LOG.info( + "The Axolotl config has been saved to the MLflow artifacts." + ) + else: LOG.info( - "The Axolotl config has been saved to the MLflow artifacts." 
+                        "Skipping logging artifacts to MLflow (HF_MLFLOW_LOG_ARTIFACTS is not enabled)"
                    )
            except (FileNotFoundError, ConnectionError) as err:
                LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py
index 66044f7f0..6e4f9bada 100644
--- a/src/axolotl/utils/ctx_managers/sequence_parallel.py
+++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py
@@ -1,6 +1,7 @@
 """Module for Axolotl trainer sequence parallelism manager and utilities"""

 import functools
+import inspect

 import torch
 import torch.distributed as dist
@@ -9,7 +10,7 @@ from torch.utils.hooks import RemovableHandle
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils import ModelOutput

-from axolotl.monkeypatch.attention.ring_attn.patch import (
+from axolotl.monkeypatch.ring_attn.patch import (
     get_ring_attn_group,
     update_ring_attn_params,
 )
@@ -206,12 +207,25 @@ class SequenceParallelContextManager:
     def __enter__(self):
         # Forward pre-hook to apply sequence parallelism
         def sequence_parallel_pre_hook(_, args, kwargs):
-            # Apply sequence parallelism to kwargs and get original sequence length and padding info
-            kwargs, self.original_seq_len, self.pad_len = (
-                self.apply_sequence_parallelism(batch=kwargs)
+            # Get parameter names from the model's forward function
+            forward_params = list(
+                inspect.signature(self.models[0].forward).parameters.keys()
             )
-            return args, kwargs

+            updated_kwargs = kwargs.copy()
+            for i, arg in enumerate(args):
+                if i < len(forward_params):
+                    updated_kwargs[forward_params[i]] = arg
+
+            # Any excess positional arguments are kept as-is
+            remaining_args = args[len(forward_params) :]
+
+            # Apply sequence parallelism to updated kwargs
+            updated_kwargs, self.original_seq_len, self.pad_len = (
+                self.apply_sequence_parallelism(updated_kwargs)
+            )
+
+            return remaining_args, updated_kwargs

         # Forward post-hook to gather outputs
         def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py
index 821d28645..57bff4a6b 100644
--- a/src/axolotl/utils/data/rl.py
+++ b/src/axolotl/utils/data/rl.py
@@ -74,6 +74,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
         load_from_cache_file=not cfg.is_preprocess,
         desc="Mapping RL Dataset",
+        num_proc=cfg.dataset_processes,
         **map_kwargs,
     )

diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index 5fa0cb60d..6de2d2cf7 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -484,7 +484,7 @@ def get_dataset_wrapper(
     }

     LOG.info(
-        f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
+        f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
     )

     if (
diff --git a/src/axolotl/utils/gradient_checkpointing/__init__.py b/src/axolotl/utils/gradient_checkpointing/__init__.py
index f84f76d80..ae0c559e9 100644
--- a/src/axolotl/utils/gradient_checkpointing/__init__.py
+++ b/src/axolotl/utils/gradient_checkpointing/__init__.py
@@ -5,8 +5,11 @@ from functools import partial

 from packaging import version

-from axolotl.utils.gradient_checkpointing.unsloth import (
-    Unsloth_Offloaded_Gradient_Checkpointer,
+from axolotl.utils.gradient_checkpointing.offload_cpu import (
+    CPU_Offloaded_Gradient_Checkpointer,
+)
+from axolotl.utils.gradient_checkpointing.offload_disk import (
+    
Disco, ) transformers_version = version.parse(importlib.metadata.version("transformers")) @@ -26,12 +29,31 @@ def hf_grad_checkpoint_offload_wrapper( decoder_layer, *args, use_reentrant=None ): # pylint: disable=unused-argument if uses_gc_layers(decoder_layer): - return Unsloth_Offloaded_Gradient_Checkpointer.apply( + return CPU_Offloaded_Gradient_Checkpointer.apply( decoder_layer, *args, ) - return Unsloth_Offloaded_Gradient_Checkpointer.apply( + return CPU_Offloaded_Gradient_Checkpointer.apply( + ( + decoder_layer.func.__self__ + if isinstance(decoder_layer, partial) + else decoder_layer.__self__ + ), + *args, + ) + + +def hf_grad_checkpoint_disk_offload_wrapper( + decoder_layer, *args, use_reentrant=None +): # pylint: disable=unused-argument + if uses_gc_layers(decoder_layer): + return Disco.apply( + decoder_layer, + *args, + ) + + return Disco.apply( ( decoder_layer.func.__self__ if isinstance(decoder_layer, partial) diff --git a/src/axolotl/utils/gradient_checkpointing/unsloth.py b/src/axolotl/utils/gradient_checkpointing/offload_cpu.py similarity index 95% rename from src/axolotl/utils/gradient_checkpointing/unsloth.py rename to src/axolotl/utils/gradient_checkpointing/offload_cpu.py index 7a14614b1..bbb5ad40d 100644 --- a/src/axolotl/utils/gradient_checkpointing/unsloth.py +++ b/src/axolotl/utils/gradient_checkpointing/offload_cpu.py @@ -1,4 +1,4 @@ -"""Unsloth checkpointing""" +"""CPU offloaded checkpointing""" # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. # @@ -26,7 +26,7 @@ else: torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda") -class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name +class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name torch.autograd.Function ): """ diff --git a/src/axolotl/utils/gradient_checkpointing/offload_disk.py b/src/axolotl/utils/gradient_checkpointing/offload_disk.py new file mode 100644 index 000000000..90e70f504 --- /dev/null +++ b/src/axolotl/utils/gradient_checkpointing/offload_disk.py @@ -0,0 +1,531 @@ +""" +DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching +""" + +# Copyright 2025 Axolotl AI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import concurrent.futures +import logging +import os +import queue +import shutil +import tempfile +import threading +import time +import uuid +from collections import deque +from concurrent.futures import Future +from typing import Dict + +import torch + +torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda") +torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda") + +# Setup logger +logger = logging.getLogger(__name__) + + +class DiskOffloadManager: + """ + Manages offloaded tensors and handles prefetching in a separate thread. + Includes synchronization to prevent race conditions. 
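+
+    A minimal usage sketch (illustrative only; in practice the `Disco` autograd
+    function below owns a shared instance, so direct construction is rarely
+    needed):
+
+        manager = DiskOffloadManager(prefetch_size=2)
+        path = manager.save_tensor(torch.randn(4, 4, device="cuda"))
+        tensor = manager.load_tensor(path)  # blocks until the async save lands
+        manager.cleanup_tensor(path)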
+ """ + + def __init__( + self, + prefetch_size: int = 3, + prefetch_to_gpu: bool = True, + save_workers: int = 4, + ): + """ + Args: + prefetch_size: Maximum number of tensors to prefetch in the background. + prefetch_to_gpu: Whether to prefetch tensors directly to GPU memory. + save_workers: Maximum number of concurrent save operations. + """ + self.temp_dir = tempfile.mkdtemp(prefix="disco_") + + # Track tensor paths and their status + self.tensor_paths: deque = deque() # Ordered history of tensor paths (LIFO) + self.file_locks: Dict[str, threading.Lock] = ( + {} + ) # Maps file_path -> threading.Lock() + # Maps file_path -> status ("saving", "ready", "prefetching", "loaded", "deleted") + self.file_status: Dict[str, str] = {} + + self.max_prefetch = prefetch_size + self.prefetch_to_gpu = prefetch_to_gpu + + # Thread synchronization + self.manager_lock = threading.RLock() # Used for thread-safe operations + + # Prefetch queue and cache + self.prefetch_queue: queue.Queue = queue.Queue() + self.prefetch_cache: Dict[str, torch.Tensor] = {} # Maps file_path -> tensor + + # Save queue and thread pool + self.save_queue: queue.Queue = queue.Queue() + self.save_pool = concurrent.futures.ThreadPoolExecutor(max_workers=save_workers) + self.save_futures: Dict[str, Future] = {} + self.save_semaphore = threading.Semaphore( + save_workers * 2 + ) # Limit concurrent save operations + + # Start prefetch worker thread + self.stop_event = threading.Event() + # start multiple threads for prefetching + self.prefetch_worker_count = 2 + self.prefetch_workers = [] + for _ in range(self.prefetch_worker_count): + worker = threading.Thread(target=self._prefetch_worker, daemon=True) + worker.start() + self.prefetch_workers.append(worker) + + # Start save worker thread + self.save_worker = threading.Thread(target=self._save_worker, daemon=True) + self.save_worker.start() + self.idx = 0 + + atexit.register(self.cleanup) + + def _save_worker(self): + """Background thread that processes the save queue""" + while not self.stop_event.is_set(): + try: + save_item = self.save_queue.get(timeout=0.5) + if save_item is None: + continue + + tensor, file_path = save_item + + # Submit the save task to the thread pool + future = self.save_pool.submit( + self._save_tensor_to_disk, tensor, file_path + ) + with self.manager_lock: + self.save_futures[file_path] = future + + self.save_queue.task_done() + + except queue.Empty: + time.sleep(0.01) # Small sleep to prevent CPU spinning + continue + + def _save_tensor_to_disk(self, tensor: torch.Tensor, file_path: str): + """Actually save the tensor to disk""" + try: + # Save tensor to disk + cpu_tensor = tensor.detach().cpu() + torch.save(cpu_tensor, file_path) + del cpu_tensor + + with self.manager_lock: + # Mark file as ready + self.file_status[file_path] = "ready" + + # Release semaphore + self.save_semaphore.release() + + return True + except FileNotFoundError as e: + logger.error(f"Error saving tensor to {file_path}: {e}") + with self.manager_lock: + self.file_status[file_path] = "error" + + # Release semaphore + self.save_semaphore.release() + + return False + + def _prefetch_worker(self): + """Background thread that loads tensors from disk ahead of time""" + while not self.stop_event.is_set(): + try: + file_path = self.prefetch_queue.get(timeout=0.5) + if file_path is None: + continue + + # Check if file is available and not already in cache + with self.manager_lock: + if ( + file_path not in self.file_status + or self.file_status[file_path] == "deleted" + ): + 
self.prefetch_queue.task_done()
+                        continue
+
+                    if file_path in self.prefetch_cache:
+                        self.prefetch_queue.task_done()
+                        continue
+
+                    # If file is still being saved, wait for it
+                    if (
+                        self.file_status[file_path] == "saving"
+                        and file_path in self.save_futures
+                    ):
+                        # Re-queue this prefetch request with a little delay
+                        self.prefetch_queue.task_done()
+                        time.sleep(0.1)
+                        self.prefetch_queue.put(file_path)
+                        continue
+
+                    # Mark file as being prefetched
+                    self.file_status[file_path] = "prefetching"
+
+                # Load tensor from disk and store in cache
+                try:
+                    if os.path.exists(file_path):
+                        if self.prefetch_to_gpu:
+                            tensor = torch.load(
+                                file_path,
+                                map_location=torch.device("cuda"),
+                                weights_only=True,
+                            )
+                        else:
+                            tensor = torch.load(file_path, weights_only=True)
+
+                        with self.manager_lock:
+                            self.prefetch_cache[file_path] = tensor
+                            self.file_status[file_path] = "ready"
+                    else:
+                        with self.manager_lock:
+                            if self.file_status.get(file_path) != "deleted":
+                                logger.warning(
+                                    f"Prefetch error: File not found {file_path}"
+                                )
+                                self.file_status[file_path] = "missing"
+
+                except FileNotFoundError as e:
+                    with self.manager_lock:
+                        if self.file_status.get(file_path) != "deleted":
+                            logger.warning(f"Prefetch error for {file_path}: {e}")
+                            self.file_status[file_path] = "error"
+
+                self.prefetch_queue.task_done()
+
+            except queue.Empty:
+                time.sleep(0.01)  # Small sleep to prevent CPU spinning
+                continue
+
+    def save_tensor(self, tensor: torch.Tensor):
+        """Save tensor to disk asynchronously and return file path with thread-safe operations"""
+        # Generate unique file path
+        self.idx += 1
+        file_path: str = os.path.join(
+            self.temp_dir, f"{self.idx:06d}-{uuid.uuid4()}.pt"
+        )
+
+        with self.manager_lock:
+            # Mark file as being saved
+            self.file_locks[file_path] = threading.Lock()
+            self.file_status[file_path] = "saving"
+            # Add to history
+            self.tensor_paths.append(file_path)
+
+        # Acquire semaphore to limit concurrent save operations
+        self.save_semaphore.acquire()  # pylint: disable=consider-using-with
+        # Queue tensor for saving in background
+        self.save_queue.put((tensor.detach(), file_path))
+
+        return file_path
+
+    def wait_for_save(self, file_path, timeout=None) -> None:
+        """Wait for a tensor to be saved to disk"""
+        start_time = time.time()
+        while timeout is None or time.time() - start_time < timeout:
+            with self.manager_lock:
+                if self.file_status.get(file_path) == "ready":
+                    return
+                if self.file_status.get(file_path) in ["error", "missing", "deleted"]:
+                    return
+
+                if file_path in self.save_futures:
+                    future = self.save_futures[file_path]
+                    if future.done():
+                        return
+
+            # Small sleep to prevent CPU spinning
+            time.sleep(0.01)
+
+        # Timeout
+        logger.warning(f"Timeout waiting for tensor to be saved: {file_path}")
+        return
+
+    def load_tensor(self, file_path, target_device="cuda"):
+        """Load tensor from disk or prefetch cache with proper synchronization"""
+        # Wait for tensor to be saved if it's still in progress
+        self.wait_for_save(file_path)
+
+        tensor = None
+
+        # Try to get from cache first
+        with self.manager_lock:
+            # Check if tensor is already in cache
+            if file_path in self.prefetch_cache:
+                tensor = self.prefetch_cache[file_path]
+                del self.prefetch_cache[file_path]
+                self.file_status[file_path] = "loaded"
+
+        if tensor is not None:
+            # Ensure tensor is on correct device
+            if target_device != "cpu" and tensor.device.type == "cpu":
+                tensor = tensor.to(target_device, non_blocking=True)
+            return tensor
+
+        # If not in cache, load directly from disk
+        try:
+            if not os.path.exists(file_path):
+                
logger.error(f"File not found for loading: {file_path}") + raise FileNotFoundError(f"File not found: {file_path}") + + tensor = torch.load(file_path, weights_only=True) + + with self.manager_lock: + self.file_status[file_path] = "loaded" + + if target_device != "cpu": + tensor = tensor.to(target_device, non_blocking=True) + + return tensor + + except Exception as e: + logger.error(f"Error loading tensor from {file_path}: {e}") + raise + + def _safe_delete_file(self, file_path): + """Safely delete a file with proper synchronization""" + with self.manager_lock: + # Make sure any save operation is completed + if file_path in self.save_futures: + future = self.save_futures[file_path] + try: + if not future.done(): + future.cancel() + del self.save_futures[file_path] + except FileNotFoundError as e: + logger.warning( + f"Error canceling save operation for {file_path}: {e}" + ) + + # Only delete if file exists and is not being prefetched + status = self.file_status.get(file_path) + if status in ["ready", "loaded", "error", "missing"]: + try: + if os.path.exists(file_path): + os.remove(file_path) + self.file_status[file_path] = "deleted" + return True + except FileNotFoundError as e: + logger.warning(f"Error deleting file {file_path}: {e}") + return False + + def trigger_prefetch(self, n=None): + """Trigger prefetching of the next N tensors with proper synchronization""" + if n is None: + n = self.max_prefetch + + prefetch_paths = [] + with self.manager_lock: + # Find files that are ready to be prefetched (not already in cache or being prefetched) + for path in reversed(self.tensor_paths): + if ( + path not in self.prefetch_cache + and self.file_status.get(path) == "ready" + ): + prefetch_paths.append(path) + if len(prefetch_paths) >= n: + break + + # Queue files for prefetching + for path in prefetch_paths: + self.prefetch_queue.put(path) + + def cleanup_tensor(self, file_path: str): + """Clean up a specific tensor file after it's been used""" + with self.manager_lock: + if file_path in self.tensor_paths: + self.tensor_paths.remove(file_path) + + # Remove from prefetch cache if present + if file_path in self.prefetch_cache: + del self.prefetch_cache[file_path] + + # Remove from save futures if present + if file_path in self.save_futures: + future = self.save_futures[file_path] + if not future.done(): + future.cancel() + del self.save_futures[file_path] + + # Try to delete the file + self._safe_delete_file(file_path) + + def cleanup(self): + """Clean up all temp files and stop prefetch thread with proper synchronization""" + self.stop_event.set() + + # Cancel all pending save operations + with self.manager_lock: + for _, future in self.save_futures.items(): + if not future.done(): + future.cancel() + self.save_futures.clear() + + # Drain the save queue + while not self.save_queue.empty(): + try: + self.save_queue.get_nowait() + self.save_queue.task_done() + except queue.Empty: + break + + # Shutdown the save pool + self.save_pool.shutdown(wait=False) + + # Join the save worker thread + if self.save_worker.is_alive(): + self.save_worker.join(timeout=2.0) + + # Join the prefetch worker threads + for thread in self.prefetch_workers: + if thread.is_alive(): + thread.join(timeout=2.0) + + # Clear cache and remove all temporary files + with self.manager_lock: + self.prefetch_cache.clear() + paths_to_delete = list(self.tensor_paths) + self.tensor_paths.clear() + + # Delete all temporary files + for path in paths_to_delete: + self._safe_delete_file(path) + + # Remove temp directory + try: + if 
os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir, ignore_errors=True) + except FileNotFoundError as e: + logger.warning(f"Error removing temporary directory {self.temp_dir}: {e}") + + +class Disco(torch.autograd.Function): + """ + Disco: DIsk-based Storage and Checkpointing with Optimized prefetching + Advanced disk-based gradient checkpointer with prefetching. + """ + + # Shared manager instance across all checkpointing operations + _manager = None + + @staticmethod + def get_instance(prefetch_size=1, prefetch_to_gpu=True, save_workers=4): + """Get or create the offload manager""" + if Disco._manager is None: + Disco._manager = DiskOffloadManager( + prefetch_size=prefetch_size, + prefetch_to_gpu=prefetch_to_gpu, + save_workers=save_workers, + ) + return Disco._manager + + @staticmethod + @torch_cuda_amp_custom_fwd + def forward( + ctx, + forward_function, + hidden_states, + *args, + prefetch_size=1, + prefetch_to_gpu=True, + save_workers=4, + ): + """Forward pass that offloads activations to disk asynchronously""" + # Get or create the manager + manager = Disco.get_instance( + prefetch_size=prefetch_size, + prefetch_to_gpu=prefetch_to_gpu, + save_workers=save_workers, + ) + + # Save tensor to disk asynchronously + file_path = manager.save_tensor(hidden_states) + + # Run forward pass immediately without waiting for save to complete + with torch.no_grad(): + output = forward_function(hidden_states, *args) + + # Store what we need for backward + ctx.save_for_backward(torch.tensor([0])) # Dummy tensor + ctx.file_path = file_path + ctx.forward_function = forward_function + ctx.args = args + + return output + + @staticmethod + @torch_cuda_amp_custom_bwd + def backward(ctx, *grad_outputs): + """Backward pass that loads activations from disk with prefetching""" + # Get the manager + manager = Disco._manager + + # Trigger prefetching for future tensors + # This happens at the start of backward, so should have time to complete + manager.trigger_prefetch() + + # Load hidden states from disk or prefetch cache + file_path = ctx.file_path + try: + # Ensure the file is saved before we try to load it + manager.wait_for_save(file_path) + + hidden_states = manager.load_tensor(file_path) + hidden_states.requires_grad = True + + # Compute gradients + with torch.enable_grad(): + output = ctx.forward_function(hidden_states, *ctx.args) + + # Handle tuple outputs properly + if isinstance(output, tuple): + if len(grad_outputs) == len(output): + torch.autograd.backward(output, grad_outputs) + else: + torch.autograd.backward(output, grad_outputs[0]) + else: + torch.autograd.backward(output, grad_outputs[0]) + + # Clean up the file after we're done with it + manager.cleanup_tensor(file_path) + + return ( + ( + None, # forward_function + hidden_states.grad, # hidden_states grad + ) + + (None,) * len(ctx.args) # for each arg + + ( + None, # prefetch_size + None, # prefetch_to_gpu + None, # save_workers + ) + ) + + except Exception as e: + logger.error(f"Error in backward pass: {e}") + # Clean up the file even on error + manager.cleanup_tensor(file_path) + raise diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index dff6d854b..6236f78e8 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -59,6 +59,7 @@ from axolotl.monkeypatch.multipack import ( SUPPORTED_MULTIPACK_MODEL_TYPES, patch_for_multipack, ) +from axolotl.monkeypatch.ring_attn.patch import get_ring_attn_group from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.utils.bench import 
log_gpu_memory_usage from axolotl.utils.chat_templates import get_chat_template_from_config @@ -70,7 +71,10 @@ from axolotl.utils.distributed import ( is_local_main_process, is_main_process, ) -from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper +from axolotl.utils.gradient_checkpointing import ( + hf_grad_checkpoint_disk_offload_wrapper, + hf_grad_checkpoint_offload_wrapper, +) from axolotl.utils.lora_embeddings import get_linear_embedding_layers from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant from axolotl.utils.schemas.enums import RLType @@ -620,6 +624,10 @@ class ModelLoader: if self.cfg.gradient_checkpointing in ["unsloth", "offload"]: transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper + if self.cfg.gradient_checkpointing == "offload_disk": + transformers.modeling_utils.checkpoint = ( + hf_grad_checkpoint_disk_offload_wrapper + ) if self.cfg.flash_attention: self.patch_attention() @@ -674,16 +682,25 @@ class ModelLoader: patch_self_attn_lora(self.cfg) if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1: - from axolotl.monkeypatch.attention.ring_attn import register_ring_attn + from axolotl.monkeypatch.ring_attn import ( + patch_prepare_data_loader, + patch_prepare_device_mesh, + register_ring_attn, + ) # Initialize ring attn for sequence parallelism. This must be done after # model init but before the first forward pass, since it modifies flash # attn to use ring comm for SP training across multiple GPUs. - register_ring_attn( - sequence_parallel_degree=self.cfg.sequence_parallel_degree, - heads_k_stride=self.cfg.heads_k_stride, - ring_attn_func=self.cfg.ring_attn_func, - ) + if get_ring_attn_group() is None: # If already set, this is already patched + register_ring_attn( + sequence_parallel_degree=self.cfg.sequence_parallel_degree, + heads_k_stride=self.cfg.heads_k_stride, + ring_attn_func=self.cfg.ring_attn_func, + ) + patch_prepare_data_loader() + patch_prepare_device_mesh( + sequence_parallel_degree=self.cfg.sequence_parallel_degree + ) def patch_attention(self) -> None: if hasattr(self.model_config, "model_type"): diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 25c802959..8ae9d5c04 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -178,7 +178,7 @@ class AxolotlInputConfig( # torch_dtype: torch.dtype | None - gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field( + gradient_checkpointing: Literal["offload", "offload_disk"] | bool | None = Field( default=False ) gradient_checkpointing_kwargs: dict[str, Any] | None = None diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py index a1eade531..575b7a620 100644 --- a/tests/e2e/multigpu/solo/test_grpo.py +++ b/tests/e2e/multigpu/solo/test_grpo.py @@ -166,7 +166,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): """ ) - @pytest.mark.skip(reason="flaky test") @pytest.mark.parametrize( "num_gpus", [1, 2], @@ -231,8 +230,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): "NCCL_P2P_LEVEL": "LOC", **current_env, "CUDA_VISIBLE_DEVICES": "1", - "VLLM_DISABLE_COMPILE_CACHE": "1", - # "VLLM_USE_V1": "0", } vllm_process = start_vllm( cfg.base_model, @@ -266,7 +263,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): finally: recursive_kill(vllm_process) - @pytest.mark.skip(reason="flaky test") @pytest.mark.parametrize( "num_gpus", [1, 2], @@ -325,8 +321,6 @@ def 
oai_gsm8k_transform(cfg, *args, **kwargs): "NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable **current_env, "CUDA_VISIBLE_DEVICES": "1", - "VLLM_DISABLE_COMPILE_CACHE": "1", - # "VLLM_USE_V1": "0", } vllm_process = start_vllm( cfg.base_model, diff --git a/tests/e2e/patched/test_activation_checkpointing.py b/tests/e2e/patched/test_activation_checkpointing.py index cbabab6fd..45107b871 100644 --- a/tests/e2e/patched/test_activation_checkpointing.py +++ b/tests/e2e/patched/test_activation_checkpointing.py @@ -26,10 +26,15 @@ class TestActivationCheckpointing: E2E tests for activation checkpointing """ + @pytest.mark.parametrize( + "gradient_checkpointing", + ["offload", "offload_disk"], + ) def test_activation_checkpointing_offload( self, temp_dir, fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name + gradient_checkpointing, ): # pylint: disable=duplicate-code cfg = DictDefault( @@ -64,7 +69,7 @@ class TestActivationCheckpointing: "sample_packing": True, "bf16": True, "save_safetensors": True, - "gradient_checkpointing": "offload", + "gradient_checkpointing": gradient_checkpointing, } ) diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py index 8efe62940..83faa779f 100644 --- a/tests/e2e/patched/test_sp.py +++ b/tests/e2e/patched/test_sp.py @@ -10,7 +10,7 @@ import pytest import torch from accelerate.state import PartialState -from axolotl.monkeypatch.attention.ring_attn import ( +from axolotl.monkeypatch.ring_attn import ( get_ring_attn_group, register_ring_attn, set_ring_attn_group, @@ -313,13 +313,13 @@ class TestApplySequenceParallelism: # Mock the process group monkeypatch.setattr( - "axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group", + "axolotl.monkeypatch.ring_attn.get_ring_attn_group", MagicMock, ) # Mock update_ring_attn_params monkeypatch.setattr( - "axolotl.monkeypatch.attention.ring_attn.update_ring_attn_params", + "axolotl.monkeypatch.ring_attn.update_ring_attn_params", lambda **kwargs: None, )
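
A footnote for reviewers (not part of the diff): a minimal sketch of exercising the new `Disco` checkpointing path directly. The `layer` and `hidden` names are hypothetical stand-ins; in real training this path is selected by setting `gradient_checkpointing: offload_disk` in the config, which installs `hf_grad_checkpoint_disk_offload_wrapper` as `transformers.modeling_utils.checkpoint` (see `src/axolotl/utils/models.py` above). Assumes a CUDA device and this branch installed.

```python
import torch
import torch.nn as nn

from axolotl.utils.gradient_checkpointing.offload_disk import Disco

# Any module whose forward takes the hidden states as its first argument
layer = nn.Linear(16, 16).cuda()
hidden = torch.randn(2, 16, device="cuda", requires_grad=True)

# Forward: the input activations are written to Disco's temp dir asynchronously
out = Disco.apply(layer.forward, hidden)

# Backward: activations are reloaded from disk (with prefetching) and the
# layer's forward is recomputed to build gradients
out.sum().backward()
print(hidden.grad.shape)  # torch.Size([2, 16])
```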