progress; move validation to pydantic model config
This commit is contained in:
@@ -13,9 +13,9 @@ import inspect
|
||||
import accelerate
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
@@ -80,19 +80,9 @@ def register_ring_attn(
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
if rank == 0:
|
||||
LOG.info(
|
||||
"Enabling ring attention sequence parallelism: "
|
||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||
)
|
||||
|
||||
assert sequence_parallel_degree <= world_size, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must be less than or equal to world_size ({world_size})"
|
||||
)
|
||||
assert world_size % sequence_parallel_degree == 0, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must evenly divide world_size ({world_size})"
|
||||
LOG.info(
|
||||
"Enabling ring attention sequence parallelism: "
|
||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||
)
|
||||
|
||||
# Assign ranks to sequence parallel groups
|
||||
@@ -113,9 +103,7 @@ def register_ring_attn(
|
||||
if rank in ring_attn_ranks:
|
||||
set_ring_attn_group(group)
|
||||
|
||||
# Log the GPU group assignments
|
||||
if rank == 0:
|
||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||
|
||||
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
|
||||
from ring_flash_attn import substitute_hf_flash_attn
|
||||
|
||||
@@ -17,7 +17,7 @@ from accelerate.utils import save_fsdp_model
|
||||
from datasets import Dataset
|
||||
from huggingface_hub.errors import OfflineModeIsEnabled
|
||||
from peft import PeftConfig, PeftModel
|
||||
from torch.distributed.tensor.experimental import _context_parallel
|
||||
from torch.distributed.tensor.experimental import context_parallel
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.trainer import Trainer
|
||||
@@ -216,7 +216,7 @@ def execute_training(
|
||||
torch.tensor(list(range(world_size))).reshape(mesh_shape),
|
||||
mesh_dim_names=("dp", "cp"),
|
||||
)
|
||||
stack.enter_context(_context_parallel(seq_dim=2, mesh=mesh))
|
||||
stack.enter_context(context_parallel(mesh=mesh))
|
||||
else: # flash_attention
|
||||
models = [trainer.model]
|
||||
if hasattr(trainer, "ref_model") and trainer.ref_model:
|
||||
|
||||
@@ -1194,9 +1194,24 @@ class AxolotlInputConfig(
|
||||
if not self.sequence_parallel_degree:
|
||||
self.sequence_parallel_degree = 1
|
||||
elif self.sequence_parallel_degree > 1:
|
||||
if not self.flash_attention:
|
||||
import torch
|
||||
|
||||
world_size = torch.cuda.device_count()
|
||||
if not world_size >= self.sequence_parallel_degree:
|
||||
raise ValueError(
|
||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||
f"World size ({world_size}) must be greater "
|
||||
f"than or equal to SP degree ({self.sequence_parallel_degree})"
|
||||
)
|
||||
if not world_size % self.sequence_parallel_degree == 0:
|
||||
raise ValueError(
|
||||
f"SP degree ({self.sequence_parallel_degree}) "
|
||||
f"must evenly divide world size ({world_size})"
|
||||
)
|
||||
|
||||
if not (self.flash_attention or self.sdp_attention):
|
||||
raise ValueError(
|
||||
"flash_attention: true or sdp_attention: true "
|
||||
"must be set with sequence_parallel_degree > 1"
|
||||
)
|
||||
|
||||
if self.sample_packing and self.micro_batch_size > 1:
|
||||
@@ -1205,14 +1220,15 @@ class AxolotlInputConfig(
|
||||
"due to a `ring-flash-attn` requirement"
|
||||
)
|
||||
|
||||
try:
|
||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||
except ImportError as exception:
|
||||
raise ImportError(
|
||||
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
|
||||
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||
"or `pip install ring-flash-attn>=0.1.4`."
|
||||
) from exception
|
||||
if self.flash_attention:
|
||||
try:
|
||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||
except ImportError as exception:
|
||||
raise ImportError(
|
||||
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
|
||||
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||
"or `pip install ring-flash-attn>=0.1.4`."
|
||||
) from exception
|
||||
|
||||
# TODO: monkeypatch / callback to average losses correctly across SP ranks
|
||||
# / fix gradient scaling across SP ranks. Losses, grads should be scaled
|
||||
|
||||
Reference in New Issue
Block a user