progress; move validation to pydantic model config

Dan Saunders
2025-06-07 06:58:59 +00:00
parent 10d1e44943
commit ae73123eae
3 changed files with 33 additions and 29 deletions
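For readers skimming the diff: the pattern applied in the third file is pydantic's `model_validator` hook, which lets the config schema reject bad sequence-parallel settings at parse time rather than deep inside `register_ring_attn`. A minimal sketch of that pattern, with a toy class standing in for `AxolotlInputConfig` (class and field set trimmed to the relevant slice; illustrative only, not the real schema):

    from pydantic import BaseModel, model_validator


    class SPConfig(BaseModel):
        # Toy stand-in for the sequence-parallel slice of AxolotlInputConfig
        sequence_parallel_degree: int | None = None
        flash_attention: bool = False
        sdp_attention: bool = False

        @model_validator(mode="after")
        def check_sequence_parallel(self) -> "SPConfig":
            if not self.sequence_parallel_degree:
                self.sequence_parallel_degree = 1
            elif self.sequence_parallel_degree > 1 and not (
                self.flash_attention or self.sdp_attention
            ):
                raise ValueError(
                    "flash_attention: true or sdp_attention: true "
                    "must be set with sequence_parallel_degree > 1"
                )
            return self


    SPConfig(sequence_parallel_degree=2, flash_attention=True)  # ok
    SPConfig(sequence_parallel_degree=2)  # raises pydantic.ValidationError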

View File

@@ -13,9 +13,9 @@ import inspect
 import accelerate
 import torch
 import torch.distributed as dist
-from accelerate.logging import get_logger

 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.enums import RingAttnFunc

 LOG = get_logger(__name__)
@@ -80,19 +80,9 @@ def register_ring_attn(
     rank = dist.get_rank()
     world_size = dist.get_world_size()

-    if rank == 0:
-        LOG.info(
-            "Enabling ring attention sequence parallelism: "
-            f"each sequence will be processed across {sequence_parallel_degree} GPUs"
-        )
-
-    assert sequence_parallel_degree <= world_size, (
-        f"sequence_parallel_degree ({sequence_parallel_degree}) "
-        f"must be less than or equal to world_size ({world_size})"
-    )
-    assert world_size % sequence_parallel_degree == 0, (
-        f"sequence_parallel_degree ({sequence_parallel_degree}) "
-        f"must evenly divide world_size ({world_size})"
-    )
+    LOG.info(
+        "Enabling ring attention sequence parallelism: "
+        f"each sequence will be processed across {sequence_parallel_degree} GPUs"
+    )

     # Assign ranks to sequence parallel groups
@@ -113,9 +103,7 @@ def register_ring_attn(
         if rank in ring_attn_ranks:
            set_ring_attn_group(group)

-    # Log the GPU group assignments
-    if rank == 0:
-        LOG.info(f"Sequence parallel group assignments: {group_assignments}")
+    LOG.info(f"Sequence parallel group assignments: {group_assignments}")

     if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
         from ring_flash_attn import substitute_hf_flash_attn
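The dropped `if rank == 0:` guards are safe to remove because the logger import also changed: `accelerate.logging.get_logger` already emits on the main process only by default, and the new `axolotl.utils.logging.get_logger` presumably preserves that behavior. As a rough sketch of the idea (an assumption about the wrapper, not the actual axolotl implementation), a rank-zero filter looks like:

    import logging

    import torch.distributed as dist


    class RankZeroFilter(logging.Filter):
        # Pass records through before dist init, or on global rank 0 only
        def filter(self, record: logging.LogRecord) -> bool:
            return not dist.is_initialized() or dist.get_rank() == 0


    def get_logger(name: str) -> logging.Logger:
        logger = logging.getLogger(name)
        logger.addFilter(RankZeroFilter())
        return logger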

View File

@@ -17,7 +17,7 @@ from accelerate.utils import save_fsdp_model
 from datasets import Dataset
 from huggingface_hub.errors import OfflineModeIsEnabled
 from peft import PeftConfig, PeftModel
-from torch.distributed.tensor.experimental import _context_parallel
+from torch.distributed.tensor.experimental import context_parallel
 from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.trainer import Trainer
@@ -216,7 +216,7 @@ def execute_training(
                 torch.tensor(list(range(world_size))).reshape(mesh_shape),
                 mesh_dim_names=("dp", "cp"),
             )
-            stack.enter_context(_context_parallel(seq_dim=2, mesh=mesh))
+            stack.enter_context(context_parallel(mesh=mesh))
         else:  # flash_attention
             models = [trainer.model]
             if hasattr(trainer, "ref_model") and trainer.ref_model:
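This hunk swaps torch's private `_context_parallel` entry point for the public `context_parallel` context manager in `torch.distributed.tensor.experimental`, which no longer takes a `seq_dim` argument at this call site. A standalone sketch of how that API is typically entered; the degree, tensor shape, and buffer list below are illustrative assumptions, and the real call site reuses the mesh built in `execute_training`:

    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.experimental import context_parallel

    dist.init_process_group("nccl")
    world_size = dist.get_world_size()
    cp_degree = 2  # assumed context-parallel degree
    mesh = init_device_mesh(
        "cuda", (world_size // cp_degree, cp_degree), mesh_dim_names=("dp", "cp")
    )

    # (batch, seq, hidden) activations; seq dim 1 is sharded across the cp group
    batch = torch.randn(1, 4096, 1024, device="cuda")

    with context_parallel(mesh["cp"], buffers=[batch], buffer_seq_dims=[1]):
        # Inside the block, SDPA-based attention runs context-parallel and the
        # listed buffers are sharded along their seq dims; restored on exit.
        pass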

View File

@@ -1194,9 +1194,24 @@ class AxolotlInputConfig(
         if not self.sequence_parallel_degree:
             self.sequence_parallel_degree = 1
         elif self.sequence_parallel_degree > 1:
-            if not self.flash_attention:
-                raise ValueError(
-                    "flash_attention: true must be set with sequence_parallel_degree > 1"
-                )
+            import torch
+
+            world_size = torch.cuda.device_count()
+            if not world_size >= self.sequence_parallel_degree:
+                raise ValueError(
+                    f"World size ({world_size}) must be greater "
+                    f"than or equal to SP degree ({self.sequence_parallel_degree})"
+                )
+            if not world_size % self.sequence_parallel_degree == 0:
+                raise ValueError(
+                    f"SP degree ({self.sequence_parallel_degree}) "
+                    f"must evenly divide world size ({world_size})"
+                )
+            if not (self.flash_attention or self.sdp_attention):
+                raise ValueError(
+                    "flash_attention: true or sdp_attention: true "
+                    "must be set with sequence_parallel_degree > 1"
+                )

             if self.sample_packing and self.micro_batch_size > 1:
@@ -1205,14 +1220,15 @@ class AxolotlInputConfig(
                     "due to a `ring-flash-attn` requirement"
                 )

-            try:
-                import ring_flash_attn  # noqa: F401 # pylint:disable=unused-import
-            except ImportError as exception:
-                raise ImportError(
-                    "sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
-                    "Please install it with `pip install axolotl[ring-flash-attn] "
-                    "or `pip install ring-flash-attn>=0.1.4`."
-                ) from exception
+            if self.flash_attention:
+                try:
+                    import ring_flash_attn  # noqa: F401 # pylint:disable=unused-import
+                except ImportError as exception:
+                    raise ImportError(
+                        "sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
+                        "Please install it with `pip install axolotl[ring-flash-attn] "
+                        "or `pip install ring-flash-attn>=0.1.4`."
+                    ) from exception

         # TODO: monkeypatch / callback to average losses correctly across SP ranks
         # / fix gradient scaling across SP ranks. Losses, grads should be scaled
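Two behavioral notes on the new validator: it sources world size from `torch.cuda.device_count()`, i.e. the local GPU count, and the `ring_flash_attn` import check now fires only on the flash-attention path, since the SDPA path goes through torch's `context_parallel` instead. The divisibility rule itself is easy to sanity-check in isolation (illustrative helper, not part of the commit):

    def validate_sp(world_size: int, sp_degree: int) -> None:
        # Same two checks the pydantic validator performs
        if world_size < sp_degree:
            raise ValueError(
                f"World size ({world_size}) must be greater "
                f"than or equal to SP degree ({sp_degree})"
            )
        if world_size % sp_degree != 0:
            raise ValueError(
                f"SP degree ({sp_degree}) must evenly divide world size ({world_size})"
            )


    validate_sp(8, 4)  # ok: two sequence-parallel groups of 4 GPUs
    validate_sp(8, 3)  # raises: 3 does not evenly divide 8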