Update config.qmd and rename the `sequence_parallel_size` option to `sequence_parallel_degree`

This commit is contained in:
Dan Saunders
2025-03-13 23:13:37 +00:00
parent 345a9dd831
commit 919b88f11b
11 changed files with 58 additions and 54 deletions

View File

@@ -620,6 +620,11 @@ ddp_timeout:
ddp_bucket_cap_mb: ddp_bucket_cap_mb:
ddp_broadcast_buffers: ddp_broadcast_buffers:
# Sequence parallelism
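# Set to a divisor (greater than 1) of the number of available GPUs to split each sequence into equal-sized chunks across that many GPUs.
# Use in long-context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
# Defaults to 1 (sequence parallelism disabled). Requires `flash_attention: true`.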
sequence_parallel_degree:
# Path to torch distx for optim 'adamw_anyprecision' # Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path: torchdistx_path:

View File

@@ -763,8 +763,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] = self.cfg.kd_top_k_before_softmax ] = self.cfg.kd_top_k_before_softmax
training_arguments_kwargs[ training_arguments_kwargs[
"sequence_parallel_size" "sequence_parallel_degree"
] = self.cfg.sequence_parallel_size ] = self.cfg.sequence_parallel_degree
if self.cfg.reward_model: if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig training_args_cls = AxolotlRewardConfig
@@ -911,7 +911,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator = DataCollatorForSeq2Seq collator = DataCollatorForSeq2Seq
kwargs["return_tensors"] = "pt" kwargs["return_tensors"] = "pt"
kwargs["sequence_parallel_size"] = training_args.sequence_parallel_size kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
return collator( return collator(
*collator_args, *collator_args,

View File

@@ -423,9 +423,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
# Handle sequence parallelism # Handle sequence parallelism
if self.args.sequence_parallel_size > 1: if self.args.sequence_parallel_degree > 1:
num_sp_groups = self.args.world_size // self.args.sequence_parallel_size num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_size sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
# Create base sampler for SP groups # Create base sampler for SP groups
base_sampler = torch.utils.data.distributed.DistributedSampler( base_sampler = torch.utils.data.distributed.DistributedSampler(
@@ -472,10 +472,10 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
# Handle sequence parallelism # Handle sequence parallelism
if self.args.sequence_parallel_size > 1: if self.args.sequence_parallel_degree > 1:
# Create sampler for SP groups # Create sampler for SP groups
num_sp_groups = self.args.world_size // self.args.sequence_parallel_size num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_size sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
# Create distributed sampler for the SP group # Create distributed sampler for the SP group
base_sampler = torch.utils.data.distributed.DistributedSampler( base_sampler = torch.utils.data.distributed.DistributedSampler(
@@ -570,7 +570,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
# Don't prepare dataloader for sequence parallelism # Don't prepare dataloader for sequence parallelism
# We use a distributed sampler in this case # We use a distributed sampler in this case
if self.args.sequence_parallel_size > 1: if self.args.sequence_parallel_degree > 1:
return dataloader return dataloader
return self.accelerator.prepare_data_loader(dataloader) return self.accelerator.prepare_data_loader(dataloader)
@@ -625,11 +625,11 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
# Don't prepare dataloader for sequence parallelism # Don't prepare dataloader for sequence parallelism
# We use a distributed sampler in this case # We use a distributed sampler in this case
if self.args.sequence_parallel_size > 1: if self.args.sequence_parallel_degree > 1:
return dataloader return dataloader
return self.accelerator.prepare_data_loader(dataloader) return self.accelerator.prepare_data_loader(dataloader)
if self.args.sequence_parallel_size > 1: if self.args.sequence_parallel_degree > 1:
eval_dataset = ( eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset eval_dataset if eval_dataset is not None else self.eval_dataset
) )
@@ -949,14 +949,14 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
""" """
Perform a training step on a batch of inputs. Perform a training step on a batch of inputs.
""" """
if self.args.sequence_parallel_size > 1: if self.args.sequence_parallel_degree > 1:
# At this point, inputs should already be partitioned by the sequence # At this point, inputs should already be partitioned by the sequence
# parallel data collator # parallel data collator
batch_size = inputs["input_ids"].shape[0] batch_size = inputs["input_ids"].shape[0]
seq_len = inputs["input_ids"].shape[1] seq_len = inputs["input_ids"].shape[1]
# Calculate the full sequence length across all GPUs in this SP group # Calculate the full sequence length across all GPUs in this SP group
total_seq_len = seq_len * self.args.sequence_parallel_size total_seq_len = seq_len * self.args.sequence_parallel_degree
# Pass the partitioned sequence information to ring flash attention # Pass the partitioned sequence information to ring flash attention
self._update_ring_flash_attn_params( self._update_ring_flash_attn_params(

View File

@@ -207,7 +207,7 @@ class AxolotlTrainingMixins:
}, },
) )
sequence_parallel_size: Optional[int] = field( sequence_parallel_degree: Optional[int] = field(
default=1, default=1,
metadata={"help": "The number of workers to use in sequence parallelism"}, metadata={"help": "The number of workers to use in sequence parallelism"},
) )

View File

@@ -23,21 +23,21 @@ def set_ring_attn_group(ring_attn_group: Any):
RING_ATTN_GROUP = ring_attn_group RING_ATTN_GROUP = ring_attn_group
def register_ring_attn(sequence_parallel_size: int): def register_ring_attn(sequence_parallel_degree: int):
""" """
Create ring attention group and substitute flash attn with ring flash attn. Create ring attention group and substitute flash attn with ring flash attn.
Args: Args:
sequence_parallel_size: Sequence parallelism factor. sequence_parallel_degree: Sequence parallelism factor.
""" """
LOG.info( LOG.info(
"Enabling ring attention sequence parallelism: " "Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_size} GPUs" f"each sequence will be processed across {sequence_parallel_degree} GPUs"
) )
world_size = dist.get_world_size() world_size = dist.get_world_size()
assert world_size % sequence_parallel_size == 0, ( assert world_size % sequence_parallel_degree == 0, (
f"sequence_parallel_size ({sequence_parallel_size}) " f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must evenly divide world_size ({world_size})" f"must evenly divide world_size ({world_size})"
) )
@@ -45,11 +45,11 @@ def register_ring_attn(sequence_parallel_size: int):
rank = dist.get_rank() rank = dist.get_rank()
group_assignments = {} group_assignments = {}
for i in range(world_size // sequence_parallel_size): for i in range(world_size // sequence_parallel_degree):
ring_attn_ranks = list( ring_attn_ranks = list(
range( range(
i * sequence_parallel_size, i * sequence_parallel_degree,
(i + 1) * sequence_parallel_size, (i + 1) * sequence_parallel_degree,
) )
) )
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl") group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
@@ -65,4 +65,4 @@ def register_ring_attn(sequence_parallel_size: int):
if rank == 0: if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}") LOG.info(f"Sequence parallel group assignments: {group_assignments}")
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_size) substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)

View File

@@ -95,7 +95,7 @@ class DataCollatorForSeq2Seq:
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
return_tensors (`str`): return_tensors (`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf". The type of Tensor to return. Allowable values are "np", "pt" and "tf".
sequence_parallel_size (`int`): sequence_parallel_degree (`int`):
The degree of sequence parallelism. Default to 1 for no sequence parallelism. The degree of sequence parallelism. Default to 1 for no sequence parallelism.
""" """
@@ -107,10 +107,10 @@ class DataCollatorForSeq2Seq:
label_pad_token_id: int = -100 label_pad_token_id: int = -100
position_pad_token_id: int = 0 position_pad_token_id: int = 0
return_tensors: str = "pt" return_tensors: str = "pt"
sequence_parallel_size: int = 1 sequence_parallel_degree: int = 1
def __post_init__(self): def __post_init__(self):
if self.sequence_parallel_size > 1: if self.sequence_parallel_degree > 1:
# Get information about our position in the SP group # Get information about our position in the SP group
sp_group = get_ring_attn_group() sp_group = get_ring_attn_group()
self.rank = dist.get_rank() self.rank = dist.get_rank()
@@ -183,7 +183,7 @@ class DataCollatorForSeq2Seq:
) )
features["decoder_input_ids"] = decoder_input_ids features["decoder_input_ids"] = decoder_input_ids
if self.sequence_parallel_size > 1: if self.sequence_parallel_degree > 1:
features = self.apply_sequence_parallelism(features) features = self.apply_sequence_parallelism(features)
return features return features

View File

@@ -552,14 +552,13 @@ class ModelLoader:
patch_self_attn_lora(self.cfg) patch_self_attn_lora(self.cfg)
if self.cfg.sequence_parallel_size > 1: if self.cfg.sequence_parallel_degree > 1:
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
# Initialize ring attention for sequence parallelism if enabled. # Initialize ring attn for sequence parallelism. This must be done after
# This must be done after model initialization but before the first forward pass, # model init but before the first forward pass, since it modifies flash
# as it modifies the flash attention implementation to use ring communication # attn to use ring comm for SP training across multiple GPUs.
# patterns for efficient sequence-parallel training across multiple GPUs. register_ring_attn(self.cfg.sequence_parallel_degree)
register_ring_attn(self.cfg.sequence_parallel_size)
def patch_attention(self) -> None: def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"): if hasattr(self.model_config, "model_type"):

View File

@@ -245,7 +245,7 @@ class AxolotlInputConfig(
val_set_size: float | None = Field(default=0.0) val_set_size: float | None = Field(default=0.0)
sequence_parallel_size: int | None = 1 sequence_parallel_degree: int | None = 1
special_tokens: SpecialTokensConfig | None = None special_tokens: SpecialTokensConfig | None = None
tokens: list[str] | None = None tokens: list[str] | None = None
@@ -1107,10 +1107,10 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_sequence_parallel_config(cls, data):
    """Reject sequence parallelism configs that do not enable flash attention.

    Runs as a "before" validator, i.e. on the raw input mapping, so field
    defaults have not been applied yet and `sequence_parallel_degree` may be
    missing or explicitly empty (`None`).

    Args:
        data: Raw config mapping passed to the model.

    Returns:
        The unmodified `data` mapping.

    Raises:
        ValueError: If `sequence_parallel_degree` is greater than 1 while
            `flash_attention` is unset or falsy.
    """
    sp_degree = data.get("sequence_parallel_degree")
    # Guard against a missing / empty option: comparing `None > 1` would raise
    # a TypeError for every config that simply does not use sequence parallelism.
    if sp_degree is not None and sp_degree > 1:
        if not data.get("flash_attention"):
            raise ValueError(
                "flash_attention: true must be set with sequence_parallel_degree > 1"
            )

    return data

View File

@@ -346,7 +346,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)", desc="Add position_id column (PoSE)",
) )
elif cfg.sample_packing or cfg.sequence_parallel_size > 1: elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
drop_long_kwargs = {} drop_long_kwargs = {}
if filter_map_kwargs: if filter_map_kwargs:
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)" drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
@@ -356,7 +356,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
**filter_map_kwargs, **filter_map_kwargs,
**drop_long_kwargs, **drop_long_kwargs,
) )
if cfg.eval_sample_packing or cfg.sequence_parallel_size > 1: if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.map( eval_dataset = eval_dataset.map(
add_position_ids, add_position_ids,
@@ -443,7 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
- 1 - 1
) )
* cfg.num_epochs * cfg.num_epochs
* cfg.sequence_parallel_size * cfg.sequence_parallel_degree
) )
LOG.debug( LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}", f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
@@ -476,7 +476,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
# on the agreed on value for sample_packing_eff_est # on the agreed on value for sample_packing_eff_est
total_num_steps = int( total_num_steps = int(
math.floor( math.floor(
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_size data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
) )
) )
@@ -502,7 +502,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
math.ceil( math.ceil(
len(train_dataset) len(train_dataset)
* cfg.num_epochs * cfg.num_epochs
* cfg.sequence_parallel_size * cfg.sequence_parallel_degree
/ cfg.batch_size / cfg.batch_size
) )
) )

View File

@@ -25,7 +25,7 @@ def test_integration_with_config():
], ],
"load_in_8bit": False, "load_in_8bit": False,
"sequence_len": 1024, "sequence_len": 1024,
"sequence_parallel_size": 2, "sequence_parallel_degree": 2,
"flash_attention": True, "flash_attention": True,
"sample_packing": True, "sample_packing": True,
"pad_to_sequence_len": True, "pad_to_sequence_len": True,
@@ -58,17 +58,17 @@ def test_integration_with_config():
normalize_config(cfg) normalize_config(cfg)
# Verify sequence parallelism settings were properly processed # Verify sequence parallelism settings were properly processed
assert cfg.sequence_parallel_size == 2 assert cfg.sequence_parallel_degree == 2
assert cfg.flash_attention is True assert cfg.flash_attention is True
# Check if the sequence_parallel_size was propagated to the training args # Check if the sequence_parallel_degree was propagated to the training args
from axolotl.core.training_args import AxolotlTrainingArguments from axolotl.core.training_args import AxolotlTrainingArguments
# pylint: disable=unexpected-keyword-arg # pylint: disable=unexpected-keyword-arg
training_args = AxolotlTrainingArguments( training_args = AxolotlTrainingArguments(
output_dir=temp_dir, sequence_parallel_size=cfg.sequence_parallel_size output_dir=temp_dir, sequence_parallel_degree=cfg.sequence_parallel_degree
) )
assert training_args.sequence_parallel_size == 2 assert training_args.sequence_parallel_degree == 2
def test_ring_attn_group_creation(): def test_ring_attn_group_creation():
@@ -90,7 +90,7 @@ def test_ring_attn_group_creation():
pytest.skip(f"Need an even number of GPUs, but got {world_size}") pytest.skip(f"Need an even number of GPUs, but got {world_size}")
# Register with a sequence parallel degree of 2 # Register with a sequence parallel degree of 2
register_ring_attn(sequence_parallel_size=2) register_ring_attn(sequence_parallel_degree=2)
# Get the ring attention group # Get the ring attention group
group = get_ring_attn_group() group = get_ring_attn_group()

View File

@@ -94,7 +94,7 @@ class TestRingAttention:
mock_new_group.return_value = mock_group mock_new_group.return_value = mock_group
# Call register_ring_attn with size 4 # Call register_ring_attn with size 4
register_ring_attn(sequence_parallel_size=4) register_ring_attn(sequence_parallel_degree=4)
# Verify the number of calls without examining the arguments # Verify the number of calls without examining the arguments
assert mock_new_group.call_count == 2 assert mock_new_group.call_count == 2
@@ -175,15 +175,15 @@ def test_config_validation_with_valid_inputs(cfg):
# Import the actual model class with appropriate mocks # Import the actual model class with appropriate mocks
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
# Valid configuration: sequence_parallel_size > 1 and flash_attention is True # Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
cfg = cfg | { cfg = cfg | {
"sequence_parallel_size": 2, "sequence_parallel_degree": 2,
"flash_attention": True, "flash_attention": True,
} }
# Should validate without errors # Should validate without errors
config = AxolotlInputConfig(**cfg) config = AxolotlInputConfig(**cfg)
assert config.sequence_parallel_size == 2 assert config.sequence_parallel_degree == 2
assert config.flash_attention is True assert config.flash_attention is True
@@ -191,9 +191,9 @@ def test_config_validation_with_invalid_inputs(cfg):
"""Test that invalid sequence parallelism configurations fail validation.""" """Test that invalid sequence parallelism configurations fail validation."""
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
# Invalid configuration: sequence_parallel_size > 1 but flash_attention is False # Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
cfg = cfg | { cfg = cfg | {
"sequence_parallel_size": 2, "sequence_parallel_degree": 2,
"flash_attention": False, "flash_attention": False,
} }