Compare commits
6 Commits
feat/soap-...destroy-pg

| Author | SHA1 | Date |
|---|---|---|
| | 1defb8a955 | |
| | 70b466aa67 | |
| | ef6eb77cc8 | |
| | 32ce167404 | |
| | 1c4cc639f5 | |
| | 5410195e0b | |
@@ -243,6 +243,7 @@ website:
 - docs/unsloth.qmd
 - docs/torchao.qmd
 - docs/custom_integrations.qmd
+- docs/sequence_parallelism.qmd
 
 - section: "Troubleshooting"
 contents:

@@ -658,6 +658,9 @@ ddp_broadcast_buffers:
 # subsequences, or set to 4 to split into four equal-sized subsequences.
 # See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
 sequence_parallel_degree:
+# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
+# Must evenly divide the number of KV heads in your model.
+heads_k_stride: 1
 
 # Path to torch distx for optim 'adamw_anyprecision'
 torchdistx_path:

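The new config comment says `heads_k_stride` must evenly divide the model's KV heads. As an illustration only (the `check_heads_k_stride` helper below is hypothetical, not part of Axolotl), that constraint could be pre-checked from the Hugging Face model config, assuming a `num_key_value_heads` attribute as on Llama-style configs:

```python
from transformers import AutoConfig


def check_heads_k_stride(model_name: str, heads_k_stride: int) -> None:
    """Sketch: verify that heads_k_stride evenly divides the model's KV head count."""
    config = AutoConfig.from_pretrained(model_name)
    # Models without grouped-query attention fall back to num_attention_heads.
    num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
    if num_kv_heads % heads_k_stride != 0:
        raise ValueError(
            f"heads_k_stride={heads_k_stride} does not evenly divide {num_kv_heads} KV heads"
        )


# e.g. check_heads_k_stride("meta-llama/Llama-3-8B-Instruct", heads_k_stride=1)
```
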
@@ -18,6 +18,7 @@ Axolotl supports several methods for multi-GPU training:
 
 - DeepSpeed (recommended)
 - FSDP (Fully Sharded Data Parallel)
+- Sequence parallelism
 - FSDP + QLoRA
 
 ## DeepSpeed {#sec-deepspeed}

@@ -66,6 +67,28 @@ fsdp_config:
 fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
 ```
 
+## Sequence parallelism {#sec-sequence-parallelism}
+
+We support sequence parallelism (SP) via the
+[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
+allows one to split up sequences across GPUs, which is useful in the event that a
+single sequence causes OOM errors during model training.
+
+First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`,
+or from source with `pip install .[ring-flash-attn]`.
+
+Your Axolotl YAML config should contain the following lines:
+
+```{.yaml}
+sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
+flash_attention: true # Required with sequence parallelism
+
+# Optional; strides across the key dimension. Larger values use more memory but will make training faster.
+heads_k_stride: 1
+```
+
+See our [dedicated guide](sequence_parallelism.qmd) for more details.
+
 ### FSDP + QLoRA {#sec-fsdp-qlora}
 
 For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).

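To make the added documentation concrete: conceptually, sequence parallelism gives each GPU a contiguous slice of every sequence, so the per-GPU sequence length shrinks by the parallel degree. The sketch below is illustrative only (the `shard_sequence` helper is hypothetical, not part of Axolotl or ring-flash-attention):

```python
import torch


def shard_sequence(input_ids: torch.Tensor, rank: int, sp_degree: int) -> torch.Tensor:
    """Sketch: give each rank a contiguous chunk of the sequence dimension."""
    seq_len = input_ids.shape[1]
    assert seq_len % sp_degree == 0, "sequence length must divide evenly"
    chunk = seq_len // sp_degree
    return input_ids[:, rank * chunk : (rank + 1) * chunk]


# With sequence_parallel_degree=4 and an 8192-token batch,
# each GPU processes a 2048-token slice.
batch = torch.zeros(1, 8192, dtype=torch.long)
print(shard_sequence(batch, rank=0, sp_degree=4).shape)  # torch.Size([1, 2048])
```
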
@@ -25,6 +25,8 @@ To enable sequence parallelism, add the following to your configuration file:
 ```yaml
 # Set to a divisor (> 1) of the number of GPUs available
 sequence_parallel_degree: 4 # Split sequences across 4 GPUs
+# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
+heads_k_stride: 1
 ```
 
 The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:

@@ -58,11 +60,16 @@ To use sequence parallelism, you need:
 ## Example
 
 ```yaml
-# Example config with sequence parallelism
 base_model: meta-llama/Llama-3-8B-Instruct
 sequence_len: 8192
-sequence_parallel_degree: 2 # Split each sequence into 4 parts
+...
+
+sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
 flash_attention: true # Required with sequence parallelism
+# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
+heads_k_stride: 1
+
 ...
 ```
 

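The guide's divisor requirement can be stated as a small check; this is a hedged sketch (the `validate_sp_degree` helper is hypothetical, not Axolotl's actual validation):

```python
import torch.distributed as dist


def validate_sp_degree(sequence_parallel_degree: int) -> int:
    """Sketch: the SP degree must divide the world size so ranks split into equal groups."""
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    if world_size % sequence_parallel_degree != 0:
        raise ValueError(
            f"sequence_parallel_degree={sequence_parallel_degree} must divide "
            f"the world size ({world_size})"
        )
    # Number of independent sequence-parallel groups.
    return world_size // sequence_parallel_degree
```
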
@@ -69,7 +69,6 @@ from axolotl.utils.callbacks import (
     LossWatchDogCallback,
     SaveAxolotlConfigtoWandBCallback,
     SaveBetterTransformerModelCallback,
-    SaveModelCallback,
     bench_eval_callback_factory,
     causal_lm_bench_eval_callback_factory,
     log_prediction_callback_factory,

@@ -249,7 +248,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
         if self.cfg.gc_steps:
             callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
-        callbacks.append(SaveModelCallback())
 
         return callbacks
 

@@ -937,7 +935,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
 
     def get_callbacks(self):
         callbacks = super().get_callbacks()
-        callbacks.append(SaveModelCallback())
 
         return callbacks
 

@@ -15,6 +15,7 @@ from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
 from axolotl.utils import set_pytorch_cuda_alloc_conf
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.distributed import cleanup_distributed
 from axolotl.utils.models import load_model, load_processor, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
 

@@ -159,4 +160,6 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
     del model
     del tokenizer
 
+    cleanup_distributed()
+
     return all_metrics

@@ -38,13 +38,19 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
     RING_ATTN_GROUP = ring_attn_group
 
 
-def register_ring_attn(sequence_parallel_degree: int):
+def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
     """
     Create ring attention group and substitute flash attn with ring flash attn.
 
     Args:
         sequence_parallel_degree: Sequence parallelism factor.
+        heads_k_stride: Sequence parallelism K head stride size. Passed
+            through to `ring_flash_attn.substitute_hf_flash_attn`.
     """
+    if get_ring_attn_group() is not None:
+        LOG.info("Ring attention already registered, exiting early...")
+        return
+
     LOG.info(
         "Enabling ring attention sequence parallelism: "
         f"each sequence will be processed across {sequence_parallel_degree} GPUs"

@@ -84,6 +90,11 @@ def register_ring_attn(sequence_parallel_degree: int):
     if rank == 0:
         LOG.info(f"Sequence parallel group assignments: {group_assignments}")
 
+    if heads_k_stride is None:
+        heads_k_stride = 1
+
     from ring_flash_attn import substitute_hf_flash_attn
 
-    substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
+    substitute_hf_flash_attn(
+        process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
+    )

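For context on the `group_assignments` log above, here is a rough sketch of how sequence-parallel groups of size `sequence_parallel_degree` could be laid out; the consecutive-rank grouping is an assumption for illustration, not necessarily how `register_ring_attn` builds its groups:

```python
def build_sp_groups(world_size: int, sp_degree: int) -> dict[int, int]:
    """Sketch: map each rank to a sequence-parallel group of consecutive ranks."""
    assignments: dict[int, int] = {}
    for group_id in range(world_size // sp_degree):
        ranks = list(range(group_id * sp_degree, (group_id + 1) * sp_degree))
        for rank in ranks:
            assignments[rank] = group_id
        # A real implementation would create each group with torch.distributed.new_group(ranks).
    return assignments


# 8 GPUs at sequence_parallel_degree=4 -> ranks 0-3 in group 0, ranks 4-7 in group 1.
print(build_sp_groups(world_size=8, sp_degree=4))
```
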
@@ -27,6 +27,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
 from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.distributed import cleanup_distributed
 from axolotl.utils.freeze import freeze_layers_except
 from axolotl.utils.models import load_model, load_processor, load_tokenizer
 from axolotl.utils.trainer import setup_trainer

@@ -157,6 +158,8 @@ def setup_signal_handler(
             _model.save_pretrained(
                 cfg.output_dir, safe_serialization=safe_serialization
             )
+
+        cleanup_distributed()
         sys.exit(0)
 
     _model_weakref = weakref.ref(model)

@@ -478,7 +481,7 @@ def train(
     Returns:
         Tuple of (model, tokenizer) after training
     """
-    # Setup model, tokenizer, (causal or RLHF) trainer etc.
+    # Setup model, tokenizer, (causal or RLHF) trainer, etc.
     (
         trainer,
         model,

@@ -487,34 +490,26 @@ def train(
         processor,
     ) = setup_model_and_trainer(cfg, dataset_meta)
 
-    # Determine if we need to resume from a checkpoint
-    resume_from_checkpoint = determine_resume_checkpoint(cfg)
-
-    # Configuration for saving
-    safe_serialization = cfg.save_safetensors is True
-
     # Handle untrained tokens if configured
+    safe_serialization = cfg.save_safetensors is True
     train_dataset = dataset_meta.train_dataset
     handle_untrained_tokens_fix(
         cfg, model, tokenizer, train_dataset, safe_serialization
     )
 
-    # Save initial configs
+    # Additional setup
     save_initial_configs(cfg, tokenizer, model, peft_config, processor)
-
-    # Set up signal handler for graceful termination
     setup_signal_handler(cfg, model, safe_serialization)
-
-    # Set up badges and config info for model card
     setup_model_card(cfg)
 
     # Execute the training
+    resume_from_checkpoint = determine_resume_checkpoint(cfg)
    execute_training(cfg, trainer, resume_from_checkpoint)
 
-    # Save the trained model
+    # Save the trained model and cleanup
     save_trained_model(cfg, trainer, model, safe_serialization)
-
-    # Create model card
     create_model_card(cfg, trainer)
+    if not cfg.use_ray:
+        cleanup_distributed()
 
     return model, tokenizer, trainer

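The signal-handler change above saves the model and then tears down the process group before exiting. A minimal, self-contained sketch of that pattern (assuming a plain `torch.distributed` run; Axolotl's real handler also goes through a model weak reference and skips cleanup under Ray):

```python
import signal
import sys

import torch
import torch.distributed as dist


def install_termination_handler(model, output_dir: str, safe_serialization: bool = True):
    """Sketch: on SIGTERM, save the model, destroy the process group, then exit."""

    def _handler(signum, frame):  # pylint: disable=unused-argument
        model.save_pretrained(output_dir, safe_serialization=safe_serialization)
        if dist.is_initialized():
            # Let in-flight CUDA work finish before tearing down collectives.
            torch.cuda.synchronize()
            dist.destroy_process_group()
        sys.exit(0)

    signal.signal(signal.SIGTERM, _handler)
```
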
@@ -816,27 +816,6 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
         return control
 
 
-class SaveModelCallback(TrainerCallback):
-    """Callback to save model on train end"""
-
-    def on_step_end(  # pylint: disable=unused-argument
-        self,
-        args: TrainingArguments,
-        state: TrainerState,
-        control: TrainerControl,
-        **kwargs,
-    ):
-        # Save
-        if state.global_step >= state.max_steps:
-            control.should_save = True
-
-    def on_train_end(  # pylint: disable=unused-argument
-        self, args, state, control, **kwargs
-    ):
-        control.should_save = True
-        return control
-
-
 class GCCallback(TrainerCallback):
     """Callback to garbage collect torch cache"""
 

@@ -71,8 +71,8 @@ def barrier():
 
 def is_main_process():
     """
-    Check if the current process is the main process.
-    If not in distributed mode, always return True.
+    Check if the current process is the main process. If not in distributed mode,
+    always return `True`.
     """
     if not is_distributed():
         return True

@@ -87,6 +87,18 @@ def get_world_size():
     return int(os.getenv("WORLD_SIZE", "1"))
 
 
+def cleanup_distributed():
+    """
+    Destroy process group if torch distributed is initialized. Called in training early
+    termination or when training successfully completes.
+    """
+    # Ensure that all operations are completed before destroying the process group
+    torch.cuda.synchronize()
+    # Destroy the process group
+    if torch.distributed.is_initialized():
+        torch.distributed.destroy_process_group()
+
+
 @contextmanager
 def zero_only():
     """

@@ -609,7 +609,10 @@ class ModelLoader:
         # Initialize ring attn for sequence parallelism. This must be done after
         # model init but before the first forward pass, since it modifies flash
         # attn to use ring comm for SP training across multiple GPUs.
-        register_ring_attn(self.cfg.sequence_parallel_degree)
+        register_ring_attn(
+            sequence_parallel_degree=self.cfg.sequence_parallel_degree,
+            heads_k_stride=self.cfg.heads_k_stride,
+        )
 
     def patch_attention(self) -> None:
         if hasattr(self.model_config, "model_type"):

@@ -248,6 +248,7 @@ class AxolotlInputConfig(
     val_set_size: float | None = Field(default=0.0)
 
     sequence_parallel_degree: int | None = None
+    heads_k_stride: int | None = None
 
     special_tokens: SpecialTokensConfig | None = None
     tokens: list[str] | None = None

@@ -1108,7 +1109,7 @@ class AxolotlInputConfig(
 
     @field_validator("sequence_parallel_degree", mode="before")
     @classmethod
-    def check_sequence_parallel_config(cls, value, info):
+    def check_sequence_parallel_degree(cls, value, info):
         if not value:
             value = 1
 

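The validator shown above coerces a falsy `sequence_parallel_degree` to 1. A standalone sketch of the same behaviour with pydantic v2 follows; the field names mirror the diff, but the class itself is illustrative, not Axolotl's schema:

```python
from pydantic import BaseModel, field_validator


class SPSettingsSketch(BaseModel):
    """Sketch mirroring the two fields touched by this change."""

    sequence_parallel_degree: int | None = None
    heads_k_stride: int | None = None

    @field_validator("sequence_parallel_degree", mode="before")
    @classmethod
    def check_sequence_parallel_degree(cls, value):
        # None or 0 means sequence parallelism is disabled, i.e. degree 1.
        return value or 1


print(SPSettingsSketch().sequence_parallel_degree)  # 1
```
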
@@ -110,7 +110,7 @@ class TestRingAttention:
         mock_new_group.return_value = mock_group
 
         # Call register_ring_attn with size 4
-        register_ring_attn(sequence_parallel_degree=4)
+        register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1)
 
         # Verify the number of calls without examining the arguments
         assert mock_new_group.call_count == 2

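The updated test asserts on the number of `dist.new_group` calls. Here is a hedged, standalone sketch of that mocking pattern; the `make_groups` helper is a toy stand-in, not the code under test:

```python
from unittest.mock import MagicMock, patch

import torch.distributed as dist


def make_groups(world_size: int, sp_degree: int) -> None:
    """Toy stand-in: one new_group call per sequence-parallel group."""
    for group_id in range(world_size // sp_degree):
        ranks = list(range(group_id * sp_degree, (group_id + 1) * sp_degree))
        dist.new_group(ranks)


def test_new_group_call_count():
    with patch("torch.distributed.new_group", return_value=MagicMock()) as mock_new_group:
        make_groups(world_size=8, sp_degree=4)
    # Two groups of four ranks -> two new_group calls, matching the assertion above.
    assert mock_new_group.call_count == 2
```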