From ce07081d6c3217a9d55c959feaa6ae146bcdc32b Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Tue, 1 Apr 2025 20:35:10 +0000
Subject: [PATCH] doc updates; config fix

---
 docs/config.qmd                           |  7 ++++---
 docs/sequence_parallelism.qmd             | 22 +++++++++++++---------
 examples/llama-3/instruct-dpo-lora-8b.yml |  3 +++
 3 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/docs/config.qmd b/docs/config.qmd
index 7c41c5126..74f025536 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -686,9 +686,10 @@ ddp_broadcast_buffers:
 # E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
 # subsequences, or set to 4 to split into four equal-sized subsequences.
 # See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
-sequence_parallel_degree:
-# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
-# Must evenly divide the number of KV heads in your model.
+sequence_parallel_degree: 4 # Set to the number of GPUs to split sequences across
+flash_attention: true # SP requires flash attention
+micro_batch_size: 1 # SP requires this to be set to 1
+# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
 heads_k_stride: 1
 
 # Path to torch distx for optim 'adamw_anyprecision'
diff --git a/docs/sequence_parallelism.qmd b/docs/sequence_parallelism.qmd
index 98ca4d746..f39811855 100644
--- a/docs/sequence_parallelism.qmd
+++ b/docs/sequence_parallelism.qmd
@@ -23,9 +23,10 @@ Use sequence parallelism when:
 
 To enable sequence parallelism, add the following to your configuration file:
 
 ```yaml
-# Set to a divisor (> 1) of the number of GPUs available
-sequence_parallel_degree: 4 # Split sequences across 4 GPUs
-# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
+sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
+flash_attention: true # SP requires flash attention
+micro_batch_size: 1 # SP requires this to be set to 1
+# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
 heads_k_stride: 1
 ```
@@ -66,15 +67,16 @@ sequence_len: 8192
 ...
 
 sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
-flash_attention: true # Required with sequence parallelism
-# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
+flash_attention: true # SP requires flash attention
+micro_batch_size: 1 # SP requires this to be set to 1
+# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
 heads_k_stride: 1
 
 ...
 ```
 
-This will train the Llama 3 8B model with 8K context length, with each sequence split
-into 2 subsequences of length 4096 across 2 GPUs.
+This will train the Llama 3 8B model with 8192 context length, with each sequence split
+into 4 subsequences of length 2048 across 4 GPUs.
 
 ## Sample Packing with Sequence Parallelism
 
@@ -86,12 +88,14 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
 
 ## Effect on Batch Size
 
+First, note that sequence parallelism supports only the case where `micro_batch_size: 1`.
+
 When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
 
 - Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
 - The number of batches processed per step decreases
 
 For example:
-- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
+- With 8 GPUs and no sequence parallelism: 8 different batches are processed per step
 - With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
-- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
+- If your per-GPU `micro_batch_size` is 1, the global batch size decreases from 8 to 2
diff --git a/examples/llama-3/instruct-dpo-lora-8b.yml b/examples/llama-3/instruct-dpo-lora-8b.yml
index c7568dd78..44d45b4c1 100644
--- a/examples/llama-3/instruct-dpo-lora-8b.yml
+++ b/examples/llama-3/instruct-dpo-lora-8b.yml
@@ -82,3 +82,6 @@ deepspeed:
 weight_decay: 0.0
 fsdp:
 fsdp_config:
+
+special_tokens:
+  pad_token: "<|end_of_text|>"
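
As a quick illustration of the batch-size arithmetic described in the patched documentation above, here is a minimal sketch. The `global_batch_size` helper is hypothetical (not part of Axolotl); it simply restates the rule that each group of `sequence_parallel_degree` GPUs shares one batch, so the number of distinct batches per step is the GPU count divided by that degree.

```python
def global_batch_size(num_gpus: int, sequence_parallel_degree: int,
                      micro_batch_size: int = 1,
                      gradient_accumulation_steps: int = 1) -> int:
    # Each group of `sequence_parallel_degree` GPUs works on the same batch,
    # so only num_gpus // sequence_parallel_degree distinct batches run per step.
    assert num_gpus % sequence_parallel_degree == 0, "degree must divide GPU count"
    data_parallel_groups = num_gpus // sequence_parallel_degree
    return data_parallel_groups * micro_batch_size * gradient_accumulation_steps


print(global_batch_size(num_gpus=8, sequence_parallel_degree=1))  # 8 (no SP)
print(global_batch_size(num_gpus=8, sequence_parallel_degree=4))  # 2
```

With 8 GPUs this prints 8 for the no-SP case and 2 for `sequence_parallel_degree=4`, matching the example in the updated docs.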