Compare commits


2 Commits

Author       SHA1        Message                                             Date
NanoCode012  7888a35118  chore: remove unused log                            2025-03-31 16:20:15 +07:00
NanoCode012  873385b7d5  feat: update xformers for new attention interface   2025-03-31 16:15:55 +07:00
14 changed files with 144 additions and 219 deletions

View File

@@ -243,7 +243,6 @@ website:
- docs/unsloth.qmd
- docs/torchao.qmd
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
- section: "Troubleshooting"
contents:

View File

@@ -658,9 +658,6 @@ ddp_broadcast_buffers:
# subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
sequence_parallel_degree:
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
# Must evenly divide the number of KV heads in your model.
heads_k_stride: 1
# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:

View File

@@ -18,7 +18,6 @@ Axolotl supports several methods for multi-GPU training:
- DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel)
- Sequence parallelism
- FSDP + QLoRA
## DeepSpeed {#sec-deepspeed}
@@ -67,28 +66,6 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
## Sequence parallelism {#sec-sequence-parallelism}
We support sequence parallelism (SP) via the
[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
allows sequences to be split across GPUs, which is useful when a single long
sequence would otherwise cause OOM errors during model training.
First, install `ring-flash-attn`, ideally via `pip install axolotl[ring-flash-attn]`,
or from source with `pip install .[ring-flash-attn]`.
Your Axolotl YAML config should contain the following lines:
```{.yaml}
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but will make training faster.
heads_k_stride: 1
```
See our [dedicated guide](sequence_parallelism.qmd) for more details.
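As a quick sanity check of the divisor requirement (an illustrative sketch, not part of the docs above; `world_size` and the degree value are assumptions), the GPUs split into `world_size / sequence_parallel_degree` independent sequence-parallel groups:

```python
# Illustrative only: sequence_parallel_degree must evenly divide the GPU count,
# leaving world_size // sequence_parallel_degree independent SP groups.
import torch

world_size = torch.cuda.device_count()  # e.g. 8 GPUs on one node
sequence_parallel_degree = 4

assert world_size % sequence_parallel_degree == 0, (
    "sequence_parallel_degree must evenly divide the number of GPUs"
)
num_sp_groups = world_size // sequence_parallel_degree  # e.g. 2 groups of 4 GPUs
```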
### FSDP + QLoRA {#sec-fsdp-qlora}
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).

View File

@@ -25,8 +25,6 @@ To enable sequence parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
@@ -60,16 +58,11 @@ To use sequence parallelism, you need:
## Example
```yaml
# Example config with sequence parallelism
base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192
...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
sequence_parallel_degree: 2 # Split each sequence into 2 parts
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
...
```
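To make the sharding concrete (a rough sketch, not taken from the config above, using the degree-4 variant shown earlier in these docs), each GPU in a sequence-parallel group attends over an equal slice of every sequence:

```python
# Rough sketch of the sharding implied by sequence_len and sequence_parallel_degree:
# every sequence is split into equal contiguous shards, one per GPU in the SP group.
sequence_len = 8192
sequence_parallel_degree = 4

tokens_per_gpu = sequence_len // sequence_parallel_degree
assert tokens_per_gpu == 2048  # each of the 4 GPUs processes a 2048-token shard
```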

View File

@@ -69,6 +69,7 @@ from axolotl.utils.callbacks import (
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
SaveModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory,
@@ -248,6 +249,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
callbacks.append(SaveModelCallback())
return callbacks
@@ -935,6 +937,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
def get_callbacks(self):
callbacks = super().get_callbacks()
callbacks.append(SaveModelCallback())
return callbacks

View File

@@ -15,7 +15,6 @@ from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer
@@ -160,6 +159,4 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
del model
del tokenizer
cleanup_distributed()
return all_metrics

View File

@@ -38,19 +38,13 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
RING_ATTN_GROUP = ring_attn_group
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
def register_ring_attn(sequence_parallel_degree: int):
"""
Create ring attention group and substitute flash attn with ring flash attn.
Args:
sequence_parallel_degree: Sequence parallelism factor.
heads_k_stride: Sequence parallelism K head stride size. Passed
through to `ring_flash_attn.substitute_hf_flash_attn`.
"""
if get_ring_attn_group() is not None:
LOG.info("Ring attention already registered, exiting early...")
return
LOG.info(
"Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
@@ -90,11 +84,6 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
if heads_k_stride is None:
heads_k_stride = 1
from ring_flash_attn import substitute_hf_flash_attn
substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
)
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)

View File

@@ -1,153 +1,113 @@
"""
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
Hijack the LlamaAttention forward method to use xformers if available.
Updated for transformers v4.50.0.
"""
import logging
import warnings
from typing import Optional, Tuple
from typing import Optional
import torch
import torch.nn.functional as F
import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from torch import nn
from transformers.models.llama.modeling_llama import repeat_kv
try:
import xformers.ops
XFORMERS_AVAILABLE = True
except ImportError:
logging.error("xformers not found! Please install it before trying to use it.")
XFORMERS_AVAILABLE = False
def xformers_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs, # pylint: disable=unused-argument
):
"""
Implements xformers memory-efficient attention for LlamaAttention with support for GQA.
Args:
module: The LlamaAttention module
query: Query states of shape [batch, num_heads, seq_len, head_dim]
key: Key states of shape [batch, num_kv_heads, seq_len, head_dim]
value: Value states of shape [batch, num_kv_heads, seq_len, head_dim]
attention_mask: Attention mask
scaling: Scaling factor for attention scores
dropout: Dropout probability
Returns:
attn_output: Output of xformers memory-efficient attention
attn_weights: None
"""
# First, handle grouped-query attention (GQA)
# We need to repeat key and value states to match the number of query heads
num_key_value_groups = getattr(module, "num_key_value_groups", 1)
key = repeat_kv(key, num_key_value_groups)
value = repeat_kv(value, num_key_value_groups)
# xformers expects inputs in shape [batch, seq_len, num_heads, head_dim]
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# Determine if we need a causal mask
is_causal = getattr(module, "is_causal", True)
# Set up the attention bias for xformers
if is_causal:
# Use xformers built-in causal mask
attn_bias = xformers.ops.LowerTriangularMask()
elif attention_mask is not None:
# For non-causal attention with a mask, we'd need to convert the mask
# This is a simplification - you might need to adapt based on your mask format
attn_bias = attention_mask
else:
# No mask needed
attn_bias = None
# Apply xformers memory-efficient attention
attn_output = xformers.ops.memory_efficient_attention(
query,
key,
value,
attn_bias=attn_bias,
p=dropout if module.training else 0.0,
scale=scaling,
)
# Reshape back to [batch, seq_len, hidden_size]
attn_output = attn_output.transpose(1, 2)
return attn_output, None # Return None for attn_weights to match interface
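# Illustrative shape walkthrough (not part of the commit): with a GQA model such as
# Llama-3-8B (32 query heads, 8 KV heads, head_dim 128) and a [2, 8, 4096, 128]
# key tensor, repeat_kv expands it to [2, 32, 4096, 128] to match the query heads,
# and the transposes above rearrange everything to [2, 4096, 32, 128], the
# [batch, seq_len, num_heads, head_dim] layout that memory_efficient_attention expects.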
def hijack_llama_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
def xformers_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
cos, sin = self.rotary_emb(value_states)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
#
# xformers-attn start
#
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states, key_states, value_states, attn_bias=None
)
else:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states,
key_states,
value_states,
# attn_bias=attention_mask,
attn_bias=xformers.ops.LowerTriangularMask(),
)
# (removed) tail of the old xformers_forward implementation:
    if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    #
    # xformers-attn end
    #
    if self.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(
            self.hidden_size // self.pretraining_tp, dim=1
        )
        attn_output = sum(
            F.linear(attn_output[i], o_proj_slices[i])
            for i in range(self.pretraining_tp)
        )
    else:
        attn_output = self.o_proj(attn_output)
    return attn_output, None, past_key_value
# (added) body of the new hijack_llama_attention:
    """
    Patch the LlamaAttention forward method to use xformers if available.
    """
    if not XFORMERS_AVAILABLE:
        raise ValueError(
            "xformers not available. Please install it following axolotl's requirements."
        )
    import transformers.models.llama.modeling_llama as llama_modeling

    # Add xformers to the available attention implementations
    llama_modeling.ALL_ATTENTION_FUNCTIONS["xformers"] = xformers_attention_forward

    # Create a wrapper for the original LlamaAttention forward method
    original_forward = llama_modeling.LlamaAttention.forward

    def patched_forward(self, *args, **kwargs):
        # Set the attention implementation to xformers
        # pylint: disable=protected-access
        self.config._attn_implementation = "xformers"
        return original_forward(self, *args, **kwargs)

    # Apply the patch
    llama_modeling.LlamaAttention.forward = patched_forward
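For context, a minimal usage sketch of the new interface (the import path is assumed from the repository layout and is not shown in this diff): registering "xformers" in `ALL_ATTENTION_FUNCTIONS` lets transformers' attention dispatch pick it up, so the patch no longer has to reimplement the whole forward pass.

```python
# Minimal usage sketch (assumed import path; not part of this diff).
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.llama_attn_hijack_xformers import hijack_llama_attention

hijack_llama_attention()  # registers "xformers" and patches LlamaAttention.forward
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3-8B-Instruct")
# Subsequent forward passes dispatch to xformers_attention_forward via
# config._attn_implementation = "xformers".
```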

View File

@@ -27,7 +27,6 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer
@@ -158,8 +157,6 @@ def setup_signal_handler(
_model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
cleanup_distributed()
sys.exit(0)
_model_weakref = weakref.ref(model)
@@ -481,7 +478,7 @@ def train(
Returns:
Tuple of (model, tokenizer) after training
"""
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
# Setup model, tokenizer, (causal or RLHF) trainer etc.
(
trainer,
model,
@@ -490,26 +487,34 @@ def train(
processor,
) = setup_model_and_trainer(cfg, dataset_meta)
# Handle untrained tokens if configured
# Determine if we need to resume from a checkpoint
resume_from_checkpoint = determine_resume_checkpoint(cfg)
# Configuration for saving
safe_serialization = cfg.save_safetensors is True
# Handle untrained tokens if configured
train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix(
cfg, model, tokenizer, train_dataset, safe_serialization
)
# Additional setup
# Save initial configs
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
# Set up signal handler for graceful termination
setup_signal_handler(cfg, model, safe_serialization)
# Set up badges and config info for model card
setup_model_card(cfg)
# Execute the training
resume_from_checkpoint = determine_resume_checkpoint(cfg)
execute_training(cfg, trainer, resume_from_checkpoint)
# Save the trained model and cleanup
# Save the trained model
save_trained_model(cfg, trainer, model, safe_serialization)
# Create model card
create_model_card(cfg, trainer)
if not cfg.use_ray:
cleanup_distributed()
return model, tokenizer, trainer

View File

@@ -816,6 +816,27 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
return control
class SaveModelCallback(TrainerCallback):
"""Callback to save model on train end"""
def on_step_end( # pylint: disable=unused-argument
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
# Save
if state.global_step >= state.max_steps:
control.should_save = True
def on_train_end( # pylint: disable=unused-argument
self, args, state, control, **kwargs
):
control.should_save = True
return control
class GCCallback(TrainerCallback):
"""Callback to garbage collect torch cache"""

View File

@@ -71,8 +71,8 @@ def barrier():
def is_main_process():
"""
Check if the current process is the main process. If not in distributed mode,
always return `True`.
Check if the current process is the main process.
If not in distributed mode, always return True.
"""
if not is_distributed():
return True
@@ -87,18 +87,6 @@ def get_world_size():
return int(os.getenv("WORLD_SIZE", "1"))
def cleanup_distributed():
"""
Destroy the process group if torch distributed is initialized. Called on early
termination of training or when training completes successfully.
"""
# Ensure that all operations are completed before destroying the process group
torch.cuda.synchronize()
# Destroy the process group
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
@contextmanager
def zero_only():
"""

View File

@@ -609,10 +609,7 @@ class ModelLoader:
# Initialize ring attn for sequence parallelism. This must be done after
# model init but before the first forward pass, since it modifies flash
# attn to use ring comm for SP training across multiple GPUs.
register_ring_attn(
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
heads_k_stride=self.cfg.heads_k_stride,
)
register_ring_attn(self.cfg.sequence_parallel_degree)
def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"):

View File

@@ -248,7 +248,6 @@ class AxolotlInputConfig(
val_set_size: float | None = Field(default=0.0)
sequence_parallel_degree: int | None = None
heads_k_stride: int | None = None
special_tokens: SpecialTokensConfig | None = None
tokens: list[str] | None = None
@@ -1109,7 +1108,7 @@ class AxolotlInputConfig(
@field_validator("sequence_parallel_degree", mode="before")
@classmethod
def check_sequence_parallel_degree(cls, value, info):
def check_sequence_parallel_config(cls, value, info):
if not value:
value = 1

View File

@@ -110,7 +110,7 @@ class TestRingAttention:
mock_new_group.return_value = mock_group
# Call register_ring_attn with size 4
register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1)
register_ring_attn(sequence_parallel_degree=4)
# Verify the number of calls without examining the arguments
assert mock_new_group.call_count == 2