progress on ring attn impl
This commit is contained in:
5
=0.1.4
Normal file
5
=0.1.4
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
Collecting ring-flash-attn
|
||||||
|
Downloading ring_flash_attn-0.1.4-py3-none-any.whl.metadata (7.3 kB)
|
||||||
|
Downloading ring_flash_attn-0.1.4-py3-none-any.whl (24 kB)
|
||||||
|
Installing collected packages: ring-flash-attn
|
||||||
|
Successfully installed ring-flash-attn-0.1.4
|
||||||
@@ -67,4 +67,5 @@ axolotl-contribs-lgpl==0.0.6
|
|||||||
axolotl-contribs-mit==0.0.3
|
axolotl-contribs-mit==0.0.3
|
||||||
|
|
||||||
# for sequence parallelism
|
# for sequence parallelism
|
||||||
|
yunchang.=0.6.0
|
||||||
ring-flash-attn>=0.1.4
|
ring-flash-attn>=0.1.4
|
||||||
|
|||||||
@@ -758,9 +758,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.kd_zscore_base_temp
|
self.cfg.kd_zscore_base_temp
|
||||||
)
|
)
|
||||||
if self.cfg.kd_top_k_before_softmax is not None:
|
if self.cfg.kd_top_k_before_softmax is not None:
|
||||||
training_arguments_kwargs["kd_top_k_before_softmax"] = (
|
training_arguments_kwargs[
|
||||||
self.cfg.kd_top_k_before_softmax
|
"kd_top_k_before_softmax"
|
||||||
)
|
] = self.cfg.kd_top_k_before_softmax
|
||||||
|
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"sequence_parallel_size"
|
||||||
|
] = self.cfg.sequence_parallel_size
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_args_cls = AxolotlRewardConfig
|
training_args_cls = AxolotlRewardConfig
|
||||||
@@ -845,9 +849,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||||
):
|
):
|
||||||
if training_args.pretraining:
|
if training_args.pretraining:
|
||||||
if self.cfg.pretraining_sample_concatenation is False:
|
if self.cfg.pretraining_sample_concatenation is False or self.cfg.micro_batch_size > 1:
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
|
||||||
if self.cfg.micro_batch_size > 1:
|
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -9,11 +9,14 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Dict, Literal, Optional
|
from typing import Any, Dict, Literal, Optional
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
|
from ring_flash_attn import update_ring_flash_attn_params
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
@@ -25,6 +28,7 @@ from trl.trainer.utils import pad_to_length
|
|||||||
|
|
||||||
from axolotl.integrations.base import BaseOptimizerFactory
|
from axolotl.integrations.base import BaseOptimizerFactory
|
||||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||||
|
from axolotl.utils.ring_attn import get_ring_attn_group
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
from axolotl.utils.schedulers import (
|
from axolotl.utils.schedulers import (
|
||||||
RexLR,
|
RexLR,
|
||||||
@@ -797,6 +801,58 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
|||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, **kwargs)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def training_step(
|
||||||
|
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Perform a training step on a batch of inputs.
|
||||||
|
|
||||||
|
Note: we are subclassing `transformers.trainer.Trainer` in order to compute
|
||||||
|
parameters needed for the ring flash attention implementation we're using.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (`nn.Module`):
|
||||||
|
The model to train.
|
||||||
|
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
|
||||||
|
The inputs and targets of the model.
|
||||||
|
|
||||||
|
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
||||||
|
argument `labels`. Check your model's documentation for all accepted arguments.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
`torch.Tensor`: The tensor with training loss on this batch.
|
||||||
|
"""
|
||||||
|
if self.args.sequence_parallel_size > 1:
|
||||||
|
if "attention_mask" in inputs:
|
||||||
|
# Calculate sequence lengths from attention mask
|
||||||
|
seq_lens = inputs["attention_mask"].sum(dim=1).tolist()
|
||||||
|
total_seq_len = inputs["attention_mask"].shape[0] * inputs["attention_mask"].shape[1]
|
||||||
|
else:
|
||||||
|
# Assume all sequences are the same length if no mask is provided
|
||||||
|
batch_size = inputs["input_ids"].shape[0]
|
||||||
|
seq_len = inputs["input_ids"].shape[1]
|
||||||
|
seq_lens = [seq_len] * batch_size
|
||||||
|
total_seq_len = batch_size * seq_len
|
||||||
|
|
||||||
|
self._update_ring_flash_attn_params(seq_lens, total_seq_len)
|
||||||
|
|
||||||
|
return super().training_step(model, inputs, num_items_in_batch)
|
||||||
|
|
||||||
|
def _update_ring_flash_attn_params(self, packed_seq_lens, total_seq_len):
|
||||||
|
"""
|
||||||
|
Calculate the cu_seqlens for the current forward pass and pass the value to
|
||||||
|
the substituted ring_flash_attn.
|
||||||
|
"""
|
||||||
|
cu_seqlens = torch.cumsum(
|
||||||
|
torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32),
|
||||||
|
dim=-1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
|
||||||
|
|
||||||
|
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
||||||
|
|
||||||
|
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -207,14 +207,21 @@ class AxolotlTrainingMixins:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sequence_parallel_size: Optional[int] = field(
|
||||||
|
default=1,
|
||||||
|
metadata={
|
||||||
|
"help": "The number of workers to use in sequence parallelism"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||||
"""
|
"""
|
||||||
Training arguments for Causal trainer
|
Training arguments for Causal trainer
|
||||||
|
|
||||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
|
This code is duplicated due to HF TrainingArguments not setting output_dir with a
|
||||||
so it can't be used as a mixin.
|
default value so it can't be used as a mixin.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,9 +4,7 @@ Materialization-aware gradient checkpointing monkey patch.
|
|||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import nn
|
|
||||||
from torch.utils.checkpoint import (
|
from torch.utils.checkpoint import (
|
||||||
_get_autocast_kwargs,
|
_get_autocast_kwargs,
|
||||||
check_backward_validity,
|
check_backward_validity,
|
||||||
@@ -16,10 +14,12 @@ from torch.utils.checkpoint import (
|
|||||||
)
|
)
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
|
LlamaDecoderLayer,
|
||||||
|
LlamaModel,
|
||||||
apply_rotary_pos_emb,
|
apply_rotary_pos_emb,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .async_communication import initialize_distributed, reset_global_memory_buffer
|
from .async_communication import initialize_distributed
|
||||||
from .lightseq_async_attn import _lightseq_backward, _lightseq_forward
|
from .lightseq_async_attn import _lightseq_backward, _lightseq_forward
|
||||||
|
|
||||||
# define a global buffer to save flash attention outputs
|
# define a global buffer to save flash attention outputs
|
||||||
@@ -749,7 +749,6 @@ def forward(
|
|||||||
|
|
||||||
def apply_dist_flash_attn_monkey_patch_llama():
|
def apply_dist_flash_attn_monkey_patch_llama():
|
||||||
initialize_distributed()
|
initialize_distributed()
|
||||||
transformers.models.llama.modeling_llama.LlamaModel.forward = forward
|
|
||||||
transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
|
LlamaModel.forward = forward
|
||||||
llama_layer_forward
|
LlamaDecoderLayer.forward = llama_layer_forward
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
import warnings
|
from typing import Optional, Tuple
|
||||||
from typing import List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
import transformers
|
|
||||||
from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func
|
from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func
|
||||||
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer
|
||||||
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
|
MistralAttention,
|
||||||
|
MistralDecoderLayer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def new_flash_attn_forward(
|
def new_flash_attn_forward(
|
||||||
@@ -18,6 +22,10 @@ def new_flash_attn_forward(
|
|||||||
softmax_scale=None,
|
softmax_scale=None,
|
||||||
use_sliding_windows=False,
|
use_sliding_windows=False,
|
||||||
):
|
):
|
||||||
|
assert (
|
||||||
|
self.config._attn_implementation == "flash_attention_2"
|
||||||
|
), "Only Flash Attention is supported."
|
||||||
|
|
||||||
if not self._flash_attn_uses_top_left_mask:
|
if not self._flash_attn_uses_top_left_mask:
|
||||||
causal = self.is_causal
|
causal = self.is_causal
|
||||||
else:
|
else:
|
||||||
@@ -48,26 +56,19 @@ def new_decoder_forward(
|
|||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
assert isinstance(
|
assert isinstance(self.self_attn, LlamaAttention) or isinstance(
|
||||||
self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2
|
|
||||||
) or isinstance(
|
|
||||||
self.self_attn,
|
self.self_attn,
|
||||||
transformers.models.mistral.modeling_mistral.MistralFlashAttention2,
|
MistralAttention,
|
||||||
), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch."
|
), "Llama and Mistral attention only are supported."
|
||||||
|
|
||||||
if "padding_mask" in kwargs:
|
|
||||||
warnings.warn(
|
|
||||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
|
||||||
)
|
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
hidden_states, self_attn_weights = self.self_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
@@ -75,6 +76,7 @@ def new_decoder_forward(
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
@@ -86,29 +88,19 @@ def new_decoder_forward(
|
|||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
outputs += (self_attn_weights,)
|
outputs += (self_attn_weights,)
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
outputs += (present_key_value,)
|
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
def apply_zigzag_ring_attn_monkey_patch_llama():
|
def apply_zigzag_ring_attn_monkey_patch_llama():
|
||||||
transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = (
|
# LlamaAttention._flash_attention_forward = new_flash_attn_forward
|
||||||
new_flash_attn_forward
|
ALL_ATTENTION_FUNCTIONS.update({"flash_attention_2": new_flash_attn_forward})
|
||||||
)
|
LlamaDecoderLayer.forward = new_decoder_forward
|
||||||
transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = (
|
|
||||||
new_decoder_forward
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_zigzag_ring_attn_monkey_patch_mistral():
|
def apply_zigzag_ring_attn_monkey_patch_mistral():
|
||||||
transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward = (
|
# MistralAttention._flash_attention_forward = new_flash_attn_forward
|
||||||
new_flash_attn_forward
|
ALL_ATTENTION_FUNCTIONS.update({"flash_attention_2": new_flash_attn_forward})
|
||||||
)
|
MistralDecoderLayer.forward = new_decoder_forward
|
||||||
transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward = (
|
|
||||||
new_decoder_forward
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from functools import cached_property
|
|||||||
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
||||||
|
|
||||||
import addict
|
import addict
|
||||||
|
from axolotl.utils.ring_attn import register_ring_attn
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
@@ -548,18 +549,17 @@ class ModelLoader:
|
|||||||
patch_self_attn_lora(self.cfg)
|
patch_self_attn_lora(self.cfg)
|
||||||
|
|
||||||
if self.cfg.sequence_parallel_size > 1:
|
if self.cfg.sequence_parallel_size > 1:
|
||||||
from axolotl.integrations.easy_context import (
|
# from axolotl.integrations.easy_context import (
|
||||||
apply_seq_parallel_monkey_patch,
|
# apply_seq_parallel_monkey_patch,
|
||||||
)
|
# )
|
||||||
|
|
||||||
method = self.cfg.sequence_parallel_mode
|
# method = self.cfg.sequence_parallel_mode
|
||||||
model_type = self.cfg.model_type
|
# model_type = self.cfg.model_config_type
|
||||||
|
|
||||||
# Apply the monkey patch
|
# # Apply the monkey patch
|
||||||
apply_seq_parallel_monkey_patch(method, model_type)
|
# apply_seq_parallel_monkey_patch(method, model_type)
|
||||||
|
|
||||||
# Ensure flash attention is enabled when loading the model
|
register_ring_attn(self.cfg.sequence_parallel_size)
|
||||||
self.cfg.attn_implementation = "flash_attention_2"
|
|
||||||
|
|
||||||
def patch_attention(self) -> None:
|
def patch_attention(self) -> None:
|
||||||
if hasattr(self.model_config, "model_type"):
|
if hasattr(self.model_config, "model_type"):
|
||||||
|
|||||||
40
src/axolotl/utils/ring_attn.py
Normal file
40
src/axolotl/utils/ring_attn.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import torch.distributed as dist
|
||||||
|
from ring_flash_attn import substitute_hf_flash_attn
|
||||||
|
|
||||||
|
RING_ATTN_GROUP = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_ring_attn_group():
|
||||||
|
return RING_ATTN_GROUP
|
||||||
|
|
||||||
|
|
||||||
|
def set_ring_attn_group(ring_attn_group):
|
||||||
|
global RING_ATTN_GROUP
|
||||||
|
RING_ATTN_GROUP = ring_attn_group
|
||||||
|
|
||||||
|
|
||||||
|
def register_ring_attn(sequence_parallel_size):
|
||||||
|
"""
|
||||||
|
Create ring attention group and substitute flash attention with ring flash
|
||||||
|
attention.
|
||||||
|
"""
|
||||||
|
if sequence_parallel_size == 1:
|
||||||
|
return
|
||||||
|
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
assert world_size % sequence_parallel_size == 0, \
|
||||||
|
f"sequence_parallel_size ({sequence_parallel_size}) " \
|
||||||
|
f"must evenly divide world_size ({world_size})"
|
||||||
|
|
||||||
|
for i in range(world_size // sequence_parallel_size):
|
||||||
|
ring_attn_ranks = list(
|
||||||
|
range(
|
||||||
|
i * sequence_parallel_size,
|
||||||
|
(i + 1) * sequence_parallel_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
||||||
|
if dist.get_rank() in ring_attn_ranks:
|
||||||
|
set_ring_attn_group(group)
|
||||||
|
|
||||||
|
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_size)
|
||||||
@@ -8,6 +8,7 @@ from contextlib import contextmanager
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from axolotl.integrations.easy_context import prepare_seq_parallel_inputs
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
@@ -346,7 +347,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
load_from_cache_file=not cfg.is_preprocess,
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
desc="Add position_id column (PoSE)",
|
desc="Add position_id column (PoSE)",
|
||||||
)
|
)
|
||||||
elif cfg.sample_packing:
|
elif cfg.sample_packing or cfg.sequence_parallel_size > 1:
|
||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
||||||
@@ -356,7 +357,18 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
**filter_map_kwargs,
|
**filter_map_kwargs,
|
||||||
**drop_long_kwargs,
|
**drop_long_kwargs,
|
||||||
)
|
)
|
||||||
if cfg.eval_sample_packing is not False:
|
if cfg.sequence_parallel_size > 1:
|
||||||
|
train_dataset.map(
|
||||||
|
prepare_seq_parallel_inputs,
|
||||||
|
"dist_flash_attn",
|
||||||
|
lambda batch: batch["input_ids"],
|
||||||
|
lambda batch: batch["position_ids"],
|
||||||
|
lambda batch: batch["target_ids"],
|
||||||
|
accelerator.process_index,
|
||||||
|
accelerator.num_processes,
|
||||||
|
accelerator.device,
|
||||||
|
)
|
||||||
|
if cfg.eval_sample_packing or cfg.sequence_parallel_size > 1:
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
add_position_ids,
|
add_position_ids,
|
||||||
|
|||||||
Reference in New Issue
Block a user