update
This commit is contained in:
@@ -932,9 +932,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
collator = DataCollatorForSeq2Seq
|
collator = DataCollatorForSeq2Seq
|
||||||
|
|
||||||
kwargs["return_tensors"] = "pt"
|
kwargs["return_tensors"] = "pt"
|
||||||
if issubclass(collator, DataCollatorForSeq2Seq):
|
|
||||||
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
|
|
||||||
kwargs["ring_attn_func"] = training_args.ring_attn_func
|
|
||||||
|
|
||||||
return collator(
|
return collator(
|
||||||
*collator_args,
|
*collator_args,
|
||||||
|
|||||||
@@ -1,20 +1,12 @@
|
|||||||
"""
|
"""Data collators for axolotl to pad labels and position_ids for packed sequences"""
|
||||||
Data collators for axolotl to pad labels and position_ids for packed sequences. Also
|
|
||||||
includes logic for handling sequence parallelism collation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataCollatorForSeq2Seq:
|
class DataCollatorForSeq2Seq:
|
||||||
@@ -49,8 +41,6 @@ class DataCollatorForSeq2Seq:
|
|||||||
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
||||||
return_tensors (`str`):
|
return_tensors (`str`):
|
||||||
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
||||||
sequence_parallel_degree (`int`):
|
|
||||||
The degree of sequence parallelism. Default to 1 for no sequence parallelism.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tokenizer: PreTrainedTokenizerBase
|
tokenizer: PreTrainedTokenizerBase
|
||||||
@@ -61,17 +51,6 @@ class DataCollatorForSeq2Seq:
|
|||||||
label_pad_token_id: int = -100
|
label_pad_token_id: int = -100
|
||||||
position_pad_token_id: int = 0
|
position_pad_token_id: int = 0
|
||||||
return_tensors: str = "pt"
|
return_tensors: str = "pt"
|
||||||
sequence_parallel_degree: int = 1
|
|
||||||
ring_attn_func: RingAttnFunc | None = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.sequence_parallel_degree > 1:
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
|
||||||
|
|
||||||
# Get information about our position in the SP group
|
|
||||||
sp_group = get_ring_attn_group()
|
|
||||||
self.local_rank = dist.get_rank(group=sp_group)
|
|
||||||
self.local_world_size = dist.get_world_size(group=sp_group)
|
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
has_attn_mask = "attention_mask" in features[0].keys()
|
has_attn_mask = "attention_mask" in features[0].keys()
|
||||||
@@ -141,62 +120,8 @@ class DataCollatorForSeq2Seq:
|
|||||||
)
|
)
|
||||||
features["decoder_input_ids"] = decoder_input_ids
|
features["decoder_input_ids"] = decoder_input_ids
|
||||||
|
|
||||||
# if self.sequence_parallel_degree > 1:
|
|
||||||
# features = self.apply_sequence_parallelism(features)
|
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
def apply_sequence_parallelism(
|
|
||||||
self, batch: dict[str, torch.Tensor]
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Apply sequence parallelism slicing to a batch.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
batch: Batch dictionary from parent collator.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Sliced batch dictionary.
|
|
||||||
"""
|
|
||||||
# Get local (start, end) for sequence parallelism slicing
|
|
||||||
total_seq_len = batch["input_ids"].size(1)
|
|
||||||
|
|
||||||
# Update params for varlen ring attention calculation
|
|
||||||
if batch.get("position_ids") is not None:
|
|
||||||
update_ring_attn_params(position_ids=batch["position_ids"])
|
|
||||||
|
|
||||||
# Slice batch for sequence parallel processing
|
|
||||||
for key in batch:
|
|
||||||
if batch[key].size(1) == total_seq_len:
|
|
||||||
if self.ring_attn_func in [
|
|
||||||
RingAttnFunc.VARLEN_LLAMA3,
|
|
||||||
RingAttnFunc.BATCH_RING,
|
|
||||||
]:
|
|
||||||
batch[key] = (
|
|
||||||
batch[key]
|
|
||||||
.chunk(self.local_world_size, dim=1)[self.local_rank]
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
|
|
||||||
chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
|
|
||||||
|
|
||||||
# Take rank's chunk and opposing chunk for zigzag pattern
|
|
||||||
selected_chunks = [
|
|
||||||
chunks[self.local_rank],
|
|
||||||
chunks[2 * self.local_world_size - self.local_rank - 1],
|
|
||||||
]
|
|
||||||
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
|
|
||||||
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
|
|
||||||
# TODO(djsaunde): This doesn't seem to work as expected
|
|
||||||
# Split into striped data and stack
|
|
||||||
tensor = torch.stack(
|
|
||||||
batch[key].split(self.local_world_size, dim=1),
|
|
||||||
dim=1,
|
|
||||||
).transpose(1, 2)
|
|
||||||
batch[key] = tensor[:, self.local_rank].contiguous()
|
|
||||||
|
|
||||||
return batch
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||||
|
|||||||
@@ -126,9 +126,6 @@ def normalize_config(cfg):
|
|||||||
with open(ds_config_path, encoding="utf-8") as f:
|
with open(ds_config_path, encoding="utf-8") as f:
|
||||||
cfg.deepspeed = json.load(f)
|
cfg.deepspeed = json.load(f)
|
||||||
|
|
||||||
if cfg.sequence_parallel_degree is None:
|
|
||||||
cfg.sequence_parallel_degree = 1
|
|
||||||
|
|
||||||
if cfg.saves_per_epoch:
|
if cfg.saves_per_epoch:
|
||||||
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
||||||
if save_steps < 1.0: # prevent saves on every step
|
if save_steps < 1.0: # prevent saves on every step
|
||||||
|
|||||||
@@ -719,9 +719,10 @@ class AxolotlInputConfig(
|
|||||||
and data.get("eval_sample_packing") is None
|
and data.get("eval_sample_packing") is None
|
||||||
and not data.get("eval_table_size")
|
and not data.get("eval_table_size")
|
||||||
):
|
):
|
||||||
LOG.info(
|
if is_main_process():
|
||||||
"explicitly setting `eval_sample_packing` to match `sample_packing`"
|
LOG.info(
|
||||||
)
|
"explicitly setting `eval_sample_packing` to match `sample_packing`"
|
||||||
|
)
|
||||||
data["eval_sample_packing"] = True
|
data["eval_sample_packing"] = True
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -1192,10 +1193,9 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_ring_attn_func(self):
|
def validate_ring_attn_func(self):
|
||||||
if self.sequence_parallel_degree == 1:
|
if getattr(self, "sequence_parallel_degree", 1) == 1:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
# Your validation logic for ring_attn_func
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
|
||||||
|
|
||||||
if self.ring_attn_func is not None:
|
if self.ring_attn_func is not None:
|
||||||
|
|||||||
@@ -348,7 +348,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Add position_id column (PoSE)",
|
desc="Add position_id column (PoSE)",
|
||||||
)
|
)
|
||||||
elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
|
elif cfg.sample_packing:
|
||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
||||||
@@ -358,7 +358,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
**filter_map_kwargs,
|
**filter_map_kwargs,
|
||||||
**drop_long_kwargs,
|
**drop_long_kwargs,
|
||||||
)
|
)
|
||||||
if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
|
if cfg.eval_sample_packing:
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
add_position_ids,
|
add_position_ids,
|
||||||
|
|||||||
Reference in New Issue
Block a user