diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 866a9c454..c155db42e 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -11,11 +11,13 @@ from pathlib import Path
 from typing import Any, Dict
 
 import torch
+import torch.distributed as dist
 import transformers.modelcard
 from accelerate.utils import save_fsdp_model
 from datasets import Dataset
 from huggingface_hub.errors import OfflineModeIsEnabled
 from peft import PeftConfig, PeftModel
+from torch.distributed.tensor.experimental import _context_parallel
 from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.trainer import Trainer
@@ -203,19 +205,32 @@ def execute_training(
         )
 
     if cfg.sequence_parallel_degree > 1:
-        models = [trainer.model]
-        if hasattr(trainer, "ref_model") and trainer.ref_model:
-            models.append(trainer.ref_model)
-
-        stack.enter_context(
-            SequenceParallelContextManager(
-                models=models,
-                sequence_parallel_degree=cfg.sequence_parallel_degree,
-                gradient_accumulation_steps=cfg.gradient_accumulation_steps,
-                ring_attn_func=cfg.ring_attn_func,
-                heads_k_stride=cfg.heads_k_stride,
+        if cfg.sdp_attention:
+            world_size = dist.get_world_size()
+            mesh_shape = (
+                world_size // cfg.sequence_parallel_degree,
+                cfg.sequence_parallel_degree,
+            )
+            mesh = dist.DeviceMesh(
+                "cuda",
+                torch.tensor(list(range(world_size))).reshape(mesh_shape),
+                mesh_dim_names=("dp", "cp"),
+            )
+            stack.enter_context(_context_parallel(seq_dim=2, mesh=mesh))
+        else:  # flash_attention
+            models = [trainer.model]
+            if hasattr(trainer, "ref_model") and trainer.ref_model:
+                models.append(trainer.ref_model)
+
+            stack.enter_context(
+                SequenceParallelContextManager(
+                    models=models,
+                    sequence_parallel_degree=cfg.sequence_parallel_degree,
+                    gradient_accumulation_steps=cfg.gradient_accumulation_steps,
+                    ring_attn_func=cfg.ring_attn_func,
+                    heads_k_stride=cfg.heads_k_stride,
+                )
             )
-        )
 
     LOG.info("Starting trainer...")
     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
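For context, here is a minimal standalone sketch (not part of the patch) of the rank layout that the `mesh_shape` arithmetic above produces; the `world_size=8` and `sequence_parallel_degree=2` values are hypothetical:

# Illustrative only: how ranks land on the 2-D ("dp", "cp") device mesh.
import torch

world_size = 8                 # hypothetical: 8 GPUs total
sequence_parallel_degree = 2   # hypothetical: 2-way context parallelism

mesh_shape = (world_size // sequence_parallel_degree, sequence_parallel_degree)
ranks = torch.arange(world_size).reshape(mesh_shape)
print(ranks)
# tensor([[0, 1],
#         [2, 3],
#         [4, 5],
#         [6, 7]])
# Each row is one context-parallel ("cp") group that shards a single sequence,
# e.g. ranks {0, 1}; each column is one data-parallel ("dp") group, e.g. ranks
# {0, 2, 4, 6}.

The `seq_dim=2` passed to `_context_parallel` presumably refers to the sequence axis of the (batch, heads, seq_len, head_dim) tensors that scaled dot-product attention consumes, which is the axis the cp group shards.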