diff --git a/_quarto.yml b/_quarto.yml
index 0a8e023cf..804fc5e84 100644
--- a/_quarto.yml
+++ b/_quarto.yml
@@ -243,6 +243,7 @@ website:
           - docs/unsloth.qmd
           - docs/torchao.qmd
           - docs/custom_integrations.qmd
+          - docs/sequence_parallelism.qmd
 
       - section: "Troubleshooting"
         contents:
diff --git a/docs/config.qmd b/docs/config.qmd
index 71ddcbff6..753cf47e1 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -658,6 +658,9 @@ ddp_broadcast_buffers:
 # subsequences, or set to 4 to split into four equal-sized subsequences.
 # See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
 sequence_parallel_degree:
+# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
+# Must evenly divide the number of KV heads in your model.
+heads_k_stride: 1
 
 # Path to torch distx for optim 'adamw_anyprecision'
 torchdistx_path:
diff --git a/docs/multi-gpu.qmd b/docs/multi-gpu.qmd
index 19293bb5b..5aec89763 100644
--- a/docs/multi-gpu.qmd
+++ b/docs/multi-gpu.qmd
@@ -18,6 +18,7 @@ Axolotl supports several methods for multi-GPU training:
 
 - DeepSpeed (recommended)
 - FSDP (Fully Sharded Data Parallel)
+- Sequence parallelism
 - FSDP + QLoRA
 
 ## DeepSpeed {#sec-deepspeed}
@@ -66,6 +67,28 @@ fsdp_config:
   fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
 ```
 
+## Sequence parallelism {#sec-sequence-parallelism}
+
+We support sequence parallelism (SP) via the
+[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
+splits sequences across GPUs, which is useful when a single sequence would otherwise
+cause OOM errors during model training.
+
+First, install `ring-flash-attn`; we recommend `pip install axolotl[ring-flash-attn]`,
+or `pip install .[ring-flash-attn]` when installing from source.
+
+Your Axolotl YAML config should contain the following lines:
+
+```{.yaml}
+sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
+flash_attention: true # Required with sequence parallelism
+
+# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
+heads_k_stride: 1
+```
+
+See our [dedicated guide](sequence_parallelism.qmd) for more details.
+
 ### FSDP + QLoRA {#sec-fsdp-qlora}
 
 For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
diff --git a/docs/sequence_parallelism.qmd b/docs/sequence_parallelism.qmd
index cb297c0e0..98ca4d746 100644
--- a/docs/sequence_parallelism.qmd
+++ b/docs/sequence_parallelism.qmd
@@ -25,6 +25,8 @@ To enable sequence parallelism, add the following to your configuration file:
 ```yaml
 # Set to a divisor (> 1) of the number of GPUs available
 sequence_parallel_degree: 4 # Split sequences across 4 GPUs
+# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
+heads_k_stride: 1
 ```
 
 The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
@@ -58,11 +60,16 @@ To use sequence parallelism, you need:
 ## Example
 
 ```yaml
-# Example config with sequence parallelism
 base_model: meta-llama/Llama-3-8B-Instruct
 sequence_len: 8192
-sequence_parallel_degree: 2 # Split each sequence into 4 parts
+
+...
+
+sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
 flash_attention: true # Required with sequence parallelism
+# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
+heads_k_stride: 1
+
 ...
 ```
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index b237b1ef3..2349932ba 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -69,7 +69,6 @@ from axolotl.utils.callbacks import (
     LossWatchDogCallback,
     SaveAxolotlConfigtoWandBCallback,
     SaveBetterTransformerModelCallback,
-    SaveModelCallback,
     bench_eval_callback_factory,
     causal_lm_bench_eval_callback_factory,
     log_prediction_callback_factory,
@@ -249,7 +248,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.gc_steps:
             callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
 
-        callbacks.append(SaveModelCallback())
 
         return callbacks
 
@@ -937,7 +935,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
 
     def get_callbacks(self):
         callbacks = super().get_callbacks()
-        callbacks.append(SaveModelCallback())
 
         return callbacks
 
diff --git a/src/axolotl/monkeypatch/attention/ring_attn.py b/src/axolotl/monkeypatch/attention/ring_attn.py
index a81b87909..6c9d0b429 100644
--- a/src/axolotl/monkeypatch/attention/ring_attn.py
+++ b/src/axolotl/monkeypatch/attention/ring_attn.py
@@ -38,13 +38,19 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
     RING_ATTN_GROUP = ring_attn_group
 
 
-def register_ring_attn(sequence_parallel_degree: int):
+def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
     """
     Create ring attention group and substitute flash attn with ring flash attn.
 
     Args:
         sequence_parallel_degree: Sequence parallelism factor.
+        heads_k_stride: Sequence parallelism K head stride size. Passed
+            through to `ring_flash_attn.substitute_hf_flash_attn`.
     """
+    if get_ring_attn_group() is not None:
+        LOG.info("Ring attention already registered, exiting early...")
+        return
+
     LOG.info(
         "Enabling ring attention sequence parallelism: "
         f"each sequence will be processed across {sequence_parallel_degree} GPUs"
@@ -84,6 +90,11 @@ def register_ring_attn(sequence_parallel_degree: int):
     if rank == 0:
         LOG.info(f"Sequence parallel group assignments: {group_assignments}")
 
+    if heads_k_stride is None:
+        heads_k_stride = 1
+
     from ring_flash_attn import substitute_hf_flash_attn
 
-    substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
+    substitute_hf_flash_attn(
+        process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
+    )
diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py
index 94f180ef4..ffe4699f8 100644
--- a/src/axolotl/utils/callbacks/__init__.py
+++ b/src/axolotl/utils/callbacks/__init__.py
@@ -816,27 +816,6 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
         return control
 
 
-class SaveModelCallback(TrainerCallback):
-    """Callback to save model on train end"""
-
-    def on_step_end(  # pylint: disable=unused-argument
-        self,
-        args: TrainingArguments,
-        state: TrainerState,
-        control: TrainerControl,
-        **kwargs,
-    ):
-        # Save
-        if state.global_step >= state.max_steps:
-            control.should_save = True
-
-    def on_train_end(  # pylint: disable=unused-argument
-        self, args, state, control, **kwargs
-    ):
-        control.should_save = True
-        return control
-
-
 class GCCallback(TrainerCallback):
     """Callback to garbage collect torch cache"""
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 10c171d83..9611ffca2 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -609,7 +609,10 @@ class ModelLoader:
             # Initialize ring attn for sequence parallelism. This must be done after
             # model init but before the first forward pass, since it modifies flash
             # attn to use ring comm for SP training across multiple GPUs.
-            register_ring_attn(self.cfg.sequence_parallel_degree)
+            register_ring_attn(
+                sequence_parallel_degree=self.cfg.sequence_parallel_degree,
+                heads_k_stride=self.cfg.heads_k_stride,
+            )
 
     def patch_attention(self) -> None:
         if hasattr(self.model_config, "model_type"):
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index d52146092..51c5cf08e 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -248,6 +248,7 @@ class AxolotlInputConfig(
     val_set_size: float | None = Field(default=0.0)
 
     sequence_parallel_degree: int | None = None
+    heads_k_stride: int | None = None
 
     special_tokens: SpecialTokensConfig | None = None
     tokens: list[str] | None = None
@@ -1108,7 +1109,7 @@
 
     @field_validator("sequence_parallel_degree", mode="before")
     @classmethod
-    def check_sequence_parallel_config(cls, value, info):
+    def check_sequence_parallel_degree(cls, value, info):
         if not value:
             value = 1
diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py
index 903cb1450..70beb8a54 100644
--- a/tests/e2e/patched/test_sp.py
+++ b/tests/e2e/patched/test_sp.py
@@ -110,7 +110,7 @@ class TestRingAttention:
         mock_new_group.return_value = mock_group
 
         # Call register_ring_attn with size 4
-        register_ring_attn(sequence_parallel_degree=4)
+        register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1)
 
         # Verify the number of calls without examining the arguments
         assert mock_new_group.call_count == 2
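
A note on the `heads_k_stride` constraint documented above: the value must evenly divide the model's number of KV heads. Below is a minimal sketch, not part of this diff, of a pre-flight check one could run before launching sequence-parallel training. It assumes a Hugging Face config that exposes `num_key_value_heads` (GQA models such as Llama 3); the helper name `check_heads_k_stride` is hypothetical.

```python
# Hypothetical pre-flight check (illustration only, not part of this diff).
from transformers import AutoConfig


def check_heads_k_stride(base_model: str, heads_k_stride: int) -> int:
    """Raise if heads_k_stride does not evenly divide the model's KV head count."""
    config = AutoConfig.from_pretrained(base_model)
    # GQA models (e.g. Llama 3) expose num_key_value_heads; plain MHA models
    # fall back to num_attention_heads, since every head then has its own KV.
    num_kv_heads = getattr(config, "num_key_value_heads", None) or config.num_attention_heads
    if num_kv_heads % heads_k_stride != 0:
        raise ValueError(
            f"heads_k_stride={heads_k_stride} must evenly divide the "
            f"{num_kv_heads} KV heads of {base_model}"
        )
    return num_kv_heads


if __name__ == "__main__":
    # Llama 3 8B has 8 KV heads, so strides of 1, 2, 4, or 8 pass this check.
    # (Gated repo: downloading the config requires Hugging Face access.)
    check_heads_k_stride("meta-llama/Meta-Llama-3-8B-Instruct", heads_k_stride=2)
```

The default of `heads_k_stride: 1` is always safe, since 1 divides any KV head count; per the docs added above, larger strides use more memory but should make training faster.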