diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index a51211263..44f8c5d2b 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -932,9 +932,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             collator = DataCollatorForSeq2Seq
             kwargs["return_tensors"] = "pt"

-            if issubclass(collator, DataCollatorForSeq2Seq):
-                kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
-                kwargs["ring_attn_func"] = training_args.ring_attn_func

         return collator(
             *collator_args,
diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index 3395f3f44..45facf832 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -1,20 +1,12 @@
-"""
-Data collators for axolotl to pad labels and position_ids for packed sequences. Also
-includes logic for handling sequence parallelism collation.
-"""
+"""Data collators for axolotl to pad labels and position_ids for packed sequences"""

 from dataclasses import dataclass
 from typing import Any

 import numpy as np
-import torch
-import torch.distributed as dist
 from transformers import PreTrainedTokenizerBase
 from transformers.utils import PaddingStrategy

-from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
-from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
-

 @dataclass
 class DataCollatorForSeq2Seq:
@@ -49,8 +41,6 @@ class DataCollatorForSeq2Seq:
            The id to use when padding the labels (-100 will be automatically ignored
            by PyTorch loss functions).
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
-       sequence_parallel_degree (`int`):
-           The degree of sequence parallelism. Default to 1 for no sequence parallelism.
    """

    tokenizer: PreTrainedTokenizerBase
@@ -61,17 +51,6 @@ class DataCollatorForSeq2Seq:
    label_pad_token_id: int = -100
    position_pad_token_id: int = 0
    return_tensors: str = "pt"
-   sequence_parallel_degree: int = 1
-   ring_attn_func: RingAttnFunc | None = None
-
-   def __post_init__(self):
-       if self.sequence_parallel_degree > 1:
-           from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
-
-           # Get information about our position in the SP group
-           sp_group = get_ring_attn_group()
-           self.local_rank = dist.get_rank(group=sp_group)
-           self.local_world_size = dist.get_world_size(group=sp_group)

    def __call__(self, features, return_tensors=None):
        has_attn_mask = "attention_mask" in features[0].keys()
@@ -141,62 +120,8 @@ class DataCollatorForSeq2Seq:
                )
                features["decoder_input_ids"] = decoder_input_ids

-       # if self.sequence_parallel_degree > 1:
-       #     features = self.apply_sequence_parallelism(features)
-
        return features

-   def apply_sequence_parallelism(
-       self, batch: dict[str, torch.Tensor]
-   ) -> torch.Tensor:
-       """
-       Apply sequence parallelism slicing to a batch.
-
-       Args:
-           batch: Batch dictionary from parent collator.
-
-       Returns:
-           Sliced batch dictionary.
-       """
-       # Get local (start, end) for sequence parallelism slicing
-       total_seq_len = batch["input_ids"].size(1)
-
-       # Update params for varlen ring attention calculation
-       if batch.get("position_ids") is not None:
-           update_ring_attn_params(position_ids=batch["position_ids"])
-
-       # Slice batch for sequence parallel processing
-       for key in batch:
-           if batch[key].size(1) == total_seq_len:
-               if self.ring_attn_func in [
-                   RingAttnFunc.VARLEN_LLAMA3,
-                   RingAttnFunc.BATCH_RING,
-               ]:
-                   batch[key] = (
-                       batch[key]
-                       .chunk(self.local_world_size, dim=1)[self.local_rank]
-                       .contiguous()
-                   )
-               elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
-                   chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
-
-                   # Take rank's chunk and opposing chunk for zigzag pattern
-                   selected_chunks = [
-                       chunks[self.local_rank],
-                       chunks[2 * self.local_world_size - self.local_rank - 1],
-                   ]
-                   batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
-               elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
-                   # TODO(djsaunde): This doesn't seem to work as expected
-                   # Split into striped data and stack
-                   tensor = torch.stack(
-                       batch[key].split(self.local_world_size, dim=1),
-                       dim=1,
-                   ).transpose(1, 2)
-                   batch[key] = tensor[:, self.local_rank].contiguous()
-
-       return batch
-

 @dataclass
 class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index b527dce08..e5ea44aa0 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -126,9 +126,6 @@ def normalize_config(cfg):
        with open(ds_config_path, encoding="utf-8") as f:
            cfg.deepspeed = json.load(f)

-   if cfg.sequence_parallel_degree is None:
-       cfg.sequence_parallel_degree = 1
-
    if cfg.saves_per_epoch:
        save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
        if save_steps < 1.0:  # prevent saves on every step
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 07bedbbd7..b749bebc2 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -719,9 +719,10 @@ class AxolotlInputConfig(
            and data.get("eval_sample_packing") is None
            and not data.get("eval_table_size")
        ):
-           LOG.info(
-               "explicitly setting `eval_sample_packing` to match `sample_packing`"
-           )
+           if is_main_process():
+               LOG.info(
+                   "explicitly setting `eval_sample_packing` to match `sample_packing`"
+               )
            data["eval_sample_packing"] = True

        if (
@@ -1192,10 +1193,9 @@ class AxolotlInputConfig(
    @model_validator(mode="after")
    def validate_ring_attn_func(self):
-       if self.sequence_parallel_degree == 1:
+       if getattr(self, "sequence_parallel_degree", 1) == 1:
            return self

-       # Your validation logic for ring_attn_func
        from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc

        if self.ring_attn_func is not None:
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index c1154be68..3dc9ae3f6 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -348,7 +348,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
            load_from_cache_file=not cfg.is_preprocess,
            desc="Add position_id column (PoSE)",
        )
-   elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
+   elif cfg.sample_packing:
        drop_long_kwargs = {}
        if filter_map_kwargs:
            drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
@@ -358,7 +358,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
            **filter_map_kwargs,
            **drop_long_kwargs,
        )
-       if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
+       if cfg.eval_sample_packing:
            if eval_dataset:
                eval_dataset = eval_dataset.map(
                    add_position_ids,
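
For reviewers tracing what the removed `apply_sequence_parallelism` method did, here is a minimal standalone sketch of the two slicing patterns it implemented for ring attention: the contiguous split used by `VARLEN_LLAMA3`/`BATCH_RING` and the mirrored-chunk selection used by `BATCH_ZIGZAG`. The helper names and toy shapes below are illustrative only and are not part of this patch.

```python
import torch


def ring_slice(tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Contiguous split: each rank keeps its own chunk along the sequence dim."""
    return tensor.chunk(world_size, dim=1)[rank].contiguous()


def zigzag_slice(tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Zigzag split: rank keeps chunk `rank` plus its mirror chunk
    `2 * world_size - rank - 1`, balancing early and late positions."""
    chunks = tensor.chunk(2 * world_size, dim=1)
    return torch.cat(
        [chunks[rank], chunks[2 * world_size - rank - 1]], dim=1
    ).contiguous()


# Toy example: batch of one sequence of length 8, sequence-parallel degree 2.
batch = torch.arange(8).unsqueeze(0)
print(ring_slice(batch, rank=0, world_size=2))    # tensor([[0, 1, 2, 3]])
print(zigzag_slice(batch, rank=0, world_size=2))  # tensor([[0, 1, 6, 7]])
print(zigzag_slice(batch, rank=1, world_size=2))  # tensor([[2, 3, 4, 5]])
```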