progress; move validation to pydantic model config

This commit is contained in:
Dan Saunders
2025-06-07 06:58:59 +00:00
parent 10d1e44943
commit ae73123eae
3 changed files with 33 additions and 29 deletions

View File

@@ -13,9 +13,9 @@ import inspect
import accelerate import accelerate
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from accelerate.logging import get_logger
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RingAttnFunc from axolotl.utils.schemas.enums import RingAttnFunc
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -80,19 +80,9 @@ def register_ring_attn(
rank = dist.get_rank() rank = dist.get_rank()
world_size = dist.get_world_size() world_size = dist.get_world_size()
if rank == 0: LOG.info(
LOG.info( "Enabling ring attention sequence parallelism: "
"Enabling ring attention sequence parallelism: " f"each sequence will be processed across {sequence_parallel_degree} GPUs"
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
)
assert sequence_parallel_degree <= world_size, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must be less than or equal to world_size ({world_size})"
)
assert world_size % sequence_parallel_degree == 0, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must evenly divide world_size ({world_size})"
) )
# Assign ranks to sequence parallel groups # Assign ranks to sequence parallel groups
@@ -113,9 +103,7 @@ def register_ring_attn(
if rank in ring_attn_ranks: if rank in ring_attn_ranks:
set_ring_attn_group(group) set_ring_attn_group(group)
# Log the GPU group assignments LOG.info(f"Sequence parallel group assignments: {group_assignments}")
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3: if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
from ring_flash_attn import substitute_hf_flash_attn from ring_flash_attn import substitute_hf_flash_attn

View File

@@ -17,7 +17,7 @@ from accelerate.utils import save_fsdp_model
from datasets import Dataset from datasets import Dataset
from huggingface_hub.errors import OfflineModeIsEnabled from huggingface_hub.errors import OfflineModeIsEnabled
from peft import PeftConfig, PeftModel from peft import PeftConfig, PeftModel
from torch.distributed.tensor.experimental import _context_parallel from torch.distributed.tensor.experimental import context_parallel
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer from transformers.trainer import Trainer
@@ -216,7 +216,7 @@ def execute_training(
torch.tensor(list(range(world_size))).reshape(mesh_shape), torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"), mesh_dim_names=("dp", "cp"),
) )
stack.enter_context(_context_parallel(seq_dim=2, mesh=mesh)) stack.enter_context(context_parallel(mesh=mesh))
else: # flash_attention else: # flash_attention
models = [trainer.model] models = [trainer.model]
if hasattr(trainer, "ref_model") and trainer.ref_model: if hasattr(trainer, "ref_model") and trainer.ref_model:

View File

@@ -1194,9 +1194,24 @@ class AxolotlInputConfig(
if not self.sequence_parallel_degree: if not self.sequence_parallel_degree:
self.sequence_parallel_degree = 1 self.sequence_parallel_degree = 1
elif self.sequence_parallel_degree > 1: elif self.sequence_parallel_degree > 1:
if not self.flash_attention: import torch
world_size = torch.cuda.device_count()
if not world_size >= self.sequence_parallel_degree:
raise ValueError( raise ValueError(
"flash_attention: true must be set with sequence_parallel_degree > 1" f"World size ({world_size}) must be greater "
f"than or equal to SP degree ({self.sequence_parallel_degree})"
)
if not world_size % self.sequence_parallel_degree == 0:
raise ValueError(
f"SP degree ({self.sequence_parallel_degree}) "
f"must evenly divide world size ({world_size})"
)
if not (self.flash_attention or self.sdp_attention):
raise ValueError(
"flash_attention: true or sdp_attention: true "
"must be set with sequence_parallel_degree > 1"
) )
if self.sample_packing and self.micro_batch_size > 1: if self.sample_packing and self.micro_batch_size > 1:
@@ -1205,14 +1220,15 @@ class AxolotlInputConfig(
"due to a `ring-flash-attn` requirement" "due to a `ring-flash-attn` requirement"
) )
try: if self.flash_attention:
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import try:
except ImportError as exception: import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
raise ImportError( except ImportError as exception:
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. " raise ImportError(
"Please install it with `pip install axolotl[ring-flash-attn] " "sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
"or `pip install ring-flash-attn>=0.1.4`." "Please install it with `pip install axolotl[ring-flash-attn] "
) from exception "or `pip install ring-flash-attn>=0.1.4`."
) from exception
# TODO: monkeypatch / callback to average losses correctly across SP ranks # TODO: monkeypatch / callback to average losses correctly across SP ranks
# / fix gradient scaling across SP ranks. Losses, grads should be scaled # / fix gradient scaling across SP ranks. Losses, grads should be scaled