progress; move validation to pydantic model config
This commit is contained in:
@@ -13,9 +13,9 @@ import inspect
|
|||||||
import accelerate
|
import accelerate
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from accelerate.logging import get_logger
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
@@ -80,19 +80,9 @@ def register_ring_attn(
|
|||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
|
|
||||||
if rank == 0:
|
LOG.info(
|
||||||
LOG.info(
|
"Enabling ring attention sequence parallelism: "
|
||||||
"Enabling ring attention sequence parallelism: "
|
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert sequence_parallel_degree <= world_size, (
|
|
||||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
|
||||||
f"must be less than or equal to world_size ({world_size})"
|
|
||||||
)
|
|
||||||
assert world_size % sequence_parallel_degree == 0, (
|
|
||||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
|
||||||
f"must evenly divide world_size ({world_size})"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assign ranks to sequence parallel groups
|
# Assign ranks to sequence parallel groups
|
||||||
@@ -113,9 +103,7 @@ def register_ring_attn(
|
|||||||
if rank in ring_attn_ranks:
|
if rank in ring_attn_ranks:
|
||||||
set_ring_attn_group(group)
|
set_ring_attn_group(group)
|
||||||
|
|
||||||
# Log the GPU group assignments
|
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||||
if rank == 0:
|
|
||||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
|
||||||
|
|
||||||
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
|
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
|
||||||
from ring_flash_attn import substitute_hf_flash_attn
|
from ring_flash_attn import substitute_hf_flash_attn
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from accelerate.utils import save_fsdp_model
|
|||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from huggingface_hub.errors import OfflineModeIsEnabled
|
from huggingface_hub.errors import OfflineModeIsEnabled
|
||||||
from peft import PeftConfig, PeftModel
|
from peft import PeftConfig, PeftModel
|
||||||
from torch.distributed.tensor.experimental import _context_parallel
|
from torch.distributed.tensor.experimental import context_parallel
|
||||||
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
from transformers.trainer import Trainer
|
from transformers.trainer import Trainer
|
||||||
@@ -216,7 +216,7 @@ def execute_training(
|
|||||||
torch.tensor(list(range(world_size))).reshape(mesh_shape),
|
torch.tensor(list(range(world_size))).reshape(mesh_shape),
|
||||||
mesh_dim_names=("dp", "cp"),
|
mesh_dim_names=("dp", "cp"),
|
||||||
)
|
)
|
||||||
stack.enter_context(_context_parallel(seq_dim=2, mesh=mesh))
|
stack.enter_context(context_parallel(mesh=mesh))
|
||||||
else: # flash_attention
|
else: # flash_attention
|
||||||
models = [trainer.model]
|
models = [trainer.model]
|
||||||
if hasattr(trainer, "ref_model") and trainer.ref_model:
|
if hasattr(trainer, "ref_model") and trainer.ref_model:
|
||||||
|
|||||||
@@ -1194,9 +1194,24 @@ class AxolotlInputConfig(
|
|||||||
if not self.sequence_parallel_degree:
|
if not self.sequence_parallel_degree:
|
||||||
self.sequence_parallel_degree = 1
|
self.sequence_parallel_degree = 1
|
||||||
elif self.sequence_parallel_degree > 1:
|
elif self.sequence_parallel_degree > 1:
|
||||||
if not self.flash_attention:
|
import torch
|
||||||
|
|
||||||
|
world_size = torch.cuda.device_count()
|
||||||
|
if not world_size >= self.sequence_parallel_degree:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
f"World size ({world_size}) must be greater "
|
||||||
|
f"than or equal to SP degree ({self.sequence_parallel_degree})"
|
||||||
|
)
|
||||||
|
if not world_size % self.sequence_parallel_degree == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"SP degree ({self.sequence_parallel_degree}) "
|
||||||
|
f"must evenly divide world size ({world_size})"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not (self.flash_attention or self.sdp_attention):
|
||||||
|
raise ValueError(
|
||||||
|
"flash_attention: true or sdp_attention: true "
|
||||||
|
"must be set with sequence_parallel_degree > 1"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.sample_packing and self.micro_batch_size > 1:
|
if self.sample_packing and self.micro_batch_size > 1:
|
||||||
@@ -1205,14 +1220,15 @@ class AxolotlInputConfig(
|
|||||||
"due to a `ring-flash-attn` requirement"
|
"due to a `ring-flash-attn` requirement"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
if self.flash_attention:
|
||||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
try:
|
||||||
except ImportError as exception:
|
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||||
raise ImportError(
|
except ImportError as exception:
|
||||||
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
|
raise ImportError(
|
||||||
"Please install it with `pip install axolotl[ring-flash-attn] "
|
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
|
||||||
"or `pip install ring-flash-attn>=0.1.4`."
|
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||||
) from exception
|
"or `pip install ring-flash-attn>=0.1.4`."
|
||||||
|
) from exception
|
||||||
|
|
||||||
# TODO: monkeypatch / callback to average losses correctly across SP ranks
|
# TODO: monkeypatch / callback to average losses correctly across SP ranks
|
||||||
# / fix gradient scaling across SP ranks. Losses, grads should be scaled
|
# / fix gradient scaling across SP ranks. Losses, grads should be scaled
|
||||||
|
|||||||
Reference in New Issue
Block a user