diff --git a/src/axolotl/monkeypatch/ring_attn/patch.py b/src/axolotl/monkeypatch/ring_attn/patch.py
index 7d733cfc1..ade858b46 100644
--- a/src/axolotl/monkeypatch/ring_attn/patch.py
+++ b/src/axolotl/monkeypatch/ring_attn/patch.py
@@ -13,9 +13,9 @@ import inspect
 import accelerate
 import torch
 import torch.distributed as dist
-from accelerate.logging import get_logger
 
 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.enums import RingAttnFunc
 
 LOG = get_logger(__name__)
@@ -80,19 +80,9 @@ def register_ring_attn(
     rank = dist.get_rank()
     world_size = dist.get_world_size()
 
-    if rank == 0:
-        LOG.info(
-            "Enabling ring attention sequence parallelism: "
-            f"each sequence will be processed across {sequence_parallel_degree} GPUs"
-        )
-
-    assert sequence_parallel_degree <= world_size, (
-        f"sequence_parallel_degree ({sequence_parallel_degree}) "
-        f"must be less than or equal to world_size ({world_size})"
-    )
-    assert world_size % sequence_parallel_degree == 0, (
-        f"sequence_parallel_degree ({sequence_parallel_degree}) "
-        f"must evenly divide world_size ({world_size})"
+    LOG.info(
+        "Enabling ring attention sequence parallelism: "
+        f"each sequence will be processed across {sequence_parallel_degree} GPUs"
     )
 
     # Assign ranks to sequence parallel groups
@@ -113,9 +103,7 @@ def register_ring_attn(
         if rank in ring_attn_ranks:
             set_ring_attn_group(group)
 
-    # Log the GPU group assignments
-    if rank == 0:
-        LOG.info(f"Sequence parallel group assignments: {group_assignments}")
+    LOG.info(f"Sequence parallel group assignments: {group_assignments}")
 
     if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
         from ring_flash_attn import substitute_hf_flash_attn
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index c155db42e..26638a975 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -17,7 +17,7 @@ from accelerate.utils import save_fsdp_model
 from datasets import Dataset
 from huggingface_hub.errors import OfflineModeIsEnabled
 from peft import PeftConfig, PeftModel
-from torch.distributed.tensor.experimental import _context_parallel
+from torch.distributed.tensor.experimental import context_parallel
 from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.trainer import Trainer
@@ -216,7 +216,7 @@ def execute_training(
                 torch.tensor(list(range(world_size))).reshape(mesh_shape),
                 mesh_dim_names=("dp", "cp"),
             )
-            stack.enter_context(_context_parallel(seq_dim=2, mesh=mesh))
+            stack.enter_context(context_parallel(mesh=mesh))
         else:  # flash_attention
             models = [trainer.model]
             if hasattr(trainer, "ref_model") and trainer.ref_model:
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index e5f105053..ec1a4e23d 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -1194,9 +1194,24 @@ class AxolotlInputConfig(
         if not self.sequence_parallel_degree:
             self.sequence_parallel_degree = 1
         elif self.sequence_parallel_degree > 1:
-            if not self.flash_attention:
+            import torch
+
+            world_size = torch.cuda.device_count()
+            if not world_size >= self.sequence_parallel_degree:
                 raise ValueError(
-                    "flash_attention: true must be set with sequence_parallel_degree > 1"
+                    f"World size ({world_size}) must be greater "
+                    f"than or equal to SP degree ({self.sequence_parallel_degree})"
+                )
+            if not world_size % self.sequence_parallel_degree == 0:
+                raise ValueError(
+                    f"SP degree ({self.sequence_parallel_degree}) "
+                    f"must evenly divide world size ({world_size})"
+                )
+
+            if not (self.flash_attention or self.sdp_attention):
+                raise ValueError(
+                    "flash_attention: true or sdp_attention: true "
+                    "must be set with sequence_parallel_degree > 1"
                 )
 
             if self.sample_packing and self.micro_batch_size > 1:
@@ -1205,14 +1220,15 @@ class AxolotlInputConfig(
                     "due to a `ring-flash-attn` requirement"
                 )
 
-            try:
-                import ring_flash_attn  # noqa: F401 # pylint:disable=unused-import
-            except ImportError as exception:
-                raise ImportError(
-                    "sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
-                    "Please install it with `pip install axolotl[ring-flash-attn] "
-                    "or `pip install ring-flash-attn>=0.1.4`."
-                ) from exception
+            if self.flash_attention:
+                try:
+                    import ring_flash_attn  # noqa: F401 # pylint:disable=unused-import
+                except ImportError as exception:
+                    raise ImportError(
+                        "sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
+                        "Please install it with `pip install axolotl[ring-flash-attn] "
+                        "or `pip install ring-flash-attn>=0.1.4`."
+                    ) from exception
 
             # TODO: monkeypatch / callback to average losses correctly across SP ranks
             # / fix gradient scaling across SP ranks. Losses, grads should be scaled
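Illustrative sketch (not part of the diff): given the updated validator in config.py above, a config fragment along these lines should pass the sequence-parallel checks on a host whose visible GPU count is a multiple of the SP degree, without requiring ring-flash-attn; other required axolotl fields are omitted.

    sequence_parallel_degree: 2   # must evenly divide the number of visible GPUs
    sdp_attention: true           # accepted as an alternative to flash_attention: true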