Compare commits


3 Commits

Author            SHA1        Message                              Date
Salman Mohammadi  1a09d5e844  some refactoring                     2025-02-19 17:35:35 +00:00
Salman Mohammadi  cf61b4aba7  Merge branch 'main' into grpo_liger  2025-02-19 16:17:42 +00:00
Salman Mohammadi  14d274efe6  WIP liger support                    2025-02-19 15:34:32 +00:00
25 changed files with 358 additions and 1710 deletions

View File

@@ -4,10 +4,6 @@ on:
pull_request:
paths:
- 'tests/e2e/multigpu/*.py'
- 'requirements.txt'
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every Monday & Thursday

View File

@@ -37,11 +37,15 @@ temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
force_build=True,
gpu="A10G",
).env(df_args)
cicd_image = (
Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
force_build=True,
gpu="A10G",
)
.env(df_args)
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
app = App("Axolotl CI/CD", secrets=[])
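
For context, the CI image refactor above relies on Modal's chained image builders; a minimal sketch of the pattern (builder names per Modal's documented API, everything else illustrative):

import pathlib
from modal import Image

def build_ci_image(dockerfile_dir: str, env_vars: dict) -> Image:
    # Each builder call layers on top of the previous image, so the
    # chain reads top-down: Dockerfile -> env vars -> extra pip deps.
    return (
        Image.from_dockerfile(pathlib.Path(dockerfile_dir) / "Dockerfile")
        .env(env_vars)
        .pip_install("fastapi==0.110.0", "pydantic==2.6.3")
    )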

View File

@@ -407,10 +407,7 @@ save_total_limit: # Checkpoints saved at a time
max_steps:
# bool of whether to include tokens per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
include_tokens_per_second: # Optional[bool]
# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
auto_find_batch_size: # Optional[bool]
include_tokens_per_second:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128

View File

@@ -13,12 +13,12 @@ liger-kernel==0.5.2
packaging==23.2
peft==0.14.0
transformers==4.49.0
transformers==4.48.3
tokenizers>=0.21.0
accelerate==1.3.0
datasets==3.2.0
deepspeed==0.16.1
trl==0.15.1
trl==0.15.0
optimum==1.16.2
hf_transfer

View File

@@ -123,6 +123,8 @@ class ModalCloud(Cloud):
if env := self.get_env():
image = image.env(env)
image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
return image
def get_secrets(self):

View File

@@ -9,6 +9,7 @@ import logging
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
LOG = logging.getLogger("axolotl")
@@ -30,31 +31,21 @@ class GRPOStrategy:
@classmethod
def set_training_args_kwargs(cls, cfg):
grpo_args_kwargs = {}
if cfg.trl and cfg.trl.use_vllm:
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
if cfg.trl and cfg.trl.vllm_device:
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
else:
grpo_args_kwargs["vllm_device"] = "auto"
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = cfg.trl.vllm_gpu_memory_utilization
if cfg.trl and cfg.trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
if cfg.trl and cfg.trl.num_generations:
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
if cfg.trl and cfg.trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
grpo_args_kwargs[
"ref_model_mixup_alpha"
] = cfg.trl.ref_model_mixup_alpha
if cfg.trl and cfg.trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
training_kwargs = [
"use_vllm",
"vllm_device",
"vllm_gpu_memory_utilization",
"vllm_max_model_len",
"vllm_dtype",
"use_liger_loss",
"num_generations",
"log_completions",
"sync_ref_model",
"ref_model_mixup_alpha",
"ref_model_sync_steps",
"max_completion_length",
]
grpo_args_kwargs = {k: cfg.trl[k] for k in training_kwargs if cfg.trl[k]}
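# Note: the truthiness filter drops any key whose configured value is falsy
# (False, 0, None), so unlike the old explicit assignments,
# max_completion_length and log_completions are no longer forwarded
# unconditionally.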
return grpo_args_kwargs
@classmethod
@@ -71,9 +62,7 @@ class GRPOStrategy:
def set_trainer_kwargs(cls, cfg):
trainer_kwargs = {}
if cfg.trl and cfg.trl.reward_processing_classes:
trainer_kwargs[
"reward_processing_classes"
] = cfg.trl.reward_processing_classes
trainer_kwargs["reward_processing_classes"] = cfg.trl.reward_processing_classes
return trainer_kwargs
@classmethod

View File

@@ -13,3 +13,4 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""
Axolotl GRPO Config for GRPO training
"""
use_liger_loss: bool = False

View File

@@ -1,13 +1,24 @@
"""
Axolotl GRPO trainer
"""
from contextlib import contextmanager, nullcontext
from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module
import torch
from transformers import PreTrainedModel
from trl import GRPOConfig, GRPOTrainer
from trl.models import unwrap_model_for_generation
from axolotl.core.trainers.base import SchedulerMixin
from transformers.utils import is_liger_kernel_available
if is_liger_kernel_available():
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from accelerate.utils import broadcast_object_list, gather_object
from trl.trainer.utils import pad
# mypy: ignore-errors
@@ -20,7 +31,20 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.use_liger_loss = kwargs["args"].use_liger_loss
if self.use_liger_loss:
if not is_liger_kernel_available():
raise ValueError(
"You set `use_liger_loss=True` but the liger kernel is not available. "
"Please install liger-kernel first: `pip install liger-kernel`"
)
self.grpo_loss_fn = LigerFusedLinearGRPOLoss(
beta=self.beta,
compiled=is_compiled_module(self.model),
use_ref_model=True,
num_generations=self.args.num_generations,
)
# pylint: disable=access-member-before-definition
# Enable gradient checkpointing if requested
if kwargs["args"].gradient_checkpointing:
@@ -29,9 +53,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
self.model.config.use_cache = False
# Enable gradient checkpointing on the base model for PEFT
if is_peft_model(self.model) and hasattr(
self.model.base_model, "gradient_checkpointing_enable"
):
if is_peft_model(self.model) and hasattr(self.model.base_model, "gradient_checkpointing_enable"):
self.model.base_model.gradient_checkpointing_enable()
# Enable gradient checkpointing for non-PEFT models
elif hasattr(self.model, "gradient_checkpointing_enable"):
@@ -39,15 +61,12 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
# pylint: enable=access-member-before-definition
def _enable_gradient_checkpointing(
self, model: PreTrainedModel, args: GRPOConfig
) -> PreTrainedModel:
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# pylint: disable=unused-argument,redefined-builtin
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs
or gradient_checkpointing_kwargs["use_reentrant"]
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
)
if use_reentrant:
@@ -58,9 +77,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(
make_inputs_require_grad
)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
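# With reentrant checkpointing, inputs to checkpointed blocks must require
# grad; hooking the embedding output keeps gradients flowing into adapter
# weights even when the embedding itself is frozen.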
return model
# pylint: enable=unused-argument,redefined-builtin
@@ -72,25 +89,18 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
) as unwrapped_model:
if is_compiled_module(unwrapped_model):
unwrapped_model = (
unwrapped_model._orig_mod # pylint: disable=protected-access
)
unwrapped_model = unwrapped_model._orig_mod # pylint: disable=protected-access
if is_peft_model(unwrapped_model):
unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict()
unwrapped_model.unmerge_adapter()
# Remove base_model and base_layer prefixes
state_dict = {
k.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", ""): v
k.removeprefix("base_model.model.").removeprefix("base_model.model.").replace(".base_layer", ""): v
for k, v in state_dict.items()
}
# Remove values with adapter prefix (example: "_lora")
state_dict = {
k: v
for k, v in state_dict.items()
if unwrapped_model.prefix not in k
}
state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
# When module to save, remove its prefix and discard the original module
state_dict = {
k.replace("modules_to_save.default.", ""): v
@@ -99,10 +109,218 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
}
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = (
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
)
if self.accelerator.is_main_process:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if self.use_liger_loss:
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
device = self.accelerator.device
prompts = [x["prompt"] for x in inputs]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
prompt_inputs = self.processing_class(
prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
if self.max_prompt_length is not None:
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
self._last_loaded_step = self.state.global_step
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
else:
completion_ids = [None] * len(all_prompts_text) * self.num_generations
# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.accelerator.process_index * len(prompts) * self.num_generations,
(self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
)
llm_model.load_weights(state_dict.items())
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()
completion_ids = completion_ids[process_slice]
# Pad the completions, and concatenate them with the prompts
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
prompt_inputs_repeated = torch.repeat_interleave(
prompt_inputs["input_ids"], self.num_generations, dim=0
)
prompt_completion_ids = torch.cat([prompt_inputs_repeated, completion_ids], dim=1)
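# repeat_interleave aligns rows: prompt i is repeated num_generations times
# so that row j of prompt_inputs_repeated matches its j-th sampled completion.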
else:
# Regular generation path
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(
**prompt_inputs, generation_config=self.generation_config
)
prompt_length = prompt_inputs["input_ids"].size(1)
completion_ids = prompt_completion_ids[:, prompt_length:]
# Get the per-token log probabilities for the completions for the model and the reference model
def get_per_token_logps(model, input_ids, num_logits_to_keep):
# We add 1 to `num_logits_to_keep` because the last logit of the sequence is later excluded
outputs = model(input_ids, num_logits_to_keep=num_logits_to_keep + 1)
hidden_states = outputs.last_hidden_state[:, :-1]
logits = outputs.logits # (B, L, V)
logits = logits[
:, :-1, :
] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
per_token_logps = []
for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
log_probs = logits_row.log_softmax(dim=-1)
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps), hidden_states
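# Unlike the upstream TRL helper, this also returns hidden_states: the fused
# Liger GRPO loss consumes hidden states plus the lm_head weight directly
# instead of materialized full-vocab logits.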
num_logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
per_token_logps, hidden_states = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)
with torch.inference_mode():
if self.ref_model is not None:
ref_per_token_logps, ref_hidden_states = get_per_token_logps(
self.ref_model, prompt_completion_ids, num_logits_to_keep
)
else:
with self.accelerator.unwrap_model(model).disable_adapter():
ref_per_token_logps, ref_hidden_states = get_per_token_logps(
model, prompt_completion_ids, num_logits_to_keep
)
# done in liger
# Compute the KL divergence between the model and the reference model
# per_token_kl = (
# torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
# )
# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
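# eos_idx defaults to the sequence length (no EOS found) and otherwise holds
# the first EOS position, so the mask keeps every token up to and including
# the first EOS.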
# Decode the generated completions
completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = [[{"role": "assistant", "content": completion}] for completion in completions]
# Compute the rewards
prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
for i, (reward_func, reward_processing_class) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes)
):
if isinstance(reward_func, PreTrainedModel):
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
else:
texts = [p + c for p, c in zip(prompts, completions)]
reward_inputs = reward_processing_class(
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
)
reward_inputs = super()._prepare_inputs(reward_inputs)
with torch.inference_mode():
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
else:
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
for key in reward_kwargs:
for example in inputs:
# Repeat each value in the column for `num_generations` times
reward_kwargs[key].extend([example[key]] * self.num_generations)
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
# Sum the rewards from all reward functions
rewards = rewards_per_func.sum(dim=1)
# done in liger
# # Compute grouped-wise rewards
# mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
# std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
# done in liger
# # Normalize the rewards to compute the advantages
# mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
# std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
# advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
# done in liger
# x - x.detach() allows for preserving gradients from x
# per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
# per_token_loss = -(per_token_loss - self.beta * per_token_kl)
# loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
# Log the metrics
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
self._metrics["completion_length"].append(completion_length)
reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
else:
reward_func_name = reward_func.__name__
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
lm_head = model.get_output_embeddings()
if self.ref_model is not None:
ref_lm_head = self.ref_model.get_output_embeddings()
else:
with self.null_ref_context():
ref_lm_head = model.get_output_embeddings()
ref_weight = ref_lm_head.weight
ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None
loss, metrics = self.grpo_loss_fn(
lm_head,
hidden_states, # this is the hidden states from the model
completion_mask,
rewards,
bias=lm_head.bias if hasattr(lm_head, "bias") else None,
ref_input=ref_hidden_states, # this is the hidden states from the ref model
ref_weight=ref_weight,
ref_bias=ref_bias,
)
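# The fused linear GRPO loss folds the lm_head projection into the chunked
# loss computation, so full-vocab logits need not be materialized for the
# loss itself.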
return loss
else:
return super().compute_loss(model, inputs, return_outputs, num_items_in_batch)
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with (
self.accelerator.unwrap_model(self.model).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
if self.ref_adapter_name:
self.model.set_adapter(self.ref_adapter_name)
yield
if self.ref_adapter_name:
self.model.set_adapter(self.model_adapter_name or "default")
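
As a sanity check on the per-token log-probability gathering used in compute_loss above, a self-contained toy version (shapes invented for illustration):

import torch

logits = torch.randn(2, 3, 5)                 # (batch, completion_len, vocab)
completion_ids = torch.randint(0, 5, (2, 3))  # sampled token ids

log_probs = logits.log_softmax(dim=-1)
per_token_logps = torch.gather(
    log_probs, dim=2, index=completion_ids.unsqueeze(-1)
).squeeze(-1)                                 # (batch, completion_len)
assert per_token_logps.shape == (2, 3)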

View File

@@ -0,0 +1,58 @@
### AXOLOTL COMMUNITY LICENSE AGREEMENT
This Axolotl Community License Agreement (“Agreement”) is entered into by and between Axolotl AI Corp. (“Axolotl”) and
any individual or entity (“Licensee”) who wishes to use the Software (as defined below) in accordance with the terms
and conditions set forth in this Agreement.
1. Definitions
1.1 “Licensee” refers to any individual or entity who has obtained a copy of the Software under this Agreement.
1.2 “Plugin Integration” means independent integration software modules which may or may not be offered by Axolotl,
which may be licensed separately by their respective authors and/or licensors.
1.3 “Software” refers to the specific sub-directory of the Axolotl, Inc. software located at
https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations and its subdirectories which
permits Plugin Integrations to integrate with the Axolotl service.
2. Grant of License
2.1 Axolotl hereby grants Licensee a worldwide, non-exclusive, royalty-free, license to use, copy, modify, merge,
publish, distribute, sublicense, and/or otherwise exploit the Software, subject to the following conditions:
- Licensee must comply with all the terms and conditions of this Agreement.
- Licensee must include the original copyright notice and disclaimer of warranty in all copies or substantial
portions of the Software.
2.2 Licensee may use the Software for any lawful purpose, except as restricted in Section 3.
3. Restrictions
3.1 Licensee shall not use the Software for any activity that constitutes a commercial activity of offering for
free or for sale any services, platform, or equivalent to third parties for the purposes of allowing such
third parties to fine-tune artificial intelligence models.
3.2 Licensee shall not:
- Use the Software for any illegal or unauthorized purpose.
- Reverse engineer, decompile, or disassemble the Software.
- Remove or modify any copyright, trademark, or other proprietary notices contained in the Software.
- Use the Software in a way that could damage, disable, overburden, or impair the functionality of the
Software or interfere with any third-party use of the Software.
3.3 Axolotl reserves the right to restrict certain Plugin Integrations for use with the Software. To the extent Licensee integrates a permitted, applicable Plugin Integration with the Software, Licensee shall comply with any additional terms and conditions imposed by the licensors of such Plugin Integration for use of such Plugin Integrations. Licensee shall contact Axolotl if it has questions about whether its use of the Software falls beyond the scope of this Agreement.
4. Intellectual Property Rights
4.1 Axolotl and its contributors retain all intellectual property rights in and to the Software. Licensee
acknowledges that this Agreement does not transfer any ownership rights or intellectual property rights to
Licensee.
5. Disclaimer of Warranty
5.1 THE SOFTWARE IS PROVIDED “AS IS,” WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN ACTION OF
CONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
6. Termination
6.1 Axolotl may terminate this Agreement at any time if Licensee fails to comply with any of the terms and
conditions set forth herein. Upon termination, Licensee shall cease all use of the Software and destroy any
copies in its possession.
7. Governing Law
7.1 This Agreement shall be governed by and construed in accordance with the laws of the State of California,
without regards to conflicts of laws provisions thereof.
8. Entire Agreement
8.1 This Agreement constitutes the entire agreement between Axolotl and Licensee with respect to the subject matter
hereof and supersedes all prior or contemporaneous understandings or agreements between the parties concerning
the Software, whether written or oral. Axolotl may update the terms of this Agreement from time to time, and
Licensee's continued use of the Software after any such updates shall constitute acceptance of the updated terms
on a go-forward basis. Axolotl will use commercially reasonable efforts to provide Licensee notice of any
material updates. By using the Software, Licensee acknowledges that it has read, understood, and agrees to be
bound by the terms and conditions of this Agreement.
This Agreement was last updated on August 23, 2024.

View File

@@ -1,391 +0,0 @@
"""
benchmark utility helper for benchmarking the KL divergence triton kernel
"""
import gc
import time
import torch
from torch.utils.benchmark import Timer
from axolotl.integrations.kd.topk_logprob.forward_kl import loss as eager_loss
from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss
# pylint: disable=cell-var-from-loop
def benchmark_kl_div_loss_with_backward():
# Test configurations
batch_sizes = [1, 4]
seq_lens = [64, 512, 2048, 4096, 8192]
vocab_size = 32000
top_k = 64
# Store results
results = []
# Run benchmarks
for batch_size in batch_sizes:
for seq_len in seq_lens:
# Generate random test data
torch.manual_seed(42)
# Create tensors with gradients
student_logits = torch.randn(
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
)
# pylint: disable=duplicate-code
target_token_ids = torch.randint(
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
)
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
target_mask = torch.randint(
0, 2, (batch_size, seq_len, top_k), device="cuda"
).float()
# Clone student_logits for the two implementations
student_logits_ref = student_logits.clone().detach().requires_grad_(True)
student_logits_triton = student_logits.clone().detach().requires_grad_(True)
# Define functions for timing that include both forward and backward passes
def run_reference():
# Forward pass
loss_ref = eager_loss(
student_logits_ref, target_token_ids, target_logprobs, target_mask
)
# Backward pass
loss_ref.backward()
def run_triton():
# Forward pass
# pylint: disable=duplicate-code
loss_triton = triton_loss(
student_logits_triton,
target_token_ids,
target_logprobs,
target_mask,
)
# Backward pass
loss_triton.backward()
# Benchmark reference implementation (forward + backward)
t0 = Timer(
stmt="run_reference()",
globals={
"run_reference": run_reference,
},
)
# Reset gradients before timing
student_logits_ref.grad = None
ref_time = t0.timeit(10).median * 1000 # Convert to ms
# Benchmark Triton implementation (forward + backward)
t1 = Timer(
stmt="run_triton()",
globals={
"run_triton": run_triton,
},
)
# Reset gradients before timing
student_logits_triton.grad = None
triton_time = t1.timeit(10).median * 1000 # Convert to ms
# Compute speedup
speedup = ref_time / triton_time if triton_time > 0 else float("inf")
# Store results
results.append(
{
"batch_size": batch_size,
"seq_len": seq_len,
"reference_time_ms": ref_time,
"triton_time_ms": triton_time,
"speedup": speedup,
}
)
print(f"Batch size: {batch_size}, Seq len: {seq_len}")
print(f" Reference time (fwd+bwd): {ref_time:.2f} ms")
print(f" Triton time (fwd+bwd): {triton_time:.2f} ms")
print(f" Speedup: {speedup:.2f}x")
return results
def benchmark_forward_backward_separately():
"""
Benchmark forward and backward passes separately to identify where the speedup comes from.
"""
# Test configurations
batch_sizes = [1, 4, 8]
seq_lens = [64, 512, 2048]
vocab_size = 32000
top_k = 64
# Store results
detailed_results = []
# Run benchmarks
for batch_size in batch_sizes:
for seq_len in seq_lens:
# Generate random test data
torch.manual_seed(42)
# Create tensors with gradients
student_logits = torch.randn(
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
)
# pylint: disable=duplicate-code
target_token_ids = torch.randint(
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
)
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
target_mask = torch.randint(
0, 2, (batch_size, seq_len, top_k), device="cuda"
).float()
# Clone student_logits for the two implementations
student_logits_ref = student_logits.clone().detach().requires_grad_(True)
student_logits_triton = student_logits.clone().detach().requires_grad_(True)
# Forward-only reference
def run_reference_forward():
with torch.no_grad():
return eager_loss(
student_logits_ref,
target_token_ids,
target_logprobs,
target_mask,
)
# Forward-only triton
def run_triton_forward():
with torch.no_grad():
return triton_loss(
student_logits_triton,
target_token_ids,
target_logprobs,
target_mask,
)
# Benchmark forward pass only
t0_fwd = Timer(
stmt="run_reference_forward()",
globals={
"run_reference_forward": run_reference_forward,
},
)
ref_fwd_time = t0_fwd.timeit(10).median * 1000 # Convert to ms
t1_fwd = Timer(
stmt="run_triton_forward()",
globals={
"run_triton_forward": run_triton_forward,
},
)
triton_fwd_time = t1_fwd.timeit(10).median * 1000 # Convert to ms
# Pre-compute losses for backward pass benchmarking
loss_ref = eager_loss(
student_logits_ref, target_token_ids, target_logprobs, target_mask
)
loss_triton = triton_loss(
student_logits_triton, target_token_ids, target_logprobs, target_mask
)
# Backward-only reference
def run_reference_backward():
student_logits_ref.grad = None
loss_ref.backward()
# Backward-only triton
def run_triton_backward():
student_logits_triton.grad = None
loss_triton.backward()
# Benchmark backward pass only
t0_bwd = Timer(
stmt="run_reference_backward()",
globals={
"run_reference_backward": run_reference_backward,
},
)
ref_bwd_time = t0_bwd.timeit(10).median * 1000 # Convert to ms
t1_bwd = Timer(
stmt="run_triton_backward()",
globals={
"run_triton_backward": run_triton_backward,
},
)
triton_bwd_time = t1_bwd.timeit(10).median * 1000 # Convert to ms
# Compute speedups
fwd_speedup = (
ref_fwd_time / triton_fwd_time if triton_fwd_time > 0 else float("inf")
)
bwd_speedup = (
ref_bwd_time / triton_bwd_time if triton_bwd_time > 0 else float("inf")
)
total_ref_time = ref_fwd_time + ref_bwd_time
total_triton_time = triton_fwd_time + triton_bwd_time
total_speedup = (
total_ref_time / total_triton_time
if total_triton_time > 0
else float("inf")
)
# Store results
detailed_results.append(
{
"batch_size": batch_size,
"seq_len": seq_len,
"ref_forward_ms": ref_fwd_time,
"triton_forward_ms": triton_fwd_time,
"forward_speedup": fwd_speedup,
"ref_backward_ms": ref_bwd_time,
"triton_backward_ms": triton_bwd_time,
"backward_speedup": bwd_speedup,
"total_ref_ms": total_ref_time,
"total_triton_ms": total_triton_time,
"total_speedup": total_speedup,
}
)
print(f"Batch size: {batch_size}, Seq len: {seq_len}")
print(
f" Forward: Reference={ref_fwd_time:.2f}ms, Triton={triton_fwd_time:.2f}ms, Speedup={fwd_speedup:.2f}x"
)
print(
f" Backward: Reference={ref_bwd_time:.2f}ms, Triton={triton_bwd_time:.2f}ms, Speedup={bwd_speedup:.2f}x"
)
print(
f" Total: Reference={total_ref_time:.2f}ms, Triton={total_triton_time:.2f}ms, Speedup={total_speedup:.2f}x"
)
return detailed_results
def benchmark_memory_usage_with_backward():
# Test configurations
batch_sizes = [1, 2]
seq_len = 8192
vocab_size = 128000
top_k = 64
# Store results
mem_results = []
# Run benchmarks
for batch_size in batch_sizes:
# Generate random test data
torch.manual_seed(42)
student_logits = torch.randn(
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
)
target_token_ids = torch.randint(
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
)
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
target_mask = torch.randint(
0, 2, (batch_size, seq_len, top_k), device="cuda"
).float()
# Clone student_logits for the implementations
student_logits_ref = student_logits.clone().detach().requires_grad_(True)
student_logits_triton = student_logits.clone().detach().requires_grad_(True)
# Measure PyTorch memory usage (forward + backward)
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
loss_ref = eager_loss(
student_logits_ref, target_token_ids, target_logprobs, target_mask
)
loss_ref.backward()
torch.cuda.synchronize()
pytorch_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
# Measure Triton memory usage (forward + backward)
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
loss_triton = triton_loss(
student_logits_triton, target_token_ids, target_logprobs, target_mask
)
loss_triton.backward()
torch.cuda.synchronize()
triton_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
# Measure Triton memory usage with different chunk sizes (forward + backward)
for n_chunks in [1, 2, 4, 8]:
student_logits_chunk = student_logits.clone().detach().requires_grad_(True)
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
loss_chunk = triton_loss(
student_logits_chunk,
target_token_ids,
target_logprobs,
target_mask,
)
loss_chunk.backward()
torch.cuda.synchronize()
chunk_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
mem_results.append(
{
"batch_size": batch_size,
"implementation": f"Triton (chunks={n_chunks})",
"memory_mb": chunk_mem,
}
)
# Store results
mem_results.append(
{
"batch_size": batch_size,
"implementation": "PyTorch",
"memory_mb": pytorch_mem,
}
)
mem_results.append(
{
"batch_size": batch_size,
"implementation": "Triton (default)",
"memory_mb": triton_mem,
}
)
print(f"Batch size: {batch_size} (with backward pass)")
print(f" PyTorch memory: {pytorch_mem:.2f} MB")
print(f" Triton memory: {triton_mem:.2f} MB")
print(f" Memory reduction: {(1 - triton_mem/pytorch_mem)*100:.2f}%")
return mem_results
def main():
print("Running benchmarks with forward and backward passes...")
benchmark_kl_div_loss_with_backward()
clean()
print("\nRunning detailed forward/backward benchmarks...")
# benchmark_forward_backward_separately()
# clean()
print("\nRunning memory usage benchmarks with backward passes...")
benchmark_memory_usage_with_backward()
clean()
def clean():
for _ in range(5):
gc.collect()
torch.cuda.empty_cache()
time.sleep(1)
if __name__ == "__main__":
main()
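
For reference, torch.utils.benchmark.Timer (used throughout this now-deleted benchmark) handles the warmup and CUDA synchronization that a bare time.time() loop would miss; a minimal standalone version:

import torch
from torch.utils.benchmark import Timer

def bench_ms(fn, runs: int = 10) -> float:
    # Timer inserts CUDA synchronization around the timed region.
    return Timer(stmt="fn()", globals={"fn": fn}).timeit(runs).median * 1000

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")
    print(f"matmul: {bench_ms(lambda: x @ x):.2f} ms")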

View File

@@ -1,16 +1,14 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""
loss for top_k KL divergence

View File

@@ -1,750 +0,0 @@
"""
Optimized Triton kernel for KL divergence loss between teacher and student models.
"""
# pylint: disable=invalid-name,unused-argument
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
@triton.jit
def fused_logsumexp_logprobs_kernel(
student_logits_ptr, # Input logits in original dtype
student_logprobs_ptr, # Output logprobs (float32)
token_ids_ptr, # Token IDs for top-k
B,
S,
V,
K, # batch size, seq len, vocab size, top-k
temperature,
stride_l_b,
stride_l_s,
stride_l_v,
stride_lp_b,
stride_lp_s,
stride_lp_k,
stride_t_b,
stride_t_s,
stride_t_k,
BLOCK_SIZE: tl.constexpr,
):
"""
Fused kernel that computes logsumexp and logprobs for topk tokens.
All computations are done in float32 for numerical stability.
"""
# Program ID
pid = tl.program_id(0)
batch_idx = pid // S
seq_idx = pid % S
# Bounds check
if batch_idx >= B or seq_idx >= S:
return
# Compute logsumexp over the vocabulary
max_val = -float("inf")
# Phase 1: Find max value across vocabulary
for v_offset in range(0, V, BLOCK_SIZE):
# Create block indices and mask
block_size = min(BLOCK_SIZE, V - v_offset)
block_idx = tl.arange(0, BLOCK_SIZE)
mask = block_idx < block_size
# Load logits block and convert to float32 in-place
ptrs = (
student_logits_ptr
+ batch_idx * stride_l_b
+ seq_idx * stride_l_s
+ (v_offset + block_idx) * stride_l_v
)
block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)
# Apply temperature scaling if needed
if temperature != 1.0:
block_logits = block_logits / temperature
# Update max value
block_max = tl.max(block_logits, axis=0)
max_val = tl.maximum(max_val, block_max)
# Phase 2: Compute sum of exp(logits - max_val)
sum_exp = 0.0
for v_offset in range(0, V, BLOCK_SIZE):
# Create block indices and mask
block_size = min(BLOCK_SIZE, V - v_offset)
block_idx = tl.arange(0, BLOCK_SIZE)
mask = block_idx < block_size
# Load logits block and convert to float32 in-place
ptrs = (
student_logits_ptr
+ batch_idx * stride_l_b
+ seq_idx * stride_l_s
+ (v_offset + block_idx) * stride_l_v
)
block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)
# Apply temperature scaling if needed
if temperature != 1.0:
block_logits = block_logits / temperature
# Compute exp(logits - max_val) and add to sum
block_exp = tl.exp(block_logits - max_val)
sum_exp += tl.sum(block_exp * mask, axis=0)
# Compute final logsumexp
logsumexp = max_val + tl.log(sum_exp)
# Phase 3: Compute and store logprobs for the top-k tokens
token_ids_base = token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
logprobs_base = (
student_logprobs_ptr + batch_idx * stride_lp_b + seq_idx * stride_lp_s
)
for k in range(K):
# Load token ID for position k
token_id = tl.load(token_ids_base + k * stride_t_k)
# Load the corresponding logit and convert to float32
token_logit_ptr = (
student_logits_ptr
+ batch_idx * stride_l_b
+ seq_idx * stride_l_s
+ token_id * stride_l_v
)
token_logit = tl.load(token_logit_ptr).to(tl.float32)
# Apply temperature scaling if needed
if temperature != 1.0:
token_logit = token_logit / temperature
# Compute logprob directly: logit - logsumexp
token_logprob = token_logit - logsumexp
# Store the result
tl.store(logprobs_base + k * stride_lp_k, token_logprob)
@triton.jit
def grad_softmax_kernel(
grad_student_logits_ptr,
target_token_ids_ptr,
teacher_probs_ptr,
student_probs_ptr,
mask_ptr,
B,
S,
V,
K, # batch size, seq len, vocab size, top-k
scale,
stride_gl_b,
stride_gl_s,
stride_gl_v,
stride_t_b,
stride_t_s,
stride_t_k,
stride_p_b,
stride_p_s,
stride_p_k,
stride_sp_b,
stride_sp_s,
stride_sp_k,
stride_m_b,
stride_m_s,
stride_m_k,
BLOCK_SIZE: tl.constexpr,
):
# Program ID
pid = tl.program_id(0)
batch_idx = pid // S
seq_idx = pid % S
# Bounds check
if batch_idx >= B or seq_idx >= S:
return
# Base pointers for this (batch, seq) pair
grad_logits_base = (
grad_student_logits_ptr + batch_idx * stride_gl_b + seq_idx * stride_gl_s
)
token_ids_base = (
target_token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
)
teacher_probs_base = (
teacher_probs_ptr + batch_idx * stride_p_b + seq_idx * stride_p_s
)
student_probs_base = (
student_probs_ptr + batch_idx * stride_sp_b + seq_idx * stride_sp_s
)
mask_base = mask_ptr + batch_idx * stride_m_b + seq_idx * stride_m_s
# Process each teacher probability one at a time, computing all gradients for it
for k in range(0, K):
# Load data for current position k
teacher_prob = tl.load(teacher_probs_base + k * stride_p_k)
student_prob_k = tl.load(student_probs_base + k * stride_sp_k)
mask_val = tl.load(mask_base + k * stride_m_k)
# Precompute the self-influence term (multiplied by scale)
self_term = teacher_prob * (1.0 - student_prob_k) * scale
# Calculate gradient contributions for all positions j
for j in range(0, K):
token_id_j = tl.load(token_ids_base + j * stride_t_k)
student_prob_j = tl.load(student_probs_base + j * stride_sp_k)
mask_j = tl.load(mask_base + j * stride_m_k)
# Calculate the masking factor
combined_mask = mask_val * mask_j
# Determine if this is a diagonal or off-diagonal term
is_k_equals_j = tl.where(k == j, 1.0, 0.0)
# Compute the gradient contribution
# For diagonal (k==j): -teacher_prob * (1-student_prob_k) * scale * mask
# For off-diagonal: -(-teacher_prob * student_prob_j) * scale * mask
grad_contribution = (
-(
self_term * is_k_equals_j
- teacher_prob * student_prob_j * scale * (1.0 - is_k_equals_j)
)
* combined_mask
)
# Atomically update the gradient for this token
tl.atomic_add(
grad_logits_base + token_id_j * stride_gl_v, grad_contribution
)
@triton.jit
def grad_topk_softmax_kernel(
grad_student_logits_ptr,
student_logits_ptr,
target_token_ids_ptr,
teacher_probs_ptr,
student_probs_ptr,
mask_ptr,
B,
S,
V,
K, # batch size, seq len, vocab size, top-k
scale,
stride_gl_b,
stride_gl_s,
stride_gl_v,
stride_l_b,
stride_l_s,
stride_l_v,
stride_t_b,
stride_t_s,
stride_t_k,
stride_p_b,
stride_p_s,
stride_p_k,
stride_sp_b,
stride_sp_s,
stride_sp_k,
stride_m_b,
stride_m_s,
stride_m_k,
BLOCK_SIZE: tl.constexpr,
):
# Program ID
pid = tl.program_id(0)
batch_idx = pid // S
seq_idx = pid % S
# Bounds check
if batch_idx >= B or seq_idx >= S:
return
# Base pointers for this (batch, seq) pair
grad_logits_base = (
grad_student_logits_ptr + batch_idx * stride_gl_b + seq_idx * stride_gl_s
)
# logits_base = student_logits_ptr + batch_idx * stride_l_b + seq_idx * stride_l_s
token_ids_base = (
target_token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
)
teacher_probs_base = (
teacher_probs_ptr + batch_idx * stride_p_b + seq_idx * stride_p_s
)
student_probs_base = (
student_probs_ptr + batch_idx * stride_sp_b + seq_idx * stride_sp_s
)
mask_base = mask_ptr + batch_idx * stride_m_b + seq_idx * stride_m_s
# Load all token IDs, probs and masks for this position
token_ids = tl.zeros([K], dtype=tl.int32)
teacher_probs = tl.zeros([K], dtype=tl.float32)
student_probs = tl.zeros([K], dtype=tl.float32)
masks = tl.zeros([K], dtype=tl.float32)
for k in range(K):
token_ids[k] = tl.load(token_ids_base + k * stride_t_k)
teacher_probs[k] = tl.load(teacher_probs_base + k * stride_p_k)
student_probs[k] = tl.load(student_probs_base + k * stride_sp_k)
masks[k] = tl.load(mask_base + k * stride_m_k)
# Process gradients for all tokens in this position
for k in range(K):
# token_id = token_ids[k]
mask_k = masks[k]
# Skip computation if mask is zero by multiplying gradient by mask
for j in range(K):
other_token_id = token_ids[j]
mask_j = masks[j]
combined_mask = mask_k * mask_j
# Compute gradient differently for diagonal vs off-diagonal entries
# Using * 1.0 to convert boolean to float
is_diagonal = tl.where(j == k, 1.0, 0.0)
# Self influence: gradient = teacher_prob * (1 - student_prob)
self_grad = teacher_probs[k] * (1.0 - student_probs[k]) * is_diagonal
# Cross influence: gradient = -teacher_prob[k] * student_prob[j]
cross_grad = -teacher_probs[k] * student_probs[j] * (1.0 - is_diagonal)
# Combined gradient scaled by mask
grad_val = (self_grad + cross_grad) * scale * combined_mask
tl.atomic_add(grad_logits_base + other_token_id * stride_gl_v, grad_val)
# Triton-accelerated implementation of KL divergence loss for top-k tokens
# Chunking helper functions for handling long sequences
def chunk_tensor(
tensor: torch.Tensor, max_seq_len: int
) -> Tuple[torch.Tensor, Optional[int]]:
"""Split a tensor along sequence dimension if needed."""
_, seq_len, *__ = tensor.shape
if seq_len <= max_seq_len:
return tensor, None
num_chunks = (seq_len + max_seq_len - 1) // max_seq_len
chunks = []
for i in range(num_chunks):
start_idx = i * max_seq_len
end_idx = min((i + 1) * max_seq_len, seq_len)
chunks.append(tensor[:, start_idx:end_idx, ...])
return chunks, num_chunks
def merge_chunks(chunks: list, original_shape: torch.Size):
"""Merge chunks back into a single tensor with original shape."""
return torch.cat(chunks, dim=1)
# Triton-accelerated implementation of KL divergence loss for top-k tokens
class TopKKLDivergence(torch.autograd.Function):
"""
Autograd function for KL divergence loss between top-k logprobs
with support for chunking to handle very long sequences.
"""
# Max sequence length to process in a single kernel launch
# This is a tunable parameter that might need adjustment based on GPU memory
MAX_SEQ_LEN = 8192
@staticmethod
def forward(
ctx,
student_logits,
target_token_ids,
target_logprobs,
target_mask,
num_items_in_batch=-1,
kd_temperature=1.0,
top_k_before_softmax=0,
):
"""
Forward pass for KL divergence loss between top-k logprobs with chunking.
"""
# Only convert target_logprobs to float, leave student_logits as is
target_logprobs = target_logprobs.float()
# Get dimensions
batch_size, _, vocab_size = student_logits.shape
_, teacher_seq_len, top_k = target_token_ids.shape
# Slice student logits to match teacher sequence length
student_logits_for_kd = student_logits[:, :teacher_seq_len, :]
# Store original values for backward pass
ctx.original_seq_len = teacher_seq_len
ctx.original_dtype = student_logits.dtype
# Apply chunking for long sequences
if teacher_seq_len > TopKKLDivergence.MAX_SEQ_LEN:
# Chunk the inputs
student_logits_chunks, num_chunks = chunk_tensor(
student_logits_for_kd, TopKKLDivergence.MAX_SEQ_LEN
)
target_token_ids_chunks, _ = chunk_tensor(
target_token_ids, TopKKLDivergence.MAX_SEQ_LEN
)
# target_logprobs_chunks, _ = chunk_tensor(
# target_logprobs, TopKKLDivergence.MAX_SEQ_LEN
# )
# target_mask_chunks, _ = chunk_tensor(
# target_mask, TopKKLDivergence.MAX_SEQ_LEN
# )
# Process each chunk
student_logprobs_chunks = []
student_probs_chunks = []
for i in range(num_chunks):
chunk_logits = student_logits_chunks[i]
chunk_token_ids = target_token_ids_chunks[i]
chunk_seq_len = chunk_logits.shape[1]
if top_k_before_softmax:
# Apply temperature to student logits
if kd_temperature != 1.0:
chunk_logits = chunk_logits / kd_temperature
# Gather student logits for top-k tokens
chunk_logits_topk = torch.gather(
chunk_logits, dim=-1, index=chunk_token_ids
)
# Compute softmax over gathered logits
chunk_logprobs_topk = torch.log_softmax(chunk_logits_topk, dim=-1)
chunk_probs_topk = torch.exp(chunk_logprobs_topk)
else:
# Allocate output tensor for logprobs directly (always in float32)
chunk_logprobs_topk = torch.empty(
(batch_size, chunk_seq_len, top_k),
dtype=torch.float32,
device=chunk_logits.device,
)
# Launch fused kernel directly
grid = (batch_size * chunk_seq_len,)
fused_logsumexp_logprobs_kernel[grid](
chunk_logits.contiguous(),
chunk_logprobs_topk,
chunk_token_ids.contiguous(),
batch_size,
chunk_seq_len,
vocab_size,
top_k,
kd_temperature,
chunk_logits.stride(0),
chunk_logits.stride(1),
chunk_logits.stride(2),
chunk_logprobs_topk.stride(0),
chunk_logprobs_topk.stride(1),
chunk_logprobs_topk.stride(2),
chunk_token_ids.stride(0),
chunk_token_ids.stride(1),
chunk_token_ids.stride(2),
min(1024, triton.next_power_of_2(vocab_size)),
)
# Calculate probs from logprobs
chunk_probs_topk = torch.exp(chunk_logprobs_topk)
# Store results
student_logprobs_chunks.append(chunk_logprobs_topk)
student_probs_chunks.append(chunk_probs_topk)
# Merge results
student_logprobs_topk = torch.cat(student_logprobs_chunks, dim=1)
student_probs_topk = torch.cat(student_probs_chunks, dim=1)
# Save chunking info for backward pass
ctx.used_chunking = True
ctx.num_chunks = num_chunks
else:
# Original code path for shorter sequences
if top_k_before_softmax:
# Apply temperature to student logits
if kd_temperature != 1.0:
student_logits_for_kd = student_logits_for_kd / kd_temperature
# Gather student logits for top-k tokens
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
)
# Compute softmax over gathered logits
student_logprobs_topk = torch.log_softmax(student_logits_topk, dim=-1)
student_probs_topk = torch.exp(student_logprobs_topk)
else:
# Allocate output tensor for logprobs directly (always in float32)
student_logprobs_topk = torch.empty(
(batch_size, teacher_seq_len, top_k),
dtype=torch.float32,
device=student_logits.device,
)
# Launch fused kernel directly
grid = (batch_size * teacher_seq_len,)
fused_logsumexp_logprobs_kernel[grid](
student_logits_for_kd.contiguous(),
student_logprobs_topk,
target_token_ids.contiguous(),
batch_size,
teacher_seq_len,
vocab_size,
top_k,
kd_temperature,
student_logits_for_kd.stride(0),
student_logits_for_kd.stride(1),
student_logits_for_kd.stride(2),
student_logprobs_topk.stride(0),
student_logprobs_topk.stride(1),
student_logprobs_topk.stride(2),
target_token_ids.stride(0),
target_token_ids.stride(1),
target_token_ids.stride(2),
min(1024, triton.next_power_of_2(vocab_size)),
)
# Calculate probs from logprobs
student_probs_topk = torch.exp(student_logprobs_topk)
# No chunking used
ctx.used_chunking = False
# Save tensors for backward pass
ctx.save_for_backward(
student_logits_for_kd,
target_token_ids,
target_logprobs,
target_mask,
student_probs_topk,
)
ctx.kd_temperature = kd_temperature
ctx.top_k_before_softmax = top_k_before_softmax
ctx.num_items_in_batch = num_items_in_batch
# Convert mask to boolean
valid_mask = target_mask.bool()
# Extract valid tokens only - this is where the error was happening
# Use cloned contiguous tensors and explicit indexing for safety
student_logprobs_flat = student_logprobs_topk.view(-1, top_k)
target_logprobs_flat = target_logprobs.view(-1, top_k)
valid_mask_flat = valid_mask.view(-1, top_k)
# Gather valid indices explicitly to avoid illegal memory access
valid_indices = torch.nonzero(valid_mask_flat.view(-1)).squeeze(-1)
student_logprobs_valid = torch.index_select(
student_logprobs_flat.view(-1), 0, valid_indices
)
target_logprobs_valid = torch.index_select(
target_logprobs_flat.view(-1), 0, valid_indices
)
# Convert teacher logprobs to probabilities
teacher_probs_valid = torch.exp(target_logprobs_valid)
# Compute KL divergence loss
token_losses = teacher_probs_valid * (
target_logprobs_valid - student_logprobs_valid
)
kd_loss = token_losses.sum()
# Apply temperature scaling
# pylint: disable=duplicate-code
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Normalize by number of items or valid tokens
if num_items_in_batch > 0:
kd_loss = kd_loss / float(num_items_in_batch)
else:
num_valid_tokens = valid_indices.numel()
kd_loss = kd_loss / float(num_valid_tokens if num_valid_tokens > 0 else 1)
return kd_loss
@staticmethod
def backward(ctx, grad_output):
"""
Optimized backward pass for KL divergence loss with proper dtype handling and chunking.
"""
(
student_logits,
target_token_ids,
target_logprobs,
target_mask,
student_probs,
) = ctx.saved_tensors
kd_temperature = ctx.kd_temperature
num_items_in_batch = ctx.num_items_in_batch
original_dtype = ctx.original_dtype
# Get dimensions
batch_size, _, vocab_size = student_logits.shape
_, teacher_seq_len, top_k = target_token_ids.shape
# Initialize gradient tensor in float32 to support atomic operations
grad_student_logits = torch.zeros_like(student_logits, dtype=torch.float32)
# Compute scaling factor
scale = grad_output.item()
# Apply temperature scaling from forward pass
if kd_temperature != 1.0:
scale = scale * (kd_temperature**2)
# Normalize by number of items or valid tokens
if num_items_in_batch > 0:
scale = scale / float(num_items_in_batch)
else:
scale = scale / float(target_mask.sum().item())
# Apply chain rule for temperature scaling (1/temperature)
if kd_temperature != 1.0:
scale = scale / kd_temperature
# Convert teacher logprobs to probabilities
teacher_probs = torch.exp(target_logprobs)
# Use chunking for the backward pass if used in forward
if getattr(ctx, "used_chunking", False):
num_chunks = ctx.num_chunks
max_seq = TopKKLDivergence.MAX_SEQ_LEN
# Process each chunk
for i in range(num_chunks):
start_idx = i * max_seq
end_idx = min((i + 1) * max_seq, teacher_seq_len)
chunk_len = end_idx - start_idx
# Get chunk slices
# student_logits_chunk = student_logits[:, start_idx:end_idx, :]
target_token_ids_chunk = target_token_ids[:, start_idx:end_idx, :]
teacher_probs_chunk = teacher_probs[:, start_idx:end_idx, :]
student_probs_chunk = student_probs[:, start_idx:end_idx, :]
target_mask_chunk = target_mask[:, start_idx:end_idx, :]
grad_student_logits_chunk = grad_student_logits[:, start_idx:end_idx, :]
# Launch gradient computation kernel for this chunk
grid = (batch_size * chunk_len,)
grad_softmax_kernel[grid](
grad_student_logits_chunk.contiguous(),
target_token_ids_chunk.contiguous(),
teacher_probs_chunk.contiguous(),
student_probs_chunk.contiguous(),
target_mask_chunk.contiguous(),
batch_size,
chunk_len,
vocab_size,
top_k,
scale,
grad_student_logits_chunk.stride(0),
grad_student_logits_chunk.stride(1),
grad_student_logits_chunk.stride(2),
target_token_ids_chunk.stride(0),
target_token_ids_chunk.stride(1),
target_token_ids_chunk.stride(2),
teacher_probs_chunk.stride(0),
teacher_probs_chunk.stride(1),
teacher_probs_chunk.stride(2),
student_probs_chunk.stride(0),
student_probs_chunk.stride(1),
student_probs_chunk.stride(2),
target_mask_chunk.stride(0),
target_mask_chunk.stride(1),
target_mask_chunk.stride(2),
min(1024, triton.next_power_of_2(top_k)),
)
# Update the gradient tensor (already in-place)
else:
# Original code path for shorter sequences
# Launch gradient computation kernel
grid = (batch_size * teacher_seq_len,)
grad_softmax_kernel[grid](
grad_student_logits.contiguous(),
target_token_ids.contiguous(),
teacher_probs.contiguous(),
student_probs.contiguous(),
target_mask.contiguous(),
batch_size,
teacher_seq_len,
vocab_size,
top_k,
scale,
grad_student_logits.stride(0),
grad_student_logits.stride(1),
grad_student_logits.stride(2),
target_token_ids.stride(0),
target_token_ids.stride(1),
target_token_ids.stride(2),
teacher_probs.stride(0),
teacher_probs.stride(1),
teacher_probs.stride(2),
student_probs.stride(0),
student_probs.stride(1),
student_probs.stride(2),
target_mask.stride(0),
target_mask.stride(1),
target_mask.stride(2),
min(1024, triton.next_power_of_2(top_k)),
)
# Convert gradient back to original dtype if needed
if original_dtype != torch.float32:
grad_student_logits = grad_student_logits.to(original_dtype)
# Return gradients for student_logits and None for other inputs
return grad_student_logits, None, None, None, None, None, None
# Wrapper function for chunked computation
def loss(
student_logits: torch.Tensor,
target_token_ids: torch.Tensor,
target_logprobs: torch.Tensor,
target_mask: torch.Tensor,
num_items_in_batch: int = -1,
kd_temperature: float = 1.0,
top_k_before_softmax: int = 0,
max_seq_len: Optional[int] = None,
):
"""
Triton-accelerated, memory-efficient KL divergence loss computation for knowledge distillation
with support for very long sequences.
Args:
student_logits: Student logits [B, seq_len, vocab_size]
target_token_ids: Teacher token IDs [B, seq_len, top_k]
target_logprobs: Teacher logprobs [B, seq_len, top_k]
target_mask: Token mask [B, seq_len, top_k]
num_items_in_batch: Number of items for normalization (-1 for auto)
kd_temperature: Temperature for KD
top_k_before_softmax: Flag for softmax application order
max_seq_len: Override default MAX_SEQ_LEN value for chunking
"""
# Allow overriding the max sequence length
if max_seq_len is not None and max_seq_len > 0:
TopKKLDivergence.MAX_SEQ_LEN = max_seq_len
total_loss = TopKKLDivergence.apply(
student_logits,
target_token_ids,
target_logprobs,
target_mask,
-1 if num_items_in_batch <= 0 else num_items_in_batch,
kd_temperature,
top_k_before_softmax,
)
return total_loss
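
The deleted autograd function above computes forward KL restricted to the teacher's top-k support; the eager computation it accelerates is roughly the following (a sketch, with the [B, S, K] shapes assumed from the docstring):

import torch

def topk_forward_kl(student_logprobs, teacher_logprobs, mask):
    # All inputs are [B, S, K] over the teacher's top-k tokens.
    teacher_probs = teacher_logprobs.exp()
    token_losses = teacher_probs * (teacher_logprobs - student_logprobs)
    return (token_losses * mask).sum() / mask.sum().clamp(min=1)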

View File

@@ -1,67 +0,0 @@
"""
Optimized Triton kernels for logsumexp
"""
# pylint: disable=invalid-name,unused-argument
import triton
import triton.language as tl
# Helper function for computing logsumexp
@triton.jit
def logsumexp_kernel(
logits_ptr,
output_ptr,
B,
S,
V, # batch size, seq len, vocab size
stride_b,
stride_s,
stride_v,
out_stride_b,
out_stride_s,
BLOCK_SIZE: tl.constexpr,
):
# Program ID
# pylint: disable=duplicate-code
pid = tl.program_id(0)
batch_idx = pid // S
seq_idx = pid % S
# Bounds check
if batch_idx >= B or seq_idx >= S:
return
# Pointers
logits_base = logits_ptr + batch_idx * stride_b + seq_idx * stride_s
# Find maximum for numerical stability
max_val = -float("inf")
for v_offset in range(0, V, BLOCK_SIZE):
v_size = min(BLOCK_SIZE, V - v_offset)
mask = tl.arange(0, BLOCK_SIZE) < v_size
logits_block = tl.load(
logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
mask=mask,
other=-float("inf"),
)
max_val = tl.maximum(max_val, tl.max(logits_block, axis=0))
# Compute sum of exp(logit - max_val)
sum_exp = 0.0
for v_offset in range(0, V, BLOCK_SIZE):
v_size = min(BLOCK_SIZE, V - v_offset)
mask = tl.arange(0, BLOCK_SIZE) < v_size
logits_block = tl.load(
logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
mask=mask,
other=-float("inf"),
)
sum_exp += tl.sum(tl.exp(logits_block - max_val), axis=0)
# Compute logsumexp
result = max_val + tl.log(sum_exp)
# Store result
tl.store(output_ptr + batch_idx * out_stride_b + seq_idx * out_stride_s, result)
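
The kernel above implements the standard two-pass numerically stable logsumexp (running max, then sum of shifted exponentials); the same computation in plain PyTorch, for comparison:

import torch

def stable_logsumexp(logits: torch.Tensor, dim: int = -1) -> torch.Tensor:
    m = logits.max(dim=dim, keepdim=True).values                # pass 1: max
    return m.squeeze(dim) + (logits - m).exp().sum(dim).log()   # pass 2: sum-exp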

View File

@@ -20,7 +20,6 @@ from axolotl.core.trainers.base import AxolotlTrainer
from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
from .topk_logprob.forward_kl_triton import loss as topk_kd_loss_triton
class AxolotlKDTrainer(AxolotlTrainer):
@@ -86,12 +85,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
num_items_in_batch=num_items_in_batch,
)
else:
loss_fn = (
topk_kd_loss
if self.args.kd_top_k_before_softmax
else topk_kd_loss_triton
)
loss_kd = loss_fn(
loss_kd = topk_kd_loss(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,

View File

@@ -127,8 +127,6 @@ class ReLoRACallback(TrainerCallback):
optimizer: torch.optim.Optimizer,
**_kwargs,
):
if not optimizer:
optimizer = state.optimizer
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
checkpoint_folder = os.path.join(
args.output_dir,

View File

@@ -272,7 +272,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
dict(zip(feature_names, row))
)
for key, val in tokenized_prompt.items():
res[key].append(val)
for i in range(0, len(val), self.sequence_len):
res[key].append(val[i : i + self.sequence_len])
# If there are no examples left, return an empty dictionary
if not res:
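The effect of the new chunking loop, with an assumed toy sequence_len of 4: instead of appending the full tokenized value, it is split into sequence_len-sized slices:
sequence_len = 4  # assumed toy value
val = list(range(10))
chunks = [val[i : i + sequence_len] for i in range(0, len(val), sequence_len)]
# chunks == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]; each slice is appended to res[key]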

View File

@@ -342,7 +342,6 @@ class LoraConfig(BaseModel):
peft_use_dora: Optional[bool] = None
peft_use_rslora: Optional[bool] = None
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
peft_init_lora_weights: Optional[Union[bool, str]] = None
qlora_sharded_model_loading: Optional[bool] = Field(
default=False,

View File

@@ -33,3 +33,4 @@ class TRLConfig(BaseModel):
sync_ref_model: Optional[bool] = False
ref_model_mixup_alpha: Optional[float] = 0.9
ref_model_sync_steps: Optional[int] = 64
use_liger_loss: Optional[bool] = False
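A minimal sketch of the new flag on the Pydantic config model; the downstream branch that actually enables the Liger loss is assumed, not shown in this hunk:
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
trl_cfg = TRLConfig(use_liger_loss=True)
assert trl_cfg.use_liger_loss is True  # defaults to False when omitted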

View File

@@ -172,11 +172,10 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
)
try:
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
min_input_len = np.min(ds_lengths)
LOG.info(f"min_input_len: {min_input_len}")
max_input_len = np.max(ds_lengths)
LOG.info(f"max_input_len: {max_input_len}")
min_input_len = np.min(get_dataset_lengths(dataset))
LOG.debug(f"min_input_len: {min_input_len}")
max_input_len = np.max(get_dataset_lengths(dataset))
LOG.debug(f"max_input_len: {max_input_len}")
except AttributeError:
pass

View File

@@ -1321,8 +1321,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
if loftq_bits:
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
lora_config_kwargs["init_lora_weights"] = "loftq"
if cfg.peft_init_lora_weights:
lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
if cfg.peft_use_dora:
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
LOG.info("Initializing LoRA weights using dora. This might take longer.")

View File

@@ -4,17 +4,13 @@ helper util to calculate dataset lengths
import numpy as np
def get_dataset_lengths(dataset, from_arrow=False):
def get_dataset_lengths(dataset):
if "length" in dataset.column_names:
lengths = np.array(dataset["length"])
elif "position_ids" in dataset.column_names:
position_ids = dataset["position_ids"]
lengths = np.array([x[-1] + 1 for x in position_ids])
else:
if from_arrow:
input_ids = dataset.data.column("input_ids")
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
else:
input_ids = dataset["input_ids"]
lengths = np.array([len(seq) for seq in input_ids])
input_ids = dataset["input_ids"]
lengths = np.array([len(seq) for seq in input_ids])
return lengths
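A usage sketch of the reverted helper against an assumed in-memory dataset:
from datasets import Dataset
ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5]]})
lengths = get_dataset_lengths(ds)  # array([3, 2])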

View File

@@ -90,12 +90,6 @@ class TestKnowledgeDistillation:
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
)
check_tensorboard(
temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
)
@pytest.mark.parametrize(
"load_in_8bit",
@@ -127,9 +121,3 @@ class TestKnowledgeDistillation:
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
)
check_tensorboard(
temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
)

View File

@@ -1,163 +0,0 @@
"""
sanity checks on kl loss and gradients
"""
import torch
# Import both implementations
from axolotl.integrations.kd.topk_logprob.forward_kl import loss as eager_loss
from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss
def test_kl_loss_gradient():
"""Test that the gradient of the Triton implementation matches the eager implementation."""
# Set the random seed for reproducibility
torch.manual_seed(42)
# Create random inputs
batch_size = 2
seq_len = 3
vocab_size = 100
top_k = 5
# Generate random student logits
student_logits = torch.randn(
batch_size, seq_len, vocab_size, requires_grad=True, device="cuda"
)
student_logits_triton = student_logits.detach().clone().requires_grad_(True)
# Generate random target token IDs, ensuring they're valid indices
# pylint: disable=duplicate-code
target_token_ids = torch.randint(
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
)
# Generate random target logprobs (before normalization)
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
# Normalize the target logprobs to ensure they form a valid distribution
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
# Create a random mask with some tokens masked out
target_mask = torch.randint(
0, 2, (batch_size, seq_len, top_k), device="cuda"
).float()
# Additional parameters
num_items_in_batch = batch_size * seq_len
kd_temperature = 1.0
top_k_before_softmax = 0  # start with mode 0; mode 1 is exercised below
# Compute the loss and gradients with eager implementation
loss_eager = eager_loss(
student_logits,
target_token_ids,
target_logprobs,
target_mask,
num_items_in_batch,
kd_temperature,
top_k_before_softmax,
)
loss_eager.backward()
grad_eager = student_logits.grad.clone()
# Reset gradients
student_logits.grad.zero_()
# Compute the loss and gradients with Triton implementation
loss_triton = triton_loss(
student_logits_triton,
target_token_ids,
target_logprobs,
target_mask,
num_items_in_batch,
kd_temperature,
top_k_before_softmax,
)
loss_triton.backward()
grad_triton = student_logits_triton.grad.clone()
# Compare loss values
print(f"Eager loss: {loss_eager.item()}")
print(f"Triton loss: {loss_triton.item()}")
loss_diff = abs(loss_eager.item() - loss_triton.item())
print(f"Loss difference: {loss_diff}")
assert loss_diff < 1e-5, "Loss values differ significantly!"
# Compare gradients
grad_diff = (grad_eager - grad_triton).abs().max().item()
print(f"Max gradient difference: {grad_diff}")
# Print some sample gradients
sample_idx = (0, 0, 0) # (batch, seq, vocab)
print(f"Sample eager gradient: {grad_eager[sample_idx].item()}")
print(f"Sample triton gradient: {grad_triton[sample_idx].item()}")
# Compute relative difference for non-zero gradients
mask = grad_eager.abs() > 1e-10
if mask.sum() > 0:
rel_diff = (
(
(grad_eager[mask] - grad_triton[mask]).abs()
/ (grad_eager[mask].abs() + 1e-10)
)
.max()
.item()
)
print(f"Max relative gradient difference: {rel_diff}")
assert rel_diff < 1e-3, "Gradients differ significantly!"
# Also test top_k_before_softmax = 1 mode
top_k_before_softmax = 1
# Reset the gradients
student_logits = torch.randn(
batch_size, seq_len, vocab_size, requires_grad=True, device="cuda"
)
student_logits_triton = student_logits.detach().clone().requires_grad_(True)
# Compute the loss and gradients with eager implementation
loss_eager = eager_loss(
student_logits,
target_token_ids,
target_logprobs,
target_mask,
num_items_in_batch,
kd_temperature,
top_k_before_softmax,
)
loss_eager.backward()
grad_eager = student_logits.grad.clone()
# Compute the loss and gradients with Triton implementation
loss_triton = triton_loss(
student_logits_triton,
target_token_ids,
target_logprobs,
target_mask,
num_items_in_batch,
kd_temperature,
top_k_before_softmax,
)
loss_triton.backward()
grad_triton = student_logits_triton.grad.clone()
# Compare gradients for top_k_before_softmax = 1
grad_diff = (grad_eager - grad_triton).abs().max().item()
print("\nWith top_k_before_softmax=1:")
print(f"Max gradient difference: {grad_diff}")
# Compute relative difference for non-zero gradients
mask = grad_eager.abs() > 1e-10
if mask.sum() > 0:
rel_diff = (
(
(grad_eager[mask] - grad_triton[mask]).abs()
/ (grad_eager[mask].abs() + 1e-10)
)
.max()
.item()
)
assert (
rel_diff < 1e-3
), f"Gradients differ significantly, Max relative gradient difference: {rel_diff}"

View File

@@ -1,204 +0,0 @@
"""
sanity checks on logsumexp kernel validity
"""
import torch
import triton
from axolotl.integrations.kd.topk_logprob.logsumexp import logsumexp_kernel
# PyTorch implementation of logsumexp for reference
def torch_logsumexp(logits):
"""PyTorch implementation of logsumexp over last dimension"""
return torch.logsumexp(logits, dim=-1)
# Wrapper function for Triton logsumexp kernel
def triton_logsumexp(logits):
"""Triton implementation of logsumexp over last dimension"""
B, S, V = logits.shape # pylint: disable=invalid-name
output = torch.empty((B, S), dtype=torch.float32, device=logits.device)
grid = (B * S,)
logsumexp_kernel[grid](
logits.contiguous(),
output,
B,
S,
V,
logits.stride(0),
logits.stride(1),
logits.stride(2),
output.stride(0),
output.stride(1),
min(1024, triton.next_power_of_2(V)),
)
return output
class TritonLogSumExp(torch.autograd.Function):
"""
Custom autograd function wrapping the Triton logsumexp kernel so its gradients can be checked against PyTorch
"""
@staticmethod
def forward(ctx, logits):
B, S, V = logits.shape # pylint: disable=invalid-name
output = torch.empty((B, S), dtype=torch.float32, device=logits.device)
# Save inputs for backward pass
ctx.save_for_backward(logits)
ctx.shape = logits.shape
grid = (B * S,)
logsumexp_kernel[grid](
logits.contiguous(),
output,
B,
S,
V,
logits.stride(0),
logits.stride(1),
logits.stride(2),
output.stride(0),
output.stride(1),
min(1024, triton.next_power_of_2(V)),
)
return output
@staticmethod
def backward(ctx, grad_output):
(logits,) = ctx.saved_tensors
# For logsumexp, the gradient is softmax(input) * grad_output
# First compute the logsumexp
lse = TritonLogSumExp.apply(logits)
# Compute softmax by exponentiating differences
softmax_output = torch.exp(logits - lse.unsqueeze(-1))
# Compute gradient of logsumexp by multiplying the softmax output by the gradient
grad_input = softmax_output * grad_output.unsqueeze(-1)
return grad_input
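This backward pass relies on the identity that the gradient of logsumexp is the softmax of its input, which a quick PyTorch check illustrates on an assumed toy tensor:
import torch
x = torch.randn(5, requires_grad=True)
torch.logsumexp(x, dim=-1).backward()
assert torch.allclose(x.grad, torch.softmax(x, dim=-1))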
def test_logsumexp_values():
"""Test that the Triton logsumexp implementation matches PyTorch's"""
# Set random seed for reproducibility
torch.manual_seed(42)
# Test with various input shapes
test_shapes = [
(2, 3, 10), # small vocab
(4, 5, 100), # medium vocab
(2, 2, 32000), # large vocab (typical for LLMs)
]
for shape in test_shapes:
# Create random input tensors
logits = torch.randn(shape, device="cuda", requires_grad=False)
# Compute logsumexp using both implementations
torch_result = torch_logsumexp(logits)
triton_result = triton_logsumexp(logits)
# Compare results
max_diff = (torch_result - triton_result).abs().max().item()
print(f"Shape {shape}, Max diff: {max_diff}")
# Assert that the results are very close
assert max_diff < 1e-5, f"Results differ for shape {shape}: max diff {max_diff}"
def test_logsumexp_edge_cases():
"""Test edge cases for numerical stability"""
# Set random seed for reproducibility
torch.manual_seed(42)
# Case 1: Very large values that might cause overflow
logits_large = torch.ones(2, 3, 100, device="cuda") * 1000
# Case 2: Very small values that might cause underflow
logits_small = torch.ones(2, 3, 100, device="cuda") * -1000
# Case 3: Mix of large and small values
logits_mixed = torch.zeros(2, 3, 100, device="cuda")
logits_mixed[:, :, 0] = 1000 # One very large value
# Case 4: All identical values
logits_identical = torch.ones(2, 3, 100, device="cuda") * 5
# Case 5: Extreme values with NaN check
logits_extreme = torch.cat(
[
torch.full((1, 3, 50), 1e10, device="cuda"),
torch.full((1, 3, 50), -1e10, device="cuda"),
],
dim=0,
)
for i, logits in enumerate(
[logits_large, logits_small, logits_mixed, logits_identical, logits_extreme]
):
# Compute logsumexp using both implementations
torch_result = torch_logsumexp(logits)
triton_result = triton_logsumexp(logits)
# Check for NaNs
assert not torch.isnan(
torch_result
).any(), f"PyTorch produced NaNs for case {i+1}"
assert not torch.isnan(
triton_result
).any(), f"Triton produced NaNs for case {i+1}"
# Compare results
max_diff = (torch_result - triton_result).abs().max().item()
print(f"Edge case {i+1}, Max diff: {max_diff}")
# For very extreme values, allow a bit more tolerance
if i == 4: # extreme case
assert max_diff < 1e-2, f"Results differ too much for edge case {i+1}"
else:
assert max_diff < 1e-5, f"Results differ too much for edge case {i+1}"
def test_logsumexp_gradients():
"""Test that the gradients of Triton logsumexp match PyTorch's"""
# Set random seed for reproducibility
torch.manual_seed(42)
# Create input tensors with gradients enabled
shapes = [(2, 3, 10), (4, 5, 100)]
for shape in shapes:
# Create two identical tensors for PyTorch and Triton
logits_torch = torch.randn(shape, device="cuda", requires_grad=True)
logits_triton = logits_torch.clone().detach().requires_grad_(True)
# Forward pass
torch_output = torch_logsumexp(logits_torch)
triton_output = TritonLogSumExp.apply(logits_triton)
# Compare forward pass values
max_diff_forward = (torch_output - triton_output).abs().max().item()
assert max_diff_forward < 1e-5, f"Forward pass values differ for shape {shape}"
# Create random gradient
grad_output = torch.randn_like(torch_output)
# Backward pass
torch_output.backward(grad_output)
triton_output.backward(grad_output)
# Compare gradients
max_diff_grad = (logits_torch.grad - logits_triton.grad).abs().max().item()
print(f"Shape {shape}, Max gradient diff: {max_diff_grad}")
# Assert that gradients are very close
assert (
max_diff_grad < 1e-5
), f"Gradients differ for shape {shape}: max diff {max_diff_grad}"

View File

@@ -102,11 +102,7 @@ def is_hopper():
def check_tensorboard(
temp_run_dir: str,
tag: str,
comparison_val: float,
assertion_err: str,
lt: bool = True,
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
) -> None:
"""
helper function to parse and check tensorboard logs
@@ -116,20 +112,10 @@ def check_tensorboard(
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == tag)] # pylint: disable=invalid-name
if lt:
if "%s" in assertion_err:
assert df.value.values[-1] < comparison_val, (
assertion_err % df.value.values[-1]
)
else:
assert df.value.values[-1] < comparison_val, assertion_err
if "%s" in assertion_err:
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
else:
if "%s" in assertion_err:
assert df.value.values[-1] > comparison_val, (
assertion_err % df.value.values[-1]
)
else:
assert df.value.values[-1] > comparison_val, assertion_err
assert df.value.values[-1] < lt_val, assertion_err
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
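A usage sketch of the simplified helper, with an assumed run directory and threshold; the %s placeholder pulls in the observed value:
check_tensorboard(
    temp_dir + "/runs",
    "train/loss",
    1.0,
    "Train Loss (%s) is too high",
)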