Compare commits: `sdpa-cp`...`kd-fix-202` (36 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 2491303c46 | |
| | 2c66483a47 | |
| | 01382b9a79 | |
| | cfcd69df0d | |
| | 2302b14a84 | |
| | a8e2bddd19 | |
| | d55a51623f | |
| | 73a84ad0dd | |
| | 3cffe881bb | |
| | e77d62933d | |
| | 3a0faa97ca | |
| | 20602fd93f | |
| | 770bb0605a | |
| | 24b96b1c4f | |
| | 90c7228ff9 | |
| | 9eb53f5c9e | |
| | 225b420dc5 | |
| | b75db13615 | |
| | c7b1db329e | |
| | a40e484803 | |
| | 9899c924f9 | |
| | 505009b454 | |
| | b4e96ef12c | |
| | a8d9fab635 | |
| | 49e2fa825d | |
| | 7263845207 | |
| | 5ccfd225cb | |
| | 28eb8632a1 | |
| | 5cfaac3767 | |
| | ca70fb7cb0 | |
| | 22b50d6619 | |
| | a2248673d8 | |
| | 0399aefcb3 | |
| | 83ad248e5b | |
| | 6fafe46562 | |
| | 0e46367e01 | |
`.github/workflows/multi-gpu-e2e.yml` (vendored, 2 changes)
```diff
@@ -8,7 +8,7 @@ on:
       - 'setup.py'
       - 'pyproject.toml'
       - '.github/workflows/multi-gpu-e2e.yml'
-      - 'src/axolotl/core/trainers/mixins/context_parallel.py'
+      - 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
       - 'src/axolotl/utils/distributed.py'
   workflow_dispatch:
   schedule:
```
```diff
@@ -75,7 +75,7 @@ quartodoc:
   - title: Context Managers
     desc: Context managers for altering trainer behaviors
     contents:
-      - utils.ctx_managers.context_parallel
+      - utils.ctx_managers.sequence_parallel
   - title: Prompt Strategies
     desc: Prompt formatting strategies
     contents:
@@ -274,7 +274,7 @@ website:
         - docs/unsloth.qmd
         - docs/torchao.qmd
         - docs/custom_integrations.qmd
-        - docs/context_parallelism.qmd
+        - docs/sequence_parallelism.qmd

       - section: "Troubleshooting"
         contents:
```
`deepspeed_configs/zero2_torch_compile.json` (new file, 31 lines)
```diff
@@ -0,0 +1,31 @@
+{
+  "compile": {
+    "disable": false,
+    "backend": "inductor"
+  },
+  "zero_optimization": {
+    "stage": 2,
+    "offload_optimizer": {
+      "device": "cpu"
+    },
+    "contiguous_gradients": true,
+    "overlap_comm": true
+  },
+  "bf16": {
+    "enabled": "auto"
+  },
+  "fp16": {
+    "enabled": "auto",
+    "auto_cast": false,
+    "loss_scale": 0,
+    "initial_scale_power": 32,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "gradient_accumulation_steps": "auto",
+  "gradient_clipping": "auto",
+  "train_batch_size": "auto",
+  "train_micro_batch_size_per_gpu": "auto",
+  "wall_clock_breakdown": false
+}
```
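Note: to train with this preset, an axolotl config would typically point its `deepspeed:` key at the file, e.g. `deepspeed: deepspeed_configs/zero2_torch_compile.json` (path assumed relative to the repo root), the same way the existing `deepspeed_configs/*.json` presets are referenced; the new `compile` block appears to enable DeepSpeed's torch.compile integration with the `inductor` backend.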
```diff
@@ -764,13 +764,13 @@ ddp_timeout:
 ddp_bucket_cap_mb:
 ddp_broadcast_buffers:

-# Context parallelism
+# Sequence parallelism
 # Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
 # Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
 # E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
 # subsequences, or set to 4 to split into four equal-sized subsequences.
-# See https://docs.axolotl.ai/docs/context_parallelism.html for more details.
-context_parallel_degree:
+# See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details.
+sequence_parallel_degree:
 # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
 # Must evenly divide the number of KV heads in your model.
 heads_k_stride: 1
```
````diff
@@ -18,7 +18,7 @@ Axolotl supports several methods for multi-GPU training:

 - DeepSpeed (recommended)
 - FSDP (Fully Sharded Data Parallel)
-- Context parallelism
+- Sequence parallelism
 - FSDP + QLoRA

 ## DeepSpeed {#sec-deepspeed}
@@ -80,14 +80,14 @@ fsdp_config:
   fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
 ```

-## Context parallelism {#sec-sequence-parallelism}
+## Sequence parallelism {#sec-sequence-parallelism}

-We support context parallelism (SP) via the
+We support sequence parallelism (SP) via the
 [ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
 allows one to split up sequences across GPUs, which is useful in the event that a
 single sequence causes OOM errors during model training.

-See our [dedicated guide](context_parallelism.qmd) for more information.
+See our [dedicated guide](sequence_parallelism.qmd) for more information.

 ### FSDP + QLoRA {#sec-fsdp-qlora}
````
````diff
@@ -1,16 +1,16 @@
 ---
-title: Context Parallelism
+title: Sequence Parallelism
 description: Train with long sequences split across multiple GPUs.
 ---

-Context parallelism is a technique that splits sequences across multiple GPUs,
+Sequence parallelism is a technique that splits sequences across multiple GPUs,
 allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
 GPU processes a different portion of the sequence, and the results are aggregated
 through a ring communication pattern.

-## When to Use Context Parallelism
+## When to Use Sequence Parallelism

-Use context parallelism when:
+Use sequence parallelism when:

 - You need to train with sequence lengths that don't fit into a single GPU's memory
 - You have multiple GPUs available
@@ -18,11 +18,11 @@ Use context parallelism when:

 ## Configuration

-To enable context parallelism, add the following to your configuration file:
+To enable sequence parallelism, add the following to your configuration file:

 ```yaml
 # Set to a divisor (> 1) of the number of GPUs available
-context_parallel_degree: 4 # Split sequences across 4 GPUs
+sequence_parallel_degree: 4 # Split sequences across 4 GPUs
 # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
 heads_k_stride: 1
 # Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -30,23 +30,23 @@ heads_k_stride: 1
 ring_attn_func:
 ```

-The `context_parallel_degree` should be a divisor of the total number of GPUs. For example:
+The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:

 - With 8 GPUs, valid values would be 2, 4, or 8
 - With 4 GPUs, valid values would be 2 or 4

 ## Implementation Details

-When context parallelism is enabled:
+When sequence parallelism is enabled:

-1. Each sequence is divided into equal chunks across the GPUs in a context parallel group
+1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
 2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
 3. Position IDs are adjusted to maintain proper relative positions
 4. The trainer uses special ring communication patterns for attention operations

 ## Requirements

-To use context parallelism, you need:
+To use sequence parallelism, you need:

 - Multiple GPUs (at least 2)
 - The `ring-flash-attn` package. Install with:
@@ -66,7 +66,7 @@ sequence_len: 8192

 ...

-context_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
+sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
 # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
 heads_k_stride: 1
 # Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -79,22 +79,22 @@ ring_attn_func:
 This will train the Llama 3 8B model with 8K context length, with each sequence split
 into 2 subsequences of length 4096 across 2 GPUs.

-## Sample Packing with Context Parallelism
+## Sample Packing with Sequence Parallelism

-Context parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
+Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:

 1. Samples are first packed together
-2. The packed sequences are then divided across GPUs in the context parallel group
+2. The packed sequences are then divided across GPUs in the sequence parallel group
 3. Position IDs are automatically adjusted to maintain proper relative positions

 ## Effect on Batch Size

-When using context parallelism, your effective global batch size is **divided** by the `context_parallel_degree`. This happens because:
+When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:

-- Each group of `context_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
+- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
 - The number of batches processed per step decreases

 For example:
-- With 8 GPUs and no context parallelism: 8 different batches processed per step
-- With 8 GPUs and `context_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
+- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
+- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
 - If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
````
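To make the batch-size arithmetic in that documentation change concrete, here is a small standalone sketch (not part of this PR; the function and variable names are illustrative rather than axolotl config keys) that reproduces the example numbers:

```python
def effective_batch(num_gpus: int, sequence_parallel_degree: int,
                    micro_batch_size: int, grad_accum: int = 1) -> tuple[int, int]:
    """Return (distinct batches per step, global batch size) under sequence parallelism."""
    assert num_gpus % sequence_parallel_degree == 0, "degree must divide the GPU count"
    sp_groups = num_gpus // sequence_parallel_degree  # each group shares one batch
    batches_per_step = sp_groups
    global_batch_size = sp_groups * micro_batch_size * grad_accum
    return batches_per_step, global_batch_size


# Matches the doc's example: 8 GPUs, degree 4, micro batch 2 -> 2 batches/step, global batch 4
print(effective_batch(8, 4, 2))  # (2, 4)
print(effective_batch(8, 1, 2))  # (8, 16) with no sequence parallelism
```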
```diff
@@ -73,7 +73,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
         load_in_8bit=False,
         load_in_4bit=False,
         flash_attention=False,
-        context_parallel_degree=None,
+        sequence_parallel_degree=None,
         deepspeed=None,
         fsdp=None,
         fsdp_config=None,
```
```diff
@@ -21,11 +21,6 @@ from axolotl.core.trainers import (
     AxolotlTrainer,
     ReLoRATrainer,
 )
-from axolotl.core.training_args import (
-    AxolotlPRMConfig,
-    AxolotlRewardConfig,
-    AxolotlTrainingArguments,
-)
 from axolotl.integrations.base import PluginManager
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback
@@ -130,6 +125,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         return callbacks

     def _get_trainer_cls(self):
+        """
+        Gets the trainer class for the given configuration.
+        """
         if self.cfg.plugins:
             plugin_manager = PluginManager.get_instance()
             trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
@@ -146,6 +144,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         return AxolotlTrainer

     def build(self, total_num_steps):
+        from axolotl.core.training_args import (
+            AxolotlPRMConfig,
+            AxolotlRewardConfig,
+            AxolotlTrainingArguments,
+        )
+
         training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
             total_num_steps
         )
@@ -314,20 +318,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs["image_resize_algorithm"] = (
                 self.cfg.image_resize_algorithm
             )
-        if self.cfg.kd_ce_alpha is not None:
-            training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
-        if self.cfg.kd_alpha is not None:
-            training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
-        if self.cfg.kd_temperature is not None:
-            training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
-        if self.cfg.kd_zscore_base_temp is not None:
-            training_arguments_kwargs["kd_zscore_base_temp"] = (
-                self.cfg.kd_zscore_base_temp
-            )
-        if self.cfg.kd_top_k_before_softmax is not None:
-            training_arguments_kwargs["kd_top_k_before_softmax"] = (
-                self.cfg.kd_top_k_before_softmax
-            )
+
+        if self.cfg.plugins:
+            plugin_manager = PluginManager.get_instance()
+            plugin_training_args = plugin_manager.get_training_args(self.cfg)
+            if plugin_training_args:
+                training_arguments_kwargs.update(plugin_training_args)

         if self.cfg.reward_model:
             training_args_cls = AxolotlRewardConfig
@@ -408,7 +404,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         return trainer

     def build_collator(
-        self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
+        self,
+        training_args,  # type: "AxolotlTrainingArguments"  # type: ignore
+        is_eval=False,
+        **kwargs,
     ):
         if training_args.pretraining:
             if (
@@ -437,7 +436,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             ]
         ]
         collator_args = [self.tokenizer]
-        if self.cfg.reward_model:
+
+        if self.cfg.plugins:
+            plugin_manager = PluginManager.get_instance()
+            collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs(
+                self.cfg, is_eval=is_eval
+            )
+
+            if collator_cls_and_kwargs:
+                collator = collator_cls_and_kwargs[0]
+                if kwargs and isinstance(kwargs, dict):
+                    kwargs.update(collator_cls_and_kwargs[1])
+        elif self.cfg.reward_model:
             collator = RewardDataCollatorWithPadding
         elif use_batch_sampler_collator:
             # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
@@ -468,16 +478,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             collator_args.pop(0)
             kwargs.pop("pad_to_multiple_of", None)
             kwargs.pop("padding", None)
-        elif self.cfg.kd_trainer:
-            from axolotl.integrations.kd.collator import (
-                DataCollatorForKD,
-                KDBatchSamplerDataCollatorForSeq2Seq,
-            )
-
-            if self.cfg.sample_packing:
-                collator = KDBatchSamplerDataCollatorForSeq2Seq
-            else:
-                collator = DataCollatorForKD
         else:
             collator = DataCollatorForSeq2Seq
```
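Read together with the plugin-system changes later in this diff, these hunks appear to move the KD-specific wiring out of the core builder: the hard-coded `kd_*` training-argument assignments and the `kd_trainer` collator branch are deleted here, and equivalent values are instead expected to arrive through `PluginManager.get_training_args` and `PluginManager.get_collator_cls_and_kwargs`.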
```diff
@@ -12,11 +12,6 @@ from axolotl.core.trainers import (
 from axolotl.core.trainers.dpo import DPOStrategy
 from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
 from axolotl.core.trainers.grpo import GRPOStrategy
-from axolotl.core.training_args import (
-    AxolotlCPOConfig,
-    AxolotlKTOConfig,
-    AxolotlORPOConfig,
-)
 from axolotl.integrations.base import PluginManager
 from axolotl.loaders.utils import ensure_dtype
 from axolotl.utils.logging import get_logger
@@ -54,7 +49,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):

         if self.cfg.rl is RLType.GRPO:
             trainer_cls = GRPOStrategy.get_trainer_class(
-                context_parallel=self.cfg.context_parallel_degree > 1
+                sequence_parallel=self.cfg.sequence_parallel_degree > 1
             )
             trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
@@ -79,6 +74,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         """
         Returns training_args and trainer_kwargs
         """
+        from axolotl.core.training_args import (
+            AxolotlCPOConfig,
+            AxolotlKTOConfig,
+            AxolotlORPOConfig,
+        )
+
         training_args_kwargs, trainer_kwargs = self._set_base_training_args(
             total_num_steps=total_num_steps
         )
@@ -165,6 +166,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             if blocklist_key in training_args_kwargs:
                 del training_args_kwargs[blocklist_key]

+        if self.cfg.plugins:
+            plugin_manager = PluginManager.get_instance()
+            plugin_training_args = plugin_manager.get_training_args(self.cfg)
+            if plugin_training_args:
+                training_args_kwargs.update(plugin_training_args)
+
         training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
             logging_first_step=True,
             **training_args_kwargs,
```
```diff
@@ -5,7 +5,7 @@

 from .base import AxolotlTrainer
 from .dpo.trainer import AxolotlDPOTrainer
-from .grpo.trainer import AxolotlGRPOContextParallelTrainer, AxolotlGRPOTrainer
+from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
 from .mamba import AxolotlMambaTrainer
 from .relora import ReLoRATrainer
 from .trl import (
```
```diff
@@ -7,13 +7,11 @@ from __future__ import annotations

 import os
 from collections import defaultdict
 from functools import partial, wraps
-from typing import Any, Callable, Literal, Optional
+from typing import Callable, Literal, Optional

-from axolotl.utils.ctx_managers.context_parallel.distributed import get_context_parallel_manager
 import datasets
 import torch
 from datasets import Dataset
 from torch import nn
 from torch.utils.data import (
     BatchSampler,
     DataLoader,
@@ -35,6 +33,7 @@ from axolotl.core.trainers.utils import (
     sanitize_kwargs_for_ds_tagging,
     sanitize_kwargs_for_tagging,
 )
+from axolotl.utils import get_not_null
 from axolotl.utils.logging import get_logger
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -67,32 +66,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
         if self.args.orpo_alpha:
             self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

-        # SPDA device mesh init
-        import torch.distributed as dist
-
-        world_size = dist.get_world_size()
-        mesh_shape = (
-            world_size // 2,
-            2,
-        )
-        self.world_mesh = dist.DeviceMesh(
-            "cuda",
-            torch.tensor(list(range(world_size))).reshape(mesh_shape),
-            mesh_dim_names=("dp", "cp"),
-        )
-
-    def training_step(
-        self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch=None
-    ) -> torch.Tensor:
-        ctx_manager = get_context_parallel_manager(
-            world_mesh=self.world_mesh,
-            model=model,
-        )
-        to_shard = {k: v for k, v in inputs.items() if v.ndim > 1}
-        with ctx_manager(list(to_shard.values())):
-            super().training_step(model, inputs, num_items_in_batch)
-
     def _wrap_model(self, model, training=True, dataloader=None):
         if self.args.torch_compile:
             torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
@@ -129,7 +102,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
             )
             batch_max_len = train_batch_size * self.args.max_seq_length

-        return MultipackBatchSampler(
+        sampler = MultipackBatchSampler(
             base_sampler,
             lengths=get_dataset_lengths(dataset),
             packing_efficiency_estimate=self.args.sample_packing_efficiency,
@@ -141,6 +114,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
             drop_last=True,
         )

+        len(sampler)
+        return sampler
+
     def _get_train_sampler(
         self, train_dataset: Optional[Dataset] = None
     ) -> Optional[Sampler]:
@@ -248,7 +224,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
         }

         if not isinstance(dataset, torch.utils.data.IterableDataset):
-            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["drop_last"] = get_not_null(
+                self.args.dataloader_drop_last, True
+            )
             if sampler_fn is not None:
                 sampler = sampler_fn(dataset)
                 if isinstance(sampler, BatchSampler):
```
```diff
@@ -8,7 +8,7 @@ from trl.trainer.grpo_trainer import RewardFunc

 from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
 from axolotl.core.trainers.grpo.trainer import (
-    AxolotlGRPOContextParallelTrainer,
+    AxolotlGRPOSequenceParallelTrainer,
     AxolotlGRPOTrainer,
 )
 from axolotl.utils.dict import DictDefault
@@ -23,10 +23,10 @@ class GRPOStrategy:

     @classmethod
     def get_trainer_class(
-        cls, context_parallel: bool
-    ) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOContextParallelTrainer]:
-        if context_parallel:
-            return AxolotlGRPOContextParallelTrainer
+        cls, sequence_parallel: bool
+    ) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]:
+        if sequence_parallel:
+            return AxolotlGRPOSequenceParallelTrainer
         return AxolotlGRPOTrainer

     @classmethod
@@ -69,8 +69,8 @@ class GRPOStrategy:
         grpo_args_kwargs["log_completions"] = trl.log_completions
         grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print

-        if cfg.context_parallel_degree > 1:
-            grpo_args_kwargs["context_parallel_degree"] = cfg.context_parallel_degree
+        if cfg.sequence_parallel_degree > 1:
+            grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree

         if trl.reward_weights:
             grpo_args_kwargs["reward_weights"] = trl.reward_weights
@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
 class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
     """Axolotl GRPO Config for GRPO training"""

-    context_parallel_degree: int | None = None
+    sequence_parallel_degree: int | None = None
```
```diff
@@ -1,7 +1,7 @@
 """Repeat random sampler (similar to the one implemented in
 https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) that adds
-context parallelism functionality; i.e., duplicating data across ranks in the same
-context parallel group.
+sequence parallelism functionality; i.e., duplicating data across ranks in the same
+sequence parallel group.
 """

 from typing import Iterator, Sized
@@ -10,26 +10,26 @@ import torch
 from torch.utils.data import Sampler


-class ContextParallelRepeatRandomSampler(Sampler):
-    """Sampler for GRPO training with context parallelism.
+class SequenceParallelRepeatRandomSampler(Sampler):
+    """Sampler for GRPO training with sequence parallelism.

     This sampler ensures:
-    - Ranks in the same context parallel (SP) group receive identical data.
+    - Ranks in the same sequence parallel (SP) group receive identical data.
     - Each index is repeated multiple times for sampling different completions.
     - Entire batches are repeated for reuse in multiple updates.
-    - Data is properly distributed across CP groups.
+    - Data is properly distributed across SP groups.

-    In the table below, the values represent dataset indices. Each CP group has
-    `context_parallel_degree = 2` GPUs working together on the same data. There are 2
-    CP groups (SP0 and SP1), with `world_size = 4` total GPUs.
+    In the table below, the values represent dataset indices. Each SP group has
+    `sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
+    SP groups (SP0 and SP1), with `world_size = 4` total GPUs.

-                     Context Parallel Groups
+                     Sequence Parallel Groups
                      |      SP0      |      SP1      |
                      | GPU 0 | GPU 1 | GPU 2 | GPU 3 |
     global_step  step          <---> mini_repeat_count=3
-                        <----------> batch_size=2 per CP group
-    grad_accum=2  ▲  ▲  0  0  [0 0 0 1 1 1]  [2 2 2 3 3 3]  <- CP groups get different data
-                  ▼  |  0  1  [0 0 0 1 1 1]  [2 2 2 3 3 3]  <- Same data for each CP group GPU
+                        <----------> batch_size=2 per SP group
+    grad_accum=2  ▲  ▲  0  0  [0 0 0 1 1 1]  [2 2 2 3 3 3]  <- SP groups get different data
+                  ▼  |  0  1  [0 0 0 1 1 1]  [2 2 2 3 3 3]  <- Same data for each SP group GPU
                      |
                   |  1  2  [0 0 0 1 1 1]  [2 2 2 3 3 3]  <- Repeat same indices for iterations
  num_iterations=2 ▼  1  3  [0 0 0 1 1 1]  [2 2 2 3 3 3]  <- When using gradient accumulation
@@ -45,7 +45,7 @@ class ContextParallelRepeatRandomSampler(Sampler):
         rank: Rank of current process.
         batch_size: Number of samples per batch.
         repeat_count: How many times to repeat the full sampling process.
-        context_parallel_degree: Number of ranks in a context parallel group.
+        sequence_parallel_degree: Number of ranks in a sequence parallel group.
         shuffle: Whether to shuffle the dataset.
         seed: Random seed for shuffling.
         drop_last: Whether to drop the last incomplete batch.
@@ -59,7 +59,7 @@ class ContextParallelRepeatRandomSampler(Sampler):
         rank: int,
         batch_size: int = 1,
         repeat_count: int = 1,
-        context_parallel_degree: int = 1,
+        sequence_parallel_degree: int = 1,
         shuffle: bool = True,
         seed: int = 0,
         drop_last: bool = False,
@@ -76,16 +76,16 @@ class ContextParallelRepeatRandomSampler(Sampler):
         self.world_size = world_size
         self.rank = rank

-        # Context parallelism parameters
-        self.context_parallel_degree = context_parallel_degree
-        self.num_sp_groups = world_size // context_parallel_degree
-        self.sp_group_id = rank // context_parallel_degree
+        # Sequence parallelism parameters
+        self.sequence_parallel_degree = sequence_parallel_degree
+        self.num_sp_groups = world_size // sequence_parallel_degree
+        self.sp_group_id = rank // sequence_parallel_degree

         # Adjust dataset size for distributed sampling
         self.num_samples = len(self.dataset)
         self.total_size = self.num_samples

-        # Calculate effective number of samples per CP group
+        # Calculate effective number of samples per SP group
         if (
             self.drop_last
             and self.total_size % (self.num_sp_groups * self.batch_size) != 0
@@ -125,8 +125,8 @@ class ContextParallelRepeatRandomSampler(Sampler):
             padding = indices[: self.batch_size - len(indices) % self.batch_size]
             indices += padding

-        # Subsample based on CP group ID
-        # Each CP group gets distinct batches of data
+        # Subsample based on SP group ID
+        # Each SP group gets distinct batches of data
         batch_indices = []
         for i in range(0, len(indices), self.batch_size * self.num_sp_groups):
             start_idx = i + self.sp_group_id * self.batch_size
```
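The group bookkeeping in this sampler (`sp_group_id = rank // sequence_parallel_degree`, then striding over the dataset one round of SP groups at a time) can be illustrated with a small standalone sketch; this is a simplification that ignores shuffling, repeats, and padding, and is not the PR's implementation:

```python
def sp_group_batches(num_indices: int, world_size: int,
                     sequence_parallel_degree: int, batch_size: int, rank: int) -> list[int]:
    """Return the dataset indices a given rank would see under SP-group subsampling."""
    num_sp_groups = world_size // sequence_parallel_degree
    sp_group_id = rank // sequence_parallel_degree  # ranks in the same group share data
    indices = list(range(num_indices))
    batch_indices: list[int] = []
    # Stride over the dataset one "round" of SP groups at a time; each group
    # takes its own contiguous batch from that round.
    for i in range(0, len(indices), batch_size * num_sp_groups):
        start = i + sp_group_id * batch_size
        batch_indices.extend(indices[start:start + batch_size])
    return batch_indices


# world_size=4, degree=2 -> 2 SP groups; ranks 0 and 1 see identical indices,
# ranks 2 and 3 see the other half of each round.
print(sp_group_batches(8, 4, 2, 2, rank=0))  # [0, 1, 4, 5]
print(sp_group_batches(8, 4, 2, 2, rank=2))  # [2, 3, 6, 7]
```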
```diff
@@ -1,4 +1,4 @@
-"""Axolotl GRPO trainers (with and without context parallelism handling)"""
+"""Axolotl GRPO trainers (with and without sequence parallelism handling)"""

 # pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
@@ -41,7 +41,7 @@ from trl.trainer.grpo_config import GRPOConfig
 from trl.trainer.grpo_trainer import RewardFunc, nanstd
 from trl.trainer.utils import pad

-from axolotl.core.trainers.grpo.sampler import ContextParallelRepeatRandomSampler
+from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
 from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
 from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
 from axolotl.monkeypatch.ring_attn import get_ring_attn_group
@@ -59,8 +59,8 @@ class AxolotlGRPOTrainer(
     _tag_names = ["trl", "grpo", "axolotl"]


-class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
-    """Extend the base GRPOTrainer for context parallelism handling"""
+class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
+    """Extend the base GRPOTrainer for sequence parallelism handling"""

     def __init__(
         self,
@@ -97,11 +97,11 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
             optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
         )

-        # Get number of CP groups (number of processes divided by CP degree)
+        # Get number of SP groups (number of processes divided by SP degree)
         num_processes = self.accelerator.num_processes
-        num_sp_groups = num_processes // self.args.context_parallel_degree
+        num_sp_groups = num_processes // self.args.sequence_parallel_degree

-        # Calculate batch size per CP group (not per process)
+        # Calculate batch size per SP group (not per process)
         sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
         possible_values = [
             n_gen
@@ -111,7 +111,7 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):

         if self.num_generations not in possible_values:
             raise ValueError(
-                f"The batch size per CP group ({num_sp_groups} x "
+                f"The batch size per SP group ({num_sp_groups} x "
                 f"{self.args.per_device_train_batch_size}) must be evenly divisible by "
                 f"the number of generations per prompt ({self.num_generations}). Given "
                 "the current configuration, the valid values for the number of "
@@ -119,7 +119,7 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
             )

         if self.args.eval_strategy != "no":
-            # If context parallelism is enabled, calculate batch size per CP group
+            # If sequence parallelism is enabled, calculate batch size per SP group
             sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups  # type: ignore[union-attr]
             possible_values = [
                 n_gen
@@ -129,8 +129,8 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):

             if self.num_generations not in possible_values:
                 raise ValueError(
-                    f"With context parallelism (degree {self.args.context_parallel_degree}), "
-                    f"the eval batch size per CP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
+                    f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
+                    f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
                     f"must be evenly divisible by the number of generations per prompt "
                     f"({self.num_generations}). Given the current eval batch size, "
                     f"the valid values for the number of generations are: {possible_values}."
@@ -143,7 +143,7 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
             self.local_world_size = 1

     def train(self, *args, **kwargs):
-        # Initialize the CP group
+        # Initialize the SP group
         self.sp_group = get_ring_attn_group()
         self.rank = dist.get_rank()
         self.world_size = dist.get_world_size()
@@ -159,16 +159,16 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
             * self.args.gradient_accumulation_steps
         )

-        return ContextParallelRepeatRandomSampler(
+        return SequenceParallelRepeatRandomSampler(
             dataset=self.train_dataset,
             mini_repeat_count=self.num_generations,
             world_size=self.world_size,
             rank=self.rank,
             batch_size=effective_batch_size
             // self.num_generations
-            // self.args.context_parallel_degree,
+            // self.args.sequence_parallel_degree,
             repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
-            context_parallel_degree=self.args.context_parallel_degree,
+            sequence_parallel_degree=self.args.sequence_parallel_degree,
             shuffle=True,
             seed=self.args.seed,
             drop_last=True,
@@ -226,11 +226,11 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
         ):
             self.accelerator.even_batches = False

-        # Return unprepared dataloader if using context parallelism
+        # Return unprepared dataloader if using sequence parallelism
         # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
         # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
         # slice each batch along the sequence dimension).
-        if self.args.context_parallel_degree > 1:
+        if self.args.sequence_parallel_degree > 1:
             return dataloader

         # Otherwise prepare with accelerator
@@ -303,21 +303,21 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
         # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
         all_prompts_text = gather_object(prompts_text)
         if self.accelerator.is_main_process:
-            if self.args.context_parallel_degree > 1:
-                # Calculate context parallel group information
+            if self.args.sequence_parallel_degree > 1:
+                # Calculate sequence parallel group information
                 world_size = self.accelerator.num_processes
-                context_parallel_degree = self.args.context_parallel_degree
-                num_sp_groups = world_size // context_parallel_degree
+                sequence_parallel_degree = self.args.sequence_parallel_degree
+                num_sp_groups = world_size // sequence_parallel_degree

-                # Since processes in the same CP group have the same prompts, we need to ensure
-                # we only take one copy of each prompt from each CP group
+                # Since processes in the same SP group have the same prompts, we need to ensure
+                # we only take one copy of each prompt from each SP group
                 ordered_set_of_prompts = []
                 for sp_group_id in range(num_sp_groups):
-                    # Get the first process from each CP group (typically the group leader)
-                    group_leader_rank = sp_group_id * context_parallel_degree
+                    # Get the first process from each SP group (typically the group leader)
+                    group_leader_rank = sp_group_id * sequence_parallel_degree

-                    # Extract prompts from this CP group, accounting for num_generations duplicates
-                    # We only need prompts from one rank in each CP group
+                    # Extract prompts from this SP group, accounting for num_generations duplicates
+                    # We only need prompts from one rank in each SP group
                     group_prompts = all_prompts_text[
                         group_leader_rank
                         * len(prompts_text) : (group_leader_rank + 1)
@@ -330,7 +330,7 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
                 # num_generations outputs for each one. This is faster than generating outputs for each duplicate
                 # prompt individually.
                 ordered_set_of_prompts = all_prompts_text[
-                    :: self.num_generations * self.args.context_parallel_degree
+                    :: self.num_generations * self.args.sequence_parallel_degree
                 ]

         with profiling_context(self, "vLLM.generate"):
@@ -347,28 +347,28 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
             )
         else:
             completion_ids = [None] * (
-                len(all_prompts_text) // self.args.context_parallel_degree
+                len(all_prompts_text) // self.args.sequence_parallel_degree
             )

         # Broadcast the completions from the main process to all processes
         completion_ids = broadcast_object_list(completion_ids, from_process=0)

-        # Determine the appropriate slice based on context parallelism
-        if self.args.context_parallel_degree > 1:
-            # Calculate CP group ID (which group of ranks this rank belongs to)
+        # Determine the appropriate slice based on sequence parallelism
+        if self.args.sequence_parallel_degree > 1:
+            # Calculate SP group ID (which group of ranks this rank belongs to)
             sp_group_id = self.accelerator.process_index // self.local_world_size

-            # Calculate the start index for this CP group
+            # Calculate the start index for this SP group
             sp_group_start = sp_group_id * len(prompts) * self.local_world_size

-            # All ranks in the same CP group get the same data slice
+            # All ranks in the same SP group get the same data slice
             process_slice = slice(
                 sp_group_start,
                 sp_group_start + len(prompts),
             )
             completion_ids = completion_ids[process_slice]
         else:
-            # Original behavior for non-context parallel case
+            # Original behavior for non-sequence parallel case
             process_slice = slice(
                 self.accelerator.process_index * len(prompts),
                 (self.accelerator.process_index + 1) * len(prompts),
@@ -578,20 +578,20 @@ class AxolotlGRPOContextParallelTrainer(AxolotlGRPOTrainer):
         advantages = advantages / (std_grouped_rewards + 1e-4)

         # Slice to keep only the local part of the data
-        if self.args.context_parallel_degree > 1:
-            # Calculate CP group ID (which group of ranks this rank belongs to)
+        if self.args.sequence_parallel_degree > 1:
+            # Calculate SP group ID (which group of ranks this rank belongs to)
             sp_group_id = self.accelerator.process_index // self.local_world_size

-            # Calculate the start index for this CP group
+            # Calculate the start index for this SP group
             sp_group_start = sp_group_id * len(prompts) * self.local_world_size

-            # All ranks in the same CP group get the same data slice
+            # All ranks in the same SP group get the same data slice
             process_slice = slice(
                 sp_group_start,
                 sp_group_start + len(prompts),
             )
         else:
-            # Original behavior for non-context parallel case
+            # Original behavior for non-sequence parallel case
             process_slice = slice(
                 self.accelerator.process_index * len(prompts),
                 (self.accelerator.process_index + 1) * len(prompts),
```
@@ -2,238 +2,17 @@
|
||||
extra axolotl specific training args
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Type
|
||||
|
||||
from PIL.Image import Resampling
|
||||
from transformers import TrainingArguments
|
||||
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||
|
||||
from axolotl.integrations.config import merge_training_args
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingMixins:
|
||||
"""
|
||||
Mixin class for the Axolotl training args.
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
model_type: Optional[str] = field(
|
||||
default=None, metadata={"help": "HF model configuration model_type."}
|
||||
)
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
pretraining: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Indicates to trainer whether we are doing continued pretraining."
|
||||
},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
sample_packing_sequentially: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||
},
|
||||
)
|
||||
multipack_real_batches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
)
|
||||
sample_packing_efficiency: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||
)
|
||||
sample_packing_bin_size: int = field(
|
||||
default=200,
|
||||
metadata={
|
||||
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
|
||||
},
|
||||
)
|
||||
sample_packing_group_size: int = field(
|
||||
default=100000,
|
||||
metadata={
|
||||
"help": "The number of samples to group together for packing. Increase for better packing."
|
||||
},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
relora_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for ReLoRA"},
|
||||
)
|
||||
relora_warmup_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_anneal_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_prune_ratio: Optional[float] = field(
|
||||
default=0.9,
|
||||
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
bench_dataset: Optional[str] = field(
|
||||
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||
metadata={
|
||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||
},
|
||||
)
|
||||
do_bench_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||
)
|
||||
do_causal_lm_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
|
||||
)
|
||||
max_bench_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||
},
|
||||
)
|
||||
bench_source_max_len: int = field(
|
||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||
)
|
||||
dataloader_prefetch_factor: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "prefetch_factor argument to the dataloader"},
|
||||
)
|
||||
cosine_min_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
|
||||
)
|
||||
cosine_constant_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
|
||||
},
|
||||
)
|
||||
loraplus_lr_ratio: Optional[float] = field(
|
||||
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
|
||||
)
|
||||
loraplus_lr_embedding: Optional[float] = field(
|
||||
default=1e-6,
|
||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||
)
|
||||
embedding_lr_scale: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||
)
|
||||
lr_groups: Optional[list[dict]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Specify learning rate groups for with different LRs."},
|
||||
)
|
||||
embedding_lr: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||
)
|
||||
qlora: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "whether this is a qlora training"},
|
||||
)
|
||||
orpo_alpha: Optional[float] = field(
|
||||
default=None,
|
||||
)
|
||||
lisa_n_layers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "the number of activate layers in LISA"},
|
||||
)
|
||||
lisa_step_interval: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to switch layers in LISA"},
|
||||
)
|
||||
lisa_layers_attribute: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "path under the model to access the layers"},
|
||||
)
|
||||
curriculum_sampling: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||
)
|
||||
alternate_lr_scheduler_type: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
||||
},
|
||||
)
|
||||
chat_template: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Chat template converting chat messages to text"},
|
||||
)
|
||||
|
||||
kd_ce_alpha: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
|
||||
},
|
||||
)
|
||||
|
||||
kd_alpha: Optional[float] = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The alpha scaling parameter for KD loss"},
|
||||
)
|
||||
|
||||
kd_temperature: Optional[float] = field(
|
||||
default=1.0,
|
||||
metadata={
|
||||
"help": "the temperature parameter for KL divergence loss when using KD"
|
||||
},
|
||||
)
|
||||
|
||||
kd_zscore_base_temp: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "the base temperature parameter for KL divergence with z-score when using KD"
|
||||
},
|
||||
)
|
||||
|
||||
kd_top_k_before_softmax: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
|
||||
},
|
||||
)
|
||||
|
||||
adam_beta3: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
|
||||
},
|
||||
)
|
||||
adam_epsilon2: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
|
||||
},
|
||||
)
|
||||
|
||||
# multi-modal section
|
||||
|
||||
image_size: int | tuple[int, int] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The size of the image to resize to"},
|
||||
)
|
||||
|
||||
image_resize_algorithm: Resampling | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The algorithm to use for image resizing"},
|
||||
)
|
||||
|
||||
# end of multi-modal section
|
||||
AxolotlTrainingMixins: Type = merge_training_args()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
`src/axolotl/core/training_args_base.py` (new file, 220 lines)
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
Base Axolotl Training Mixins shared across various trainer configs
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from PIL.Image import Resampling
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingMixins:
|
||||
"""
|
||||
Mixin class for the Axolotl training args.
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
model_type: Optional[str] = field(
|
||||
default=None, metadata={"help": "HF model configuration model_type."}
|
||||
)
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
pretraining: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Indicates to trainer whether we are doing continued pretraining."
|
||||
},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
sample_packing_sequentially: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||
},
|
||||
)
|
||||
multipack_real_batches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
)
|
||||
sample_packing_efficiency: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||
)
|
||||
sample_packing_bin_size: int = field(
|
||||
default=200,
|
||||
metadata={
|
||||
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
|
||||
},
|
||||
)
|
||||
sample_packing_group_size: int = field(
|
||||
default=100000,
|
||||
metadata={
|
||||
"help": "The number of samples to group together for packing. Increase for better packing."
|
||||
},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
relora_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for ReLoRA"},
|
||||
)
|
||||
relora_warmup_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_anneal_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_prune_ratio: Optional[float] = field(
|
||||
default=0.9,
|
||||
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
bench_dataset: Optional[str] = field(
|
||||
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||
metadata={
|
||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||
},
|
||||
)
|
||||
do_bench_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||
)
|
||||
do_causal_lm_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
|
||||
)
|
||||
max_bench_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||
},
|
||||
)
|
||||
bench_source_max_len: int = field(
|
||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||
)
|
||||
dataloader_prefetch_factor: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "prefetch_factor argument to the dataloader"},
|
||||
)
|
||||
cosine_min_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
|
||||
)
|
||||
cosine_constant_lr_ratio: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
|
||||
},
|
||||
)
|
||||
loraplus_lr_ratio: Optional[float] = field(
|
||||
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
|
||||
)
|
||||
loraplus_lr_embedding: Optional[float] = field(
|
||||
default=1e-6,
|
||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||
)
|
||||
embedding_lr_scale: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||
)
|
||||
lr_groups: Optional[list[dict]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Specify learning rate groups for with different LRs."},
|
||||
)
|
||||
embedding_lr: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||
)
|
||||
qlora: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "whether this is a qlora training"},
|
||||
)
|
||||
orpo_alpha: Optional[float] = field(
|
||||
default=None,
|
||||
)
|
||||
lisa_n_layers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "the number of activate layers in LISA"},
|
||||
)
|
||||
lisa_step_interval: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to switch layers in LISA"},
|
||||
)
|
||||
lisa_layers_attribute: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "path under the model to access the layers"},
|
||||
)
|
||||
curriculum_sampling: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||
)
|
||||
alternate_lr_scheduler_type: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
||||
},
|
||||
)
|
||||
chat_template: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Chat template converting chat messages to text"},
|
||||
)
|
||||
|
||||
# kd_ce_alpha: Optional[float] = field(
|
||||
# default=None,
|
||||
# metadata={
|
||||
# "help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
|
||||
# },
|
||||
# )
|
||||
#
|
||||
# kd_alpha: Optional[float] = field(
|
||||
# default=1.0,
|
||||
# metadata={"help": "The alpha scaling parameter for KD loss"},
|
||||
# )
|
||||
#
|
||||
# kd_temperature: Optional[float] = field(
|
||||
# default=1.0,
|
||||
# metadata={
|
||||
# "help": "the temperature parameter for KL divergence loss when using KD"
|
||||
# },
|
||||
# )
|
||||
|
||||
adam_beta3: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
|
||||
},
|
||||
)
|
||||
adam_epsilon2: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
|
||||
},
|
||||
)
|
||||
|
||||
# multi-modal section
|
||||
|
||||
image_size: int | tuple[int, int] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The size of the image to resize to"},
|
||||
)
|
||||
|
||||
image_resize_algorithm: Resampling | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The algorithm to use for image resizing"},
|
||||
)
|
||||
|
||||
# end of multi-modal section
|
||||
@@ -22,6 +22,7 @@ from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import importlib
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
|
||||
|
||||
from peft import PeftModel
|
||||
@@ -83,6 +84,11 @@ class BasePlugin:
|
||||
def get_input_args(self) -> str | None:
|
||||
"""Returns a pydantic model for the plugin's input arguments."""
|
||||
|
||||
def get_training_args_mixin(self) -> str | None:
|
||||
"""
|
||||
Returns a dataclass model for the plugin's training arguments.
|
||||
"""
|
||||
|
||||
def load_datasets(
|
||||
self, cfg: DictDefault, preprocess: bool = False
|
||||
) -> Union["TrainDatasetMeta", None]:
|
||||
@@ -158,6 +164,31 @@ class BasePlugin:
|
||||
trainer: The trainer object for training.
|
||||
"""
|
||||
|
||||
def get_training_args(self, cfg: DictDefault): # pylint: disable=unused-argument):
|
||||
"""
|
||||
Returns custom training arguments to set on TrainingArgs.
|
||||
|
||||
Args:
|
||||
cfg: The global axolotl configuration.
|
||||
|
||||
Returns:
|
||||
object: dict containing the training arguments.
|
||||
"""
|
||||
|
||||
def get_collator_cls_and_kwargs(
|
||||
self, cfg: DictDefault, is_eval: bool = False
|
||||
): # pylint: disable=unused-argument):
|
||||
"""
|
||||
Returns a custom class for the collator.
|
||||
|
||||
Args:
|
||||
cfg: The global axolotl configuration.
|
||||
is_eval: Whether this is an eval split.
|
||||
|
||||
Returns:
|
||||
class: The class for the collator.
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
|
||||
"""Creates and returns an optimizer for training.
|
||||
@@ -278,7 +309,7 @@ def load_plugin(plugin_name: str) -> BasePlugin:
|
||||
return plugin
|
||||
|
||||
|
||||
class PluginManager:
|
||||
class PluginManager: # pylint: disable=too-many-public-methods
|
||||
"""The `PluginManager` class is responsible for loading and managing plugins. It
|
||||
should be a singleton so it can be accessed from anywhere in the codebase.
|
||||
|
||||
@@ -337,8 +368,11 @@ class PluginManager:
|
||||
plugin = load_plugin(plugin_name)
|
||||
self.plugins[plugin_name] = plugin
|
||||
LOG.info(f"Plugin loaded successfully: {plugin_name}")
|
||||
except ImportError:
|
||||
except ImportError as exc:
|
||||
LOG.error(f"Failed to load plugin: {plugin_name}")
|
||||
# print stacktrace
|
||||
traceback.print_exc()
|
||||
print(f"Error: {exc}")
|
||||
|
||||
def get_input_args(self) -> list[str]:
|
||||
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'
|
||||
@@ -353,6 +387,20 @@ class PluginManager:
|
||||
input_args.append(input_args_from_plugin)
|
||||
return input_args
|
||||
|
||||
def get_training_args_mixin(self):
|
||||
"""
|
||||
Returns a list of dataclass mixins for all registered plugins' training args.
|
||||
|
||||
Returns:
|
||||
list[str]: A list of dotted import paths to training args mixin dataclasses
|
||||
"""
|
||||
training_args = []
|
||||
for plugin in self.plugins.values():
|
||||
training_args_from_plugin = plugin.get_training_args_mixin()
|
||||
if training_args_from_plugin is not None:
|
||||
training_args.append(training_args_from_plugin)
|
||||
return training_args
|
||||
|
||||
def load_datasets(
|
||||
self, cfg: DictDefault, preprocess: bool = False
|
||||
) -> Union["TrainDatasetMeta", None]:
|
||||
@@ -442,6 +490,42 @@ class PluginManager:
|
||||
return trainer_cls
|
||||
return None
|
||||
|
||||
def get_training_args(self, cfg):
|
||||
"""
|
||||
Calls the get_training_args method of all registered plugins and returns the combined training arguments.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugins.
|
||||
|
||||
Returns:
|
||||
object: The training arguments
|
||||
"""
|
||||
training_args_kwargs = {}
|
||||
for plugin in self.plugins.values():
|
||||
training_args = plugin.get_training_args(cfg)
|
||||
if training_args is not None:
|
||||
training_args_kwargs.update(training_args)
|
||||
|
||||
return training_args_kwargs
|
||||
|
||||
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
|
||||
"""
|
||||
Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugins.
|
||||
is_eval (bool): Whether this is an eval split.
|
||||
|
||||
Returns:
|
||||
object: The collator class, or None if none was found.
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
collator = plugin.get_collator_cls_and_kwargs(cfg, is_eval=is_eval)
|
||||
if collator is not None:
|
||||
collator_cls, collator_kwargs = collator
|
||||
return collator_cls, collator_kwargs
|
||||
return None
|
||||
|
||||
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
|
||||
"""Calls the `post_trainer_create` method of all registered plugins.
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ Module to handle merging the plugins' input arguments with the base configuratio
|
||||
This was moved here to prevent circular imports.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Type
|
||||
|
||||
from axolotl.utils.schemas.config import (
|
||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||
@@ -61,3 +61,43 @@ def merge_input_args():
|
||||
]
|
||||
return AxolotlConfigWCapabilities, AxolotlInputConfig
|
||||
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
|
||||
|
||||
|
||||
def merge_training_args() -> Type:
|
||||
"""
|
||||
Merges training arguments from registered plugins with the base TrainingArguments.
|
||||
|
||||
This function retrieves the training arguments from registered plugins using the PluginManager.
|
||||
It then dynamically creates new classes, AxolotlTrainingMixins,
|
||||
that inherit from the base configurations and include the training arguments from the plugins.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the newly created classes, AxolotlTrainingMixins.
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
from axolotl.core.training_args_base import (
|
||||
AxolotlTrainingMixins as AxolotlTrainingMixinsBase,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
training_args_mixins: List[str] = plugin_manager.get_training_args_mixin()
|
||||
mixin_classes = []
|
||||
dynamic_input = ""
|
||||
for plugin_args in training_args_mixins:
|
||||
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
|
||||
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
|
||||
mixin_classes.append(plugin_cls)
|
||||
if dynamic_input:
|
||||
dynamic_input += f"class AxolotlTrainingMixins(AxolotlTrainingMixinsBase, {', '.join(mixin_classes)}):\n pass\n"
|
||||
|
||||
namespace: Dict[Any, Any] = {}
|
||||
local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase}
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
dynamic_input, {**globals(), **local_vars}, namespace
|
||||
)
|
||||
AxolotlTrainingMixins = namespace[ # pylint: disable=invalid-name
|
||||
"AxolotlTrainingMixins"
|
||||
]
|
||||
return AxolotlTrainingMixins
|
||||
return AxolotlTrainingMixinsBase
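
For orientation, here is a minimal sketch of the contract this hook pair expects: a plugin exposes `get_training_args_mixin()` returning the dotted path to a dataclass, and `merge_training_args()` imports that class and mixes it into the dynamically built `AxolotlTrainingMixins`. The `MyPlugin*` names below are invented for illustration and are not part of this changeset.

```python
# Hypothetical plugin-side sketch; only the dotted-path contract is taken from this diff.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class MyPluginTrainingArgsMixin:
    """Extra fields a plugin can contribute to the merged TrainingArguments."""

    my_plugin_alpha: Optional[float] = field(
        default=None,
        metadata={"help": "Illustrative plugin-specific hyperparameter"},
    )


class MyPlugin:  # in practice this would subclass axolotl.integrations.base.BasePlugin
    def get_training_args_mixin(self) -> str:
        # merge_training_args() rsplits this path, imports MyPluginTrainingArgsMixin,
        # and exec()s a subclass of AxolotlTrainingMixinsBase that inherits from it.
        return "my_package.training_args.MyPluginTrainingArgsMixin"
```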
|
||||
|
||||
@@ -21,3 +21,32 @@ datasets:
|
||||
```
|
||||
|
||||
An example dataset can be found at [`axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample`](https://huggingface.co/datasets/axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample)
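
For reference, the KD options introduced by this changeset can be sketched roughly as below. Field names are taken from `KDArgs` in this diff and the plugin path follows the module layout shown here; the values, model, and server URL are placeholders rather than a tested configuration (see the serve commands below for the teacher side).

```yaml
# Illustrative sketch only; option names come from KDArgs in this changeset.
plugins:
  - axolotl.integrations.kd.KDPlugin

kd_trainer: true
kd_ce_alpha: 0.1          # weight on the hard cross-entropy loss
kd_alpha: 0.9             # weight on the soft KD loss
kd_temperature: 1.0
kd_normalize_topk: true

# Online KD against a running teacher server (vllm or sglang):
kd_online_server: vllm
kd_online_server_base_url: http://localhost:8888
kd_online_topk: 64
kd_online_timeout: 120
kd_temperature_min: 0.7   # optional cosine schedule down from kd_temperature
```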
|
||||
|
||||
## Online KD (sglang)
|
||||
|
||||
```bash
|
||||
export UV_TORCH_BACKEND=cu124
|
||||
uv venv sglang --python 3.11
|
||||
source sglang/bin/activate
|
||||
uv pip install --upgrade pip
|
||||
uv pip install setuptools
|
||||
uv pip install torch~=2.5.1 --index-url https://download.pytorch.org/whl/cu124
|
||||
uv pip install sgl-kernel --force-reinstall --no-deps
|
||||
uv pip install "sglang[all]>=0.4.2.post4" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/
|
||||
```
|
||||
|
||||
## Online KD (vllm)
|
||||
|
||||
```bash
|
||||
VLLM_USE_V1=0 vllm serve open-r1/OlympicCoder-32B --max-model-len 16400 --port 8888 --max-logprobs 128 --return-tokens-as-token-ids --tensor-parallel-size 8 --max-num-seqs 256 --gpu_memory_utilization 0.2 --enable-chunked-prefill
|
||||
```
|
||||
|
||||
```bash
|
||||
vllm serve open-r1/OlympicCoder-32B --max-model-len 16400 --port 8888 --max-logprobs 128 --return-tokens-as-token-ids --tensor-parallel-size 8 --no-enable-prefix-caching --gpu-memory-utilization 0.3 --max-num-batched-tokens 131072 --host 0.0.0.0
|
||||
```
|
||||
|
||||
|
||||
```bash
|
||||
python -m sglang.launch_server --model-path open-r1/OlympicCoder-32B --tensor-parallel-size 8 --port 8080 --host 0.0.0.0 --max-running-requests 256 --context-length 16400 --mem-fraction-static 0.2 --schedule-conservativeness 0.3 --chunked-prefill-size 131072 --schedule-policy fcfs --skip-tokenizer-init
|
||||
```
|
||||
|
||||
@@ -15,7 +15,12 @@
|
||||
"""
|
||||
Plugin init to add KD support to Axolotl.
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback
|
||||
|
||||
from .args import KDArgs  # pylint: disable=unused-import  # noqa: F401
|
||||
|
||||
@@ -28,9 +33,75 @@ class KDPlugin(BasePlugin):
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.kd.KDArgs"
|
||||
|
||||
def get_training_args_mixin(self):
|
||||
return "axolotl.integrations.kd.args.KDTrainingArgsMixin"
|
||||
|
||||
def get_trainer_cls(self, cfg):
|
||||
if cfg.kd_trainer:
|
||||
from .trainer import AxolotlKDTrainer
|
||||
|
||||
return AxolotlKDTrainer
|
||||
return None
|
||||
|
||||
def get_training_args(self, cfg):
|
||||
return {
|
||||
"kd_ce_alpha": cfg.kd_ce_alpha,
|
||||
"kd_alpha": cfg.kd_alpha,
|
||||
"kd_temperature": cfg.kd_temperature,
|
||||
"kd_beta": cfg.kd_beta,
|
||||
"kd_normalize_topk": cfg.kd_normalize_topk,
|
||||
}
|
||||
|
||||
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
|
||||
if not cfg.kd_trainer:
|
||||
return None, None
|
||||
|
||||
from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq
|
||||
|
||||
use_batch_sampler_collator = False
|
||||
if is_eval is False and cfg.sample_packing:
|
||||
use_batch_sampler_collator = True
|
||||
if cfg.eval_sample_packing and is_eval:
|
||||
use_batch_sampler_collator = True
|
||||
|
||||
if cfg.kd_online_server_base_url:
|
||||
from .collator_online_teacher import OnlineTeacherCollator
|
||||
|
||||
return OnlineTeacherCollator, {
|
||||
"kd_online_server_base_url": cfg.kd_online_server_base_url,
|
||||
"kd_online_topk": cfg.kd_online_topk,
|
||||
"kd_temperature": cfg.kd_temperature,
|
||||
"kd_online_server": cfg.kd_online_server,
|
||||
"kd_online_timeout": cfg.kd_online_timeout,
|
||||
"kd_normalize_topk": cfg.kd_normalize_topk,
|
||||
}
|
||||
|
||||
if use_batch_sampler_collator:
|
||||
return KDBatchSamplerDataCollatorForSeq2Seq, {}
|
||||
return DataCollatorForKD, {}
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
from .kernels.models import apply_kernel
|
||||
|
||||
apply_kernel(cfg.model_config_type)
|
||||
|
||||
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
|
||||
"""
|
||||
Adds temp scheduler callback to the Trainer instance.
|
||||
|
||||
Args:
|
||||
cfg (Any): Configuration object containing the KD settings.
|
||||
trainer (Trainer): Huggingface Trainer instance.
|
||||
|
||||
Returns:
|
||||
list: List containing the configured callback instances.
|
||||
"""
|
||||
if cfg.kd_temperature_min is not None and cfg.kd_online_server_base_url:
|
||||
callback = KDTemperatureSchedulerCallback(
|
||||
cfg.kd_temperature,
|
||||
cfg.kd_temperature_min,
|
||||
trainer,
|
||||
)
|
||||
return [callback]
|
||||
|
||||
return []
|
||||
|
||||
@@ -15,9 +15,19 @@
|
||||
"""
|
||||
Plugin args for KD support.
|
||||
"""
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class InferenceServerType(str, Enum):
|
||||
"""
|
||||
Online inference server types to handle different request args.
|
||||
"""
|
||||
|
||||
vllm = "vllm" # pylint: disable=invalid-name
|
||||
sglang = "sglang" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class KDArgs(BaseModel):
|
||||
@@ -25,13 +35,41 @@ class KDArgs(BaseModel):
|
||||
Input args for knowledge distillation.
|
||||
"""
|
||||
|
||||
kd_trainer: Optional[bool] = None # whether to use KD trainer
|
||||
kd_ce_alpha: Optional[float] = (
|
||||
kd_trainer: bool | None = None  # whether to use KD trainer
|
||||
kd_ce_alpha: float | None = (
|
||||
None # loss coefficient for cross-entropy loss during KD
|
||||
)
|
||||
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
||||
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
||||
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
|
||||
kd_top_k_before_softmax: Optional[bool] = (
|
||||
None # whether to sample top k before softmax during KD
|
||||
kd_alpha: float | None = None # loss coefficient for KD loss
|
||||
kd_temperature: float | None = None # temperature for sampling during KD
|
||||
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
|
||||
kd_normalize_topk: bool | None = (
|
||||
None # whether to normalize student logits during KD
|
||||
)
|
||||
|
||||
# TODO online kd
|
||||
kd_online_server_base_url: str | None = None
|
||||
kd_online_topk: int | None = None
|
||||
kd_online_server: InferenceServerType | None = Field(
|
||||
default_factory=lambda: InferenceServerType.vllm
|
||||
)
|
||||
kd_online_timeout: int | None = 120
|
||||
kd_temperature_min: float | None = (
|
||||
None # kd temperature scheduling during online kd
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class KDTrainingArgsMixin:
|
||||
"""
|
||||
Additional args for KD training.
|
||||
"""
|
||||
|
||||
kd_ce_alpha: float | None = (
|
||||
None # loss coefficient for cross-entropy loss during KD
|
||||
)
|
||||
kd_alpha: float | None = None # loss coefficient for KD loss
|
||||
kd_temperature: float | None = None # temperature for sampling during KD
|
||||
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
|
||||
kd_normalize_topk: bool | None = (
|
||||
None # whether to normalize student logits during KD
|
||||
)
|
||||
|
||||
36
src/axolotl/integrations/kd/callbacks.py
Normal file
36
src/axolotl/integrations/kd/callbacks.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
Transformers trainer callbacks to schedule the KD temperature during training
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
|
||||
|
||||
class KDTemperatureSchedulerCallback(TrainerCallback):
|
||||
"""
|
||||
KD temperature scheduler callback for the trainer.
|
||||
"""
|
||||
|
||||
def __init__(self, temperature_start, temperature_min, trainer):
|
||||
self.temperature_start = temperature_start
|
||||
self.temperature_min = temperature_min
|
||||
self.temperature = temperature_start
|
||||
|
||||
self.trainer = trainer
|
||||
|
||||
def on_step_end(
|
||||
self, args, state, control, **kwargs
|
||||
): # pylint: disable=unused-argument
|
||||
# cosine decay temperature over the max steps
|
||||
|
||||
progress = state.global_step / state.max_steps
|
||||
# Cosine decay factor: 0.5 * (1 + cos(pi * progress))
|
||||
# This factor goes from 1 (at progress=0) to 0 (at progress=1)
|
||||
decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
|
||||
self.temperature = self.temperature_start - (
|
||||
(self.temperature_start - self.temperature_min) * (1.0 - decay_factor)
|
||||
)
|
||||
|
||||
if hasattr(self.trainer.data_collator, "kd_temperature"):
|
||||
self.trainer.data_collator.kd_temperature = self.temperature
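
Restated as a formula (editorial note, derived from the callback above): with start temperature $T_0$, floor $T_{\min}$, and training progress $t \in [0, 1]$ (global step over max steps), the scheduled temperature is

$$
T(t) \;=\; T_{\min} + \bigl(T_0 - T_{\min}\bigr)\cdot \tfrac{1}{2}\bigl(1 + \cos(\pi t)\bigr),
$$

so it starts at $T_0$, ends at $T_{\min}$, and sits exactly halfway between the two at $t = 0.5$.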
|
||||
@@ -15,12 +15,15 @@
|
||||
"""
|
||||
Chat template prompt strategy loader with KD support
|
||||
"""
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
"""
|
||||
@@ -101,10 +104,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
# fill with -inf for padding_len tokens for top_k tokens
|
||||
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
||||
|
||||
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
|
||||
# otherwise, we need to shift in the trainer
|
||||
shift = 0
|
||||
for _ in range(shift, input_padding_len):
|
||||
# we shift for causal models in the trainer, so start the range from 0
|
||||
for _ in range(0, input_padding_len):
|
||||
target_logprobs.append([-float("inf")] * top_k)
|
||||
target_token_ids.append(list(range(top_k)))
|
||||
target_mask.append([0] * top_k)
|
||||
@@ -143,6 +144,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
#
|
||||
# Convert from log to probability
|
||||
teacher_probs_t1 = position_logprobs_tensor.exp()
|
||||
# normalize probabilities to sum to 1 in case they aren't already
|
||||
teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
|
||||
if teacher_probs_t1_sum > 1e-9:
|
||||
teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
|
||||
if self.kd_temperature != self.gen_temperature:
|
||||
# Exponentiate by factor (T1 / T2)
|
||||
exponent = self.gen_temperature / self.kd_temperature
|
||||
@@ -162,12 +167,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
target_logprobs.append(position_logprobs_scaled)
|
||||
target_token_ids.append(position_token_ids)
|
||||
|
||||
if shift == 1:
|
||||
# since we started at index 1 for causal, we need one more padding token
|
||||
target_logprobs.append([-float("inf")] * top_k)
|
||||
target_token_ids.append(list(range(top_k)))
|
||||
target_mask.append([0] * top_k)
|
||||
|
||||
# Update sample with transformed logprobs
|
||||
sample["target_logprobs"] = target_logprobs
|
||||
sample["target_token_ids"] = target_token_ids
|
||||
@@ -184,13 +183,124 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
|
||||
"""
|
||||
Strategy for datasets with complete structured KD logprob data.
|
||||
"""
|
||||
|
||||
def transform_logprobs(self, sample):
|
||||
"""
|
||||
Transform logprobs to target format for KD training
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
logprobs = sample.pop(self.logprobs_field)
|
||||
target_seq_len = len(logprobs)
|
||||
input_seq_len = len(sample["input_ids"])
|
||||
input_padding_len = input_seq_len - target_seq_len
|
||||
# get non-zero top-k (prune None logprobs from vllm data step)
|
||||
top_k_vals = [
|
||||
len(logprobs[i])
|
||||
for i in range(len(logprobs))
|
||||
if logprobs[i] is not None and len(logprobs[i])
|
||||
]
|
||||
max_top_k = max(set(top_k_vals), key=top_k_vals.count)
|
||||
min_top_k = min(set(top_k_vals), key=top_k_vals.count)
|
||||
top_k = min(max_top_k, min_top_k)
|
||||
if top_k == 0:
|
||||
raise ValueError("No non-zero top-k logprobs found.")
|
||||
|
||||
target_logprobs = []
|
||||
target_token_ids = []
|
||||
target_mask = []
|
||||
|
||||
if input_padding_len < 0:
|
||||
# the provided logprobs cover more positions than the tokenized input,
|
||||
# so we need to slice from the left/beginning of logprobs
|
||||
logprobs = logprobs[:-input_seq_len]
|
||||
input_padding_len = 0
|
||||
# target_seq_len = input_seq_len
|
||||
|
||||
# truncate the second dimension of the logprobs to top_k
|
||||
logprobs = [row[:top_k] for row in logprobs]
|
||||
|
||||
# fill with -inf for padding_len tokens for top_k tokens
|
||||
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
|
||||
|
||||
# we shift for causal models in the trainer, so start the range from 0
|
||||
for _ in range(0, input_padding_len):
|
||||
target_logprobs.append([-float("inf")] * top_k)
|
||||
target_token_ids.append(list(range(top_k)))
|
||||
target_mask.append([0] * top_k)
|
||||
|
||||
for position in range(input_padding_len, input_seq_len):
|
||||
if sample["labels"][position] == -100:
|
||||
target_mask.append([0] * top_k)
|
||||
else:
|
||||
target_mask.append([1] * top_k)
|
||||
|
||||
for token_pos_logprobs, pos_target_token_ids in zip(
|
||||
logprobs, sample["target_token_ids"]
|
||||
):
|
||||
# Convert to a tensor for easier manipulation
|
||||
position_logprobs_tensor = torch.tensor(
|
||||
token_pos_logprobs, dtype=torch.float
|
||||
)
|
||||
|
||||
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
|
||||
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
|
||||
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
|
||||
#
|
||||
# Convert from log to probability
|
||||
teacher_probs_t1 = position_logprobs_tensor.exp()
|
||||
# normalize probabilities to sum to 1 in case they aren't already
|
||||
teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
|
||||
if teacher_probs_t1_sum > 1e-9:
|
||||
teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
|
||||
if self.kd_temperature != self.gen_temperature:
|
||||
# Exponentiate by factor (T1 / T2)
|
||||
exponent = self.gen_temperature / self.kd_temperature
|
||||
teacher_probs_t2 = teacher_probs_t1**exponent
|
||||
else:
|
||||
teacher_probs_t2 = teacher_probs_t1
|
||||
# Re-normalize
|
||||
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
|
||||
dim=0, keepdim=True
|
||||
)
|
||||
# Convert back to log
|
||||
position_logprobs_tensor = torch.log(teacher_probs_t2)
|
||||
|
||||
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
|
||||
position_logprobs_scaled = position_logprobs_tensor.tolist()
|
||||
|
||||
target_logprobs.append(position_logprobs_scaled)
|
||||
target_token_ids.append(pos_target_token_ids)
|
||||
|
||||
# Update sample with transformed logprobs
|
||||
sample["target_logprobs"] = target_logprobs
|
||||
sample["target_token_ids"] = target_token_ids
|
||||
sample["target_mask"] = target_mask
|
||||
|
||||
return sample
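
The exponent-based re-scaling used in `transform_logprobs` above can be written out as follows (editorial restatement of the in-code comments): given teacher top-k probabilities $p_{T_1}(k)$ sampled at generation temperature $T_1$ (`gen_temperature`) and a target KD temperature $T_2$ (`kd_temperature`),

$$
p_{T_2}(k) \;=\; \frac{p_{T_1}(k)^{T_1 / T_2}}{\sum_{j} p_{T_1}(j)^{T_1 / T_2}},
$$

with the sum running over the retained top-k entries; the result is stored back as $\log p_{T_2}(k)$.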
|
||||
|
||||
def _tokenize_single_prompt(self, prompt):
|
||||
logprobs = prompt.pop(self.logprobs_field)
|
||||
target_token_ids = prompt.pop("target_token_ids")
|
||||
tokenized_prompt = super()._tokenize_single_prompt(prompt)
|
||||
tokenized_prompt[self.logprobs_field] = logprobs
|
||||
tokenized_prompt["target_token_ids"] = target_token_ids
|
||||
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
class KDStrategyLoader(StrategyLoader):
|
||||
"""
|
||||
Load ChatTemplateStrategy with KD support using StrategyLoader.
|
||||
"""
|
||||
|
||||
def _get_strategy_cls(self):
|
||||
return ChatTemplateStrategyWithKD
|
||||
return ChatTemplateStrategyWithKDv2
|
||||
|
||||
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
|
||||
strategy_params = super()._get_strategy_params(cfg, ds_cfg)
|
||||
|
||||
@@ -47,11 +47,16 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
|
||||
position_pad_token_id: int = 0
|
||||
return_tensors: str = "pt"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
if return_tensors is None:
|
||||
return_tensors = self.return_tensors
|
||||
|
||||
padding_side = self.tokenizer.padding_side
|
||||
max_len = 0
|
||||
|
||||
# Pad labels and position_ids first
|
||||
for feature_name, pad_token_id in [
|
||||
@@ -102,7 +107,9 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
|
||||
target_mask_list.append(f.pop("target_mask"))
|
||||
|
||||
# Determine max lengths
|
||||
max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
|
||||
max_teacher_seq_len = max_len or max(
|
||||
len(seq) for seq in target_logprobs_list
|
||||
)
|
||||
max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)
|
||||
|
||||
padded_target_logprobs = []
|
||||
@@ -209,7 +216,9 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
|
||||
# We want to produce a single "merged" feature dict for each sub-batch.
|
||||
out_features = [{} for _ in features]
|
||||
|
||||
for i, sub_features in enumerate(features):
|
||||
for i, sub_features in enumerate( # pylint: disable=too-many-nested-blocks
|
||||
features
|
||||
):
|
||||
# sub_features is a list of dicts, each dict = one sequence’s features
|
||||
# We'll merge them into out_features[i].
|
||||
#
|
||||
@@ -243,10 +252,17 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
|
||||
# For example, input_ids or labels are often arrays.
|
||||
arrays = []
|
||||
for feat in sub_features:
|
||||
if field_name in feat:
|
||||
if field_name in feat and isinstance(
|
||||
feat[field_name], (list, torch.Tensor)
|
||||
):
|
||||
if isinstance(
|
||||
feat[field_name][0], (dict, str)
|
||||
): # pylint: disable=too-many-nested-blocks
|
||||
continue
|
||||
arr = np.array(feat[field_name])
|
||||
arrays.append(arr)
|
||||
out_features[i][field_name] = np.concatenate(arrays)
|
||||
if arrays:
|
||||
out_features[i][field_name] = np.concatenate(arrays)
|
||||
|
||||
# 3) Now call the parent collator, which will do:
|
||||
# - padding of labels/position_ids
|
||||
|
||||
561
src/axolotl/integrations/kd/collator_online_teacher.py
Normal file
561
src/axolotl/integrations/kd/collator_online_teacher.py
Normal file
@@ -0,0 +1,561 @@
|
||||
"""
|
||||
Packed data loader for online teacher training supporting vllm and sglang.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from orjson import orjson
|
||||
|
||||
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.integrations.kd.utils import normalize_logprobs
|
||||
from axolotl.utils.data.utils import retry_on_request_exceptions
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def hmac_sha_from_int_list(int_list, key, hash_func=hashlib.sha256):
|
||||
"""
|
||||
Create HMAC-SHA hash from a list of integers
|
||||
|
||||
Args:
|
||||
int_list: List of integers
|
||||
key: Secret key (string or bytes)
|
||||
hash_func: Hash function (default: sha256)
|
||||
|
||||
Returns:
|
||||
HMAC digest as hex string
|
||||
"""
|
||||
# Convert key to bytes if it's a string
|
||||
if isinstance(key, str):
|
||||
key = key.encode("utf-8")
|
||||
|
||||
# Convert list of ints to bytes
|
||||
# Method 1: Convert each int to bytes and concatenate
|
||||
data = b"".join(i.to_bytes(4, byteorder="big") for i in int_list)
|
||||
|
||||
# Create HMAC
|
||||
h = hmac.new(key, data, hash_func)
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
||||
"""
|
||||
Collator for online teacher training.
|
||||
"""
|
||||
|
||||
DEFAULT_LABEL_PAD_TOKEN_ID: int = -100
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
kd_online_server_base_url: Optional[str] = None,
|
||||
kd_online_topk: Optional[int] = None,
|
||||
kd_temperature: Optional[float] = 1.0,
|
||||
kd_online_server: Optional[str] = "vllm",
|
||||
kd_online_timeout: Optional[int] = 120,
|
||||
kd_cache_dir: Optional[str] = None,
|
||||
kd_normalize_topk: Optional[bool] = True,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if kd_online_server_base_url is None:
|
||||
raise ValueError(
|
||||
"kd_online_server_base_url must be provided for OnlineTeacherDataloader"
|
||||
)
|
||||
if kd_online_topk is None or kd_online_topk <= 0:
|
||||
raise ValueError(
|
||||
"kd_online_topk must be a positive integer for OnlineTeacherDataloader"
|
||||
)
|
||||
|
||||
self.kd_online_server_base_url = kd_online_server_base_url.rstrip("/")
|
||||
self.kd_online_topk = kd_online_topk
|
||||
self.kd_temperature = kd_temperature
|
||||
self.kd_online_server = kd_online_server
|
||||
self.http_session = requests.Session()
|
||||
self.kd_online_timeout = kd_online_timeout
|
||||
self.kd_cache_dir = kd_cache_dir
|
||||
self.kd_normalize_topk = kd_normalize_topk
|
||||
|
||||
def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
|
||||
"""
|
||||
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
|
||||
"""
|
||||
if not raw_logprobs or self.kd_online_topk == 0:
|
||||
return (
|
||||
[-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
|
||||
)
|
||||
|
||||
raw_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
|
||||
return normalize_logprobs(raw_logprobs_tensor, self.kd_online_topk).tolist()
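
Editorial note: assuming `normalize_logprobs` performs a log-softmax over the $K$ provided entries (consistent with the docstring above), the operation is

$$
\hat{\ell}_k \;=\; \ell_k - \log \sum_{j=1}^{K} e^{\ell_j},
$$

i.e. the retained top-k teacher probabilities are re-normalized to sum to one before being compared against the student's (likewise re-normalized) top-k distribution.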
|
||||
|
||||
@retry_on_request_exceptions(max_retries=10, delay=5)
|
||||
def fetch_online_logprobs_sglang(
|
||||
self, batch_input_ids: List[List[int]], labels: List[List[int]]
|
||||
):
|
||||
"""
|
||||
Fetches logprobs from an online teacher served by sglang for a batch of input_ids.
|
||||
Assumes the API returns per-position top-k entries as [logprob, token_id, token_str] triples.
|
||||
"""
|
||||
api_endpoint = f"{self.kd_online_server_base_url}/generate"
|
||||
|
||||
payload = {
|
||||
"input_ids": batch_input_ids,
|
||||
"return_logprob": True,
|
||||
"top_logprobs_num": self.kd_online_topk,
|
||||
"logprob_start_len": 0,
|
||||
"return_text_in_logprobs": True,
|
||||
"echo": True,
|
||||
"sampling_params": {
|
||||
"max_new_tokens": 0,
|
||||
"temperature": self.kd_temperature,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
}
|
||||
|
||||
# Initialize with empty lists, so if API call fails, these are returned.
|
||||
ret_data_target_token_ids: List[List[List[int]]] = []
|
||||
ret_data_target_logprobs: List[List[List[float]]] = []
|
||||
ret_data_target_mask: List[List[List[int]]] = []
|
||||
|
||||
try:
|
||||
response = self.http_session.post(
|
||||
api_endpoint, json=payload, timeout=self.kd_online_timeout
|
||||
)
|
||||
response.raise_for_status()
|
||||
api_data: list[dict] = response.json()
|
||||
|
||||
# Ensure api_data is a list, and its length matches batch_input_ids
|
||||
if not isinstance(api_data, list) or len(api_data) != len(batch_input_ids):
|
||||
LOG.error(
|
||||
f"API response format error. Expected a list of {len(batch_input_ids)} "
|
||||
f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
|
||||
)
|
||||
# Return empty data; items processed later will get default empty KD fields
|
||||
return {
|
||||
"target_token_ids": ret_data_target_token_ids,
|
||||
"target_logprobs": ret_data_target_logprobs,
|
||||
"target_mask": ret_data_target_mask,
|
||||
}
|
||||
|
||||
for sequence_data, seq_input_ids, seq_labels in zip(
|
||||
api_data, batch_input_ids, labels
|
||||
):
|
||||
current_target_logprobs = []
|
||||
current_target_token_ids = []
|
||||
current_target_mask = []
|
||||
|
||||
meta_info = sequence_data.pop("meta_info", {})
|
||||
# Ensure input_top_logprobs is a list
|
||||
input_top_logprobs: Optional[list[None | list[tuple]]] = meta_info.pop(
|
||||
"input_top_logprobs", []
|
||||
)
|
||||
if not isinstance(input_top_logprobs, list):
|
||||
LOG.warning(
|
||||
f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
|
||||
)
|
||||
input_top_logprobs = [] # Treat as empty
|
||||
|
||||
# basic check that the logprob data len matches the input len, so no need to handle padding
|
||||
assert len(seq_input_ids) == len(input_top_logprobs)
|
||||
|
||||
for i, _, label in zip(
|
||||
range(len(seq_input_ids)), seq_input_ids, seq_labels
|
||||
):
|
||||
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
|
||||
# this is always the case for the first token.
|
||||
# there is never logprob data for the first token since that's a true input
|
||||
# so we replace the None value with padding data
|
||||
current_target_logprobs.append(
|
||||
[-float("inf")] * self.kd_online_topk
|
||||
)
|
||||
current_target_token_ids.append([0] * self.kd_online_topk)
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
elif (
|
||||
i < len(input_top_logprobs)
|
||||
and input_top_logprobs[i] is not None
|
||||
):
|
||||
pos_top_logprobs_data = input_top_logprobs[i]
|
||||
# Ensure pos_top_logprobs_data is a list of lists as expected
|
||||
if not (
|
||||
isinstance(pos_top_logprobs_data, list)
|
||||
and all(
|
||||
isinstance(item, list) for item in pos_top_logprobs_data
|
||||
)
|
||||
and len(pos_top_logprobs_data) > 0
|
||||
and len(pos_top_logprobs_data[0]) == 3
|
||||
): # [logprob, token_id, token_str]
|
||||
LOG.warning(
|
||||
f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
|
||||
)
|
||||
current_target_logprobs.append(
|
||||
[-float("inf")] * self.kd_online_topk
|
||||
)
|
||||
current_target_token_ids.append([0] * self.kd_online_topk)
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
continue
|
||||
|
||||
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
|
||||
pos_logprobs_raw, pos_token_ids, _ = [
|
||||
list(row) for row in zip(*pos_top_logprobs_data)
|
||||
]
|
||||
|
||||
# Ensure correct length (top_k)
|
||||
if len(pos_logprobs_raw) < self.kd_online_topk:
|
||||
pad_len = self.kd_online_topk - len(pos_logprobs_raw)
|
||||
pos_logprobs_raw.extend([-float("inf")] * pad_len)
|
||||
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
|
||||
|
||||
# truncate to top_k in case the response was longer
|
||||
current_target_token_ids.append(
|
||||
pos_token_ids[: self.kd_online_topk]
|
||||
)
|
||||
|
||||
if self.kd_normalize_topk:
|
||||
normalized_logprobs_for_position = self._normalize_logprobs(
|
||||
pos_logprobs_raw[: self.kd_online_topk]
|
||||
)
|
||||
current_target_logprobs.append(
|
||||
normalized_logprobs_for_position
|
||||
)
|
||||
else:
|
||||
current_target_logprobs.append(
|
||||
pos_logprobs_raw[: self.kd_online_topk]
|
||||
)
|
||||
|
||||
# Mask depends on the corresponding label for the student
|
||||
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
else:
|
||||
current_target_mask.append([1] * self.kd_online_topk)
|
||||
else:
|
||||
# Pad if no logprobs for this position (either due to length mismatch or None entry)
|
||||
current_target_logprobs.append(
|
||||
[-float("inf")] * self.kd_online_topk
|
||||
)
|
||||
current_target_token_ids.append([0] * self.kd_online_topk)
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
|
||||
ret_data_target_token_ids.append(current_target_token_ids)
|
||||
ret_data_target_logprobs.append(current_target_logprobs)
|
||||
ret_data_target_mask.append(current_target_mask)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
LOG.error(f"Error fetching logprobs from online teacher: {e}")
|
||||
raise e
|
||||
# ret_logprobs_data will be returned with empty lists, handled by the caller.
|
||||
except Exception as e: # Catch other potential errors during processing
|
||||
LOG.error(
|
||||
f"Unexpected error processing API response in fetch_online_logprobs: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise e
|
||||
|
||||
return {
|
||||
"target_token_ids": ret_data_target_token_ids,
|
||||
"target_logprobs": ret_data_target_logprobs,
|
||||
"target_mask": ret_data_target_mask,
|
||||
}
|
||||
|
||||
@retry_on_request_exceptions(max_retries=10, delay=5)
|
||||
def fetch_online_logprobs_vllm(
|
||||
self, batch_input_ids: List[List[int]], labels: List[List[int]]
|
||||
):
|
||||
"""
|
||||
Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
|
||||
Assumes API returns token IDs as strings in logprob dictionary keys.
|
||||
"""
|
||||
api_endpoint = f"{self.kd_online_server_base_url}/v1/completions"
|
||||
|
||||
payload = {
|
||||
"prompt": batch_input_ids,
|
||||
"echo": True,
|
||||
"logprobs": True,
|
||||
"prompt_logprobs": self.kd_online_topk,
|
||||
"top_logprobs": self.kd_online_topk,
|
||||
"max_new_tokens": 0,
|
||||
"skip_special_tokens": False,
|
||||
"temperature": self.kd_temperature,
|
||||
"sampling_params": {
|
||||
"max_tokens": 0,
|
||||
},
|
||||
}
|
||||
|
||||
# Initialize with empty lists, so if API call fails, these are returned.
|
||||
ret_data_target_token_ids: List[List[List[int]]] = []
|
||||
ret_data_target_logprobs: List[List[List[float]]] = []
|
||||
ret_data_target_mask: List[List[List[int]]] = []
|
||||
|
||||
try:
|
||||
headers = {"Accept-Encoding": "deflate, gzip, br, zstd"}
|
||||
response = self.http_session.post(
|
||||
api_endpoint,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=self.kd_online_timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
api_data: dict = orjson.loads(response.content)
|
||||
choices: list[dict] = api_data["choices"]
|
||||
|
||||
# Ensure choices is a list, and its length matches batch_input_ids
|
||||
if not isinstance(choices, list) or len(choices) != len(batch_input_ids):
|
||||
LOG.error(
|
||||
f"API response format error. Expected a list of {len(batch_input_ids)} "
|
||||
f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
|
||||
)
|
||||
# Return empty data; items processed later will get default empty KD fields
|
||||
return {
|
||||
"target_token_ids": ret_data_target_token_ids,
|
||||
"target_logprobs": ret_data_target_logprobs,
|
||||
"target_mask": ret_data_target_mask,
|
||||
}
|
||||
|
||||
for sequence_data, seq_input_ids, seq_labels in zip(
|
||||
choices, batch_input_ids, labels
|
||||
):
|
||||
# seq_input_ids: List[int]
|
||||
# seq_labels: List[int]
|
||||
|
||||
current_target_logprobs = []
|
||||
current_target_token_ids = []
|
||||
current_target_mask = []
|
||||
|
||||
# Ensure input_top_logprobs is a list
|
||||
input_top_logprobs: Optional[list[None | dict[str, dict]]] = (
|
||||
sequence_data.pop("prompt_logprobs", [])
|
||||
)
|
||||
|
||||
if not isinstance(input_top_logprobs, list):
|
||||
LOG.warning(
|
||||
f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
|
||||
)
|
||||
input_top_logprobs = [] # Treat as empty
|
||||
|
||||
# basic check that the logprob data len matches the input len, so no need to handle padding
|
||||
assert len(seq_input_ids) == len(input_top_logprobs)
|
||||
|
||||
seq_len = len(seq_input_ids)
|
||||
|
||||
for i, _, label in zip(range(seq_len), seq_input_ids, seq_labels):
|
||||
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
|
||||
# this is always the case for the first token.
|
||||
# there is never logprob data for the first token since that's a true input
|
||||
continue
|
||||
if (
|
||||
i < len(input_top_logprobs)
|
||||
and input_top_logprobs[i] is not None
|
||||
):
|
||||
pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i] # type: ignore[assignment]
|
||||
# Ensure pos_top_logprobs_data is a dict of token_id -> logprob info as expected
|
||||
if not (
|
||||
isinstance(pos_top_logprobs_data, dict)
|
||||
and all(
|
||||
isinstance(item, dict)
|
||||
for item in pos_top_logprobs_data.values()
|
||||
)
|
||||
and len(pos_top_logprobs_data.keys()) > 0
|
||||
): # [logprob, token_id, token_str]
|
||||
LOG.warning(
|
||||
f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
|
||||
)
|
||||
current_target_logprobs.append(
|
||||
[-float("inf")] * self.kd_online_topk
|
||||
)
|
||||
current_target_token_ids.append(
|
||||
list(range(self.kd_online_topk))
|
||||
)
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
continue
|
||||
|
||||
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
|
||||
pos_token_ids_str = list(pos_top_logprobs_data.keys())
|
||||
pos_logprobs_dict = pos_top_logprobs_data.values()
|
||||
pos_token_ids = [
|
||||
int(token_id) for token_id in pos_token_ids_str
|
||||
]
|
||||
pos_logprobs_raw = [
|
||||
float(logprob.get("logprob", -float("inf")))
|
||||
for logprob in pos_logprobs_dict
|
||||
]
|
||||
|
||||
# Ensure correct length (top_k)
|
||||
if len(pos_logprobs_raw) < self.kd_online_topk:
|
||||
pad_len = self.kd_online_topk - len(pos_logprobs_raw)
|
||||
LOG.warning(
|
||||
f"Padding position {i} with {pad_len} top-k tokens and logprobs."
|
||||
)
|
||||
pos_logprobs_raw.extend([-float("inf")] * pad_len)
|
||||
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
|
||||
|
||||
# truncate to top_k in case the response was longer
|
||||
current_target_token_ids.append(
|
||||
pos_token_ids[: self.kd_online_topk]
|
||||
)
|
||||
|
||||
if self.kd_normalize_topk:
|
||||
normalized_logprobs_for_position = self._normalize_logprobs(
|
||||
pos_logprobs_raw[: self.kd_online_topk]
|
||||
)
|
||||
current_target_logprobs.append(
|
||||
normalized_logprobs_for_position
|
||||
)
|
||||
else:
|
||||
current_target_logprobs.append(
|
||||
pos_logprobs_raw[: self.kd_online_topk]
|
||||
)
|
||||
|
||||
# Mask depends on the corresponding label for the student
|
||||
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
else:
|
||||
current_target_mask.append([1] * self.kd_online_topk)
|
||||
else:
|
||||
# Pad if no logprobs for this position (either due to length mismatch or None entry)
|
||||
current_target_logprobs.append(
|
||||
[-float("inf")] * self.kd_online_topk
|
||||
)
|
||||
current_target_token_ids.append(
|
||||
list(range(self.kd_online_topk))
|
||||
)
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
for i in range(max(0, seq_len - len(current_target_logprobs))):
|
||||
current_target_logprobs.append(
|
||||
[-float("inf")] * self.kd_online_topk
|
||||
)
|
||||
current_target_token_ids.append(list(range(self.kd_online_topk)))
|
||||
current_target_mask.append([0] * self.kd_online_topk)
|
||||
|
||||
ret_data_target_token_ids.append(current_target_token_ids)
|
||||
ret_data_target_logprobs.append(current_target_logprobs)
|
||||
ret_data_target_mask.append(current_target_mask)
|
||||
|
||||
# TODO save and load targets to disk for caching for next epoch
|
||||
# generate a hmac SHA256 hash over the list seq_input_ids and convert it to an int
|
||||
# if self.kd_cache_dir:
|
||||
# hash_input_ids = hmac_sha_from_int_list(
|
||||
# seq_input_ids, f"{self.kd_online_server_base_url}:{self.kd_online_topk}"
|
||||
# )
|
||||
# with open(f"{self.kd_cache_dir}/{hash_input_ids}.parquet", "wb") as f:
|
||||
# pd.DataFrame(ret_logprobs_data).to_parquet(f, index=False)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
LOG.error(f"Error fetching logprobs from online teacher: {e}")
|
||||
raise e
|
||||
# ret_logprobs_data will be returned with empty lists, handled by the caller.
|
||||
except Exception as e: # Catch other potential errors during processing
|
||||
LOG.error(
|
||||
f"Unexpected error processing API response in fetch_online_logprobs: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise e
|
||||
|
||||
return {
|
||||
"target_token_ids": ret_data_target_token_ids,
|
||||
"target_logprobs": ret_data_target_logprobs,
|
||||
"target_mask": ret_data_target_mask,
|
||||
}
|
||||
|
||||
def __call__(
|
||||
self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
if not features:
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
|
||||
for (
|
||||
sub_batch_features
|
||||
) in features: # sub_batch_features is List[Dict[str, Any]]
|
||||
if not sub_batch_features:
|
||||
continue
|
||||
|
||||
input_ids_for_api_call: List[List[int]] = []
|
||||
labels_for_api_call: List[List[int]] = []
|
||||
# Store references to the original item dictionaries to update them in-place
|
||||
items_for_api_call: List[Dict[str, Any]] = []
|
||||
|
||||
for item_dict in sub_batch_features:
|
||||
if not isinstance(item_dict, dict):
|
||||
LOG.warning(
|
||||
f"Skipping non-dict item in sub_batch_features: {item_dict}"
|
||||
)
|
||||
continue
|
||||
|
||||
current_input_ids = item_dict.get("input_ids")
|
||||
current_labels = item_dict.get("labels")
|
||||
|
||||
if current_input_ids is not None and current_labels is not None:
|
||||
# Ensure input_ids and labels are lists of ints for JSON serialization
|
||||
input_ids_list = (
|
||||
current_input_ids.tolist()
|
||||
if hasattr(current_input_ids, "tolist")
|
||||
else list(current_input_ids)
|
||||
)
|
||||
labels_list = (
|
||||
current_labels.tolist()
|
||||
if hasattr(current_labels, "tolist")
|
||||
else list(current_labels)
|
||||
)
|
||||
|
||||
input_ids_for_api_call.append(input_ids_list)
|
||||
labels_for_api_call.append(labels_list)
|
||||
items_for_api_call.append(item_dict)
|
||||
else:
|
||||
# This item will not get teacher logprobs from the API.
|
||||
# Initialize KD fields to empty lists so downstream collators handle them uniformly.
|
||||
item_dict.setdefault("target_token_ids", [])
|
||||
item_dict.setdefault("target_logprobs", [])
|
||||
item_dict.setdefault("target_mask", [])
|
||||
|
||||
# print(items_for_api_call)
|
||||
if items_for_api_call: # Only call API if there's something to process
|
||||
if self.kd_online_server == "sglang":
|
||||
api_responses_for_sub_batch = self.fetch_online_logprobs_sglang(
|
||||
input_ids_for_api_call, labels_for_api_call
|
||||
)
|
||||
else:
|
||||
api_responses_for_sub_batch = self.fetch_online_logprobs_vllm(
|
||||
input_ids_for_api_call, labels_for_api_call
|
||||
)
|
||||
|
||||
# api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask"
|
||||
# Each value is a list, corresponding to items_for_api_call
|
||||
for i, item_to_update in enumerate(items_for_api_call):
|
||||
# TODO make sure to figure out which input in sub_batch_features to update the batch in the original `features` object so the super class can handle it properly.
|
||||
if api_responses_for_sub_batch and i < len(
|
||||
api_responses_for_sub_batch["target_token_ids"]
|
||||
): # Check bounds
|
||||
assert len(
|
||||
api_responses_for_sub_batch["target_token_ids"][i]
|
||||
) == len(item_to_update["input_ids"])
|
||||
assert len(
|
||||
api_responses_for_sub_batch["target_logprobs"][i]
|
||||
) == len(item_to_update["input_ids"])
|
||||
assert len(
|
||||
api_responses_for_sub_batch["target_mask"][i]
|
||||
) == len(item_to_update["labels"])
|
||||
item_to_update["target_token_ids"] = (
|
||||
api_responses_for_sub_batch["target_token_ids"][i]
|
||||
)
|
||||
item_to_update["target_logprobs"] = api_responses_for_sub_batch[
|
||||
"target_logprobs"
|
||||
][i]
|
||||
item_to_update["target_mask"] = api_responses_for_sub_batch[
|
||||
"target_mask"
|
||||
][i]
|
||||
else:
|
||||
# API call failed for this item, or response was shorter than expected.
|
||||
# Ensure KD fields are initialized as empty lists.
|
||||
LOG.warning(
|
||||
f" (index {i}), or API response was too short. "
|
||||
f"API response keys: {list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}"
|
||||
)
|
||||
item_to_update.setdefault("target_token_ids", [])
|
||||
item_to_update.setdefault("target_logprobs", [])
|
||||
item_to_update.setdefault("target_mask", [])
|
||||
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
485
src/axolotl/integrations/kd/kernels/liger.py
Normal file
485
src/axolotl/integrations/kd/kernels/liger.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""
|
||||
Liger Kernels for Chunked Top-K Log-Prob Distillation
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from liger_kernel.chunked_loss.fused_linear_distillation import (
|
||||
LigerFusedLinearDistillationBase,
|
||||
)
|
||||
|
||||
from axolotl.integrations.kd.utils import normalize_logprobs
|
||||
|
||||
|
||||
class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
||||
"""
|
||||
Chunked kl-div loss for top-k logprobs
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def distillation_loss_fn(
|
||||
student_logits_temp_scaled: torch.Tensor, # [chunk_size, vocab_size], already temp-scaled
|
||||
target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k]
|
||||
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
|
||||
target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
|
||||
beta: float = 0.0,
|
||||
normalize_topk: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute Top-K KL divergence loss for a chunk.
|
||||
Args:
|
||||
student_logits_temp_scaled: Student logits, scaled by temperature. Shape: (N, V).
|
||||
target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
|
||||
target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
|
||||
target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
|
||||
beta: Controls the type of KL divergence.
|
||||
0.0 for Forward KL (P_teacher || P_student).
|
||||
1.0 for Reverse KL (P_student || P_teacher).
|
||||
0.5 for Symmetric KL (average of Forward and Reverse).
|
||||
normalize_topk: Whether to normalize the log probabilities
|
||||
Returns:
|
||||
Sum of KL divergence losses for the chunk.
|
||||
"""
|
||||
topk = target_token_ids_chunk.shape[-1]
|
||||
student_logits_temp_scaled = ( # [chunk_size, vocab_size]
|
||||
student_logits_temp_scaled.float()
|
||||
)
|
||||
target_logprobs_chunk = target_logprobs_chunk.float()
|
||||
|
||||
# Gather student logits for the top-k teacher token IDs
|
||||
# target_token_ids_chunk: [chunk_size, top_k]
|
||||
# student_logits_topk_temp_scaled: [chunk_size, top_k]
|
||||
student_logits_topk_temp_scaled = torch.gather(
|
||||
student_logits_temp_scaled, dim=-1, index=target_token_ids_chunk
|
||||
)
|
||||
|
||||
# Student log-probabilities for the gathered top-k tokens
|
||||
student_lse = torch.logsumexp(
|
||||
student_logits_temp_scaled, dim=-1, keepdim=True
|
||||
) # [chunk_size, 1]
|
||||
student_logprobs_topk_temp_scaled = (
|
||||
student_logits_topk_temp_scaled - student_lse
|
||||
)
|
||||
|
||||
# we have the top-k student logprobs, normalize them
|
||||
if normalize_topk:
|
||||
student_logprobs_topk_temp_scaled = normalize_logprobs(
|
||||
student_logprobs_topk_temp_scaled, topk
|
||||
)
|
||||
|
||||
valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k]
|
||||
|
||||
student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
|
||||
teacher_logprobs_valid = target_logprobs_chunk[valid_mask]
|
||||
|
||||
# Teacher probabilities P(y|x_teacher) from logprobs
|
||||
# target_logprobs_valid are already normalized (log(softmax(teacher_logits/T)))
|
||||
teacher_probs_valid = teacher_logprobs_valid.exp()
|
||||
# Student probabilities P_student from log P_student
|
||||
student_probs_topk_valid = student_logprobs_topk_valid.exp()
|
||||
|
||||
# kd_loss_per_token = torch.zeros_like(target_logprobs_valid)
|
||||
|
||||
# KL divergence: sum(P_teacher * (log P_teacher - log P_student))
|
||||
# = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
|
||||
# The distillation loss is often formulated as -sum(P_teacher * log P_student)
|
||||
# or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
|
||||
# Here, target_logprobs_valid are log_softmax_teacher.
|
||||
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
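# Editorial summary of the beta branches below (restated, not new behavior):
#   beta == 0 (forward KL):   sum_k p_T(k) * (log p_T(k) - log p_S(k))
#   beta == 1 (reverse KL):   sum_k p_S(k) * (log p_S(k) - log p_T(k))
#   0 < beta < 1 (generalized JSD): with m = (1 - beta) * p_S + beta * p_T,
#       loss = beta * KL(p_T || m) + (1 - beta) * KL(p_S || m)
# where the sums run over the valid (unmasked) top-k entries only.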
|
||||
if beta == 0.0: # Contribution from Forward KL
|
||||
fwd_kl_per_token = teacher_probs_valid * (
|
||||
teacher_logprobs_valid - student_logprobs_topk_valid
|
||||
)
|
||||
kd_loss = fwd_kl_per_token.sum()
|
||||
elif beta == 1.0: # Contribution from Reverse KL
|
||||
rev_kl_per_token = student_probs_topk_valid * (
|
||||
student_logprobs_topk_valid - teacher_logprobs_valid
|
||||
)
|
||||
kd_loss = rev_kl_per_token.sum()
|
||||
else:
|
||||
# JSD - Jensen-Shannon Divergence / Symmetric
|
||||
mean_probs = (
|
||||
1 - beta
|
||||
) * student_probs_topk_valid + beta * teacher_probs_valid
|
||||
log_mean_probs = mean_probs.log()
|
||||
student_kl = F.kl_div(
|
||||
log_mean_probs,
|
||||
student_logprobs_topk_valid,
|
||||
reduction="sum",
|
||||
log_target=True,
|
||||
)
|
||||
teacher_kl = F.kl_div(
|
||||
log_mean_probs, teacher_logprobs_valid, reduction="sum", log_target=True
|
||||
)
|
||||
jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
|
||||
kd_loss = jsd_loss
|
||||
|
||||
return kd_loss
|
||||
|
||||
@staticmethod
|
||||
def _compute_loss_kl_topk(
|
||||
student_input_chunk: torch.Tensor,
|
||||
student_weight: torch.Tensor,
|
||||
# Args for student_bias, target_token_ids_chunk etc. are passed to the lambda wrapped by grad_and_value
|
||||
# or through `partial`. Let's make them explicit here for clarity.
|
||||
target_token_ids_chunk: torch.Tensor,
|
||||
target_logprobs_chunk: torch.Tensor,
|
||||
target_mask_chunk: torch.Tensor,
|
||||
target_chunk: torch.Tensor, # For hard loss (true labels)
|
||||
student_bias: torch.Tensor = None, # This will be one of the grad targets
|
||||
# Other params passed via `partial` from `forward`
|
||||
distillation_loss_fn=None,
|
||||
ignore_index: int = -100,
|
||||
weight_hard_loss: float = 0.5,
|
||||
weight_soft_loss: float = 0.5,
|
||||
compute_ce_loss: bool = True,
|
||||
temperature: float = 1.0,
|
||||
beta: float = 0.0,
|
||||
normalize_topk: bool = True,
|
||||
):
|
||||
# Compute student logits for the chunk from hidden states and LM head
|
||||
# student_input_chunk: [chunk_size, hidden_dim]
|
||||
# student_lm_head_weight: [vocab_size, hidden_dim]
|
||||
# student_logits_chunk: [chunk_size, vocab_size]
|
||||
student_logits_chunk = F.linear(
|
||||
student_input_chunk, student_weight, student_bias
|
||||
)
|
||||
|
||||
ce_loss = torch.tensor(
|
||||
0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
|
||||
)
|
||||
if compute_ce_loss and weight_hard_loss > 0.0:
|
||||
ce_loss = F.cross_entropy(
|
||||
student_logits_chunk.view(-1, student_logits_chunk.shape[-1]),
|
||||
target_chunk.view(-1),
|
||||
reduction="sum",
|
||||
ignore_index=ignore_index,
|
||||
)
|
||||
|
||||
soft_loss = torch.tensor(
|
||||
0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
|
||||
)
|
||||
if weight_soft_loss > 0.0:
|
||||
student_logits_chunk_temp_scaled = student_logits_chunk / temperature
|
||||
|
||||
# Assuming student_weight.shape[0] (vocab_size) is adequate for target_token_ids_chunk.max()
|
||||
# No explicit padding here; user must ensure vocab alignment or pre-pad student_weight.
|
||||
|
||||
soft_loss = distillation_loss_fn(
|
||||
student_logits_chunk_temp_scaled,
|
||||
target_token_ids_chunk,
|
||||
target_logprobs_chunk,
|
||||
target_mask_chunk,
|
||||
beta=beta,
|
||||
normalize_topk=normalize_topk,
|
||||
)
|
||||
|
||||
return soft_loss, ce_loss
|
||||
|
||||
@classmethod
|
||||
def forward(
|
||||
cls,
|
||||
ctx,
|
||||
student_input: torch.Tensor, # [batch_size, seq_len, dim]
|
||||
student_lm_head_weight: torch.Tensor, # [dim, vocab_size]
|
||||
target_token_ids: torch.Tensor, # [batch_size, seq_len, top_k]
|
||||
target_logprobs: torch.Tensor, # [batch_size, seq_len, top_k]
|
||||
target_mask: torch.Tensor, # [batch_size, seq_len, top_k]
|
||||
true_labels: torch.Tensor, # [batch_size, seq_len]
|
||||
student_lm_head_bias: torch.Tensor = None,
|
||||
weight_hard_loss: float = 0.5,
|
||||
weight_soft_loss: float = 0.5,
|
||||
ignore_index: int = -100,
|
||||
temperature: float = 1.0,
|
||||
beta: float = 0.0,
|
||||
compiled: bool = False,
|
||||
chunk_size: int = 1024,
|
||||
compute_ce_loss: bool = True,
|
||||
normalize_topk: bool = True,
|
||||
):
|
||||
CHUNK_SIZE = chunk_size # pylint: disable=invalid-name
|
||||
grad_weight_acc = torch.zeros_like(student_lm_head_weight)
|
||||
grad_inputs_list = []
|
||||
grad_bias_acc = (
|
||||
torch.zeros_like(student_lm_head_bias)
|
||||
if student_lm_head_bias is not None
|
||||
else None
|
||||
)
|
||||
kd_loss_acc = torch.zeros(
|
||||
(), device=student_input.device, dtype=student_input.dtype
|
||||
)
|
||||
ce_loss_acc = torch.zeros(
|
||||
(), device=student_input.device, dtype=student_input.dtype
|
||||
)
|
||||
|
||||
# This function will be what torch.func.grad_and_value differentiates.
|
||||
# It takes student_input_chunk, student_weight (full), student_bias (full) as primals.
|
||||
# Other necessary data (target_*, etc.) are passed as non-differentiable arguments.
|
||||
def loss_fn_for_grad(
|
||||
_student_input_chunk,
|
||||
_student_lm_head_weight, # full weight
|
||||
_student_lm_head_bias, # full bias
|
||||
# Fixed arguments for a given chunk, not differentiated:
|
||||
_target_token_ids_chunk,
|
||||
_target_logprobs_chunk,
|
||||
_target_mask_chunk,
|
||||
_true_labels_chunk,
|
||||
):
|
||||
return cls._compute_loss_kl_topk(
|
||||
student_input_chunk=_student_input_chunk,
|
||||
student_weight=_student_lm_head_weight,
|
||||
target_token_ids_chunk=_target_token_ids_chunk,
|
||||
target_logprobs_chunk=_target_logprobs_chunk,
|
||||
target_mask_chunk=_target_mask_chunk,
|
||||
target_chunk=_true_labels_chunk,
|
||||
student_bias=_student_lm_head_bias,
|
||||
distillation_loss_fn=cls.distillation_loss_fn,
|
||||
ignore_index=ignore_index,
|
||||
weight_hard_loss=weight_hard_loss,
|
||||
weight_soft_loss=weight_soft_loss,
|
||||
compute_ce_loss=compute_ce_loss,
|
||||
temperature=temperature,
|
||||
beta=beta,
|
||||
normalize_topk=normalize_topk,
|
||||
)
|
||||
|
||||
def accumulate_chunk_grads(
|
||||
student_input_chunk_ac,
|
||||
target_token_ids_chunk_ac,
|
||||
target_logprobs_chunk_ac,
|
||||
target_mask_chunk_ac,
|
||||
true_labels_chunk_ac,
|
||||
):
|
||||
# student_weight and student_bias are closed over from the outer scope (full tensors)
|
||||
if student_lm_head_bias is not None:
|
||||
(
|
||||
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
|
||||
(chunk_kd_loss, chunk_ce_loss),
|
||||
) = torch.func.grad_and_value(
|
||||
loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True
|
||||
)(
|
||||
student_input_chunk_ac,
|
||||
student_lm_head_weight,
|
||||
student_lm_head_bias, # primals
|
||||
target_token_ids_chunk_ac,
|
||||
target_logprobs_chunk_ac,
|
||||
target_mask_chunk_ac,
|
||||
true_labels_chunk_ac,
|
||||
) # non-primals
|
||||
grad_bias_acc.add_(chunk_grad_bias)
|
||||
else:
|
||||
argnums_for_grad = (0, 1) # Differentiate wrt input_chunk, weight
|
||||
(
|
||||
(chunk_grad_input, chunk_grad_weight), # No grad for bias
|
||||
(chunk_kd_loss, chunk_ce_loss),
|
||||
) = torch.func.grad_and_value(
|
||||
loss_fn_for_grad, argnums=argnums_for_grad, has_aux=True
|
||||
)(
|
||||
student_input_chunk_ac,
|
||||
student_lm_head_weight,
|
||||
None, # Pass None for student_bias primal
|
||||
target_token_ids_chunk_ac,
|
||||
target_logprobs_chunk_ac,
|
||||
target_mask_chunk_ac,
|
||||
true_labels_chunk_ac,
|
||||
)
|
||||
|
||||
grad_weight_acc.add_(chunk_grad_weight)
|
||||
kd_loss_acc.add_(chunk_kd_loss)
|
||||
ce_loss_acc.add_(chunk_ce_loss)
|
||||
|
||||
return chunk_grad_input
|
||||
|
||||
if compiled:
|
||||
accumulate_chunk_grads_compiled = torch.compile(
|
||||
accumulate_chunk_grads, dynamic=True, backend="inductor"
|
||||
) # dynamic=True often helpful
|
||||
else:
|
||||
accumulate_chunk_grads_compiled = accumulate_chunk_grads
|
||||
|
||||
# Use the same chunking logic as LigerFusedLinearDistillationBase.forward
|
||||
B, N, D = student_input.shape # pylint: disable=invalid-name
|
||||
K = target_token_ids.shape[-1] # pylint: disable=invalid-name
|
||||
|
||||
student_input_flat = student_input.reshape(-1, student_input.shape[-1])
|
||||
target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1])
|
||||
target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1])
|
||||
target_mask_flat = target_mask.reshape(-1, target_mask.shape[-1])
|
||||
# pad and shift for cross entropy loss
|
||||
true_labels = torch.nn.functional.pad(true_labels, (0, 1), value=ignore_index)
|
||||
true_labels_flat = true_labels[:, 1:].contiguous().view(-1)
|
||||
|
||||
num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE)
|
||||
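# Illustrative worked example (shapes assumed, not taken from this change): with B=2,
# N=2048, and CHUNK_SIZE=1024, student_input_flat has 4096 rows, so num_chunks = 4 and
# each chunk covers roughly 1024 token positions.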
|
||||
_student_input_chunks = torch.chunk(
|
||||
student_input_flat, chunks=num_chunks, dim=0
|
||||
)
|
||||
_target_token_ids_chunks = torch.chunk(
|
||||
target_token_ids_flat, chunks=num_chunks, dim=0
|
||||
)
|
||||
_target_logprobs_chunks = torch.chunk(
|
||||
target_logprobs_flat, chunks=num_chunks, dim=0
|
||||
)
|
||||
_target_mask_chunks = torch.chunk(target_mask_flat, chunks=num_chunks, dim=0)
|
||||
_true_labels_chunks = torch.chunk(true_labels_flat, chunks=num_chunks, dim=0)
|
||||
|
||||
for i in range(num_chunks):
|
||||
grad_input_chunk = accumulate_chunk_grads_compiled(
|
||||
_student_input_chunks[i],
|
||||
_target_token_ids_chunks[i],
|
||||
_target_logprobs_chunks[i],
|
||||
_target_mask_chunks[i],
|
||||
_true_labels_chunks[i],
|
||||
)
|
||||
grad_inputs_list.append(grad_input_chunk)
|
||||
|
||||
grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)
|
||||
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
|
||||
|
||||
# For matching None returns in backward for non-tensor/non-grad_requiring inputs
|
||||
ctx.hyperparams_count = 9 # Corresponds to number of hyperparams after main tensors in fwd signature
|
||||
ctx.bias_was_none = student_lm_head_bias is None
|
||||
ctx.orig_dims = (B, N, D, K)
|
||||
|
||||
# since the input is packed, there is a single batch, so the batchmean reduction of kl-div is just the accumulated sum
|
||||
# we still need to scale the kd_loss by the temp^2
|
||||
kd_loss_acc = kd_loss_acc * (temperature**2)
|
||||
final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
|
||||
|
||||
return final_loss
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_input_flat, grad_weight, grad_bias_maybe = (
|
||||
ctx.saved_tensors
|
||||
) # grad_input_flat is (B*N, D)
|
||||
|
||||
# Scale gradients by grad_output if it's not 1.0
|
||||
if not torch.equal(
|
||||
grad_output,
|
||||
torch.tensor(1.0, device=grad_output.device, dtype=grad_output.dtype),
|
||||
):
|
||||
grad_input_flat = grad_input_flat * grad_output
|
||||
grad_weight = grad_weight * grad_output
|
||||
if grad_bias_maybe is not None:
|
||||
grad_bias_maybe = grad_bias_maybe * grad_output
|
||||
|
||||
# Reshape grad_input_flat to match original student_input shape (B, N, D)
|
||||
# ctx.orig_dims stores (B, N, D, K)
|
||||
# We need the first three dimensions for student_input's shape.
|
||||
# Ensure that orig_dims are not (0,0,0,K) for empty inputs leading to view errors
|
||||
if (
|
||||
ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0
|
||||
and grad_input_flat.numel() == 0
|
||||
):
|
||||
# If original input was empty, gradient should also be empty with correct shape
|
||||
grad_input_reshaped = torch.zeros(
|
||||
ctx.orig_dims[0],
|
||||
ctx.orig_dims[1],
|
||||
ctx.orig_dims[2],
|
||||
dtype=grad_input_flat.dtype,
|
||||
device=grad_input_flat.device,
|
||||
)
|
||||
elif grad_input_flat.numel() == 0 and not (
|
||||
ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0
|
||||
):
|
||||
# This case should ideally not happen if forward path is correct (non-empty input -> non-empty flat grad)
|
||||
# but as a safeguard:
|
||||
grad_input_reshaped = torch.zeros(
|
||||
ctx.orig_dims[0],
|
||||
ctx.orig_dims[1],
|
||||
ctx.orig_dims[2],
|
||||
dtype=grad_input_flat.dtype,
|
||||
device=grad_input_flat.device,
|
||||
)
|
||||
else:
|
||||
grad_input_reshaped = grad_input_flat.view(
|
||||
ctx.orig_dims[0], ctx.orig_dims[1], ctx.orig_dims[2]
|
||||
)
|
||||
|
||||
nones_for_hyperparams = [None] * ctx.hyperparams_count
|
||||
grad_bias_return = grad_bias_maybe if not ctx.bias_was_none else None
|
||||
|
||||
return (
|
||||
grad_input_reshaped, # Gradient for student_input (reshaped)
|
||||
grad_weight, # Gradient for student_lm_head_weight
|
||||
None, # Gradient for target_token_ids
|
||||
None, # Gradient for target_logprobs
|
||||
None, # Gradient for target_mask
|
||||
None, # Gradient for true_labels
|
||||
grad_bias_return, # Gradient for student_lm_head_bias
|
||||
*nones_for_hyperparams, # Grads for weight_hard_loss, ..., compute_ce_loss
|
||||
)
|
||||
|
||||
|
||||
class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
|
||||
"""
|
||||
Wrapper module for the chunked top-k logprob KL-divergence loss.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_hard_loss: float = 0.5,
|
||||
weight_soft_loss: float = 0.5,
|
||||
temperature: float = 1.0, # This is the kd_temperature
|
||||
beta: float = 1.0,
|
||||
ignore_index: int = -100,
|
||||
compiled: bool = True,
|
||||
chunk_size: int = 1024,
|
||||
compute_ce_loss: bool = True,
|
||||
normalize_topk: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0):
|
||||
raise ValueError("Loss weights must be between 0.0 and 1.0.")
|
||||
if temperature <= 0:
|
||||
raise ValueError("Temperature must be positive.")
|
||||
|
||||
self.weight_hard_loss = weight_hard_loss
|
||||
self.weight_soft_loss = weight_soft_loss
|
||||
self.temperature = temperature
|
||||
self.beta = beta
|
||||
self.ignore_index = ignore_index
|
||||
self.compiled = compiled
|
||||
self.chunk_size = chunk_size
|
||||
self.compute_ce_loss = compute_ce_loss
|
||||
self.normalize_topk = normalize_topk
|
||||
|
||||
if not self.compute_ce_loss and self.weight_hard_loss > 0.0:
|
||||
print(
|
||||
f"Warning: compute_ce_loss is False, but weight_hard_loss ({self.weight_hard_loss}) > 0. Hard loss will effectively be zero."
|
||||
)
|
||||
# self.weight_hard_loss = 0.0 # Or let user manage this
|
||||
if self.weight_soft_loss == 0.0:
|
||||
print(
|
||||
"Warning: weight_soft_loss is 0.0. Soft (KD) loss will not be computed."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
lm_head_weight: torch.Tensor, # Weights of the linear layer in the LM head
|
||||
student_hidden_states: torch.Tensor, # student_hidden_states before the lm_head
|
||||
target_token_ids: torch.Tensor,
|
||||
target_logprobs: torch.Tensor,
|
||||
target_mask: torch.Tensor,
|
||||
true_labels: torch.Tensor,
|
||||
student_bias: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
return LigerFusedLinearKLTopKLogprobFunction.apply(
|
||||
student_hidden_states,
|
||||
lm_head_weight,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
true_labels,
|
||||
student_bias,
|
||||
self.weight_hard_loss,
|
||||
self.weight_soft_loss,
|
||||
self.ignore_index,
|
||||
self.temperature,
|
||||
self.beta,
|
||||
self.compiled,
|
||||
self.chunk_size,
|
||||
self.compute_ce_loss,
|
||||
self.normalize_topk,
|
||||
)
|
||||
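A minimal usage sketch of the loss module above. Everything here is illustrative: the shapes, hidden size, vocabulary size, and top-k value are assumptions rather than values from this change, and the sketch presumes the module is called like any other torch.nn.Module loss (it also relies on cls._compute_loss_kl_topk defined earlier in the same file).

import torch

batch_size, seq_len, hidden_dim, vocab_size, top_k = 2, 128, 512, 32000, 64

kd_loss_fn = LigerFusedLinearKLTopKLogprobLoss(
    weight_hard_loss=0.5,
    weight_soft_loss=0.5,
    temperature=1.0,
    chunk_size=1024,
)

# Random stand-ins for the student hidden states, LM head, and teacher top-k targets
lm_head_weight = torch.randn(vocab_size, hidden_dim, requires_grad=True)
student_hidden_states = torch.randn(batch_size, seq_len, hidden_dim, requires_grad=True)
target_token_ids = torch.randint(0, vocab_size, (batch_size, seq_len, top_k))
target_logprobs = torch.log_softmax(torch.randn(batch_size, seq_len, top_k), dim=-1)
target_mask = torch.ones(batch_size, seq_len, top_k)
true_labels = torch.randint(0, vocab_size, (batch_size, seq_len))

loss = kd_loss_fn(
    lm_head_weight,
    student_hidden_states,
    target_token_ids,
    target_logprobs,
    target_mask,
    true_labels,
)
loss.backward()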
97
src/axolotl/integrations/kd/kernels/models.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Model patcher for the chunked top-k KL-divergence loss.
|
||||
"""
|
||||
|
||||
from typing import Optional, Union, Unpack
|
||||
|
||||
import torch
|
||||
from transformers import Cache
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.utils import LossKwargs
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
|
||||
"""
|
||||
placeholder kwargs for hf model classes
|
||||
"""
|
||||
|
||||
|
||||
def kldiv_forward_llama_like(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
target_logprobs: Optional[torch.Tensor] = None,
|
||||
target_token_ids: Optional[torch.LongTensor] = None,
|
||||
target_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0, # pylint: disable=unused-argument
|
||||
**kwargs: Unpack[KwargsForCausalLM], # type: ignore[misc]
|
||||
) -> CausalLMOutputWithPast:
|
||||
# pylint: disable=duplicate-code
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
# TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
|
||||
# self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
|
||||
|
||||
loss = self.loss_function(
|
||||
self.lm_head.weight,
|
||||
hidden_states,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
true_labels=labels,
|
||||
)
|
||||
num_items_in_batch = kwargs.pop("num_items_in_batch", -1)
|
||||
if num_items_in_batch is not None and num_items_in_batch > 0:
|
||||
loss = loss / num_items_in_batch
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=None,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def apply_kernel(model_type):
|
||||
# Dynamically import the module and attention class
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
|
||||
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
|
||||
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
|
||||
model_cls.forward = kldiv_forward_llama_like
|
||||
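# A hedged usage illustration ("llama" is an assumed example model_type): calling
# apply_kernel("llama") imports transformers.models.llama.modeling_llama, resolves
# LlamaForCausalLM, and replaces its forward with kldiv_forward_llama_like, so the KD
# loss consumes hidden states directly instead of materializing full logits.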
@@ -16,40 +16,7 @@
|
||||
loss for top_k KL divergence
|
||||
"""
|
||||
import torch
|
||||
|
||||
|
||||
def zscore_standardize(
|
||||
logits: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
base_temperature: float = 1.0,
|
||||
eps: float = 1e-9,
|
||||
):
|
||||
"""
|
||||
Z-score standardize along the last dimension of `logits`.
|
||||
i.e., for each [B, seq_len] row, across K entries:
|
||||
z = (logits - mean) / std,
|
||||
then scale by 1 / base_temperature if desired.
|
||||
|
||||
mask can be broadcastable or None. If None, we standardize all elements.
|
||||
"""
|
||||
if mask is None:
|
||||
# shape: [B, seq_len, K]
|
||||
# Mean and std over dim=-1
|
||||
mean = logits.mean(dim=-1, keepdim=True)
|
||||
var = logits.var(dim=-1, unbiased=False, keepdim=True)
|
||||
else:
|
||||
# If you have to exclude some tokens, multiply by mask, etc.
|
||||
float_mask = mask.to(logits.dtype)
|
||||
count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
|
||||
mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
|
||||
var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count
|
||||
|
||||
std = torch.sqrt(var.clamp_min(eps))
|
||||
z = (logits - mean) / std
|
||||
|
||||
# Scale by 1 / base_temperature
|
||||
z = z / base_temperature
|
||||
return z
|
||||
from torch import nn
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@@ -60,7 +27,6 @@ def loss(
|
||||
target_mask: torch.Tensor,
|
||||
num_items_in_batch: int = -1, # Use -1 to indicate "None"
|
||||
kd_temperature: float = 1.0,
|
||||
top_k_before_softmax: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
A KD loss function that is TorchScript-friendly.
|
||||
@@ -77,8 +43,6 @@ def loss(
|
||||
num_items_in_batch (int, optional): The number of items in the batch.
|
||||
kd_temperature (float, optional): The temperature for KD.
|
||||
Default: 1.0
|
||||
top_k_before_softmax (int, optional): Flag indicating whether to gather the student's top-k logits before the softmax (normalizing over only the top-k entries rather than the full vocabulary)
|
||||
Default: 0
|
||||
"""
|
||||
|
||||
target_logprobs = target_logprobs.float()
|
||||
@@ -88,46 +52,24 @@ def loss(
|
||||
# student_logits shape: [B, student_seq_len, vocab_size]
|
||||
teacher_seq_len = target_token_ids.shape[1]
|
||||
|
||||
if top_k_before_softmax:
|
||||
# Slice student logits to match teacher-provided sequence length
|
||||
student_logits_for_kd = student_logits[
|
||||
:, :teacher_seq_len, :
|
||||
] # [B, teacher_seq_len, vocab_size]
|
||||
# Slice student logits to match teacher-provided sequence length
|
||||
student_logits_for_kd = (
|
||||
student_logits[:, :teacher_seq_len, :] / kd_temperature
|
||||
) # [B, teacher_seq_len, vocab_size]
|
||||
|
||||
# Gather student logits for teacher's top-K tokens
|
||||
student_logits_topk = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
) # [B, teacher_seq_len, K]
|
||||
# keep in full precision for numerical stability of loss
|
||||
student_logits_for_kd = student_logits_for_kd.float()
|
||||
|
||||
student_logits_topk = student_logits_topk.float()
|
||||
# Gather student logits for teacher's top-K tokens
|
||||
student_logits_topk = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
) # [B, teacher_seq_len, K]
|
||||
|
||||
# Apply KD temperature to student’s logits
|
||||
if kd_temperature != 1.0:
|
||||
student_logits_topk = student_logits_topk / kd_temperature
|
||||
# Compute logsumexp across full vocabulary
|
||||
student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
|
||||
|
||||
# Convert student top-k logits to logprobs
|
||||
student_logprobs_topk = student_logits_topk - torch.logsumexp(
|
||||
student_logits_topk, dim=-1, keepdim=True
|
||||
) # [B, teacher_seq_len, K]
|
||||
else:
|
||||
# Slice student logits to match teacher-provided sequence length
|
||||
student_logits_for_kd = (
|
||||
student_logits[:, :teacher_seq_len, :] / kd_temperature
|
||||
) # [B, teacher_seq_len, vocab_size]
|
||||
|
||||
# keep in full precision for numerical stability of loss
|
||||
student_logits_for_kd = student_logits_for_kd.float()
|
||||
|
||||
# Gather student logits for teacher's top-K tokens
|
||||
student_logits_topk = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
) # [B, teacher_seq_len, K]
|
||||
|
||||
# Compute logsumexp across full vocabulary
|
||||
student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
|
||||
|
||||
# Convert just the top-k logits to logprobs
|
||||
student_logprobs_topk = student_logits_topk - student_lse
|
||||
# Convert just the top-k logits to logprobs
|
||||
student_logprobs_topk = student_logits_topk - student_lse
|
||||
|
||||
# Convert teacher_mask to boolean for indexing
|
||||
# In TorchScript, .bool() is sometimes unsupported, so we do:
|
||||
@@ -144,10 +86,6 @@ def loss(
|
||||
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
|
||||
kd_loss = kd_loss_per_token.sum()
|
||||
|
||||
# Multiply by T^2 (classical KD scaling)
|
||||
if kd_temperature != 1.0:
|
||||
kd_loss = kd_loss * (kd_temperature**2)
|
||||
|
||||
# Normalize by number of items (if provided) or by valid tokens
|
||||
if num_items_in_batch > 0:
|
||||
kd_loss = kd_loss / float(num_items_in_batch)
|
||||
@@ -158,80 +96,74 @@ def loss(
|
||||
return kd_loss
|
||||
|
||||
|
||||
def topk_kd_loss_with_zscore(
|
||||
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
|
||||
target_token_ids: torch.Tensor, # [B, seq_len, K]
|
||||
target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space
|
||||
target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len]
|
||||
kd_temperature: float = 1.0, # classic KD temperature
|
||||
zscore_base_temp: float = 1.0, # from the paper
|
||||
num_items_in_batch: int = -1,
|
||||
):
|
||||
class ChunkedTopKKDLoss(nn.Module):
|
||||
"""
|
||||
A variant of top_k KL divergence with Z-score scaling
|
||||
from "Logit Standardization in Knowledge Distillation".
|
||||
A wrapper that chunks (splits) the student and teacher outputs along the time dimension
|
||||
to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies.
|
||||
|
||||
Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to top-K teacher logprobs.
|
||||
"""
|
||||
|
||||
target_logprobs = target_logprobs.float()
|
||||
def __init__(self, num_output_chunks: int = 8, kd_temperature: float = 1.0):
|
||||
super().__init__()
|
||||
self.num_output_chunks = num_output_chunks
|
||||
self.kd_temperature = kd_temperature
|
||||
|
||||
B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name
|
||||
# 1) Gather the student's top-k logits to match teacher
|
||||
student_logits_for_kd = student_logits[
|
||||
:, :teacher_seq_len, :
|
||||
] # [B, seq_len, vocab]
|
||||
student_topk_logits = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
) # [B, seq_len, K]
|
||||
def forward(
|
||||
self,
|
||||
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
|
||||
target_token_ids: torch.Tensor, # [B, seq_len, K]
|
||||
target_logprobs: torch.Tensor, # [B, seq_len, K]
|
||||
target_mask: torch.Tensor, # [B, seq_len, K]
|
||||
num_items_in_batch: int = -1, # optional batch size for normalization
|
||||
) -> torch.Tensor:
|
||||
|
||||
student_topk_logits = student_topk_logits.float()
|
||||
# 1. Split along the "token" dimension (dim=1).
|
||||
student_logits_chunks = student_logits.chunk(self.num_output_chunks, dim=1)
|
||||
token_ids_chunks = target_token_ids.chunk(self.num_output_chunks, dim=1)
|
||||
logprobs_chunks = target_logprobs.chunk(self.num_output_chunks, dim=1)
|
||||
mask_chunks = target_mask.chunk(self.num_output_chunks, dim=1)
|
||||
|
||||
# 2) If you want to keep the "classical" T scaling, apply it first
|
||||
if kd_temperature != 1.0:
|
||||
student_topk_logits = student_topk_logits / kd_temperature
|
||||
# We'll accumulate a global "sum of losses" and "sum of valid tokens"
|
||||
# so that our final average is consistent with the entire sequence/batch.
|
||||
total_loss = 0.0
|
||||
total_valid_tokens = 0
|
||||
|
||||
# 3) Convert teacher logprobs -> treat them as “logits” for z-score
|
||||
# (They differ by +some_constant from real logits, but in z-score
|
||||
# that constant is subtracted out anyway.)
|
||||
teacher_logits_for_zscore = target_logprobs # rename variable for clarity
|
||||
# 2. Loop over each chunk and compute a chunk-specific loss.
|
||||
for st_chunk, tid_chunk, lp_chunk, msk_chunk in zip(
|
||||
student_logits_chunks, token_ids_chunks, logprobs_chunks, mask_chunks
|
||||
):
|
||||
# We pass num_items_in_batch=-1 so that the kd_loss
|
||||
# will average over *this chunk's* valid tokens only.
|
||||
chunk_loss = loss(
|
||||
student_logits=st_chunk,
|
||||
target_token_ids=tid_chunk,
|
||||
target_logprobs=lp_chunk,
|
||||
target_mask=msk_chunk,
|
||||
num_items_in_batch=-1, # ensure per-chunk averaging by valid tokens
|
||||
kd_temperature=self.kd_temperature,
|
||||
)
|
||||
|
||||
# 4) Z-score teacher and student
|
||||
# If target_mask is 2D, expand to 3D for the K dimension
|
||||
if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
|
||||
target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)
|
||||
# kd_loss returns an average over the chunk's valid tokens.
|
||||
# We want a global average in the end, so we need to re-weight
|
||||
# by the number of valid tokens in this chunk and keep track of the total.
|
||||
chunk_valid_mask = msk_chunk.to(torch.bool)
|
||||
chunk_valid_count = chunk_valid_mask.sum() # scalar tensor
|
||||
|
||||
teacher_z = zscore_standardize(
|
||||
teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
|
||||
)
|
||||
student_z = zscore_standardize(
|
||||
student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
|
||||
)
|
||||
# Re-scale "chunk average" back to "chunk sum"
|
||||
chunk_loss_sum = chunk_loss * chunk_valid_count
|
||||
|
||||
# 5) Convert to log-probs for KL
|
||||
teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
|
||||
student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
|
||||
total_loss += chunk_loss_sum
|
||||
total_valid_tokens += chunk_valid_count
|
||||
|
||||
# 6) Restrict to valid tokens if needed
|
||||
valid_mask = target_mask.bool() # shape [B, seq_len, K]
|
||||
teacher_probs_z = teacher_logprobs_z.exp()
|
||||
teacher_probs_z = teacher_probs_z[valid_mask]
|
||||
teacher_logprobs_z = teacher_logprobs_z[valid_mask]
|
||||
student_logprobs_z = student_logprobs_z[valid_mask]
|
||||
# 3. Normalize *once* at the end.
|
||||
if num_items_in_batch > 0:
|
||||
# If the user gave us a manual denominator (e.g. total items in batch),
|
||||
# we divide by it. Typically used if each item is of different length.
|
||||
final_loss = total_loss / float(num_items_in_batch)
|
||||
else:
|
||||
# Otherwise, divide by total valid tokens across all chunks.
|
||||
# to get the same result as a non-chunked approach.
|
||||
final_loss = total_loss / float(total_valid_tokens)
|
||||
|
||||
# 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] )
|
||||
kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z)
|
||||
kd_loss = kd_loss_per_token.sum()
|
||||
|
||||
# 8) If using classical KD scaling by T^2
|
||||
if kd_temperature != 1.0:
|
||||
kd_loss = kd_loss * (kd_temperature**2)
|
||||
|
||||
# Optionally scale by zscore_base_temp**2 if you want (paper might differ).
|
||||
# kd_loss = kd_loss * (zscore_base_temp**2)
|
||||
|
||||
# 9) Normalize
|
||||
if num_items_in_batch is not None and num_items_in_batch > 0:
|
||||
kd_loss = kd_loss / float(num_items_in_batch)
|
||||
else:
|
||||
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
|
||||
|
||||
return kd_loss
|
||||
return final_loss
|
||||
|
||||
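A brief usage sketch of the chunked wrapper above (the chunk count and tensor names are illustrative assumptions):

chunked_kd = ChunkedTopKKDLoss(num_output_chunks=8, kd_temperature=1.0)
kd_loss = chunked_kd(
    student_logits,      # [B, seq_len, vocab_size]
    target_token_ids,    # [B, seq_len, K]
    target_logprobs,     # [B, seq_len, K]
    target_mask,         # [B, seq_len, K]
    num_items_in_batch=-1,  # fall back to averaging over valid tokens across chunks
)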
@@ -18,8 +18,7 @@ KD trainer
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
|
||||
from .topk_logprob.forward_kl import loss as topk_kd_loss
|
||||
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
|
||||
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
|
||||
|
||||
|
||||
class AxolotlKDTrainer(AxolotlTrainer):
|
||||
@@ -27,6 +26,18 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
Custom trainer subclass for Knowledge Distillation (KD)
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.model_accepts_loss_kwargs = True
|
||||
self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss(
|
||||
self.args.kd_ce_alpha, # hard label loss
|
||||
self.args.kd_alpha, # kd loss
|
||||
self.args.kd_temperature,
|
||||
self.args.kd_beta,
|
||||
compute_ce_loss=bool(self.args.kd_ce_alpha),
|
||||
normalize_topk=self.args.kd_normalize_topk,
|
||||
)
|
||||
|
||||
def _set_signature_columns_if_needed(self):
|
||||
super()._set_signature_columns_if_needed()
|
||||
columns_to_add = []
|
||||
@@ -52,12 +63,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
|
||||
Subclass and override for custom behavior.
|
||||
"""
|
||||
|
||||
target_logprobs = inputs.pop("target_logprobs")
|
||||
target_token_ids = inputs.pop("target_token_ids")
|
||||
target_mask = inputs.pop("target_mask")
|
||||
|
||||
seq_len = target_token_ids.shape[1]
|
||||
if (
|
||||
self.args.sample_packing
|
||||
and hasattr(inputs, "attention_mask")
|
||||
and hasattr(inputs, "position_ids")
|
||||
):
|
||||
del inputs["attention_mask"]
|
||||
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss_kwargs = {}
|
||||
@@ -65,49 +76,4 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
||||
inputs = {**inputs, **loss_kwargs}
|
||||
outputs = model(**inputs)
|
||||
|
||||
# FIXME: account for tokenizer.padding_side
|
||||
student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
|
||||
|
||||
shift_logits = student_logits.contiguous()
|
||||
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
|
||||
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
|
||||
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
|
||||
|
||||
if self.args.kd_zscore_base_temp:
|
||||
loss_kd = topk_kd_loss_with_zscore(
|
||||
shift_logits,
|
||||
target_token_ids_for_loss,
|
||||
target_logprobs_for_loss,
|
||||
target_mask_for_loss,
|
||||
kd_temperature=self.args.kd_temperature,
|
||||
zscore_base_temp=self.args.kd_zscore_base_temp,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
else:
|
||||
loss_kd = topk_kd_loss(
|
||||
shift_logits,
|
||||
target_token_ids_for_loss,
|
||||
target_logprobs_for_loss,
|
||||
target_mask_for_loss,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
kd_temperature=self.args.kd_temperature,
|
||||
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
|
||||
)
|
||||
|
||||
if self.args.kd_ce_alpha > 0:
|
||||
kd_alpha = self.args.kd_alpha
|
||||
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
|
||||
else:
|
||||
loss = loss_kd
|
||||
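# Illustrative numbers for the weighting in the pre-existing path above (assumed, not
# from this change): with kd_ce_alpha=0.25, kd_alpha=0.75, a CE loss of 2.0, and a KD
# loss of 1.2, the combined loss is 0.25 * 2.0 + 0.75 * 1.2 = 1.4.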
# Save past state if it exists
|
||||
# TODO: this needs to be fixed and made cleaner later.
|
||||
if self.args.past_index >= 0:
|
||||
self._past = outputs[ # pylint: disable=attribute-defined-outside-init
|
||||
self.args.past_index
|
||||
]
|
||||
|
||||
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
|
||||
loss *= self.accelerator.num_processes
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
return outputs[0]
|
||||
|
||||
100
src/axolotl/integrations/kd/utils.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Helper KD utils"""
|
||||
|
||||
import math
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import FloatTensor, Tensor
|
||||
|
||||
|
||||
def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
|
||||
"""
|
||||
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
|
||||
"""
|
||||
# Ensure raw_logprobs matches kd_online_topk length for tensor operations
|
||||
# This should ideally be handled by the caller ensuring correct padding/truncation first
|
||||
if logprobs.shape[-1] != topk:
|
||||
# pad last dimension of logprobs to match topk length with -inf
|
||||
padding_len = topk - logprobs.shape[-1]
|
||||
padding_tensor = torch.full(
|
||||
(
|
||||
*logprobs.shape[:-1],
|
||||
padding_len,
|
||||
), # Takes all dimensions of logprobs except the last, then appends padding_len
|
||||
float("-inf"),
|
||||
dtype=logprobs.dtype,
|
||||
device=logprobs.device,
|
||||
)
|
||||
logprobs = torch.cat((logprobs, padding_tensor), dim=-1)
|
||||
|
||||
# Convert logprobs at T_online to probabilities
|
||||
# use log sum exp trick to avoid underflow
|
||||
position_logprobs_lse = torch.logsumexp(logprobs, dim=-1, keepdim=True)
|
||||
teacher_probs_t_online = torch.exp(logprobs - position_logprobs_lse)
|
||||
|
||||
# Normalize probabilities (sum to 1)
|
||||
# This is important if the top-k from server aren't a full distribution
|
||||
teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=-1, keepdim=True)
|
||||
teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum
|
||||
|
||||
final_logprobs_tensor = torch.log(teacher_probs_t_online)
|
||||
|
||||
return final_logprobs_tensor
|
||||
|
||||
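# A hedged numeric illustration of normalize_logprobs (values assumed): for
# logprobs = log([[0.5, 0.3, 0.1]]) and topk=4, the input is padded with one -inf
# entry, re-normalized so the probabilities sum to 1 (0.5/0.9, 0.3/0.9, 0.1/0.9, 0.0),
# and converted back to logprobs, so exp(result).sum(dim=-1) returns 1.0.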
|
||||
def strided_chunk_views(
|
||||
tensor: Union[np.ndarray, torch.Tensor],
|
||||
chunks: int,
|
||||
dim: int = 0,
|
||||
stride: int = 1,
|
||||
chunk_size: int | None = None,
|
||||
) -> List[Union[np.ndarray, torch.Tensor]]:
|
||||
"""
|
||||
Split a tensor into chunks along a dimension with striding, prioritizing views over copies.
|
||||
|
||||
Args:
|
||||
tensor: Input tensor (numpy array or torch tensor)
|
||||
chunks: Number of chunks to create
|
||||
dim: Dimension along which to chunk (default: 0)
|
||||
stride: Stride between chunk starting positions (default: 1)
|
||||
chunk_size: Size of each chunk. If None, calculated automatically (default: None)
|
||||
|
||||
Returns:
|
||||
List of tensor chunks (views when possible, copies when necessary)
|
||||
"""
|
||||
|
||||
# Get the size of the specified dimension
|
||||
dim_size = tensor.shape[dim]
|
||||
|
||||
# Calculate chunk size if not provided
|
||||
if chunk_size is None:
|
||||
chunk_size = (dim_size + chunks - 1) // chunks # Ceiling division
|
||||
|
||||
chunks_list = []
|
||||
|
||||
for i in range(chunks):
|
||||
start_idx = i * stride
|
||||
end_idx = min(start_idx + chunk_size, dim_size)
|
||||
|
||||
# Break if we've gone beyond the tensor
|
||||
if start_idx >= dim_size:
|
||||
break
|
||||
|
||||
# Create slice objects for all dimensions
|
||||
slices = [slice(None)] * tensor.ndim
|
||||
slices[dim] = slice(start_idx, end_idx)
|
||||
|
||||
chunk = tensor[tuple(slices)]
|
||||
chunks_list.append(chunk)
|
||||
|
||||
return chunks_list
|
||||
|
||||
|
||||
def chunk_overlap(input_tensor: Tensor, chunks: int, dim: int = 0, overlap: int = 1):
|
||||
dim_size = input_tensor.shape[dim]
|
||||
stride = math.ceil(dim_size / chunks)
|
||||
|
||||
return strided_chunk_views(
|
||||
input_tensor, chunks, dim, stride=stride, chunk_size=stride + overlap
|
||||
)
|
||||
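# A hedged illustration of chunk_overlap (sizes assumed): for a length-10 tensor with
# chunks=3 and overlap=1, stride = ceil(10 / 3) = 4 and chunk_size = 5, so the helper
# returns views over indices [0:5], [4:9], and [8:10].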
@@ -2,10 +2,10 @@
|
||||
|
||||
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
|
||||
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
|
||||
their context parallel version of Flash Attention 2.
|
||||
their sequence parallel version of Flash Attention 2.
|
||||
|
||||
We also provide some patches for accelerate functions to prepare the dataloader for
|
||||
context parallelism training.
|
||||
sequence parallelism training.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
@@ -13,9 +13,9 @@ import inspect
|
||||
import accelerate
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate.logging import get_logger
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
@@ -63,15 +63,15 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
||||
|
||||
|
||||
def register_ring_attn(
|
||||
context_parallel_degree: int,
|
||||
sequence_parallel_degree: int,
|
||||
heads_k_stride: int | None,
|
||||
ring_attn_func: RingAttnFunc | None,
|
||||
):
|
||||
"""Create ring attention group and substitute flash attn with ring flash attn.
|
||||
|
||||
Args:
|
||||
context_parallel_degree: Context parallelism factor.
|
||||
heads_k_stride: Context parallelism K head stride size. Passed through to
|
||||
sequence_parallel_degree: Sequence parallelism factor.
|
||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||
`varlen_llama3` `ring_flash_attn` implementation.
|
||||
ring_attn_func: `ring_flash_attn` ring attention implementation. If sample
|
||||
packing is enabled, it must be a `varlen` function; otherwise, it must be a
|
||||
@@ -80,18 +80,28 @@ def register_ring_attn(
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
LOG.info(
|
||||
"Enabling ring attention context parallelism: "
|
||||
f"each sequence will be processed across {context_parallel_degree} GPUs"
|
||||
if rank == 0:
|
||||
LOG.info(
|
||||
"Enabling ring attention sequence parallelism: "
|
||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
||||
)
|
||||
|
||||
assert sequence_parallel_degree <= world_size, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must be less than or equal to world_size ({world_size})"
|
||||
)
|
||||
assert world_size % sequence_parallel_degree == 0, (
|
||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
||||
f"must evenly divide world_size ({world_size})"
|
||||
)
|
||||
|
||||
# Assign ranks to context parallel groups
|
||||
# Assign ranks to sequence parallel groups
|
||||
group_assignments = {}
|
||||
for i in range(world_size // context_parallel_degree):
|
||||
for i in range(world_size // sequence_parallel_degree):
|
||||
ring_attn_ranks = list(
|
||||
range(
|
||||
i * context_parallel_degree,
|
||||
(i + 1) * context_parallel_degree,
|
||||
i * sequence_parallel_degree,
|
||||
(i + 1) * sequence_parallel_degree,
|
||||
)
|
||||
)
|
||||
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
||||
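# Illustrative only (values assumed): with world_size=8 and sequence_parallel_degree=4,
# this loop builds two ring-attention groups, ranks [0, 1, 2, 3] and [4, 5, 6, 7].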
@@ -103,7 +113,9 @@ def register_ring_attn(
|
||||
if rank in ring_attn_ranks:
|
||||
set_ring_attn_group(group)
|
||||
|
||||
LOG.info(f"Context parallel group assignments: {group_assignments}")
|
||||
# Log the GPU group assignments
|
||||
if rank == 0:
|
||||
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
|
||||
|
||||
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
|
||||
from ring_flash_attn import substitute_hf_flash_attn
|
||||
@@ -138,7 +150,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
|
||||
|
||||
|
||||
def patch_prepare_data_loader():
|
||||
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the CP degree.
|
||||
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If source code to patch does not exist.
|
||||
@@ -164,15 +176,15 @@ def patch_prepare_data_loader():
|
||||
patched_function = namespace["prepare_data_loader"]
|
||||
|
||||
accelerate.data_loader.prepare_data_loader = patched_function
|
||||
LOG.info("Patched accelerate.data_loader.prepare_data_loader for CP support")
|
||||
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
|
||||
|
||||
|
||||
def patch_prepare_device_mesh(context_parallel_degree: int):
|
||||
def patch_prepare_device_mesh(sequence_parallel_degree: int):
|
||||
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
|
||||
that includes context parallelism with the specified degree.
|
||||
that includes sequence parallelism with the specified degree.
|
||||
|
||||
Args:
|
||||
context_parallel_degree (int): The degree of context parallelism to use.
|
||||
sequence_parallel_degree (int): The degree of sequence parallelism to use.
|
||||
"""
|
||||
|
||||
def _prepare_device_mesh(self):
|
||||
@@ -187,11 +199,11 @@ def patch_prepare_device_mesh(context_parallel_degree: int):
|
||||
):
|
||||
return self.state.ds_device_mesh
|
||||
|
||||
# Create device mesh with context parallelism
|
||||
# Create device mesh with sequence parallelism
|
||||
world_size = dist.get_world_size()
|
||||
mesh_shape = (
|
||||
world_size // context_parallel_degree,
|
||||
context_parallel_degree,
|
||||
world_size // sequence_parallel_degree,
|
||||
sequence_parallel_degree,
|
||||
)
|
||||
device_ids = list(range(world_size))
|
||||
|
||||
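# Illustrative only (values assumed): with world_size=8 and sequence_parallel_degree=2,
# mesh_shape = (4, 2), i.e. four data-parallel groups of two sequence-parallel ranks each.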
@@ -209,5 +221,5 @@ def patch_prepare_device_mesh(context_parallel_degree: int):
|
||||
|
||||
LOG.info(
|
||||
"Successfully patched Accelerator._prepare_device_mesh "
|
||||
f"with context_parallel_degree={context_parallel_degree}"
|
||||
f"with sequence_parallel_degree={sequence_parallel_degree}"
|
||||
)
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import typing
|
||||
import weakref
|
||||
from contextlib import ExitStack
|
||||
from pathlib import Path
|
||||
@@ -31,7 +34,7 @@ from axolotl.loaders import (
|
||||
load_processor,
|
||||
load_tokenizer,
|
||||
)
|
||||
from axolotl.utils.ctx_managers import ContextParallelContextManager
|
||||
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import cleanup_distributed
|
||||
from axolotl.utils.freeze import freeze_layers_except
|
||||
@@ -44,6 +47,9 @@ try:
|
||||
except ImportError:
|
||||
BetterTransformer = None
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -147,7 +153,7 @@ def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
|
||||
|
||||
|
||||
def setup_signal_handler(
|
||||
cfg: DictDefault, model: PeftModel | PreTrainedModel, safe_serialization: bool
|
||||
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
|
||||
):
|
||||
"""
|
||||
Set up signal handler for graceful termination.
|
||||
@@ -201,20 +207,15 @@ def execute_training(
|
||||
)
|
||||
)
|
||||
|
||||
if cfg.context_parallel_degree > 1 and not cfg.sdp_attention:
|
||||
# Models to enter context parallel manager for
|
||||
if cfg.sequence_parallel_degree > 1:
|
||||
models = [trainer.model]
|
||||
if hasattr(trainer, "ref_model") and trainer.ref_model:
|
||||
models.append(trainer.ref_model)
|
||||
|
||||
# Attention backend
|
||||
backend = "sdp_attention" if cfg.sdp_attention else "flash_attention"
|
||||
|
||||
stack.enter_context(
|
||||
ContextParallelContextManager(
|
||||
SequenceParallelContextManager(
|
||||
models=models,
|
||||
backend=backend,
|
||||
context_parallel_degree=cfg.context_parallel_degree,
|
||||
sequence_parallel_degree=cfg.sequence_parallel_degree,
|
||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||
ring_attn_func=cfg.ring_attn_func,
|
||||
heads_k_stride=cfg.heads_k_stride,
|
||||
@@ -228,7 +229,7 @@ def execute_training(
|
||||
def save_trained_model(
|
||||
cfg: DictDefault,
|
||||
trainer: Any,
|
||||
model: PeftModel | PreTrainedModel,
|
||||
model: PreTrainedModel,
|
||||
safe_serialization: bool,
|
||||
):
|
||||
"""
|
||||
@@ -379,7 +380,7 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
|
||||
def save_initial_configs(
|
||||
cfg: DictDefault,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
model: PeftModel | PreTrainedModel,
|
||||
model: PreTrainedModel,
|
||||
peft_config: PeftConfig | None,
|
||||
processor: ProcessorMixin | None,
|
||||
):
|
||||
@@ -433,7 +434,7 @@ def setup_model_card(cfg: DictDefault):
|
||||
|
||||
def handle_untrained_tokens_fix(
|
||||
cfg: DictDefault,
|
||||
model: PeftModel | PreTrainedModel,
|
||||
model: PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
train_dataset: Dataset,
|
||||
safe_serialization: bool,
|
||||
@@ -476,7 +477,7 @@ def handle_untrained_tokens_fix(
|
||||
|
||||
|
||||
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
|
||||
Trainer,
|
||||
"HFRLTrainerBuilder" | "HFCausalTrainerBuilder",
|
||||
PeftModel | PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PeftConfig | None,
|
||||
|
||||
@@ -52,3 +52,10 @@ def patch_optimized_env():
|
||||
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
|
||||
|
||||
def get_not_null(value, default=None):
|
||||
"""
|
||||
return the value if it's not None, otherwise return the default value
|
||||
"""
|
||||
return value if value is not None else default
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1,7 +1,7 @@
|
||||
"""Data collators for axolotl to pad labels and position_ids for packed sequences"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, List
|
||||
|
||||
import numpy as np
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
@@ -161,7 +161,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
if not isinstance(features[0], list):
|
||||
features = [features]
|
||||
features: List[List[dict]] = [features]
|
||||
out_features = [{} for _ in features]
|
||||
for i, features_ in enumerate(features):
|
||||
for feature in features_[0].keys():
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Init for context manager submodule."""
|
||||
"""Init for context manager submodule"""
|
||||
|
||||
from .context_parallel.manager import ContextParallelContextManager
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
__all__ = ["ContextParallelContextManager"]
|
||||
from .sequence_parallel import SequenceParallelContextManager
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
# BSD 3-Clause License
|
||||
|
||||
# Copyright 2024 Meta
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without modification,
|
||||
# are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice,this list
|
||||
# of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice, this
|
||||
# list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its contributors may
|
||||
# be used to endorse or promote products derived from this software without specific
|
||||
# prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY
|
||||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
|
||||
# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
|
||||
# SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
|
||||
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
|
||||
# TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
|
||||
# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
|
||||
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
|
||||
# DAMAGE.
|
||||
|
||||
"""
|
||||
Distributed utils for SDPA context parallel implementation. Slightly modified from
|
||||
https://github.com/pytorch/torchtune/blob/2344509cf83bd886538fe3e8263e5145d1afb5c2/torchtune/training/_distributed.py.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
from typing import Callable, Generator, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.distributed.tensor.experimental import context_parallel
|
||||
from torch.distributed.tensor.experimental._attention import set_rotate_method
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
|
||||
|
||||
def _get_sdpa_context() -> (
|
||||
Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
|
||||
):
|
||||
"""
|
||||
Creates a context manager to confine to flash/efficient/cuDNN attention backends.
|
||||
|
||||
Returns:
|
||||
A context manager function that takes an optional context parallel context.
|
||||
"""
|
||||
|
||||
@contextlib.contextmanager
|
||||
def context(cp_context: Union[Generator[None, None, None], None] = None):
|
||||
with contextlib.ExitStack() as stack:
|
||||
if cp_context is not None:
|
||||
stack.enter_context(
|
||||
sdpa_kernel(
|
||||
[
|
||||
SDPBackend.FLASH_ATTENTION,
|
||||
SDPBackend.EFFICIENT_ATTENTION,
|
||||
SDPBackend.CUDNN_ATTENTION,
|
||||
]
|
||||
)
|
||||
)
|
||||
stack.enter_context(cp_context)
|
||||
|
||||
yield
|
||||
|
||||
return context
|
||||
|
||||
|
||||
def get_context_parallel_manager(
|
||||
*,
|
||||
world_mesh: torch.distributed.DeviceMesh,
|
||||
model: nn.Module,
|
||||
) -> Callable[[list[torch.Tensor]], Generator[None, None, None]]:
|
||||
"""
|
||||
Context manager for applying context parallelism to a model. In addition to applying the
|
||||
standard context manager to patch SDPA and shard model inputs and buffers along the sequence
|
||||
dimension, this context manager also calls into _get_sdpa_context to filter to acceptable SDPA backends.
|
||||
|
||||
Args:
|
||||
world_mesh: Global device mesh.
|
||||
model: Model to apply context parallelism to.
|
||||
|
||||
Returns:
|
||||
A context manager applying context parallelism if enabled is True. Otherwise a context manager
|
||||
disabling the math SDPA backend.
|
||||
|
||||
Raises:
|
||||
ValueError: if enabled is True but world_mesh does not contain a "cp" dimension
|
||||
"""
|
||||
|
||||
if "cp" not in world_mesh.mesh_dim_names:
|
||||
raise ValueError(
|
||||
"Context parallel is enabled but no context parallel device mesh is provided."
|
||||
)
|
||||
# TODO: context parallel for multimodal models requires extra work
|
||||
# if not isinstance(model, TransformerDecoder):
|
||||
# raise ValueError("Context parallel is only supported for text models")
|
||||
# model_buffers = list(model.buffers())
|
||||
|
||||
# def get_all_buffers(module, prefix=""):
|
||||
# buffers = {}
|
||||
# for name, buffer in module.named_buffers(recurse=False):
|
||||
# full_name = f"{prefix}.{name}" if prefix else name
|
||||
# buffers[full_name] = buffer
|
||||
|
||||
# for name, child in module.named_children():
|
||||
# child_prefix = f"{prefix}.{name}" if prefix else name
|
||||
# buffers.update(get_all_buffers(child, child_prefix))
|
||||
|
||||
# return buffers
|
||||
|
||||
# model_buffers = get_all_buffers(model)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def context(model_inputs: list[torch.Tensor]):
|
||||
# Create context parallel context if enabled
|
||||
cp_context = None
|
||||
if any([isinstance(input, BlockMask) for input in model_inputs]):
|
||||
raise ValueError(
|
||||
"Context parallel with flex attention is not yet supported"
|
||||
)
|
||||
set_rotate_method("allgather")
|
||||
|
||||
cp_context = context_parallel(
|
||||
world_mesh["cp"],
|
||||
# buffers=model_inputs + model_buffers,
|
||||
buffers=model_inputs,
|
||||
# buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers),
|
||||
buffer_seq_dims=[1] * len(model_inputs),
|
||||
no_restore_buffers=set(model_inputs),
|
||||
)
|
||||
|
||||
# Create and enter the train context with the optional cp_context
|
||||
sdpa_context = _get_sdpa_context()
|
||||
|
||||
with sdpa_context(cp_context):
|
||||
yield
|
||||
|
||||
return context
|
||||
@@ -1,216 +0,0 @@
|
||||
"""Module for Axolotl trainer context parallelism manager and utilities."""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from typing import Callable, Literal
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.utils import ModelOutput
|
||||
|
||||
from axolotl.monkeypatch.ring_attn import (
|
||||
get_ring_attn_group,
|
||||
patch_prepare_data_loader,
|
||||
patch_prepare_device_mesh,
|
||||
register_ring_attn,
|
||||
)
|
||||
from axolotl.utils.ctx_managers.context_parallel.distributed import (
|
||||
get_context_parallel_manager,
|
||||
)
|
||||
from axolotl.utils.ctx_managers.context_parallel.utils import (
|
||||
AllGatherWithGrad,
|
||||
apply_context_parallelism,
|
||||
)
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
|
||||
|
||||
class ContextParallelContextManager:
|
||||
"""Context manager for context parallelism operations.
|
||||
|
||||
This class provides a context that will automatically apply context parallelism
|
||||
during model forward passes using a pre-forward hook, and gather outputs from
|
||||
across the context parallelism group using a post-forward hook.
|
||||
|
||||
Args:
|
||||
models: List of models to apply context parallelism to pre- and post- forward
|
||||
hooks.
|
||||
backend: Which attention backend to use.
|
||||
context_parallel_degree: Number of processes to split sequences over.
|
||||
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
||||
ring_attn_func: Which ring attention function to use. Currently unused.
|
||||
heads_k_stride: Context parallelism K head stride size. Passed through to
|
||||
`varlen_llama3` `ring_flash_attn` implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
models: list[PreTrainedModel],
|
||||
backend: Literal["sdp_attention", "flash_attention"],
|
||||
context_parallel_degree: int,
|
||||
gradient_accumulation_steps: int,
|
||||
ring_attn_func: RingAttnFunc,
|
||||
heads_k_stride: int | None,
|
||||
):
|
||||
self.models = models
|
||||
self.backend = backend
|
||||
self.context_parallel_degree = context_parallel_degree
|
||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||||
self.ring_attn_func = ring_attn_func
|
||||
self.heads_k_stride = heads_k_stride
|
||||
self._register_ring_attn()
|
||||
|
||||
# Store hook handles for removal
|
||||
self.hook_handles: list[RemovableHandle] = []
|
||||
|
||||
if self.backend == "flash_attention":
|
||||
# Set distributed info for local rank
|
||||
self.process_group = get_ring_attn_group()
|
||||
self.local_rank = dist.get_rank(self.process_group)
|
||||
self.local_world_size = dist.get_world_size(self.process_group)
|
||||
|
||||
# Create a partially applied version of the apply_context_parallelism function
|
||||
self.apply_context_parallelism = functools.partial(
|
||||
apply_context_parallelism,
|
||||
local_rank=self.local_rank,
|
||||
local_world_size=self.local_world_size,
|
||||
gradient_accumulation_steps=self.gradient_accumulation_steps,
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
)
|
||||
|
||||
# Store original sequence length and padding information
|
||||
self.original_seq_len = 0
|
||||
self.pad_len = 0
|
||||
else:
|
||||
# SPDA device mesh init
|
||||
world_size = dist.get_world_size()
|
||||
mesh_shape = (
|
||||
world_size // self.context_parallel_degree,
|
||||
self.context_parallel_degree,
|
||||
)
|
||||
world_mesh = dist.DeviceMesh(
|
||||
"cuda",
|
||||
torch.tensor(list(range(world_size))).reshape(mesh_shape),
|
||||
mesh_dim_names=("dp", "cp"),
|
||||
)
|
||||
|
||||
# SDPA context parallel managers
|
||||
self.context_parallel_managers = []
|
||||
for model in models:
|
||||
ctx_manager = get_context_parallel_manager(
|
||||
world_mesh=world_mesh,
|
||||
model=model,
|
||||
)
|
||||
self.context_parallel_managers.append(ctx_manager)
|
||||
|
||||
def __enter__(self):
|
||||
self._register_model_hooks()
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Remove all hooks
|
||||
for handle in self.hook_handles:
|
||||
handle.remove()
|
||||
self.hook_handles = []
|
||||
|
||||
# TODO(djsaunde): Un-patch attention and accelerate functions (low priority)
|
||||
|
||||
def _register_ring_attn(self):
|
||||
if self.backend == "flash_attention":
|
||||
# Initialize ring attn for context parallelism
|
||||
register_ring_attn(
|
||||
context_parallel_degree=self.context_parallel_degree,
|
||||
heads_k_stride=self.heads_k_stride,
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
)
|
||||
|
||||
# Patches for accelerate functionality
|
||||
patch_prepare_data_loader()
|
||||
patch_prepare_device_mesh(context_parallel_degree=self.context_parallel_degree)
|
||||
|
||||
def _register_model_hooks(self):
|
||||
# Forward pre-hook to apply context parallelism
|
||||
def cp_flash_pre_hook(_, args, kwargs):
|
||||
# Get parameter names from the model's forward function
|
||||
forward_params = list(
|
||||
inspect.signature(self.models[0].forward).parameters.keys()
|
||||
)
|
||||
|
||||
updated_kwargs = kwargs.copy()
|
||||
for i, arg in enumerate(args):
|
||||
if i < len(forward_params):
|
||||
updated_kwargs[forward_params[i]] = arg
|
||||
|
||||
# Any excess positional arguments are kept as-is
|
||||
remaining_args = args[len(forward_params) :]
|
||||
|
||||
# Apply context parallelism to updated kwargs
|
||||
updated_kwargs, self.original_seq_len, self.pad_len = (
|
||||
self.apply_context_parallelism(updated_kwargs)
|
||||
)
|
||||
|
||||
return remaining_args, updated_kwargs
|
||||
|
||||
# Forward post-hook to gather outputs
|
||||
def cp_flash_post_hook(_, __, output: ModelOutput) -> ModelOutput:
|
||||
# Gather the sharded outputs
|
||||
output = self._gather_outputs(output)
|
||||
|
||||
# Remove padding if it was added
|
||||
if self.pad_len > 0:
|
||||
for key, value in output.items():
|
||||
if isinstance(value, torch.Tensor) and value.dim() > 1:
|
||||
if value.size(1) == self.original_seq_len + self.pad_len:
|
||||
# Slice to remove padding
|
||||
output[key] = value[:, : self.original_seq_len].contiguous()
|
||||
|
||||
return output
|
||||
|
||||
def make_sdpa_pre_hook(model_idx: int) -> Callable:
|
||||
def cp_sdpa_pre_hook(_, args, kwargs):
|
||||
# Get parameter names from the model's forward function
|
||||
forward_params = list(
|
||||
inspect.signature(self.models[0].forward).parameters.keys()
|
||||
)
|
||||
|
||||
updated_kwargs = kwargs.copy()
|
||||
for i, arg in enumerate(args):
|
||||
if i < len(forward_params):
|
||||
updated_kwargs[forward_params[i]] = arg
|
||||
|
||||
# Any excess positional arguments are kept as-is
|
||||
remaining_args = args[len(forward_params) :]
|
||||
|
||||
to_shard = {k: v for k, v in updated_kwargs.items() if v.ndim > 1}
|
||||
|
||||
with self.context_parallel_managers[model_idx](list(to_shard.values())):
|
||||
return remaining_args, updated_kwargs
|
||||
|
||||
return cp_sdpa_pre_hook
|
||||
|
||||
# Register both hooks
|
||||
for i, model in enumerate(self.models):
|
||||
if self.backend == "flash_attention":
|
||||
self.hook_handles.append(
|
||||
model.register_forward_pre_hook(cp_flash_pre_hook, with_kwargs=True)
|
||||
)
|
||||
self.hook_handles.append(
|
||||
model.register_forward_hook(cp_flash_post_hook)
|
||||
)
|
||||
else:
|
||||
self.hook_handles.append(
|
||||
model.register_forward_pre_hook(
|
||||
make_sdpa_pre_hook(i), with_kwargs=True
|
||||
)
|
||||
)
|
||||
|
||||
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
|
||||
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
|
||||
for key, value in output.items():
|
||||
if isinstance(value, torch.Tensor) and value.dim() > 1:
|
||||
output[key] = AllGatherWithGrad.apply(value, self.process_group)
|
||||
|
||||
return output
|
||||
@@ -1,15 +1,28 @@
"""Utils for context parallel context manager."""
"""Module for Axolotl trainer sequence parallelism manager and utilities"""

import functools
import inspect

import torch
import torch.distributed as dist
from torch import nn
from torch.utils.hooks import RemovableHandle
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import ModelOutput

from axolotl.monkeypatch.ring_attn.patch import update_ring_attn_params
from axolotl.monkeypatch.ring_attn import (
    get_ring_attn_group,
    patch_prepare_data_loader,
    patch_prepare_device_mesh,
    register_ring_attn,
    update_ring_attn_params,
)
from axolotl.utils.schemas.enums import RingAttnFunc


# TODO(djsaunde): implement zigzag, stripe patterns here (and elsewhere) in this
# module. Currently, we just focus on batch ring and varlen llama3 for simplicity.
def apply_context_parallelism(
def apply_sequence_parallelism(
    batch: dict[str, torch.Tensor],
    local_rank: int,
    local_world_size: int,
@@ -17,15 +30,15 @@ def apply_context_parallelism(
    ring_attn_func: RingAttnFunc,  # pylint: disable=unused-argument
) -> tuple[dict[str, torch.Tensor], int, int]:
    """
    Apply context parallelism slicing to a batch.
    Apply sequence parallelism slicing to a batch.

    Special handling is implemented for integer logits_to_keep, which indicates
    to only keep the last N tokens in the input sequence during generation.
    to only keep the last N tokens in the sequence during generation.

    Args:
        batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
        local_rank: Local rank in the context parallel group.
        local_world_size: World size of the context parallel group.
        local_rank: Local rank in the sequence parallel group.
        local_world_size: World size of the sequence parallel group.
        gradient_accumulation_steps: Number of steps to accumulate gradients over.
        ring_attn_func: Which ring attention function to use. Currently unused, but
            related to above TODO.
@@ -120,7 +133,7 @@ def apply_context_parallelism(
    # Update the total sequence length after padding
    total_seq_len = batch["input_ids"].size(1)

    # Slice batch for context parallel
    # Slice batch for sequence parallel
    for key in batch:
        if not isinstance(batch[key], torch.Tensor) or batch[key].dim() <= 1:
            continue
@@ -146,6 +159,144 @@ def apply_context_parallelism(
    return batch, original_seq_len, pad_len


class SequenceParallelContextManager:
    """Context manager for sequence parallelism operations.

    This class provides a context that will automatically apply sequence parallelism
    during model forward passes using a pre-forward hook, and gather outputs from
    across the sequence parallelism group using a post-forward hook.

    Args:
        models: List of models to apply sequence parallelism to pre- and post- forward
            hooks.
        sequence_parallel_degree: Number of processes to split sequences over.
        gradient_accumulation_steps: Number of steps to accumulate gradients over.
        ring_attn_func: Which ring attention function to use. Currently unused.
        heads_k_stride: Sequence parallelism K head stride size. Passed through to
            `varlen_llama3` `ring_flash_attn` implementation.
    """

    def __init__(
        self,
        models: list[nn.Module],
        sequence_parallel_degree: int,
        gradient_accumulation_steps: int,
        ring_attn_func: RingAttnFunc,
        heads_k_stride: int | None,
    ):
        self.models = models
        self.sequence_parallel_degree = sequence_parallel_degree
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.ring_attn_func = ring_attn_func
        self.heads_k_stride = heads_k_stride
        self._register_ring_attn()

        # Set distributed info for local rank
        self.process_group = get_ring_attn_group()
        self.local_rank = dist.get_rank(self.process_group)
        self.local_world_size = dist.get_world_size(self.process_group)

        # Will store hook handles for removal
        self.hook_handles: list[RemovableHandle] = []

        # Store original sequence length and padding information
        self.original_seq_len = 0
        self.pad_len = 0

        # Create a partially applied version of the apply_sequence_parallelism function
        self.apply_sequence_parallelism = functools.partial(
            apply_sequence_parallelism,
            local_rank=self.local_rank,
            local_world_size=self.local_world_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            ring_attn_func=self.ring_attn_func,
        )

    def __enter__(self):
        self._register_model_hooks()

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Remove all hooks
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles = []

        # TODO(djsaunde): Un-patch attention and accelerate functions (low priority)

    def _register_ring_attn(self):
        # Initialize ring attn for sequence parallelism
        register_ring_attn(
            sequence_parallel_degree=self.sequence_parallel_degree,
            heads_k_stride=self.heads_k_stride,
            ring_attn_func=self.ring_attn_func,
        )

        # Patches for accelerate functionality
        patch_prepare_data_loader()
        patch_prepare_device_mesh(
            sequence_parallel_degree=self.sequence_parallel_degree
        )

    def _register_model_hooks(self):
        # Forward pre-hook to apply sequence parallelism
        def sequence_parallel_pre_hook(_, args, kwargs):
            # Get parameter names from the model's forward function
            forward_params = list(
                inspect.signature(self.models[0].forward).parameters.keys()
            )

            updated_kwargs = kwargs.copy()
            for i, arg in enumerate(args):
                if i < len(forward_params):
                    updated_kwargs[forward_params[i]] = arg

            # Any excess positional arguments are kept as-is
            remaining_args = args[len(forward_params) :]

            # Apply sequence parallelism to updated kwargs
            updated_kwargs, self.original_seq_len, self.pad_len = (
                self.apply_sequence_parallelism(updated_kwargs)
            )

            return remaining_args, updated_kwargs

        # Forward post-hook to gather outputs
        def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
            # Gather the sharded outputs
            output = self._gather_outputs(output)

            # Remove padding if it was added
            if self.pad_len > 0:
                for key, value in output.items():
                    if isinstance(value, torch.Tensor) and value.dim() > 1:
                        if value.size(1) == self.original_seq_len + self.pad_len:
                            # Slice to remove padding
                            output[key] = value[:, : self.original_seq_len].contiguous()

            return output

        # Register both hooks
        for model in self.models:
            self.hook_handles.append(
                model.register_forward_pre_hook(
                    sequence_parallel_pre_hook, with_kwargs=True
                )
            )
            self.hook_handles.append(
                model.register_forward_hook(sequence_parallel_post_hook)
            )

    def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
        """Gather sharded outputs from all ranks and reconstruct the full tensor."""
        for key, value in output.items():
            if isinstance(value, torch.Tensor) and value.dim() > 1:
                output[key] = AllGatherWithGrad.apply(value, self.process_group)

        return output


class AllGatherWithGrad(torch.autograd.Function):
    """Custom autograd function for all-gather to preserve gradients."""
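
The body of AllGatherWithGrad sits outside the hunks shown here. A minimal sketch of the usual shape of such a function, assuming tensors are gathered along the sequence dimension (dim=1) and each rank keeps only the gradient slice for its own shard (an illustration, not this PR's implementation):

import torch
import torch.distributed as dist


class AllGatherSeqSketch(torch.autograd.Function):
    """All-gather along dim=1 while keeping the local shard differentiable."""

    @staticmethod
    def forward(ctx, tensor, group):
        ctx.rank = dist.get_rank(group)
        ctx.local_len = tensor.size(1)
        world_size = dist.get_world_size(group)
        gathered = [torch.empty_like(tensor) for _ in range(world_size)]
        # all_gather itself is not differentiable, hence the custom backward
        dist.all_gather(gathered, tensor.contiguous(), group=group)
        return torch.cat(gathered, dim=1)

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank keeps the gradient slice corresponding to its local shard;
        # the second return value is for the non-tensor `group` argument.
        start = ctx.rank * ctx.local_len
        return grad_output[:, start : start + ctx.local_len].contiguous(), None
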
@@ -40,6 +40,7 @@ def retry_on_request_exceptions(
                except (
                    requests.exceptions.ReadTimeout,
                    requests.exceptions.ConnectionError,
                    requests.exceptions.HTTPError,
                    huggingface_hub.errors.HfHubHTTPError,
                ) as exc:
                    if attempt < max_retries - 1:

@@ -258,7 +258,7 @@ class MultipackBatchSampler(BatchSampler):
        batch_max_len: int,  # Maximum sequence length (bin capacity)
        lengths: np.ndarray,  # Sequence lengths
        packing_efficiency_estimate: float = 1.0,  # Initial efficiency estimate
        drop_last: bool = False,  # Whether to drop final batches (might be incomplete)
        drop_last: bool = True,  # Whether to drop final batches (might be incomplete)
        num_count_samples: int = 16,  # Number of times to estimate batch count
        sequential: bool = False,  # Whether to use sequential packing
        group_size: int = 100_000,  # Size of groups for parallel packing
@@ -443,10 +443,18 @@ class MultipackBatchSampler(BatchSampler):

        if self._len_across_ranks is None:
            # Sample multiple times to get stable estimate
            len_batches = min(  # pylint: disable=consider-using-generator
                [len(self._batches) for _ in range(self.num_count_samples)]
            )
            _sampled_lens = []
            for _ in range(self.num_count_samples):
                self._batches = None  # Reset cached batches
                _sampled_lens.append(len(self.generate_batches(set_stats=False)))
            len_batches = min(_sampled_lens)

            # Gather minimum across all ranks
            self._len_across_ranks = self.gather_len_batches(len_batches)
            if self._len_across_ranks is None:
                self._len_across_ranks = self.gather_len_batches(len_batches)
            else:
                self._len_across_ranks = min(
                    self._len_across_ranks, self.gather_len_batches(len_batches)
                )

        return self._len_across_ranks
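
The new __len__ logic above re-generates the packed batches several times, keeps the smallest count seen, and then takes the minimum across ranks so every rank reports the same sampler length. A toy, self-contained sketch of that estimation idea (the packing routine below is a hypothetical stand-in, not MultipackBatchSampler itself):

import random


def generate_batches_once(num_samples: int = 100, bin_capacity: int = 4) -> int:
    # Stand-in for a shuffle-dependent packing pass: the number of bins
    # (batches) produced varies from run to run.
    lengths = [random.randint(1, bin_capacity) for _ in range(num_samples)]
    batches, current = 0, 0
    for length in lengths:
        if current + length > bin_capacity:
            batches += 1
            current = 0
        current += length
    return batches + (1 if current else 0)


def estimate_len(num_count_samples: int = 16) -> int:
    # Conservative estimate: never promise more batches than the worst draw
    sampled_lens = [generate_batches_once() for _ in range(num_count_samples)]
    return min(sampled_lens)


print(estimate_len())
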
@@ -262,7 +262,7 @@ class AxolotlInputConfig(

    val_set_size: float | None = Field(default=0.0)

    context_parallel_degree: int | None = None
    sequence_parallel_degree: int | None = None
    heads_k_stride: int | None = None
    ring_attn_func: RingAttnFunc | None = None

@@ -1179,39 +1179,24 @@ class AxolotlInputConfig(

    @model_validator(mode="before")
    @classmethod
    def check_grpo_liger_context_parallel(cls, data):
    def check_grpo_liger_sequence_parallel(cls, data):
        if (
            data.get("rl") == "grpo"
            and data.get("trl", {})
            and data.get("trl").get("use_liger_loss")
            and data.get("context_parallel_degree", 1) > 1
            and data.get("sequence_parallel_degree", 1) > 1
        ):
            raise ValueError("GRPO + CP + Liger not currently supported")
            raise ValueError("GRPO + SP + Liger not currently supported")
        return data

    @model_validator(mode="after")
    def check_context_parallel_degree(self):
        if not self.context_parallel_degree:
            self.context_parallel_degree = 1
        elif self.context_parallel_degree > 1:
            import torch

            world_size = torch.cuda.device_count()
            if not world_size >= self.context_parallel_degree:
    def check_sequence_parallel_degree(self):
        if not self.sequence_parallel_degree:
            self.sequence_parallel_degree = 1
        elif self.sequence_parallel_degree > 1:
            if not self.flash_attention:
                raise ValueError(
                    f"World size ({world_size}) must be greater "
                    f"than or equal to CP degree ({self.context_parallel_degree})"
                )
            if not world_size % self.context_parallel_degree == 0:
                raise ValueError(
                    f"SP degree ({self.context_parallel_degree}) "
                    f"must evenly divide world size ({world_size})"
                )

            if not (self.flash_attention or self.sdp_attention):
                raise ValueError(
                    "flash_attention: true or sdp_attention: true "
                    "must be set with context_parallel_degree > 1"
                    "flash_attention: true must be set with sequence_parallel_degree > 1"
                )

            if self.sample_packing and self.micro_batch_size > 1:
@@ -1220,22 +1205,21 @@ class AxolotlInputConfig(
                    "due to a `ring-flash-attn` requirement"
                )

            if self.flash_attention:
                try:
                    import ring_flash_attn  # noqa: F401 # pylint:disable=unused-import
                except ImportError as exception:
                    raise ImportError(
                        "context_parallel_degree > 1 but ring_flash_attn is not installed. "
                        "Please install it with `pip install axolotl[ring-flash-attn] "
                        "or `pip install ring-flash-attn>=0.1.4`."
                    ) from exception
            try:
                import ring_flash_attn  # noqa: F401 # pylint:disable=unused-import
            except ImportError as exception:
                raise ImportError(
                    "sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
                    "Please install it with `pip install axolotl[ring-flash-attn] "
                    "or `pip install ring-flash-attn>=0.1.4`."
                ) from exception

            # TODO: monkeypatch / callback to average losses correctly across CP ranks
            # / fix gradient scaling across CP ranks. Losses, grads should be scaled
            # TODO: monkeypatch / callback to average losses correctly across SP ranks
            # / fix gradient scaling across SP ranks. Losses, grads should be scaled
            # according to the proportion of non-padding tokens per rank.
            LOG.warning(
                "Context parallelism (SP) is enabled with "
                f"context_parallel_degree={self.context_parallel_degree}. "
                "Sequence parallelism (SP) is enabled with "
                f"sequence_parallel_degree={self.sequence_parallel_degree}. "
                "Please note that logged losses may differ slightly to the non-SP "
                "losses due to transformers Trainer implementation details. "
                "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
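
A condensed restatement of the checks in this validator, written as a standalone function for illustration (the real logic lives on AxolotlInputConfig; the helper name and the exact wording of the sample-packing error are assumptions):

def check_sequence_parallel_settings(
    sequence_parallel_degree: int | None,
    flash_attention: bool,
    sample_packing: bool,
    micro_batch_size: int,
) -> int:
    # Degree defaults to 1 when unset; checks only apply when it is > 1
    degree = sequence_parallel_degree or 1
    if degree > 1:
        if not flash_attention:
            raise ValueError(
                "flash_attention: true must be set with sequence_parallel_degree > 1"
            )
        if sample_packing and micro_batch_size > 1:
            raise ValueError(
                "micro_batch_size must be 1 with sample_packing "
                "due to a `ring-flash-attn` requirement"
            )
    return degree


check_sequence_parallel_settings(
    2, flash_attention=True, sample_packing=True, micro_batch_size=1
)
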
@@ -1246,7 +1230,7 @@ class AxolotlInputConfig(

    @model_validator(mode="after")
    def validate_ring_attn_func(self):
        if getattr(self, "context_parallel_degree", 1) == 1:
        if getattr(self, "sequence_parallel_degree", 1) == 1:
            return self

        if self.ring_attn_func is not None:

@@ -16,7 +16,6 @@ from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available

from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
@@ -442,7 +441,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                - 1
            )
            * cfg.num_epochs
            * cfg.context_parallel_degree
            * cfg.sequence_parallel_degree
        )
        LOG.debug(
            f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
@@ -479,9 +478,12 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
            # on the agreed on value for sample_packing_eff_est
            total_num_steps = int(
                math.floor(
                    data_loader_len * cfg.num_epochs * cfg.context_parallel_degree
                    data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
                )
            )
            if cfg.dataloader_drop_last:
                # drop the last batch for each epoch
                total_num_steps -= int(math.ceil(cfg.num_epochs))

        def calc_sample_packing_eff_est(estimates: List[float]):
            LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
@@ -502,7 +504,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
            math.ceil(
                len(train_dataset)
                * cfg.num_epochs
                * cfg.context_parallel_degree
                * cfg.sequence_parallel_degree
                / cfg.batch_size
            )
        )
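
A worked example of the scaling in these hunks, with illustrative numbers: ranks within a sequence-parallel group work on the same sequences, so the effective data-parallel batch shrinks by the SP degree and the step count grows by the same factor.

import math

len_train_dataset = 10_000
num_epochs = 1
batch_size = 8  # global batch size across all ranks
sequence_parallel_degree = 2

total_num_steps = int(
    math.ceil(len_train_dataset * num_epochs * sequence_parallel_degree / batch_size)
)
print(total_num_steps)  # 2500 steps, versus 1250 without sequence parallelism
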
@@ -629,6 +631,8 @@ def setup_trainer(
        A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
        on the provided parameters.
    """
    from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder

    if (
        cfg.torch_compile
        and cfg.fsdp_config

@@ -64,7 +64,7 @@ def fixture_base_cfg():
        "dataloader_num_workers": 1,
        "dataloader_pin_memory": True,
        "dataloader_prefetch_factor": 2,
        "context_parallel_degree": 1,
        "sequence_parallel_degree": 1,
        # Dtype
        "fp16": False,
        "bf16": False,

@@ -1,4 +1,4 @@
"""E2E tests for context parallelism"""
"""E2E tests for sequence parallelism"""

from pathlib import Path

@@ -12,10 +12,10 @@ from axolotl.utils.dict import DictDefault
from ...utils import check_tensorboard


class TestContextParallelism:
    """Test case for training with context parallelism enabled"""
class TestSequenceParallelism:
    """Test case for training with sequence parallelism enabled"""

    def _run_context_parallel_test(
    def _run_sequence_parallel_test(
        self,
        temp_dir,
        sample_packing=True,
@@ -24,7 +24,7 @@ class TestContextParallelism:
        ring_attn_func=None,
        threshold=2.0,
    ):
        """Helper method to run context parallel tests with different configurations"""
        """Helper method to run sequence parallel tests with different configurations"""
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -66,7 +66,7 @@ class TestContextParallelism:
                "logging_steps": 1,
                "weight_decay": 0.0,
                "use_tensorboard": True,
                "context_parallel_degree": 2,
                "sequence_parallel_degree": 2,
                "ring_attn_func": ring_attn_func,
            }
        )
@@ -109,7 +109,7 @@ class TestContextParallelism:
            "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
        ],
    )
    def test_context_parallel_training(
    def test_sequence_parallel_training(
        self,
        temp_dir,
        sample_packing,
@@ -118,8 +118,8 @@ class TestContextParallelism:
        ring_attn_func,
        threshold,
    ):
        """Test context parallel training with different configurations"""
        self._run_context_parallel_test(
        """Test sequence parallel training with different configurations"""
        self._run_sequence_parallel_test(
            temp_dir,
            sample_packing=sample_packing,
            micro_batch_size=micro_batch_size,

@@ -296,7 +296,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
        "lora_alpha": 16,
        "lora_dropout": 0.05,
        "lora_target_linear": True,
        "context_parallel_degree": 2,
        "sequence_parallel_degree": 2,
        "flash_attention": True,
        "sequence_len": 1024,
        "special_tokens": {

@@ -1,4 +1,4 @@
"""Tests for context parallelism functionality."""
"""Tests for sequence parallelism functionality."""

# pylint: disable=redefined-outer-name,unused-argument

@@ -15,7 +15,7 @@ from axolotl.monkeypatch.ring_attn import (
    register_ring_attn,
    set_ring_attn_group,
)
from axolotl.utils.ctx_managers.context_parallel import apply_context_parallelism
from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RingAttnFunc
from axolotl.utils.schemas.trl import TRLConfig
@@ -54,8 +54,8 @@ def fixture_cfg():


@pytest.fixture
def context_parallel_batch():
    """Create a test batch for context parallelism tests."""
def sequence_parallel_batch():
    """Create a test batch for sequence parallelism tests."""
    batch_size = 1
    seq_len = 8

@@ -110,7 +110,7 @@ class TestRingAttention:

        # Call register_ring_attn with size 4
        register_ring_attn(
            context_parallel_degree=4,
            sequence_parallel_degree=4,
            heads_k_stride=1,
            ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
        )
@@ -126,7 +126,7 @@ class TestRingAttention:


class TestConfigValidation:
    """Tests for validating context parallelism configurations."""
    """Tests for validating sequence parallelism configurations."""

    @pytest.fixture(autouse=True)
    def setup_mocks(self, monkeypatch):
@@ -155,24 +155,24 @@ class TestConfigValidation:
        [
            # Valid configuration
            (
                {"context_parallel_degree": 2, "flash_attention": True},
                {"context_parallel_degree": 2, "flash_attention": True},
                {"sequence_parallel_degree": 2, "flash_attention": True},
                {"sequence_parallel_degree": 2, "flash_attention": True},
                True,
                None,
            ),
            # Default context_parallel_degree
            ({}, {"context_parallel_degree": 1}, True, None),
            # Invalid: context_parallel_degree > 1 without flash_attention
            # Default sequence_parallel_degree
            ({}, {"sequence_parallel_degree": 1}, True, None),
            # Invalid: sequence_parallel_degree > 1 without flash_attention
            (
                {"context_parallel_degree": 2, "flash_attention": False},
                {"sequence_parallel_degree": 2, "flash_attention": False},
                None,
                False,
                "flash_attention: true must be set",
            ),
            # Invalid: context_parallel_degree > 1 with sample_packing and micro_batch_size > 1
            # Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
            (
                {
                    "context_parallel_degree": 2,
                    "sequence_parallel_degree": 2,
                    "flash_attention": True,
                    "sample_packing": True,
                    "micro_batch_size": 2,
@@ -185,32 +185,32 @@ class TestConfigValidation:
            # Valid: Basic GRPO config
            (
                {
                    "context_parallel_degree": 2,
                    "sequence_parallel_degree": 2,
                    "flash_attention": True,
                    "micro_batch_size": 2,
                    "trl": {"use_liger_loss": True},
                },
                {
                    "context_parallel_degree": 2,
                    "sequence_parallel_degree": 2,
                    "flash_attention": True,
                    "micro_batch_size": 2,
                    "trl": TRLConfig(use_liger_loss=True),
                },
                True,
                "GRPO + CP + Liger not currently supported",
                "GRPO + SP + Liger not currently supported",
            ),
            # Invalid: GRPO config with Liger loss
            (
                {
                    "rl": "grpo",
                    "context_parallel_degree": 2,
                    "sequence_parallel_degree": 2,
                    "flash_attention": True,
                    "micro_batch_size": 2,
                    "trl": {"use_liger_loss": True},
                },
                None,
                False,
                "GRPO + CP + Liger not currently supported",
                "GRPO + SP + Liger not currently supported",
            ),
        ],
        ids=[
@@ -222,10 +222,10 @@ class TestConfigValidation:
            "grpo_with_liger_loss",
        ],
    )
    def test_context_parallel_config_validation(
    def test_sequence_parallel_config_validation(
        self, base_cfg, config_updates, expected_values, should_pass, error_msg
    ):
        """Test various context parallelism configuration scenarios."""
        """Test various sequence parallelism configuration scenarios."""
        from axolotl.utils.schemas.config import AxolotlInputConfig

        # Apply updates to base config

@@ -261,7 +261,7 @@ class TestConfigValidation:

        # Apply updates to base config
        cfg = base_cfg | {
            "context_parallel_degree": 2,
            "sequence_parallel_degree": 2,
            "flash_attention": True,
            "sample_packing": sample_packing,
        }
@@ -281,7 +281,7 @@ class TestConfigValidation:

        # Invalid configuration with invalid ring_attn_func
        cfg = base_cfg | {
            "context_parallel_degree": 2,
            "sequence_parallel_degree": 2,
            "flash_attention": True,
            "ring_attn_func": "INVALID_FUNC",
        }
@@ -294,8 +294,8 @@ class TestConfigValidation:
        assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value)


class TestApplyContextParallelism:
    """Tests for the apply_context_parallelism function."""
class TestApplySequenceParallelism:
    """Tests for the apply_sequence_parallelism function."""

    @pytest.fixture(autouse=True)
    def mock_distributed(self, monkeypatch):
@@ -324,12 +324,12 @@ class TestApplyContextParallelism:
        )

    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
    def test_world_size_one(self, mock_get_ring_attn_group, context_parallel_batch):
    def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
        """Test that function returns original batch when world size is 1."""
        mock_get_ring_attn_group.return_value = 0

        result, _, _ = apply_context_parallelism(
            batch=context_parallel_batch,
        result, _, _ = apply_sequence_parallelism(
            batch=sequence_parallel_batch,
            local_rank=0,
            local_world_size=1,
            gradient_accumulation_steps=1,
@@ -337,17 +337,17 @@ class TestApplyContextParallelism:
        )

        # Should return the original batch unchanged
        assert result == context_parallel_batch
        assert result == sequence_parallel_batch

    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
    def test_batch_ring_rank0(self, mock_get_ring_attn_group, context_parallel_batch):
    def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
        """Test BATCH_RING sharding for rank 0 in a 2-process group."""
        mock_get_ring_attn_group.return_value = 0

        batch = context_parallel_batch
        batch = sequence_parallel_batch
        seq_len = batch["input_ids"].size(1)

        result, _, _ = apply_context_parallelism(
        result, _, _ = apply_sequence_parallelism(
            batch=batch,
            local_rank=0,
            local_world_size=2,
@@ -366,15 +366,15 @@ class TestApplyContextParallelism:
        )

    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
    def test_batch_ring_rank1(self, mock_get_ring_attn_group, context_parallel_batch):
    def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
        """Test BATCH_RING sharding for rank 1 in a 2-process group."""
        mock_get_ring_attn_group.return_value = 0

        batch = context_parallel_batch
        batch = sequence_parallel_batch
        seq_len = batch["input_ids"].size(1)
        original_input_ids = batch["input_ids"].clone()

        result, _, _ = apply_context_parallelism(
        result, _, _ = apply_sequence_parallelism(
            batch=batch,
            local_rank=1,
            local_world_size=2,
@@ -386,14 +386,14 @@ class TestApplyContextParallelism:
        assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])

    # TODO(djsaunde): add back once implemented.
    # def test_batch_zigzag(self, context_parallel_batch):
    # def test_batch_zigzag(self, sequence_parallel_batch):
    #     """Test BATCH_ZIGZAG sharding pattern."""
    #     batch = context_parallel_batch
    #     batch = sequence_parallel_batch
    #     original_input_ids = batch["input_ids"].clone()
    #     seq_len = batch["input_ids"].size(1)

    #     # Test rank 0
    #     result_rank0 = apply_context_parallelism(
    #     result_rank0 = apply_sequence_parallelism(
    #         batch={k: v.clone() for k, v in batch.items()},
    #         local_rank=0,
    #         local_world_size=2,
@@ -401,7 +401,7 @@ class TestApplyContextParallelism:
    #     )

    #     # Test rank 1
    #     result_rank1 = apply_context_parallelism(
    #     result_rank1 = apply_sequence_parallelism(
    #         batch={k: v.clone() for k, v in batch.items()},
    #         local_rank=1,
    #         local_world_size=2,
@@ -430,17 +430,17 @@ class TestApplyContextParallelism:

    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
    def test_partial_application(
        self, mock_get_ring_attn_group, context_parallel_batch
        self, mock_get_ring_attn_group, sequence_parallel_batch
    ):
        """Test that we can create a partially applied version of the function."""
        mock_get_ring_attn_group.return_value = 0

        batch = context_parallel_batch
        batch = sequence_parallel_batch
        original_input_ids = batch["input_ids"].clone()

        # Create a partially applied function
        rank0_ring_parallel = functools.partial(
            apply_context_parallelism,
            apply_sequence_parallelism,
            local_rank=0,
            local_world_size=2,
            gradient_accumulation_steps=1,
@@ -457,14 +457,16 @@ class TestApplyContextParallelism:
            original_input_ids[:, : original_input_ids.shape[1] // 2],
        )

    def test_missing_position_ids(self, context_parallel_batch):
    def test_missing_position_ids(self, sequence_parallel_batch):
        """Test handling of batch without position_ids."""
        # Create a batch without position_ids
        batch = {k: v for k, v in context_parallel_batch.items() if k != "position_ids"}
        batch = {
            k: v for k, v in sequence_parallel_batch.items() if k != "position_ids"
        }
        original_input_ids = batch["input_ids"].clone()

        # This should run without error even though position_ids is missing
        result, _, _ = apply_context_parallelism(
        result, _, _ = apply_sequence_parallelism(
            batch=batch,
            local_rank=0,
            local_world_size=2,