SDPA context parallel
This commit is contained in:
@@ -11,11 +11,13 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
import transformers.modelcard
|
import transformers.modelcard
|
||||||
from accelerate.utils import save_fsdp_model
|
from accelerate.utils import save_fsdp_model
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from huggingface_hub.errors import OfflineModeIsEnabled
|
from huggingface_hub.errors import OfflineModeIsEnabled
|
||||||
from peft import PeftConfig, PeftModel
|
from peft import PeftConfig, PeftModel
|
||||||
|
from torch.distributed.tensor.experimental import _context_parallel
|
||||||
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
from transformers.trainer import Trainer
|
from transformers.trainer import Trainer
|
||||||
@@ -203,19 +205,32 @@ def execute_training(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if cfg.sequence_parallel_degree > 1:
|
if cfg.sequence_parallel_degree > 1:
|
||||||
models = [trainer.model]
|
if cfg.sdp_attention:
|
||||||
if hasattr(trainer, "ref_model") and trainer.ref_model:
|
world_size = dist.get_world_size()
|
||||||
models.append(trainer.ref_model)
|
mesh_shape = (
|
||||||
|
world_size // cfg.sequence_parallel_degree,
|
||||||
stack.enter_context(
|
cfg.sequence_parallel_degree,
|
||||||
SequenceParallelContextManager(
|
)
|
||||||
models=models,
|
mesh = dist.DeviceMesh(
|
||||||
sequence_parallel_degree=cfg.sequence_parallel_degree,
|
"cuda",
|
||||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
torch.tensor(list(range(world_size))).reshape(mesh_shape),
|
||||||
ring_attn_func=cfg.ring_attn_func,
|
mesh_dim_names=("dp", "cp"),
|
||||||
heads_k_stride=cfg.heads_k_stride,
|
)
|
||||||
|
stack.enter_context(_context_parallel(seq_dim=2, mesh=mesh))
|
||||||
|
else: # flash_attention
|
||||||
|
models = [trainer.model]
|
||||||
|
if hasattr(trainer, "ref_model") and trainer.ref_model:
|
||||||
|
models.append(trainer.ref_model)
|
||||||
|
|
||||||
|
stack.enter_context(
|
||||||
|
SequenceParallelContextManager(
|
||||||
|
models=models,
|
||||||
|
sequence_parallel_degree=cfg.sequence_parallel_degree,
|
||||||
|
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||||
|
ring_attn_func=cfg.ring_attn_func,
|
||||||
|
heads_k_stride=cfg.heads_k_stride,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("Starting trainer...")
|
LOG.info("Starting trainer...")
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
|||||||
Reference in New Issue
Block a user