update config.qmd and rename option

This commit is contained in:
Dan Saunders
2025-03-13 23:13:37 +00:00
parent 345a9dd831
commit 919b88f11b
11 changed files with 58 additions and 54 deletions

View File

@@ -620,6 +620,11 @@ ddp_timeout:
ddp_bucket_cap_mb:
ddp_broadcast_buffers:
# Sequence parallelism
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
sequence_parallel_degree:
# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:

View File

@@ -763,8 +763,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] = self.cfg.kd_top_k_before_softmax
training_arguments_kwargs[
"sequence_parallel_size"
] = self.cfg.sequence_parallel_size
"sequence_parallel_degree"
] = self.cfg.sequence_parallel_degree
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
@@ -911,7 +911,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator = DataCollatorForSeq2Seq
kwargs["return_tensors"] = "pt"
kwargs["sequence_parallel_size"] = training_args.sequence_parallel_size
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
return collator(
*collator_args,

View File

@@ -423,9 +423,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
# Handle sequence parallelism
if self.args.sequence_parallel_size > 1:
num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
sp_group_id = dist.get_rank() // self.args.sequence_parallel_size
if self.args.sequence_parallel_degree > 1:
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
# Create base sampler for SP groups
base_sampler = torch.utils.data.distributed.DistributedSampler(
@@ -472,10 +472,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Handle sequence parallelism
if self.args.sequence_parallel_size > 1:
if self.args.sequence_parallel_degree > 1:
# Create sampler for SP groups
num_sp_groups = self.args.world_size // self.args.sequence_parallel_size
sp_group_id = dist.get_rank() // self.args.sequence_parallel_size
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
# Create distributed sampler for the SP group
base_sampler = torch.utils.data.distributed.DistributedSampler(
@@ -570,7 +570,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
# Don't prepare dataloader for sequence parallelism
# We use a distributed sampler in this case
if self.args.sequence_parallel_size > 1:
if self.args.sequence_parallel_degree > 1:
return dataloader
return self.accelerator.prepare_data_loader(dataloader)
@@ -625,11 +625,11 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
# Don't prepare dataloader for sequence parallelism
# We use a distributed sampler in this case
if self.args.sequence_parallel_size > 1:
if self.args.sequence_parallel_degree > 1:
return dataloader
return self.accelerator.prepare_data_loader(dataloader)
if self.args.sequence_parallel_size > 1:
if self.args.sequence_parallel_degree > 1:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
@@ -949,14 +949,14 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
"""
Perform a training step on a batch of inputs.
"""
if self.args.sequence_parallel_size > 1:
if self.args.sequence_parallel_degree > 1:
# At this point, inputs should already be partitioned by the sequence
# parallel data collator
batch_size = inputs["input_ids"].shape[0]
seq_len = inputs["input_ids"].shape[1]
# Calculate the full sequence length across all GPUs in this SP group
total_seq_len = seq_len * self.args.sequence_parallel_size
total_seq_len = seq_len * self.args.sequence_parallel_degree
# Pass the partitioned sequence information to ring flash attention
self._update_ring_flash_attn_params(

View File

@@ -207,7 +207,7 @@ class AxolotlTrainingMixins:
},
)
sequence_parallel_size: Optional[int] = field(
sequence_parallel_degree: Optional[int] = field(
default=1,
metadata={"help": "The number of workers to use in sequence parallelism"},
)

View File

@@ -23,21 +23,21 @@ def set_ring_attn_group(ring_attn_group: Any):
RING_ATTN_GROUP = ring_attn_group
def register_ring_attn(sequence_parallel_size: int):
def register_ring_attn(sequence_parallel_degree: int):
"""
Create ring attention group and substitute flash attn with ring flash attn.
Args:
sequence_parallel_size: Sequence parallelism factor.
sequence_parallel_degree: Sequence parallelism factor.
"""
LOG.info(
"Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_size} GPUs"
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
)
world_size = dist.get_world_size()
assert world_size % sequence_parallel_size == 0, (
f"sequence_parallel_size ({sequence_parallel_size}) "
assert world_size % sequence_parallel_degree == 0, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must evenly divide world_size ({world_size})"
)
@@ -45,11 +45,11 @@ def register_ring_attn(sequence_parallel_size: int):
rank = dist.get_rank()
group_assignments = {}
for i in range(world_size // sequence_parallel_size):
for i in range(world_size // sequence_parallel_degree):
ring_attn_ranks = list(
range(
i * sequence_parallel_size,
(i + 1) * sequence_parallel_size,
i * sequence_parallel_degree,
(i + 1) * sequence_parallel_degree,
)
)
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
@@ -65,4 +65,4 @@ def register_ring_attn(sequence_parallel_size: int):
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_size)
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)

View File

@@ -95,7 +95,7 @@ class DataCollatorForSeq2Seq:
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
return_tensors (`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
sequence_parallel_size (`int`):
sequence_parallel_degree (`int`):
The degree of sequence parallelism. Defaults to 1 for no sequence parallelism.
"""
@@ -107,10 +107,10 @@ class DataCollatorForSeq2Seq:
label_pad_token_id: int = -100
position_pad_token_id: int = 0
return_tensors: str = "pt"
sequence_parallel_size: int = 1
sequence_parallel_degree: int = 1
def __post_init__(self):
if self.sequence_parallel_size > 1:
if self.sequence_parallel_degree > 1:
# Get information about our position in the SP group
sp_group = get_ring_attn_group()
self.rank = dist.get_rank()
@@ -183,7 +183,7 @@ class DataCollatorForSeq2Seq:
)
features["decoder_input_ids"] = decoder_input_ids
if self.sequence_parallel_size > 1:
if self.sequence_parallel_degree > 1:
features = self.apply_sequence_parallelism(features)
return features

View File

@@ -552,14 +552,13 @@ class ModelLoader:
patch_self_attn_lora(self.cfg)
if self.cfg.sequence_parallel_size > 1:
if self.cfg.sequence_parallel_degree > 1:
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
# Initialize ring attention for sequence parallelism if enabled.
# This must be done after model initialization but before the first forward pass,
# as it modifies the flash attention implementation to use ring communication
# patterns for efficient sequence-parallel training across multiple GPUs.
register_ring_attn(self.cfg.sequence_parallel_size)
# Initialize ring attn for sequence parallelism. This must be done after
# model init but before the first forward pass, since it modifies flash
# attn to use ring comm for SP training across multiple GPUs.
register_ring_attn(self.cfg.sequence_parallel_degree)
def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"):

View File

@@ -245,7 +245,7 @@ class AxolotlInputConfig(
val_set_size: float | None = Field(default=0.0)
sequence_parallel_size: int | None = 1
sequence_parallel_degree: int | None = 1
special_tokens: SpecialTokensConfig | None = None
tokens: list[str] | None = None
@@ -1107,10 +1107,10 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_sequence_parallel_config(cls, data):
if data.get("sequence_parallel_size") > 1:
if data.get("sequence_parallel_degree") > 1:
if not data.get("flash_attention"):
raise ValueError(
"flash_attention: true must be set with sequence_parallel_size > 1"
"flash_attention: true must be set with sequence_parallel_degree > 1"
)
return data

View File

@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing or cfg.sequence_parallel_size > 1:
elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
@@ -356,7 +356,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
**filter_map_kwargs,
**drop_long_kwargs,
)
if cfg.eval_sample_packing or cfg.sequence_parallel_size > 1:
if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids,
@@ -443,7 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
- 1
)
* cfg.num_epochs
* cfg.sequence_parallel_size
* cfg.sequence_parallel_degree
)
LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
@@ -476,7 +476,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
# on the agreed on value for sample_packing_eff_est
total_num_steps = int(
math.floor(
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_size
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
)
)
@@ -502,7 +502,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
math.ceil(
len(train_dataset)
* cfg.num_epochs
* cfg.sequence_parallel_size
* cfg.sequence_parallel_degree
/ cfg.batch_size
)
)

View File

@@ -25,7 +25,7 @@ def test_integration_with_config():
],
"load_in_8bit": False,
"sequence_len": 1024,
"sequence_parallel_size": 2,
"sequence_parallel_degree": 2,
"flash_attention": True,
"sample_packing": True,
"pad_to_sequence_len": True,
@@ -58,17 +58,17 @@ def test_integration_with_config():
normalize_config(cfg)
# Verify sequence parallelism settings were properly processed
assert cfg.sequence_parallel_size == 2
assert cfg.sequence_parallel_degree == 2
assert cfg.flash_attention is True
# Check if the sequence_parallel_size was propagated to the training args
# Check if the sequence_parallel_degree was propagated to the training args
from axolotl.core.training_args import AxolotlTrainingArguments
# pylint: disable=unexpected-keyword-arg
training_args = AxolotlTrainingArguments(
output_dir=temp_dir, sequence_parallel_size=cfg.sequence_parallel_size
output_dir=temp_dir, sequence_parallel_degree=cfg.sequence_parallel_degree
)
assert training_args.sequence_parallel_size == 2
assert training_args.sequence_parallel_degree == 2
def test_ring_attn_group_creation():
@@ -90,7 +90,7 @@ def test_ring_attn_group_creation():
pytest.skip(f"Need an even number of GPUs, but got {world_size}")
# Register with sequence parallel size of 2
register_ring_attn(sequence_parallel_size=2)
register_ring_attn(sequence_parallel_degree=2)
# Get the ring attention group
group = get_ring_attn_group()

View File

@@ -94,7 +94,7 @@ class TestRingAttention:
mock_new_group.return_value = mock_group
# Call register_ring_attn with size 4
register_ring_attn(sequence_parallel_size=4)
register_ring_attn(sequence_parallel_degree=4)
# Verify the number of calls without examining the arguments
assert mock_new_group.call_count == 2
@@ -175,15 +175,15 @@ def test_config_validation_with_valid_inputs(cfg):
# Import the actual model class with appropriate mocks
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
# Valid configuration: sequence_parallel_size > 1 and flash_attention is True
# Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
cfg = cfg | {
"sequence_parallel_size": 2,
"sequence_parallel_degree": 2,
"flash_attention": True,
}
# Should validate without errors
config = AxolotlInputConfig(**cfg)
assert config.sequence_parallel_size == 2
assert config.sequence_parallel_degree == 2
assert config.flash_attention is True
@@ -191,9 +191,9 @@ def test_config_validation_with_invalid_inputs(cfg):
"""Test that invalid sequence parallelism configurations fail validation."""
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
# Invalid configuration: sequence_parallel_size > 1 but flash_attention is False
# Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
cfg = cfg | {
"sequence_parallel_size": 2,
"sequence_parallel_degree": 2,
"flash_attention": False,
}