SDPA context parallel

This commit is contained in:
Dan Saunders
2025-06-06 00:34:12 +00:00
parent 7909bfb076
commit 10d1e44943

View File

@@ -11,11 +11,13 @@ from pathlib import Path
from typing import Any, Dict from typing import Any, Dict
import torch import torch
import torch.distributed as dist
import transformers.modelcard import transformers.modelcard
from accelerate.utils import save_fsdp_model from accelerate.utils import save_fsdp_model
from datasets import Dataset from datasets import Dataset
from huggingface_hub.errors import OfflineModeIsEnabled from huggingface_hub.errors import OfflineModeIsEnabled
from peft import PeftConfig, PeftModel from peft import PeftConfig, PeftModel
from torch.distributed.tensor.experimental import _context_parallel
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer from transformers.trainer import Trainer
@@ -203,19 +205,32 @@ def execute_training(
) )
if cfg.sequence_parallel_degree > 1: if cfg.sequence_parallel_degree > 1:
models = [trainer.model] if cfg.sdp_attention:
if hasattr(trainer, "ref_model") and trainer.ref_model: world_size = dist.get_world_size()
models.append(trainer.ref_model) mesh_shape = (
world_size // cfg.sequence_parallel_degree,
stack.enter_context( cfg.sequence_parallel_degree,
SequenceParallelContextManager( )
models=models, mesh = dist.DeviceMesh(
sequence_parallel_degree=cfg.sequence_parallel_degree, "cuda",
gradient_accumulation_steps=cfg.gradient_accumulation_steps, torch.tensor(list(range(world_size))).reshape(mesh_shape),
ring_attn_func=cfg.ring_attn_func, mesh_dim_names=("dp", "cp"),
heads_k_stride=cfg.heads_k_stride, )
stack.enter_context(_context_parallel(seq_dim=2, mesh=mesh))
else: # flash_attention
models = [trainer.model]
if hasattr(trainer, "ref_model") and trainer.ref_model:
models.append(trainer.ref_model)
stack.enter_context(
SequenceParallelContextManager(
models=models,
sequence_parallel_degree=cfg.sequence_parallel_degree,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
ring_attn_func=cfg.ring_attn_func,
heads_k_stride=cfg.heads_k_stride,
)
) )
)
LOG.info("Starting trainer...") LOG.info("Starting trainer...")
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)