Compare commits

...

6 Commits

Author SHA1 Message Date
Dan Saunders
1defb8a955 Merge branch 'main' into destroy-pg 2025-03-31 14:36:43 -04:00
Dan Saunders
70b466aa67 ray bugfix 2025-03-31 18:35:41 +00:00
Dan Saunders
ef6eb77cc8 destroy process group on Ctrl+C / training or eval run (#2457)
* fix nccl pg destroy warning

* update
2025-03-31 12:36:47 -04:00
Dan Saunders
32ce167404 update 2025-03-31 14:46:15 +00:00
Dan Saunders
1c4cc639f5 fix nccl pg destroy warning 2025-03-31 14:32:50 +00:00
Dan Saunders
5410195e0b Sequence parallelism quick follow-ups; remove ModelCallback (#2450)
* guard return if ring attn already registered

* add docs link, bits in multi-gpu docs, remove save model callback (subsumed by HF trainers)

* configurable heads_k_stride from ring-flash-attn hf adapter
2025-03-31 09:13:42 -04:00
13 changed files with 83 additions and 48 deletions

View File

@@ -243,6 +243,7 @@ website:
- docs/unsloth.qmd
- docs/torchao.qmd
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
- section: "Troubleshooting"
contents:

View File

@@ -658,6 +658,9 @@ ddp_broadcast_buffers:
# subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
sequence_parallel_degree:
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
# Must evenly divide the number of KV heads in your model.
heads_k_stride: 1
# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:

View File

@@ -18,6 +18,7 @@ Axolotl supports several methods for multi-GPU training:
- DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel)
- Sequence parallelism
- FSDP + QLoRA
## DeepSpeed {#sec-deepspeed}
@@ -66,6 +67,28 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
## Sequence parallelism {#sec-sequence-parallelism}
We support sequence parallelism (SP) via the
[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
allows one to split up sequences across GPUs, which is useful in the event that a
single sequence causes OOM errors during model training.
First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`,
or from source with `pip install .[ring-flash-attn]`.
Your Axolotl YAML config should contain the following lines:
```{.yaml}
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
```
See our [dedicated guide](sequence_parallelism.qmd) for more details.
### FSDP + QLoRA {#sec-fsdp-qlora}
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).

View File

@@ -25,6 +25,8 @@ To enable sequence parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
@@ -58,11 +60,16 @@ To use sequence parallelism, you need:
## Example
```yaml
# Example config with sequence parallelism
base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192
sequence_parallel_degree: 2 # Split each sequence into 4 parts
...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
...
```

View File

@@ -69,7 +69,6 @@ from axolotl.utils.callbacks import (
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
SaveModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory,
@@ -249,7 +248,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
callbacks.append(SaveModelCallback())
return callbacks
@@ -937,7 +935,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
def get_callbacks(self):
callbacks = super().get_callbacks()
callbacks.append(SaveModelCallback())
return callbacks

View File

@@ -15,6 +15,7 @@ from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer
@@ -159,4 +160,6 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
del model
del tokenizer
cleanup_distributed()
return all_metrics

View File

@@ -38,13 +38,19 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
RING_ATTN_GROUP = ring_attn_group
def register_ring_attn(sequence_parallel_degree: int):
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
"""
Create ring attention group and substitute flash attn with ring flash attn.
Args:
sequence_parallel_degree: Sequence parallelism factor.
heads_k_stride: Sequence parallelism K head stride size. Passed
through to `ring_flash_attn.substitute_hf_flash_attn`.
"""
if get_ring_attn_group() is not None:
LOG.info("Ring attention already registered, exiting early...")
return
LOG.info(
"Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
@@ -84,6 +90,11 @@ def register_ring_attn(sequence_parallel_degree: int):
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
if heads_k_stride is None:
heads_k_stride = 1
from ring_flash_attn import substitute_hf_flash_attn
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
)

View File

@@ -27,6 +27,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer
@@ -157,6 +158,8 @@ def setup_signal_handler(
_model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
cleanup_distributed()
sys.exit(0)
_model_weakref = weakref.ref(model)
@@ -478,7 +481,7 @@ def train(
Returns:
Tuple of (model, tokenizer) after training
"""
# Setup model, tokenizer, (causal or RLHF) trainer etc.
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
(
trainer,
model,
@@ -487,34 +490,26 @@ def train(
processor,
) = setup_model_and_trainer(cfg, dataset_meta)
# Determine if we need to resume from a checkpoint
resume_from_checkpoint = determine_resume_checkpoint(cfg)
# Configuration for saving
safe_serialization = cfg.save_safetensors is True
# Handle untrained tokens if configured
safe_serialization = cfg.save_safetensors is True
train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix(
cfg, model, tokenizer, train_dataset, safe_serialization
)
# Save initial configs
# Additional setup
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
# Set up signal handler for graceful termination
setup_signal_handler(cfg, model, safe_serialization)
# Set up badges and config info for model card
setup_model_card(cfg)
# Execute the training
resume_from_checkpoint = determine_resume_checkpoint(cfg)
execute_training(cfg, trainer, resume_from_checkpoint)
# Save the trained model
# Save the trained model and cleanup
save_trained_model(cfg, trainer, model, safe_serialization)
# Create model card
create_model_card(cfg, trainer)
if not cfg.use_ray:
cleanup_distributed()
return model, tokenizer, trainer

View File

@@ -816,27 +816,6 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
return control
class SaveModelCallback(TrainerCallback):
"""Callback to save model on train end"""
def on_step_end( # pylint: disable=unused-argument
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
# Save
if state.global_step >= state.max_steps:
control.should_save = True
def on_train_end( # pylint: disable=unused-argument
self, args, state, control, **kwargs
):
control.should_save = True
return control
class GCCallback(TrainerCallback):
"""Callback to garbage collect torch cache"""

View File

@@ -71,8 +71,8 @@ def barrier():
def is_main_process():
"""
Check if the current process is the main process.
If not in distributed mode, always return True.
Check if the current process is the main process. If not in distributed mode,
always return `True`.
"""
if not is_distributed():
return True
@@ -87,6 +87,18 @@ def get_world_size():
return int(os.getenv("WORLD_SIZE", "1"))
def cleanup_distributed():
    """
    Destroy the default process group if torch distributed is initialized.

    Intended to be called when a run ends, either through early termination or
    after successful completion. Safe to call on CPU-only hosts and in
    non-distributed runs, where it is a no-op.
    """
    # Ensure that all queued CUDA work has completed before destroying the process
    # group. `torch.cuda.synchronize()` raises when CUDA is unavailable, so only
    # call it when a CUDA device is actually present.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # Destroy the process group. `is_initialized` only exists when torch was built
    # with distributed support, hence the `is_available` check first.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
@contextmanager
def zero_only():
"""

View File

@@ -609,7 +609,10 @@ class ModelLoader:
# Initialize ring attn for sequence parallelism. This must be done after
# model init but before the first forward pass, since it modifies flash
# attn to use ring comm for SP training across multiple GPUs.
register_ring_attn(self.cfg.sequence_parallel_degree)
register_ring_attn(
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
heads_k_stride=self.cfg.heads_k_stride,
)
def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"):

View File

@@ -248,6 +248,7 @@ class AxolotlInputConfig(
val_set_size: float | None = Field(default=0.0)
sequence_parallel_degree: int | None = None
heads_k_stride: int | None = None
special_tokens: SpecialTokensConfig | None = None
tokens: list[str] | None = None
@@ -1108,7 +1109,7 @@ class AxolotlInputConfig(
@field_validator("sequence_parallel_degree", mode="before")
@classmethod
def check_sequence_parallel_config(cls, value, info):
def check_sequence_parallel_degree(cls, value, info):
if not value:
value = 1

View File

@@ -110,7 +110,7 @@ class TestRingAttention:
mock_new_group.return_value = mock_group
# Call register_ring_attn with size 4
register_ring_attn(sequence_parallel_degree=4)
register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1)
# Verify the number of calls without examining the arguments
assert mock_new_group.call_count == 2