update config.qmd and rename option
This commit is contained in:
@@ -620,6 +620,11 @@ ddp_timeout:
|
||||
ddp_bucket_cap_mb:
|
||||
ddp_broadcast_buffers:
|
||||
|
||||
# Sequence parallelism
|
||||
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
|
||||
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
|
||||
sequence_parallel_degree:
|
||||
|
||||
# Path to torch distx for optim 'adamw_anyprecision'
|
||||
torchdistx_path:
|
||||
|
||||
|
||||
@@ -763,8 +763,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
] = self.cfg.kd_top_k_before_softmax
|
||||
|
||||
training_arguments_kwargs[
|
||||
"sequence_parallel_size"
|
||||
] = self.cfg.sequence_parallel_size
|
||||
"sequence_parallel_degree"
|
||||
] = self.cfg.sequence_parallel_degree
|
||||
|
||||
if self.cfg.reward_model:
|
||||
training_args_cls = AxolotlRewardConfig
|
||||
@@ -911,7 +911,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
collator = DataCollatorForSeq2Seq
|
||||
|
||||
kwargs["return_tensors"] = "pt"
|
||||
kwargs["sequence_parallel_size"] = training_args.sequence_parallel_size
|
||||
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
|
||||
|
||||
return collator(
|
||||
*collator_args,
|
||||
|
||||
@@ -423,9 +423,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
# Handle sequence parallelism
|
||||
if self.args.sequence_parallel_size > 1:
|
||||
num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
|
||||
sp_group_id = dist.get_rank() // self.args.sequence_parallel_size
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
||||
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
||||
|
||||
# Create base sampler for SP groups
|
||||
base_sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
@@ -472,10 +472,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
|
||||
# Handle sequence parallelism
|
||||
if self.args.sequence_parallel_size > 1:
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
# Create sampler for SP groups
|
||||
num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
|
||||
sp_group_id = dist.get_rank() // self.args.sequence_parallel_size
|
||||
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
||||
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
||||
|
||||
# Create distributed sampler for the SP group
|
||||
base_sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
@@ -570,7 +570,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
|
||||
# Don't prepare dataloader for sequence parallelism
|
||||
# We use a distributed sampler in this case
|
||||
if self.args.sequence_parallel_size > 1:
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
return dataloader
|
||||
|
||||
return self.accelerator.prepare_data_loader(dataloader)
|
||||
@@ -625,11 +625,11 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
|
||||
# Don't prepare dataloader for sequence parallelism
|
||||
# We use a distributed sampler in this case
|
||||
if self.args.sequence_parallel_size > 1:
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
return dataloader
|
||||
|
||||
return self.accelerator.prepare_data_loader(dataloader)
|
||||
if self.args.sequence_parallel_size > 1:
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
eval_dataset = (
|
||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
)
|
||||
@@ -949,14 +949,14 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
"""
|
||||
Perform a training step on a batch of inputs.
|
||||
"""
|
||||
if self.args.sequence_parallel_size > 1:
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
# At this point, inputs should already be partitioned by the sequence
|
||||
# parallel data collator
|
||||
batch_size = inputs["input_ids"].shape[0]
|
||||
seq_len = inputs["input_ids"].shape[1]
|
||||
|
||||
# Calculate the full sequence length across all GPUs in this SP group
|
||||
total_seq_len = seq_len * self.args.sequence_parallel_size
|
||||
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
||||
|
||||
# Pass the partitioned sequence information to ring flash attention
|
||||
self._update_ring_flash_attn_params(
|
||||
|
||||
@@ -207,7 +207,7 @@ class AxolotlTrainingMixins:
|
||||
},
|
||||
)
|
||||
|
||||
sequence_parallel_size: Optional[int] = field(
|
||||
sequence_parallel_degree: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "The number of workers to use in sequence parallelism"},
|
||||
)
|
||||
|
||||
@@ -23,21 +23,21 @@ def set_ring_attn_group(ring_attn_group: Any):
|
||||
RING_ATTN_GROUP = ring_attn_group
|
||||
|
||||
|
||||
def register_ring_attn(sequence_parallel_size: int):
|
||||
def register_ring_attn(sequence_parallel_degree: int):
|
||||
"""
|
||||
Create ring attention group and substitute flash attn with ring flash attn.
|
||||
|
||||
Args:
|
||||
sequence_parallel_size: Sequence parallelism factor.
|
||||
sequence_parallel_degree: Sequence parallelism factor.
|
||||
"""
|
||||
LOG.info(
|
||||
"Enabling ring attention sequence parallelism: "
|
||||
f"each sequence will be processed across {sequence_parallel_size} GPUs"
|
||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||
)
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
assert world_size % sequence_parallel_size == 0, (
|
||||
f"sequence_parallel_size ({sequence_parallel_size}) "
|
||||
assert world_size % sequence_parallel_degree == 0, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must evenly divide world_size ({world_size})"
|
||||
)
|
||||
|
||||
@@ -45,11 +45,11 @@ def register_ring_attn(sequence_parallel_size: int):
|
||||
rank = dist.get_rank()
|
||||
group_assignments = {}
|
||||
|
||||
for i in range(world_size // sequence_parallel_size):
|
||||
for i in range(world_size // sequence_parallel_degree):
|
||||
ring_attn_ranks = list(
|
||||
range(
|
||||
i * sequence_parallel_size,
|
||||
(i + 1) * sequence_parallel_size,
|
||||
i * sequence_parallel_degree,
|
||||
(i + 1) * sequence_parallel_degree,
|
||||
)
|
||||
)
|
||||
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
||||
@@ -65,4 +65,4 @@ def register_ring_attn(sequence_parallel_size: int):
|
||||
if rank == 0:
|
||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||
|
||||
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_size)
|
||||
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
|
||||
|
||||
@@ -95,7 +95,7 @@ class DataCollatorForSeq2Seq:
|
||||
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
||||
return_tensors (`str`):
|
||||
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
||||
sequence_parallel_size (`int`):
|
||||
sequence_parallel_degree (`int`):
|
||||
The degree of sequence parallelism. Default to 1 for no sequence parallelism.
|
||||
"""
|
||||
|
||||
@@ -107,10 +107,10 @@ class DataCollatorForSeq2Seq:
|
||||
label_pad_token_id: int = -100
|
||||
position_pad_token_id: int = 0
|
||||
return_tensors: str = "pt"
|
||||
sequence_parallel_size: int = 1
|
||||
sequence_parallel_degree: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
if self.sequence_parallel_size > 1:
|
||||
if self.sequence_parallel_degree > 1:
|
||||
# Get information about our position in the SP group
|
||||
sp_group = get_ring_attn_group()
|
||||
self.rank = dist.get_rank()
|
||||
@@ -183,7 +183,7 @@ class DataCollatorForSeq2Seq:
|
||||
)
|
||||
features["decoder_input_ids"] = decoder_input_ids
|
||||
|
||||
if self.sequence_parallel_size > 1:
|
||||
if self.sequence_parallel_degree > 1:
|
||||
features = self.apply_sequence_parallelism(features)
|
||||
|
||||
return features
|
||||
|
||||
@@ -552,14 +552,13 @@ class ModelLoader:
|
||||
|
||||
patch_self_attn_lora(self.cfg)
|
||||
|
||||
if self.cfg.sequence_parallel_size > 1:
|
||||
if self.cfg.sequence_parallel_degree > 1:
|
||||
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
||||
|
||||
# Initialize ring attention for sequence parallelism if enabled.
|
||||
# This must be done after model initialization but before the first forward pass,
|
||||
# as it modifies the flash attention implementation to use ring communication
|
||||
# patterns for efficient sequence-parallel training across multiple GPUs.
|
||||
register_ring_attn(self.cfg.sequence_parallel_size)
|
||||
# Initialize ring attn for sequence parallelism. This must be done after
|
||||
# model init but before the first forward pass, since it modifies flash
|
||||
# attn to use ring comm for SP training across multiple GPUs.
|
||||
register_ring_attn(self.cfg.sequence_parallel_degree)
|
||||
|
||||
def patch_attention(self) -> None:
|
||||
if hasattr(self.model_config, "model_type"):
|
||||
|
||||
@@ -245,7 +245,7 @@ class AxolotlInputConfig(
|
||||
|
||||
val_set_size: float | None = Field(default=0.0)
|
||||
|
||||
sequence_parallel_size: int | None = 1
|
||||
sequence_parallel_degree: int | None = 1
|
||||
|
||||
special_tokens: SpecialTokensConfig | None = None
|
||||
tokens: list[str] | None = None
|
||||
@@ -1107,10 +1107,10 @@ class AxolotlInputConfig(
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_sequence_parallel_config(cls, data):
|
||||
if data.get("sequence_parallel_size") > 1:
|
||||
if data.get("sequence_parallel_degree") > 1:
|
||||
if not data.get("flash_attention"):
|
||||
raise ValueError(
|
||||
"flash_attention: true must be set with sequence_parallel_size > 1"
|
||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Add position_id column (PoSE)",
|
||||
)
|
||||
elif cfg.sample_packing or cfg.sequence_parallel_size > 1:
|
||||
elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
|
||||
drop_long_kwargs = {}
|
||||
if filter_map_kwargs:
|
||||
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
||||
@@ -356,7 +356,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
)
|
||||
if cfg.eval_sample_packing or cfg.sequence_parallel_size > 1:
|
||||
if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.map(
|
||||
add_position_ids,
|
||||
@@ -443,7 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
- 1
|
||||
)
|
||||
* cfg.num_epochs
|
||||
* cfg.sequence_parallel_size
|
||||
* cfg.sequence_parallel_degree
|
||||
)
|
||||
LOG.debug(
|
||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
|
||||
@@ -476,7 +476,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
# on the agreed on value for sample_packing_eff_est
|
||||
total_num_steps = int(
|
||||
math.floor(
|
||||
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_size
|
||||
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
|
||||
)
|
||||
)
|
||||
|
||||
@@ -502,7 +502,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
math.ceil(
|
||||
len(train_dataset)
|
||||
* cfg.num_epochs
|
||||
* cfg.sequence_parallel_size
|
||||
* cfg.sequence_parallel_degree
|
||||
/ cfg.batch_size
|
||||
)
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@ def test_integration_with_config():
|
||||
],
|
||||
"load_in_8bit": False,
|
||||
"sequence_len": 1024,
|
||||
"sequence_parallel_size": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
@@ -58,17 +58,17 @@ def test_integration_with_config():
|
||||
normalize_config(cfg)
|
||||
|
||||
# Verify sequence parallelism settings were properly processed
|
||||
assert cfg.sequence_parallel_size == 2
|
||||
assert cfg.sequence_parallel_degree == 2
|
||||
assert cfg.flash_attention is True
|
||||
|
||||
# Check if the sequence_parallel_size was propagated to the training args
|
||||
# Check if the sequence_parallel_degree was propagated to the training args
|
||||
from axolotl.core.training_args import AxolotlTrainingArguments
|
||||
|
||||
# pylint: disable=unexpected-keyword-arg
|
||||
training_args = AxolotlTrainingArguments(
|
||||
output_dir=temp_dir, sequence_parallel_size=cfg.sequence_parallel_size
|
||||
output_dir=temp_dir, sequence_parallel_degree=cfg.sequence_parallel_degree
|
||||
)
|
||||
assert training_args.sequence_parallel_size == 2
|
||||
assert training_args.sequence_parallel_degree == 2
|
||||
|
||||
|
||||
def test_ring_attn_group_creation():
|
||||
@@ -90,7 +90,7 @@ def test_ring_attn_group_creation():
|
||||
pytest.skip(f"Need an even number of GPUs, but got {world_size}")
|
||||
|
||||
# Register with sequence parallel size of 2
|
||||
register_ring_attn(sequence_parallel_size=2)
|
||||
register_ring_attn(sequence_parallel_degree=2)
|
||||
|
||||
# Get the ring attention group
|
||||
group = get_ring_attn_group()
|
||||
|
||||
@@ -94,7 +94,7 @@ class TestRingAttention:
|
||||
mock_new_group.return_value = mock_group
|
||||
|
||||
# Call register_ring_attn with size 4
|
||||
register_ring_attn(sequence_parallel_size=4)
|
||||
register_ring_attn(sequence_parallel_degree=4)
|
||||
|
||||
# Verify the number of calls without examining the arguments
|
||||
assert mock_new_group.call_count == 2
|
||||
@@ -175,15 +175,15 @@ def test_config_validation_with_valid_inputs(cfg):
|
||||
# Import the actual model class with appropriate mocks
|
||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||
|
||||
# Valid configuration: sequence_parallel_size > 1 and flash_attention is True
|
||||
# Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
|
||||
cfg = cfg | {
|
||||
"sequence_parallel_size": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
}
|
||||
|
||||
# Should validate without errors
|
||||
config = AxolotlInputConfig(**cfg)
|
||||
assert config.sequence_parallel_size == 2
|
||||
assert config.sequence_parallel_degree == 2
|
||||
assert config.flash_attention is True
|
||||
|
||||
|
||||
@@ -191,9 +191,9 @@ def test_config_validation_with_invalid_inputs(cfg):
|
||||
"""Test that invalid sequence parallelism configurations fail validation."""
|
||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||
|
||||
# Invalid configuration: sequence_parallel_size > 1 but flash_attention is False
|
||||
# Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
|
||||
cfg = cfg | {
|
||||
"sequence_parallel_size": 2,
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": False,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user