Compare commits
3 Commits
revert-mul
...
sp-rl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9f30d3d33a | ||
|
|
ce07081d6c | ||
|
|
3ce43b6db9 |
@@ -686,9 +686,10 @@ ddp_broadcast_buffers:
|
|||||||
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
|
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
|
||||||
# subsequences, or set to 4 to split into four equal-sized subsequences.
|
# subsequences, or set to 4 to split into four equal-sized subsequences.
|
||||||
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
|
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
|
||||||
sequence_parallel_degree:
|
sequence_parallel_degree: 4 # Set to the number of GPUs to split sequences across
|
||||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
flash_attention: true # SP requires flash attention
|
||||||
# Must evenly divide the number of KV heads in your model.
|
micro_batch_size: 1 # SP requires this is set to 1
|
||||||
|
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
|
||||||
heads_k_stride: 1
|
heads_k_stride: 1
|
||||||
|
|
||||||
# Path to torch distx for optim 'adamw_anyprecision'
|
# Path to torch distx for optim 'adamw_anyprecision'
|
||||||
|
|||||||
@@ -23,9 +23,10 @@ Use sequence parallelism when:
|
|||||||
To enable sequence parallelism, add the following to your configuration file:
|
To enable sequence parallelism, add the following to your configuration file:
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
# Set to a divisor (> 1) of the number of GPUs available
|
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
flash_attention: true # SP requires flash attention
|
||||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
micro_batch_size: 1 # SP requires this is set to 1
|
||||||
|
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
|
||||||
heads_k_stride: 1
|
heads_k_stride: 1
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -66,15 +67,16 @@ sequence_len: 8192
|
|||||||
...
|
...
|
||||||
|
|
||||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||||
flash_attention: true # Required with sequence parallelism
|
flash_attention: true # SP requires flash attention
|
||||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
micro_batch_size: 1 # SP requires this is set to 1
|
||||||
|
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
|
||||||
heads_k_stride: 1
|
heads_k_stride: 1
|
||||||
|
|
||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
This will train the Llama 3 8B model with 8K context length, with each sequence split
|
This will train the Llama 3 8B model with 8192 context length, with each sequence split
|
||||||
into 2 subsequences of length 4096 across 2 GPUs.
|
into 4 subsequences of length 2048 across 4 GPUs.
|
||||||
|
|
||||||
## Sample Packing with Sequence Parallelism
|
## Sample Packing with Sequence Parallelism
|
||||||
|
|
||||||
@@ -86,12 +88,14 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
|
|||||||
|
|
||||||
## Effect on Batch Size
|
## Effect on Batch Size
|
||||||
|
|
||||||
|
First, note that sequence parallelism supports only the case where `micro_batch_size: 1`.
|
||||||
|
|
||||||
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
||||||
|
|
||||||
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
||||||
- The number of batches processed per step decreases
|
- The number of batches processed per step decreases
|
||||||
|
|
||||||
For example:
|
For example:
|
||||||
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
- With 8 GPUs and no sequence parallelism: 8 different batches are processed per step
|
||||||
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||||
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
- If your per-GPU `micro_batch_size` is 1, the global batch size decreases from 8 to 2
|
||||||
|
|||||||
@@ -82,3 +82,6 @@ deepspeed:
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
fsdp:
|
fsdp:
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
pad_token: "<|end_of_text|>"
|
||||||
|
|||||||
@@ -1043,6 +1043,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.rpo_alpha is not None:
|
if self.cfg.rpo_alpha is not None:
|
||||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||||
|
|
||||||
|
training_args_kwargs["sequence_parallel_degree"] = (
|
||||||
|
self.cfg.sequence_parallel_degree
|
||||||
|
)
|
||||||
|
|
||||||
training_args_cls = None
|
training_args_cls = None
|
||||||
blocklist_args_kwargs = []
|
blocklist_args_kwargs = []
|
||||||
if self.cfg.rl == "simpo":
|
if self.cfg.rl == "simpo":
|
||||||
@@ -1161,6 +1165,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs["dataset_tags"] = [
|
dpo_trainer_kwargs["dataset_tags"] = [
|
||||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
]
|
]
|
||||||
|
|
||||||
dpo_trainer = trainer_cls(
|
dpo_trainer = trainer_cls(
|
||||||
*trainer_cls_args,
|
*trainer_cls_args,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
@@ -1178,21 +1183,3 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer.add_callback(callback)
|
dpo_trainer.add_callback(callback)
|
||||||
|
|
||||||
return dpo_trainer
|
return dpo_trainer
|
||||||
|
|
||||||
|
|
||||||
class HFPPOTrainerBuilder(TrainerBuilderBase):
|
|
||||||
"""
|
|
||||||
HF Factory class for PPO Trainer
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_callbacks(self):
|
|
||||||
callbacks = super().get_callbacks()
|
|
||||||
return callbacks
|
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
|
||||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
|
||||||
return callbacks
|
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
|
||||||
# build PPOConfig
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -3,16 +3,16 @@
|
|||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
|
|
||||||
from .base import AxolotlTrainer
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
from .dpo.trainer import AxolotlDPOTrainer
|
from axolotl.core.trainers.dpo import AxolotlDPOTrainer
|
||||||
from .grpo.trainer import AxolotlGRPOTrainer
|
from axolotl.core.trainers.grpo import AxolotlGRPOTrainer
|
||||||
from .mamba import AxolotlMambaTrainer
|
from axolotl.core.trainers.mamba import AxolotlMambaTrainer
|
||||||
from .relora import ReLoRATrainer
|
from axolotl.core.trainers.relora import ReLoRATrainer
|
||||||
from .trl import (
|
from axolotl.core.trainers.trl import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
AxolotlKTOTrainer,
|
AxolotlKTOTrainer,
|
||||||
AxolotlORPOTrainer,
|
AxolotlORPOTrainer,
|
||||||
|
AxolotlPPOTrainer,
|
||||||
AxolotlPRMTrainer,
|
AxolotlPRMTrainer,
|
||||||
AxolotlRewardTrainer,
|
AxolotlRewardTrainer,
|
||||||
TRLPPOTrainer,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,10 +8,11 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from torch.utils.data import (
|
from torch.utils.data import (
|
||||||
BatchSampler,
|
BatchSampler,
|
||||||
@@ -25,12 +26,8 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
|||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import (
|
from axolotl.core.trainers.handlers import SequenceParallelHandler
|
||||||
OptimizerMixin,
|
from axolotl.core.trainers.mixins import TrainerMixins
|
||||||
RngLoaderMixin,
|
|
||||||
SchedulerMixin,
|
|
||||||
SequenceParallelMixin,
|
|
||||||
)
|
|
||||||
from axolotl.core.trainers.utils import (
|
from axolotl.core.trainers.utils import (
|
||||||
sanitize_kwargs_for_ds_tagging,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
sanitize_kwargs_for_tagging,
|
sanitize_kwargs_for_tagging,
|
||||||
@@ -40,9 +37,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(
|
class AxolotlTrainer(TrainerMixins, Trainer):
|
||||||
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
|
|
||||||
):
|
|
||||||
"""Extend the base Trainer for axolotl helpers"""
|
"""Extend the base Trainer for axolotl helpers"""
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
@@ -68,9 +63,7 @@ class AxolotlTrainer(
|
|||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
|
|
||||||
# Initialize sequence parallelism if enabled
|
self.sequence_parallel_handler = SequenceParallelHandler(self.args)
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
self._setup_sequence_parallel()
|
|
||||||
|
|
||||||
def _wrap_model(self, model, training=True, dataloader=None):
|
def _wrap_model(self, model, training=True, dataloader=None):
|
||||||
if self.args.torch_compile:
|
if self.args.torch_compile:
|
||||||
@@ -131,7 +124,7 @@ class AxolotlTrainer(
|
|||||||
|
|
||||||
# Determine the base sampler first
|
# Determine the base sampler first
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
base_sampler = self._sp_get_train_sampler(self.train_dataset)
|
base_sampler = self.sequence_parallel_handler._get_train_sampler(self.train_dataset)
|
||||||
elif self.args.curriculum_sampling:
|
elif self.args.curriculum_sampling:
|
||||||
base_sampler = SequentialSampler(self.train_dataset)
|
base_sampler = SequentialSampler(self.train_dataset)
|
||||||
elif use_sample_packing:
|
elif use_sample_packing:
|
||||||
@@ -167,7 +160,7 @@ class AxolotlTrainer(
|
|||||||
|
|
||||||
# Determine the base sampler
|
# Determine the base sampler
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
base_sampler = self._sp_get_eval_sampler(eval_dataset)
|
base_sampler = self.sequence_parallel_handler._get_eval_sampler(eval_dataset)
|
||||||
elif use_multipack:
|
elif use_multipack:
|
||||||
base_sampler = SequentialSampler(eval_dataset)
|
base_sampler = SequentialSampler(eval_dataset)
|
||||||
else:
|
else:
|
||||||
@@ -239,7 +232,10 @@ class AxolotlTrainer(
|
|||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
# Otherwise prepare with accelerator
|
# Otherwise prepare with accelerator
|
||||||
return self.accelerator.prepare_data_loader(dataloader)
|
dataloader = self.accelerator.prepare_data_loader(dataloader)
|
||||||
|
|
||||||
|
return dataloader
|
||||||
|
|
||||||
|
|
||||||
def get_train_dataloader(self) -> DataLoader:
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
"""Get dataloader for training"""
|
"""Get dataloader for training"""
|
||||||
@@ -348,7 +344,57 @@ class AxolotlTrainer(
|
|||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
|
||||||
return DataLoader(bench_dataset, **dataloader_params)
|
return DataLoader(bench_dataset, **dataloader_params)
|
||||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
|
||||||
|
def training_step(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: dict[str, torch.Tensor | Any],
|
||||||
|
num_items_in_batch: int | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Perform a training step on a batch of inputs. Overrides the
|
||||||
|
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||||
|
enabled.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model to perform training step for.
|
||||||
|
inputs: Dictionary mapping of inputs.
|
||||||
|
num_items_in_batch: The number of items in the batch.
|
||||||
|
"""
|
||||||
|
# Set up sequence parallelism for this step if enabled
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
|
||||||
|
|
||||||
|
# Proceed with normal training step
|
||||||
|
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
|
||||||
|
|
||||||
|
def prediction_step(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: dict[str, torch.Tensor | Any],
|
||||||
|
prediction_loss_only: bool,
|
||||||
|
ignore_keys: list[str] | None = None,
|
||||||
|
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
|
||||||
|
"""
|
||||||
|
Perform a prediction step on a batch of inputs. Overrides the
|
||||||
|
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||||
|
enabled.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model to perform prediction step for.
|
||||||
|
inputs: Dictionary mapping of inputs.
|
||||||
|
prediction_loss_only: Whether to return only the loss.
|
||||||
|
ignore_keys: Keys to ignore in the inputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (loss, logits, labels).
|
||||||
|
"""
|
||||||
|
# Set up sequence parallelism for this prediction step if enabled
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
|
||||||
|
|
||||||
|
# Proceed with normal prediction step
|
||||||
|
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
|
|||||||
@@ -1,14 +1,10 @@
|
|||||||
"""
|
"""DPO Specific Strategy for training"""
|
||||||
DPO Specific Strategy for training
|
|
||||||
"""
|
|
||||||
|
|
||||||
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
||||||
|
|
||||||
|
|
||||||
class DPOStrategy:
|
class DPOStrategy:
|
||||||
"""
|
"""Strategy for DPO training"""
|
||||||
Strategy for DPO training
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_trainer_class(cls):
|
def get_trainer_class(cls):
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
"""
|
"""Axolotl specific DPO args"""
|
||||||
Axolotl specific DPO args
|
|
||||||
"""
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
@@ -11,6 +9,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||||
"""
|
"""DPO config for DPO training"""
|
||||||
DPO config for DPO training
|
|
||||||
"""
|
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
"""
|
"""DPO trainer for axolotl"""
|
||||||
DPO trainer for axolotl
|
|
||||||
"""
|
|
||||||
|
|
||||||
import gc
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Dict, Union
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
@@ -13,7 +10,8 @@ from transformers import Trainer
|
|||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
from axolotl.core.trainers.handlers import SequenceParallelHandler
|
||||||
|
from axolotl.core.trainers.mixins import TrainerMixins
|
||||||
from axolotl.core.trainers.utils import (
|
from axolotl.core.trainers.utils import (
|
||||||
sanitize_kwargs_for_ds_tagging,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
sanitize_kwargs_for_tagging,
|
sanitize_kwargs_for_tagging,
|
||||||
@@ -23,18 +21,18 @@ if is_sagemaker_mp_enabled():
|
|||||||
import smdistributed.modelparallel.torch as smp
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
|
|
||||||
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
|
||||||
"""
|
"""Extend the base DPOTrainer for axolotl helpers"""
|
||||||
Extend the base DPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "dpo"]
|
tag_names = ["axolotl", "dpo"]
|
||||||
|
|
||||||
def __init__(self, *args, dataset_tags=None, **kwargs):
|
def __init__(self, *args, dataset_tags=None, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
self.dataset_tags = dataset_tags
|
self.dataset_tags = dataset_tags
|
||||||
self.optimizer = None
|
self.optimizer = None
|
||||||
self.model_accepts_loss_kwargs = False
|
self.model_accepts_loss_kwargs = False
|
||||||
|
self.sequence_parallel_handler = SequenceParallelHandler(args=self.args)
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
@@ -88,7 +86,7 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
|||||||
max_prompt_length,
|
max_prompt_length,
|
||||||
max_completion_length,
|
max_completion_length,
|
||||||
add_special_tokens,
|
add_special_tokens,
|
||||||
) -> Dict:
|
) -> dict:
|
||||||
res = DPOTrainer.tokenize_row(
|
res = DPOTrainer.tokenize_row(
|
||||||
features,
|
features,
|
||||||
processing_class,
|
processing_class,
|
||||||
@@ -117,10 +115,9 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
|||||||
def training_step(
|
def training_step(
|
||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
inputs: dict[str, torch.Tensor | Any | None],
|
||||||
num_items_in_batch=None,
|
num_items_in_batch=None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
self.sequence_parallel_handler.prepare_for_training_step(self, inputs)
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
return super().training_step(model, inputs, num_items_in_batch)
|
||||||
return loss
|
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
"""
|
"""Axolotl GRPO trainer"""
|
||||||
Axolotl GRPO trainer
|
|
||||||
"""
|
|
||||||
|
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
@@ -8,16 +6,14 @@ from accelerate.utils import is_deepspeed_available, is_peft_model
|
|||||||
from trl import GRPOTrainer
|
from trl import GRPOTrainer
|
||||||
from trl.extras.profiling import profiling_decorator
|
from trl.extras.profiling import profiling_decorator
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
from axolotl.core.trainers.mixins import TrainerMixins
|
||||||
|
|
||||||
if is_deepspeed_available():
|
if is_deepspeed_available():
|
||||||
import deepspeed
|
import deepspeed
|
||||||
|
|
||||||
|
|
||||||
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
class AxolotlGRPOTrainer(TrainerMixins, GRPOTrainer):
|
||||||
"""
|
"""Extend the base GRPOTrainer for axolotl helpers"""
|
||||||
Extend the base GRPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
_tag_names = ["trl", "grpo", "axolotl"]
|
_tag_names = ["trl", "grpo", "axolotl"]
|
||||||
|
|
||||||
|
|||||||
3
src/axolotl/core/trainers/handlers/__init__.py
Normal file
3
src/axolotl/core/trainers/handlers/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""Init for trainer handlers"""
|
||||||
|
|
||||||
|
from axolotl.core.trainers.handlers.sequence_parallel import SequenceParallelHandler
|
||||||
123
src/axolotl/core/trainers/handlers/sequence_parallel.py
Normal file
123
src/axolotl/core/trainers/handlers/sequence_parallel.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
"""Handler class for sequence parallel trainer logic"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.data import DistributedSampler
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceParallelHandler:
|
||||||
|
"""
|
||||||
|
Handler class that encapsulates sequence parallelism functionality.
|
||||||
|
This replaces the SequenceParallelMixin with a composition-based approach.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, args=None):
|
||||||
|
"""
|
||||||
|
Initialize the sequence parallel handler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: The arguments object containing sequence parallelism settings.
|
||||||
|
"""
|
||||||
|
self.args = args
|
||||||
|
self.ring_attn_group = None
|
||||||
|
|
||||||
|
# Set up sequence parallelism if enabled
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
self._setup_sequence_parallel()
|
||||||
|
|
||||||
|
def _setup_sequence_parallel(self):
|
||||||
|
"""Set up sequence parallelism environment."""
|
||||||
|
from ring_flash_attn import update_ring_flash_attn_params
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||||
|
|
||||||
|
self.update_ring_flash_attn_params = update_ring_flash_attn_params
|
||||||
|
self.ring_attn_group = get_ring_attn_group()
|
||||||
|
|
||||||
|
def create_sequence_parallel_sampler(
|
||||||
|
self,
|
||||||
|
dataset,
|
||||||
|
shuffle=True,
|
||||||
|
is_eval=False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Helper method to create sampler for sequence parallelism (SP).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: Dataset to sample from.
|
||||||
|
shuffle: Whether to shuffle the dataset.
|
||||||
|
is_eval: Whether we are creating a sampler for evaluation or training.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Distributed sampler.
|
||||||
|
"""
|
||||||
|
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
||||||
|
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
||||||
|
|
||||||
|
return DistributedSampler(
|
||||||
|
dataset,
|
||||||
|
num_replicas=num_sp_groups,
|
||||||
|
rank=sp_group_id,
|
||||||
|
seed=self.args.seed if shuffle else None,
|
||||||
|
shuffle=shuffle,
|
||||||
|
drop_last=not is_eval,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_train_sampler(self, dataset):
|
||||||
|
"""
|
||||||
|
Get a training sampler configured for sequence parallelism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: The training dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured sequence parallel sampler.
|
||||||
|
"""
|
||||||
|
return self.create_sequence_parallel_sampler(
|
||||||
|
dataset,
|
||||||
|
shuffle=not self.args.curriculum_sampling,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_eval_sampler(self, eval_dataset):
|
||||||
|
"""
|
||||||
|
Get an evaluation sampler configured for sequence parallelism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
eval_dataset: The evaluation dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured sequence parallel sampler.
|
||||||
|
"""
|
||||||
|
return self.create_sequence_parallel_sampler(
|
||||||
|
eval_dataset, shuffle=False, is_eval=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_ring_flash_attn_params(self, inputs):
|
||||||
|
"""
|
||||||
|
Calculate the cu_seqlens for the current forward pass and pass the value to
|
||||||
|
the substituted ring_flash_attn.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Current batch of inputs.
|
||||||
|
"""
|
||||||
|
# At this point, inputs should already be partitioned by the sequence
|
||||||
|
# parallel data collator
|
||||||
|
batch_size = inputs["input_ids"].shape[0]
|
||||||
|
seq_len = inputs["input_ids"].shape[1]
|
||||||
|
packed_seq_lens = [seq_len] * batch_size
|
||||||
|
|
||||||
|
# Calculate the full sequence length across all GPUs in this SP group
|
||||||
|
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
||||||
|
|
||||||
|
cu_seqlens = torch.cumsum(
|
||||||
|
torch.tensor(
|
||||||
|
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||||
|
),
|
||||||
|
dim=-1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
cu_seqlens = F.pad(
|
||||||
|
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||||
|
)
|
||||||
|
|
||||||
|
self.update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
||||||
@@ -3,7 +3,12 @@
|
|||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
|
|
||||||
from .optimizer import OptimizerMixin
|
from axolotl.core.trainers.mixins.optimizer import OptimizerMixin
|
||||||
from .rng_state_loader import RngLoaderMixin
|
from axolotl.core.trainers.mixins.rng_state_loader import RngLoaderMixin
|
||||||
from .scheduler import SchedulerMixin
|
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||||
from .sequence_parallel import SequenceParallelMixin
|
|
||||||
|
|
||||||
|
class TrainerMixins(
|
||||||
|
OptimizerMixin, RngLoaderMixin, SchedulerMixin
|
||||||
|
):
|
||||||
|
"""Stub class combining all mixins for Axolotl trainers."""
|
||||||
|
|||||||
@@ -21,9 +21,7 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class RngLoaderMixin(Trainer):
|
class RngLoaderMixin(Trainer):
|
||||||
"""
|
"""Mixin for method override to load RNG states from a checkpoint"""
|
||||||
mixin for method override to load RNG states from a checkpoint
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _load_rng_state(self, checkpoint):
|
def _load_rng_state(self, checkpoint):
|
||||||
# Load RNG states from `checkpoint`
|
# Load RNG states from `checkpoint`
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Module for Axolotl trainer sequence parallelism mixin"""
|
"""Module for Axolotl trainer sequence parallelism mixin"""
|
||||||
|
# TODO(Dan): remove
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -7,7 +8,6 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from torch import nn
|
|
||||||
from torch.utils.data import DistributedSampler, Sampler
|
from torch.utils.data import DistributedSampler, Sampler
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||||
@@ -71,12 +71,12 @@ class SequenceParallelMixin:
|
|||||||
drop_last=not is_eval,
|
drop_last=not is_eval,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
|
def _get_train_sampler(self, dataset) -> Sampler | None:
|
||||||
"""
|
"""
|
||||||
Get a training sampler configured for sequence parallelism.
|
Get a training sampler configured for sequence parallelism.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset: The training dataset
|
dataset: The training dataset.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Configured sequence parallel sampler.
|
Configured sequence parallel sampler.
|
||||||
@@ -86,7 +86,7 @@ class SequenceParallelMixin:
|
|||||||
shuffle=not self.args.curriculum_sampling,
|
shuffle=not self.args.curriculum_sampling,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
def _get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
||||||
"""
|
"""
|
||||||
Get an evaluation sampler configured for sequence parallelism.
|
Get an evaluation sampler configured for sequence parallelism.
|
||||||
|
|
||||||
@@ -130,53 +130,3 @@ class SequenceParallelMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
||||||
|
|
||||||
def training_step(
|
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
inputs: dict[str, torch.Tensor | Any],
|
|
||||||
num_items_in_batch: int | None = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Perform a training step on a batch of inputs. Overrides the
|
|
||||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
|
||||||
enabled.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: Model to perform training step for.
|
|
||||||
inputs: Dictionary mapping.
|
|
||||||
"""
|
|
||||||
# Set up sequence parallelism for this step if enabled
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
self._update_ring_flash_attn_params(inputs)
|
|
||||||
|
|
||||||
# Proceed with normal training step
|
|
||||||
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
|
|
||||||
|
|
||||||
def prediction_step(
|
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
inputs: dict[str, torch.Tensor | Any],
|
|
||||||
prediction_loss_only: bool,
|
|
||||||
ignore_keys: list[str] | None = None,
|
|
||||||
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
|
|
||||||
"""
|
|
||||||
Perform a prediction step on a batch of inputs. Overrides the
|
|
||||||
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
|
||||||
enabled.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: Model to perform prediction step for.
|
|
||||||
inputs: Dictionary mapping of inputs.
|
|
||||||
prediction_loss_only: Whether to return only the loss.
|
|
||||||
ignore_keys: Keys to ignore in the inputs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (loss, logits, labels).
|
|
||||||
"""
|
|
||||||
# Set up sequence parallelism for this prediction step if enabled
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
|
||||||
self._update_ring_flash_attn_params(inputs)
|
|
||||||
|
|
||||||
# Proceed with normal prediction step
|
|
||||||
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore
|
|
||||||
|
|||||||
@@ -13,11 +13,10 @@ from trl import (
|
|||||||
RewardTrainer,
|
RewardTrainer,
|
||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import RngLoaderMixin
|
from axolotl.core.trainers.mixins import TrainerMixins
|
||||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
|
||||||
|
|
||||||
|
|
||||||
class TRLPPOTrainer(PPOTrainer):
|
class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
|
||||||
"""Wrapper for TRL PPO trainer to handle customizations"""
|
"""Wrapper for TRL PPO trainer to handle customizations"""
|
||||||
|
|
||||||
tag_names = ["axolotl", "ppo"]
|
tag_names = ["axolotl", "ppo"]
|
||||||
@@ -75,10 +74,8 @@ class TRLPPOTrainer(PPOTrainer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
|
class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
|
||||||
"""
|
"""Extend the base ORPOTrainer for axolotl helpers"""
|
||||||
Extend the base ORPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "orpo"]
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
@@ -155,18 +152,14 @@ class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
|
|||||||
return loss, metrics
|
return loss, metrics
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
|
class AxolotlKTOTrainer(TrainerMixins, KTOTrainer):
|
||||||
"""
|
"""Extend the base KTOTrainer for axolotl helpers"""
|
||||||
Extend the base KTOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "kto"]
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
|
class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
|
||||||
"""
|
"""Extend the base CPOTrainer for axolotl helpers"""
|
||||||
Extend the base CPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "cpo"]
|
tag_names = ["axolotl", "cpo"]
|
||||||
|
|
||||||
@@ -245,17 +238,13 @@ class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
|
|||||||
return loss, metrics
|
return loss, metrics
|
||||||
|
|
||||||
|
|
||||||
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
|
class AxolotlRewardTrainer(TrainerMixins, RewardTrainer):
|
||||||
"""
|
"""Extend the base RewardTrainer for axolotl helpers"""
|
||||||
Extend the base RewardTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "reward"]
|
tag_names = ["axolotl", "reward"]
|
||||||
|
|
||||||
|
|
||||||
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
|
class AxolotlPRMTrainer(TrainerMixins, PRMTrainer):
|
||||||
"""
|
"""Extend the base trl.PRMTrainer for axolotl helpers"""
|
||||||
Extend the base trl.PRMTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "prm"]
|
tag_names = ["axolotl", "prm"]
|
||||||
|
|||||||
@@ -12,9 +12,7 @@ from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingMixins:
|
class AxolotlTrainingMixins:
|
||||||
"""
|
"""Mixin class for the Axolotl training args."""
|
||||||
Mixin class for the Axolotl training args.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
|
|||||||
@@ -6,11 +6,22 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc
|
|||||||
their sequence parallel version of Flash Attention 2.
|
their sequence parallel version of Flash Attention 2.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
import torch.nn.functional as F
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ring_flash_attn import update_ring_flash_attn_params
|
||||||
|
except ImportError:
|
||||||
|
# We pass silently here, but raise an ImportError in our Axolotl config validation
|
||||||
|
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
configure_logging()
|
configure_logging()
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
@@ -32,12 +43,120 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
|||||||
Setter for ring attention group on this rank.
|
Setter for ring attention group on this rank.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
Process group for ring attention.
|
ring_attn_group: Process group for ring attention.
|
||||||
"""
|
"""
|
||||||
global RING_ATTN_GROUP # pylint: disable=global-statement
|
global RING_ATTN_GROUP # pylint: disable=global-statement
|
||||||
RING_ATTN_GROUP = ring_attn_group
|
RING_ATTN_GROUP = ring_attn_group
|
||||||
|
|
||||||
|
|
||||||
|
def patch_flash_attention_for_sequential_batch(sequence_parallel_degree: int):
|
||||||
|
"""
|
||||||
|
Patch flash attention a second time to handle batched data. This is a hack to
|
||||||
|
accommodate certain RL trainers which batch data even when `micro_batch_size: 1` is
|
||||||
|
specified in the Axolotl config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sequence_parallel_degree: Sequence parallelism factor.
|
||||||
|
"""
|
||||||
|
# Store the original flash attention function
|
||||||
|
original_flash_attention = ALL_ATTENTION_FUNCTIONS["flash_attention_2"]
|
||||||
|
|
||||||
|
def sequential_batch_flash_attention(
|
||||||
|
module: torch.nn.Module,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor | None,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
scaling: float | None = None,
|
||||||
|
sliding_window: int | None = None,
|
||||||
|
softcap: float | None = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[torch.Tensor, None]:
|
||||||
|
# Check if we have a batch dimension > 1
|
||||||
|
batch_size = query.shape[0]
|
||||||
|
|
||||||
|
if batch_size <= 1:
|
||||||
|
return original_flash_attention(
|
||||||
|
module,
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attention_mask,
|
||||||
|
dropout,
|
||||||
|
scaling,
|
||||||
|
sliding_window,
|
||||||
|
softcap,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process each item in the batch separately
|
||||||
|
outputs = []
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
# Extract single batch item
|
||||||
|
q_item = query[i:i+1]
|
||||||
|
k_item = key[i:i+1]
|
||||||
|
v_item = value[i:i+1]
|
||||||
|
|
||||||
|
# Handle attention mask - it might be None or have different shapes
|
||||||
|
mask_item = None
|
||||||
|
if attention_mask is not None:
|
||||||
|
# The mask could have different formats depending on implementation
|
||||||
|
if attention_mask.dim() >= 3 and attention_mask.shape[0] == batch_size:
|
||||||
|
mask_item = attention_mask[i:i+1]
|
||||||
|
else:
|
||||||
|
# For broadcast masks that don't have a batch dimension
|
||||||
|
mask_item = attention_mask
|
||||||
|
|
||||||
|
# At this point, inputs should already be partitioned by the sequence
|
||||||
|
# parallel data collator
|
||||||
|
batch_size = q_item.shape[0]
|
||||||
|
seq_len = q_item.shape[2]
|
||||||
|
packed_seq_lens = [seq_len] * batch_size
|
||||||
|
|
||||||
|
# Calculate the full sequence length across all GPUs in this SP group
|
||||||
|
total_seq_len = seq_len * sequence_parallel_degree
|
||||||
|
|
||||||
|
cu_seqlens = torch.cumsum(
|
||||||
|
torch.tensor(
|
||||||
|
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||||
|
),
|
||||||
|
dim=-1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
cu_seqlens = F.pad(
|
||||||
|
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||||
|
)
|
||||||
|
|
||||||
|
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
|
||||||
|
|
||||||
|
# Call the original function for a single batch item
|
||||||
|
output, _ = original_flash_attention(
|
||||||
|
module,
|
||||||
|
q_item,
|
||||||
|
k_item,
|
||||||
|
v_item,
|
||||||
|
mask_item,
|
||||||
|
dropout,
|
||||||
|
scaling,
|
||||||
|
sliding_window,
|
||||||
|
softcap,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs.append(output)
|
||||||
|
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
# Concatenate results along batch dimension
|
||||||
|
concatenated_output = torch.cat(outputs, dim=0)
|
||||||
|
return concatenated_output, None
|
||||||
|
|
||||||
|
# Replace the original function with our sequential version
|
||||||
|
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = sequential_batch_flash_attention
|
||||||
|
|
||||||
|
|
||||||
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
|
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
|
||||||
"""
|
"""
|
||||||
Create ring attention group and substitute flash attn with ring flash attn.
|
Create ring attention group and substitute flash attn with ring flash attn.
|
||||||
@@ -98,3 +217,4 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None
|
|||||||
substitute_hf_flash_attn(
|
substitute_hf_flash_attn(
|
||||||
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
|
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
|
||||||
)
|
)
|
||||||
|
patch_flash_attention_for_sequential_batch(sequence_parallel_degree)
|
||||||
|
|||||||
@@ -1351,9 +1351,7 @@ def load_model(
|
|||||||
reference_model: bool = False,
|
reference_model: bool = False,
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
||||||
"""
|
"""Load a model for a given configuration and tokenizer."""
|
||||||
Load a model for a given configuration and tokenizer.
|
|
||||||
"""
|
|
||||||
model_loader = ModelLoader(
|
model_loader = ModelLoader(
|
||||||
cfg,
|
cfg,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@@ -1362,12 +1360,16 @@ def load_model(
|
|||||||
reference_model=reference_model,
|
reference_model=reference_model,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return model_loader.load_model()
|
return model_loader.load_model()
|
||||||
|
|
||||||
|
|
||||||
def load_adapter(model, cfg, adapter, inference=False):
|
def load_adapter(
|
||||||
# type: (PreTrainedModel, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
model: PreTrainedModel,
|
||||||
|
cfg: DictDefault,
|
||||||
|
adapter: str | None,
|
||||||
|
inference: bool = False,
|
||||||
|
) -> tuple[PreTrainedModel, PeftConfig | None]:
|
||||||
if adapter is None:
|
if adapter is None:
|
||||||
return model, None
|
return model, None
|
||||||
if hasattr(model, "enable_input_require_grads"):
|
if hasattr(model, "enable_input_require_grads"):
|
||||||
@@ -1380,8 +1382,9 @@ def load_adapter(model, cfg, adapter, inference=False):
|
|||||||
raise NotImplementedError(f"{adapter} peft adapter not available")
|
raise NotImplementedError(f"{adapter} peft adapter not available")
|
||||||
|
|
||||||
|
|
||||||
def load_llama_adapter(model, cfg):
|
def load_llama_adapter(
|
||||||
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
model: PreTrainedModel, cfg: DictDefault
|
||||||
|
) -> tuple[PreTrainedModel, PeftConfig | None]:
|
||||||
from peft import AdaptionPromptConfig, get_peft_model
|
from peft import AdaptionPromptConfig, get_peft_model
|
||||||
|
|
||||||
peft_config = AdaptionPromptConfig(
|
peft_config = AdaptionPromptConfig(
|
||||||
@@ -1405,7 +1408,7 @@ def load_llama_adapter(model, cfg):
|
|||||||
return model, peft_config
|
return model, peft_config
|
||||||
|
|
||||||
|
|
||||||
def find_all_linear_names(model):
|
def find_all_linear_names(model: PreTrainedModel):
|
||||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
||||||
lora_module_names = set()
|
lora_module_names = set()
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
|
|||||||
Reference in New Issue
Block a user