SDPA context parallel

Dan Saunders
2025-06-06 00:34:12 +00:00
parent 7909bfb076
commit 10d1e44943


@@ -11,11 +11,13 @@ from pathlib import Path
 from typing import Any, Dict
 import torch
+import torch.distributed as dist
 import transformers.modelcard
 from accelerate.utils import save_fsdp_model
 from datasets import Dataset
 from huggingface_hub.errors import OfflineModeIsEnabled
 from peft import PeftConfig, PeftModel
+from torch.distributed.tensor.experimental import _context_parallel
 from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.trainer import Trainer
@@ -203,19 +205,32 @@ def execute_training(
         )
         if cfg.sequence_parallel_degree > 1:
-            models = [trainer.model]
-            if hasattr(trainer, "ref_model") and trainer.ref_model:
-                models.append(trainer.ref_model)
-            stack.enter_context(
-                SequenceParallelContextManager(
-                    models=models,
-                    sequence_parallel_degree=cfg.sequence_parallel_degree,
-                    gradient_accumulation_steps=cfg.gradient_accumulation_steps,
-                    ring_attn_func=cfg.ring_attn_func,
-                    heads_k_stride=cfg.heads_k_stride,
-                )
-            )
+            if cfg.sdp_attention:
+                world_size = dist.get_world_size()
+                mesh_shape = (
+                    world_size // cfg.sequence_parallel_degree,
+                    cfg.sequence_parallel_degree,
+                )
+                mesh = dist.DeviceMesh(
+                    "cuda",
+                    torch.tensor(list(range(world_size))).reshape(mesh_shape),
+                    mesh_dim_names=("dp", "cp"),
+                )
+                stack.enter_context(_context_parallel(seq_dim=2, mesh=mesh))
+            else:  # flash_attention
+                models = [trainer.model]
+                if hasattr(trainer, "ref_model") and trainer.ref_model:
+                    models.append(trainer.ref_model)
+                stack.enter_context(
+                    SequenceParallelContextManager(
+                        models=models,
+                        sequence_parallel_degree=cfg.sequence_parallel_degree,
+                        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
+                        ring_attn_func=cfg.ring_attn_func,
+                        heads_k_stride=cfg.heads_k_stride,
+                    )
+                )
         LOG.info("Starting trainer...")
         trainer.train(resume_from_checkpoint=resume_from_checkpoint)
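
Note on the mesh construction: the reshape lays ranks out row-major, so each contiguous block of sequence_parallel_degree ranks forms one context-parallel group. A minimal standalone sketch of the resulting layout, assuming a hypothetical world_size of 8 and sequence_parallel_degree of 2 (no distributed init required):

    import torch

    # Hypothetical values; the real code derives world_size from
    # dist.get_world_size() inside an initialized process group.
    world_size = 8
    sequence_parallel_degree = 2

    mesh_shape = (world_size // sequence_parallel_degree, sequence_parallel_degree)
    ranks = torch.tensor(list(range(world_size))).reshape(mesh_shape)
    print(ranks)
    # tensor([[0, 1],
    #         [2, 3],
    #         [4, 5],
    #         [6, 7]])
    # Each row is one context-parallel ("cp") group: the ranks that split the
    # sequence dimension between them. Each column is one data-parallel ("dp")
    # group: the ranks that hold model replicas and synchronize gradients.

The seq_dim=2 argument presumably corresponds to the (batch, num_heads, seq_len, head_dim) layout that scaled_dot_product_attention expects, so the sequence axis of the attention buffers is the one being sharded across each "cp" group.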