Compare commits

..

81 Commits

Author SHA1 Message Date
Sung Ching Liu
328bb0466b Merge branch 'main' into flx_attn_support 2025-02-21 11:27:25 -05:00
Sunny Liu
e792b54bab remove unnecessary components 2025-02-21 11:23:21 -05:00
NanoCode012
bf842730a5 fix(doc): add missing auto_find_batch_size (#2339) [skip ci] 2025-02-21 11:56:38 +07:00
Wing Lian
1db6ad60a7 support for passing init_lora_weights to lora_config (#2352) 2025-02-20 22:56:34 -05:00
salman
29b366b2e1 Bumping 0.15.1 TRL version for GRPO+PEFT fix (#2344)
* bumping TRL version

* apply upstream fixes to our custom fix

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-02-20 22:56:04 -05:00
NanoCode012
b53a41372f feat: update transformers version to 4.49.0 (#2340) 2025-02-20 21:12:06 -05:00
Wing Lian
02f45e94be calculate sample length fixes and SFT splitting fixes (#2351)
* fix chat template splitting long samples across multiple rows

* make the preprocessing faster
2025-02-20 14:29:58 -05:00
bursteratom
82d04ea060 test v2batch w/ flex attn 2025-02-13 00:11:45 -05:00
Sung Ching Liu
0ef1f011fe Merge branch 'main' into flx_attn_support 2025-02-11 23:31:56 -05:00
Sunny Liu
c0a1d205c7 packed doc mask starts at 1, 0 means masked out 2025-02-07 14:44:52 -05:00
Sunny Liu
d0e739da24 attempt at getting around bf16 error 2025-02-04 21:57:21 -05:00
Sunny Liu
3f6be519d5 stack 2025-02-04 21:25:13 -05:00
Sunny Liu
adcbc7459b misc 2025-02-04 21:17:50 -05:00
Sunny Liu
470ba65c44 make doc mask instead of the whole block mask in collator 2025-02-04 20:27:39 -05:00
Sunny Liu
8e1adc154d stuff 2025-02-02 20:36:14 -05:00
Sunny Liu
e5b36900e4 misc 2025-02-02 20:32:03 -05:00
Sunny Liu
9f6c89b12b undo my stupidity 2025-02-02 20:25:53 -05:00
Sunny Liu
b0871c8d3b attempt - mask padding 2025-02-02 20:18:49 -05:00
bursteratom
d3ea379a23 figure out slight diff from flash result 2025-02-02 01:45:54 -05:00
bursteratom
0ebab63309 test 2025-02-02 01:27:15 -05:00
bursteratom
e98581f6f5 BLOCK SIZE 2025-02-02 01:22:23 -05:00
bursteratom
b832b11c8f stuff 2025-02-02 00:51:43 -05:00
bursteratom
b692d394b1 more test 2025-02-02 00:48:57 -05:00
bursteratom
2319e5276d more test 2025-02-02 00:48:15 -05:00
bursteratom
9a43a0925d more test 2025-02-02 00:45:30 -05:00
bursteratom
10de67e8ea more test 2025-02-02 00:43:41 -05:00
bursteratom
fa7355404c test 2025-02-02 00:38:35 -05:00
bursteratom
907424a2e8 stuff 2025-02-02 00:29:09 -05:00
Sunny Liu
3f4fd3c1eb remove padding self attention 2025-02-01 22:47:10 -05:00
Sunny Liu
48c3c47071 vanills mask 2025-02-01 14:23:37 -05:00
Sunny Liu
3ed9c117fb try vanilla mask 2025-02-01 14:09:13 -05:00
Sunny Liu
84960003ed reset llama_patch_multipack.py 2025-01-30 14:40:18 -05:00
Sunny Liu
93a268e43d --no-verify
fixes silly mistake
2025-01-30 14:08:26 -05:00
Sunny Liu
065f6d477e flex batching WIP 2025-01-30 14:04:59 -05:00
Sunny Liu
96ad741cd5 flex batching WIP 2025-01-30 12:35:25 -05:00
bursteratom
ba88bc7840 wip flex block mask creation 2025-01-29 00:25:25 -05:00
Sung Ching Liu
b31796a681 Merge branch 'main' into flx_attn_support 2025-01-28 14:20:43 -05:00
Sunny Liu
5ca57cb55a undo bool conversion 2025-01-23 17:56:13 -05:00
Sunny Liu
0149de7fb0 mask to bool 2025-01-23 15:30:08 -05:00
Sunny Liu
8c34c65181 dummy 2025-01-23 14:56:26 -05:00
Sunny Liu
555aa5772a skip mask conversion if already 4d 2025-01-23 14:01:53 -05:00
Sunny Liu
e8b2789086 revert mask expand 2025-01-23 11:20:38 -05:00
Sunny Liu
85752cdfc9 mask expansion 2025-01-22 21:33:38 -05:00
Sunny Liu
f2f23c8041 mask expansion 2025-01-22 21:31:42 -05:00
Sunny Liu
8b3eec7f6e mask expansion 2025-01-22 21:29:52 -05:00
Sunny Liu
bb9bea3110 mask expansion 2025-01-22 21:27:25 -05:00
Sunny Liu
0dd18a3681 llama sdpa patching WIP - static class function import 2025-01-22 21:10:05 -05:00
Sunny Liu
152e988d3c llama sdpa patching WIP - static class function import 2025-01-22 21:02:26 -05:00
Sunny Liu
27532825a9 llama sdpa patching WIP - static class function import 2025-01-22 21:00:34 -05:00
Sunny Liu
06f83a54a5 llama sdpa patching WIP - static class function import 2025-01-22 20:45:44 -05:00
Sunny Liu
d7b133dc1f llama sdpa patching WIP - static class function import 2025-01-22 20:33:13 -05:00
Sunny Liu
f3bec17917 llama sdpa patching WIP - static class function import 2025-01-22 20:25:26 -05:00
Sunny Liu
b7deb5241c llama sdpa patching WIP 2025-01-22 20:16:27 -05:00
Sunny Liu
cee310dcfa llama sdpa patching WIP 2025-01-22 20:15:23 -05:00
Sunny Liu
d1be6e228d llama sdpa patching WIP 2025-01-22 20:14:20 -05:00
Sunny Liu
5f9f77f384 llama patch 2025-01-22 11:29:28 -05:00
bursteratom
b2a34380b3 sample packing doc mask creation WIP 2025-01-21 09:18:38 -05:00
Sunny Liu
80bfc50d1f get seqlens from position ids for foc masking 2025-01-17 17:22:04 -05:00
Sunny Liu
a5360c172c llama hijacking 2025-01-17 15:54:03 -05:00
Sunny Liu
013a9b73fc fix transformers version for testing 2025-01-16 15:32:57 -05:00
Sunny
aad62428e0 not sure if this is necessary actually 2025-01-16 15:08:34 -05:00
Sunny
a6f2c5d583 flex sample packing WIP 2025-01-15 21:12:33 -05:00
Sunny
dbcd11e533 revert seq len in multipack sampler 2025-01-14 11:45:35 -05:00
Sunny
c06a6be915 flex_attn sample packing WIP 2025-01-14 00:22:05 -05:00
bursteratom
d3a0cb5edb transformers version 2025-01-13 10:33:00 -05:00
bursteratom
8b47e456b0 revert to transformers 4.47.1 2025-01-13 10:29:27 -05:00
Sunny Liu
2319ac729c Merge branch 'main' into flx_attn_support 2025-01-13 09:42:58 -05:00
Sunny
f99cae0e7b llama test 2025-01-12 17:30:19 -05:00
Wing Lian
888cd9407f use 2.5.1 docker images as latest tag as it seems stable (#2198) 2025-01-12 13:34:17 -05:00
Wing Lian
bd62d6e10a rename liger test so it properly runs in ci (#2246) 2025-01-12 13:34:17 -05:00
NanoCode012
5eae134110 feat: add support for data_files in pretraining (#2238) 2025-01-12 13:34:17 -05:00
Wing Lian
b7d27bdfa4 update upstream HF deps (#2239)
* bump axolotl contribs for upstream main conflicts:

* bump datasets, tokenizer, trl

* remove log workarounds in trl

* bump lm-eval

* remove unsloth_ import from critical path

* remove llama fa2 from conftest

* unsloth breaks with latest upstream
2025-01-12 13:34:17 -05:00
Vincenzo di Cicco
da97a21bdc Use SequentialSampler if curriculum_sampling is enabled with sample_packing (#2235) 2025-01-12 13:34:17 -05:00
Wing Lian
e0d4b88598 update modal version for ci (#2242) 2025-01-12 13:34:17 -05:00
NanoCode012
fac059a209 fix: mistral nemo does not recognize token_type_ids in forward (#2233) 2025-01-12 13:34:17 -05:00
Wing Lian
9c9ac1cf0b add hf cache caching for GHA (#2247)
* add hf cache caching for GHA

* use modal volume to cache hf data

* make sure to update the cache as we add new fixtures in conftest
2025-01-12 13:34:17 -05:00
Wing Lian
2346f21b2b Merge group queue (#2248)
* add support for merge groups

* also lint merge groups
2025-01-12 13:34:17 -05:00
salman
0b47281f51 Fixing OSX installation (#2231)
* bumping version, removing non-osx compatible deps

* updating pylintrc

* fixing linters

* reverting changes
2025-01-12 13:34:17 -05:00
Sunny
543daaf46f llama test 2025-01-09 16:08:24 -05:00
Sunny
bcd9ad44e0 flex attention support 2025-01-06 19:54:11 -05:00
bursteratom
61ad375bf4 config validation for flex attention 2025-01-05 23:27:49 -05:00
14 changed files with 215 additions and 283 deletions

View File

@@ -407,7 +407,10 @@ save_total_limit: # Checkpoints saved at a time
max_steps: max_steps:
# bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time. # bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
include_tokens_per_second: include_tokens_per_second: # Optional[bool]
# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
auto_find_batch_size: # Optional[bool]
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128 eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128

View File

@@ -13,12 +13,12 @@ liger-kernel==0.5.2
packaging==23.2 packaging==23.2
peft==0.14.0 peft==0.14.0
transformers==4.48.3 transformers==4.49.0
tokenizers>=0.21.0 tokenizers>=0.21.0
accelerate==1.3.0 accelerate==1.3.0
datasets==3.2.0 datasets==3.2.0
deepspeed==0.16.1 deepspeed==0.16.1
trl==0.15.0 trl==0.15.1
optimum==1.16.2 optimum==1.16.2
hf_transfer hf_transfer

View File

@@ -831,7 +831,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if "max_length" in kwargs: if "max_length" in kwargs:
kwargs.pop("max_length") kwargs.pop("max_length")
elif use_batch_sampler_collator: elif use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES: if self.cfg.flex_attention is True:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq collator = V2BatchSamplerDataCollatorForSeq2Seq
elif ( elif (
self.cfg.model_config_type in ["llama"] self.cfg.model_config_type in ["llama"]

View File

@@ -9,7 +9,6 @@ import logging
from trl.trainer.grpo_trainer import RewardFunc from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -31,21 +30,31 @@ class GRPOStrategy:
@classmethod @classmethod
def set_training_args_kwargs(cls, cfg): def set_training_args_kwargs(cls, cfg):
training_kwargs = [ grpo_args_kwargs = {}
"use_vllm", if cfg.trl and cfg.trl.use_vllm:
"vllm_device", grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
"vllm_gpu_memory_utilization", if cfg.trl and cfg.trl.vllm_device:
"vllm_max_model_len", grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
"vllm_dtype", else:
"use_liger_loss", grpo_args_kwargs["vllm_device"] = "auto"
"num_generations", if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
"log_completions", grpo_args_kwargs[
"sync_ref_model", "vllm_gpu_memory_utilization"
"ref_model_mixup_alpha", ] = cfg.trl.vllm_gpu_memory_utilization
"ref_model_sync_steps", if cfg.trl and cfg.trl.vllm_max_model_len:
"max_completion_length", grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
] if cfg.trl and cfg.trl.num_generations:
grpo_args_kwargs = {k: cfg.trl[k] for k in training_kwargs if cfg.trl[k]} grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
if cfg.trl and cfg.trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
grpo_args_kwargs[
"ref_model_mixup_alpha"
] = cfg.trl.ref_model_mixup_alpha
if cfg.trl and cfg.trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
return grpo_args_kwargs return grpo_args_kwargs
@classmethod @classmethod
@@ -62,7 +71,9 @@ class GRPOStrategy:
def set_trainer_kwargs(cls, cfg): def set_trainer_kwargs(cls, cfg):
trainer_kwargs = {} trainer_kwargs = {}
if cfg.trl and cfg.trl.reward_processing_classes: if cfg.trl and cfg.trl.reward_processing_classes:
trainer_kwargs["reward_processing_classes"] = cfg.trl.reward_processing_classes trainer_kwargs[
"reward_processing_classes"
] = cfg.trl.reward_processing_classes
return trainer_kwargs return trainer_kwargs
@classmethod @classmethod

View File

@@ -13,4 +13,3 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
""" """
Axolotl GRPO Config for GRPO training Axolotl GRPO Config for GRPO training
""" """
use_liger_loss: bool = False

View File

@@ -1,24 +1,13 @@
""" """
Axolotl GRPO trainer Axolotl GRPO trainer
""" """
from contextlib import contextmanager, nullcontext
from accelerate.utils import is_peft_model from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module from accelerate.utils.other import is_compiled_module
import torch
from transformers import PreTrainedModel from transformers import PreTrainedModel
from trl import GRPOConfig, GRPOTrainer from trl import GRPOConfig, GRPOTrainer
from trl.models import unwrap_model_for_generation from trl.models import unwrap_model_for_generation
from axolotl.core.trainers.base import SchedulerMixin from axolotl.core.trainers.base import SchedulerMixin
from transformers.utils import is_liger_kernel_available
if is_liger_kernel_available():
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from accelerate.utils import broadcast_object_list, gather_object
from trl.trainer.utils import pad
# mypy: ignore-errors # mypy: ignore-errors
@@ -32,19 +21,6 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.use_liger_loss = kwargs["args"].use_liger_loss
if self.use_liger_loss:
if not is_liger_kernel_available():
raise ValueError(
"You set `use_liger_loss=True` but the liger kernel is not available. "
"Please install liger-kernel first: `pip install liger-kernel`"
)
self.grpo_loss_fn = LigerFusedLinearGRPOLoss(
beta=self.beta,
compiled=is_compiled_module(self.model),
use_ref_model=True,
num_generations=self.args.num_generations,
)
# pylint: disable=access-member-before-definition # pylint: disable=access-member-before-definition
# Enable gradient checkpointing if requested # Enable gradient checkpointing if requested
if kwargs["args"].gradient_checkpointing: if kwargs["args"].gradient_checkpointing:
@@ -53,7 +29,9 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
self.model.config.use_cache = False self.model.config.use_cache = False
# Enable gradient checkpointing on the base model for PEFT # Enable gradient checkpointing on the base model for PEFT
if is_peft_model(self.model) and hasattr(self.model.base_model, "gradient_checkpointing_enable"): if is_peft_model(self.model) and hasattr(
self.model.base_model, "gradient_checkpointing_enable"
):
self.model.base_model.gradient_checkpointing_enable() self.model.base_model.gradient_checkpointing_enable()
# Enable gradient checkpointing for non-PEFT models # Enable gradient checkpointing for non-PEFT models
elif hasattr(self.model, "gradient_checkpointing_enable"): elif hasattr(self.model, "gradient_checkpointing_enable"):
@@ -61,12 +39,15 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"]) self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
# pylint: enable=access-member-before-definition # pylint: enable=access-member-before-definition
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel: def _enable_gradient_checkpointing(
self, model: PreTrainedModel, args: GRPOConfig
) -> PreTrainedModel:
"""Enables gradient checkpointing for the model.""" """Enables gradient checkpointing for the model."""
# pylint: disable=unused-argument,redefined-builtin # pylint: disable=unused-argument,redefined-builtin
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
use_reentrant = ( use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"] "use_reentrant" not in gradient_checkpointing_kwargs
or gradient_checkpointing_kwargs["use_reentrant"]
) )
if use_reentrant: if use_reentrant:
@@ -77,7 +58,9 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
def make_inputs_require_grad(module, input, output): def make_inputs_require_grad(module, input, output):
output.requires_grad_(True) output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) model.get_input_embeddings().register_forward_hook(
make_inputs_require_grad
)
return model return model
# pylint: enable=unused-argument,redefined-builtin # pylint: enable=unused-argument,redefined-builtin
@@ -89,18 +72,25 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
gather_deepspeed3_params=self.args.ds3_gather_for_generation, gather_deepspeed3_params=self.args.ds3_gather_for_generation,
) as unwrapped_model: ) as unwrapped_model:
if is_compiled_module(unwrapped_model): if is_compiled_module(unwrapped_model):
unwrapped_model = unwrapped_model._orig_mod # pylint: disable=protected-access unwrapped_model = (
unwrapped_model._orig_mod # pylint: disable=protected-access
)
if is_peft_model(unwrapped_model): if is_peft_model(unwrapped_model):
unwrapped_model.merge_adapter() unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict() state_dict = unwrapped_model.state_dict()
unwrapped_model.unmerge_adapter()
# Remove base_model and base_layer prefixes # Remove base_model and base_layer prefixes
state_dict = { state_dict = {
k.removeprefix("base_model.model.").removeprefix("base_model.model.").replace(".base_layer", ""): v k.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", ""): v
for k, v in state_dict.items() for k, v in state_dict.items()
} }
# Remove values with adapter prefix (example: "_lora") # Remove values with adapter prefix (example: "_lora")
state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k} state_dict = {
k: v
for k, v in state_dict.items()
if unwrapped_model.prefix not in k
}
# When module to save, remove its prefix and discard the original module # When module to save, remove its prefix and discard the original module
state_dict = { state_dict = {
k.replace("modules_to_save.default.", ""): v k.replace("modules_to_save.default.", ""): v
@@ -109,218 +99,10 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
} }
else: else:
state_dict = unwrapped_model.state_dict() state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process: if self.accelerator.is_main_process:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model llm_model = (
llm_model.load_weights(state_dict.items()) self.llm.llm_engine.model_executor.driver_worker.model_runner.model
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if self.use_liger_loss:
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
device = self.accelerator.device
prompts = [x["prompt"] for x in inputs]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
prompt_inputs = self.processing_class(
prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
if self.max_prompt_length is not None:
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
self._last_loaded_step = self.state.global_step
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
else:
completion_ids = [None] * len(all_prompts_text) * self.num_generations
# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.accelerator.process_index * len(prompts) * self.num_generations,
(self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
) )
completion_ids = completion_ids[process_slice] llm_model.load_weights(state_dict.items())
if is_peft_model(unwrapped_model):
# Pad the completions, and concatenate them with the prompts unwrapped_model.unmerge_adapter()
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
prompt_inputs_repeated = torch.repeat_interleave(
prompt_inputs["input_ids"], self.num_generations, dim=0
)
prompt_completion_ids = torch.cat([prompt_inputs_repeated, completion_ids], dim=1)
else:
# Regular generation path
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(
**prompt_inputs, generation_config=self.generation_config
)
prompt_length = prompt_inputs["input_ids"].size(1)
completion_ids = prompt_completion_ids[:, prompt_length:]
# Get the per-token log probabilities for the completions for the model and the reference model
def get_per_token_logps(model, input_ids, num_logits_to_keep):
# We add 1 to `num_logits_to_keep` because the last logits of the sequence is later excluded
outputs = model(input_ids, num_logits_to_keep=num_logits_to_keep + 1)
hidden_states = outputs.last_hidden_state[:, :-1]
logits = outputs.logits # (B, L, V)
logits = logits[
:, :-1, :
] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
per_token_logps = []
for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
log_probs = logits_row.log_softmax(dim=-1)
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps), hidden_states
num_logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
per_token_logps, hidden_states = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)
with torch.inference_mode():
if self.ref_model is not None:
ref_per_token_logps, ref_hidden_states = get_per_token_logps(
self.ref_model, prompt_completion_ids, num_logits_to_keep
)
else:
with self.accelerator.unwrap_model(model).disable_adapter():
ref_per_token_logps, ref_hidden_states = get_per_token_logps(
model, prompt_completion_ids, num_logits_to_keep
)
# done in liger
# Compute the KL divergence between the model and the reference model
# per_token_kl = (
# torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
# )
# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
# Decode the generated completions
completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = [[{"role": "assistant", "content": completion}] for completion in completions]
# Compute the rewards
prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
for i, (reward_func, reward_processing_class) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes)
):
if isinstance(reward_func, PreTrainedModel):
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
else:
texts = [p + c for p, c in zip(prompts, completions)]
reward_inputs = reward_processing_class(
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
)
reward_inputs = super()._prepare_inputs(reward_inputs)
with torch.inference_mode():
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
else:
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
for key in reward_kwargs:
for example in inputs:
# Repeat each value in the column for `num_generations` times
reward_kwargs[key].extend([example[key]] * self.num_generations)
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
# Sum the rewards from all reward functions
rewards = rewards_per_func.sum(dim=1)
# done in liger
# # Compute grouped-wise rewards
# mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
# std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
# done in liger
# # Normalize the rewards to compute the advantages
# mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
# std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
# advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
# done in liger
# x - x.detach() allows for preserving gradients from x
# per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
# per_token_loss = -(per_token_loss - self.beta * per_token_kl)
# loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
# Log the metrics
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
self._metrics["completion_length"].append(completion_length)
reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
else:
reward_func_name = reward_func.__name__
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
lm_head = model.get_output_embeddings()
if self.ref_model is not None:
ref_lm_head = self.ref_model.get_output_embeddings()
else:
with self.null_ref_context():
ref_lm_head = model.get_output_embeddings()
ref_weight = ref_lm_head.weight
ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None
loss, metrics = self.grpo_loss_fn(
lm_head,
hidden_states, # this is the hidden states from the model
completion_mask,
rewards,
bias=lm_head.bias if hasattr(lm_head, "bias") else None,
ref_input=ref_hidden_states, # this is the hidden states from the ref model
ref_weight=ref_weight,
ref_bias=ref_bias,
)
else:
super().compute_loss(model, inputs, return_outputs, num_items_in_batch)
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with (
self.accelerator.unwrap_model(self.model).disable_adapter()
if self.is_peft_model and not self.ref_adapter_name
else nullcontext()
):
if self.ref_adapter_name:
self.model.set_adapter(self.ref_adapter_name)
yield
if self.ref_adapter_name:
self.model.set_adapter(self.model_adapter_name or "default")

View File

@@ -127,6 +127,8 @@ class ReLoRACallback(TrainerCallback):
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
**_kwargs, **_kwargs,
): ):
if not optimizer:
optimizer = state.optimizer
if state.global_step > 0 and state.global_step % self.relora_steps == 0: if state.global_step > 0 and state.global_step % self.relora_steps == 0:
checkpoint_folder = os.path.join( checkpoint_folder = os.path.join(
args.output_dir, args.output_dir,

View File

@@ -95,6 +95,103 @@ def get_cu_seqlens(attn_mask):
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens) return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def get_packed_mask_from_pos_ids(position_ids):
if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0)
device = position_ids.device
results = []
for i, row in enumerate(position_ids):
# Count the number of consecutive zeros from the right side
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
# Adjust the row to exclude padding
adjusted_row = row[:-padding_length] if padding_length else row.clone()
# Find where the position resets to 0 (indicating a new sequence)
seq_starts = torch.cat(
[
torch.tensor([True], dtype=torch.bool, device=device),
adjusted_row[1:] == 0,
]
)
# Get the indices where the sequence starts
start_indices = torch.cat(
[
torch.nonzero(seq_starts).unbind(dim=1)[0],
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1]
# Append the padding length to the sequence lengths
doc_mask = torch.ones(len(row), dtype=torch.int32, device=device)
for i, seq_len in enumerate(seq_lengths):
start_id = start_indices[i]
doc_mask[start_id : start_id + seq_len] = (
(i+1) * doc_mask[start_id : start_id + seq_len]
)
if padding_length:
doc_mask[len(adjusted_row) :] = 0 * doc_mask[len(adjusted_row) :]
results.append(doc_mask)
return torch.stack(results)
def get_seqlens_from_pos_ids(position_ids):
"""generate a sequence length set using pos ids for doc mask creation in flex attention"""
if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0)
max_seq_len = position_ids.shape[1]
device = position_ids.device
results = []
totalseqlens = []
for row in position_ids:
# Count the number of consecutive zeros from the right side
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
# Adjust the row to exclude padding
adjusted_row = row[:-padding_length] if padding_length else row.clone()
# Find where the position resets to 0 (indicating a new sequence)
seq_starts = torch.cat(
[
torch.tensor([True], dtype=torch.bool, device=device),
adjusted_row[1:] == 0,
]
)
# Get the indices where the sequence starts
start_indices = torch.cat(
[
torch.nonzero(seq_starts).unbind(dim=1)[0],
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1]
# Append the padding length to the sequence lengths
if padding_length:
seq_lengths = torch.cat(
[
seq_lengths,
torch.tensor(
[len(row) - torch.sum(seq_lengths)],
dtype=torch.int32,
device=device,
),
]
)
results.append(seq_lengths)
totalseqlens.append(len(adjusted_row))
return results, torch.tensor(totalseqlens, dtype=torch.int32, device=device)
def get_cu_seqlens_from_pos_ids(position_ids): def get_cu_seqlens_from_pos_ids(position_ids):
"""generate a cumulative sequence length mask for flash attention using pos ids""" """generate a cumulative sequence length mask for flash attention using pos ids"""
if len(position_ids.shape) == 1: if len(position_ids.shape) == 1:
@@ -176,7 +273,10 @@ def mask_2d_to_4d(
when they attend to each other within that sequence. when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking. This expansion transforms the mask to lower triangular form to prevent future peeking.
""" """
bsz, src_len = mask.size()
if len(mask.size()) == 4:
return mask
bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
tgt_len = tgt_len if tgt_len is not None else src_len tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2) mask = mask.unsqueeze(1).unsqueeze(2)

View File

@@ -272,8 +272,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
dict(zip(feature_names, row)) dict(zip(feature_names, row))
) )
for key, val in tokenized_prompt.items(): for key, val in tokenized_prompt.items():
for i in range(0, len(val), self.sequence_len): res[key].append(val)
res[key].append(val[i : i + self.sequence_len])
# If there are no examples left, return an empty dictionary # If there are no examples left, return an empty dictionary
if not res: if not res:

View File

@@ -342,6 +342,7 @@ class LoraConfig(BaseModel):
peft_use_dora: Optional[bool] = None peft_use_dora: Optional[bool] = None
peft_use_rslora: Optional[bool] = None peft_use_rslora: Optional[bool] = None
peft_layer_replication: Optional[List[Tuple[int, int]]] = None peft_layer_replication: Optional[List[Tuple[int, int]]] = None
peft_init_lora_weights: Optional[Union[bool, str]] = None
qlora_sharded_model_loading: Optional[bool] = Field( qlora_sharded_model_loading: Optional[bool] = Field(
default=False, default=False,
@@ -822,6 +823,7 @@ class AxolotlInputConfig(
xformers_attention: Optional[bool] = None xformers_attention: Optional[bool] = None
sdp_attention: Optional[bool] = None sdp_attention: Optional[bool] = None
s2_attention: Optional[bool] = None s2_attention: Optional[bool] = None
flex_attention: Optional[bool] = None
flash_attention: Optional[bool] = None flash_attention: Optional[bool] = None
flash_attn_cross_entropy: Optional[bool] = None flash_attn_cross_entropy: Optional[bool] = None
flash_attn_rms_norm: Optional[bool] = None flash_attn_rms_norm: Optional[bool] = None
@@ -1788,6 +1790,26 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
) )
return data return data
@model_validator(mode="before")
@classmethod
def check_flex_torch_version(cls, data):
if (data.get("flex_attention") is not None) and (
data.get("flex_attention") is True
):
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")
if torch_version is None:
import torch
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
if version.parse(torch_version) < version.parse("2.5.1"):
raise ValueError(
"Flex attention is not supported on torch version < 2.5.1"
)
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_torch_compile_auto(cls, data): def check_torch_compile_auto(cls, data):

View File

@@ -33,4 +33,3 @@ class TRLConfig(BaseModel):
sync_ref_model: Optional[bool] = False sync_ref_model: Optional[bool] = False
ref_model_mixup_alpha: Optional[float] = 0.9 ref_model_mixup_alpha: Optional[float] = 0.9
ref_model_sync_steps: Optional[int] = 64 ref_model_sync_steps: Optional[int] = 64
use_liger_loss: Optional[bool] = False

View File

@@ -172,10 +172,11 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
) )
try: try:
min_input_len = np.min(get_dataset_lengths(dataset)) ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
LOG.debug(f"min_input_len: {min_input_len}") min_input_len = np.min(ds_lengths)
max_input_len = np.max(get_dataset_lengths(dataset)) LOG.info(f"min_input_len: {min_input_len}")
LOG.debug(f"max_input_len: {max_input_len}") max_input_len = np.max(ds_lengths)
LOG.info(f"max_input_len: {max_input_len}")
except AttributeError: except AttributeError:
pass pass

View File

@@ -403,7 +403,7 @@ class ModelLoader:
if ( if (
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
and self.cfg.flash_attention and (self.cfg.flash_attention or self.cfg.flex_attention)
and self.cfg.sample_packing and self.cfg.sample_packing
): ):
if "auto_map" in self.model_config: if "auto_map" in self.model_config:
@@ -707,7 +707,13 @@ class ModelLoader:
""" """
sample packing uses custom FA2 patch sample packing uses custom FA2 patch
""" """
if self.cfg.flash_attention:
if self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention"
)
elif self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention: if not self.cfg.sample_packing and self.cfg.s2_attention:
pass pass
self.model_kwargs["attn_implementation"] = "flash_attention_2" self.model_kwargs["attn_implementation"] = "flash_attention_2"
@@ -1113,7 +1119,7 @@ class ModelLoader:
should_convert = ( should_convert = (
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility. # convert them back to fp16/bf16 for flash-attn compatibility.
((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp) ((needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) and not qlora_fsdp)
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
) )
@@ -1321,6 +1327,8 @@ def load_lora(model, cfg, inference=False, config_only=False):
if loftq_bits: if loftq_bits:
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits) lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
lora_config_kwargs["init_lora_weights"] = "loftq" lora_config_kwargs["init_lora_weights"] = "loftq"
if cfg.peft_init_lora_weights:
lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
if cfg.peft_use_dora: if cfg.peft_use_dora:
lora_config_kwargs["use_dora"] = cfg.peft_use_dora lora_config_kwargs["use_dora"] = cfg.peft_use_dora
LOG.info("Initializing LoRA weights using dora. This might take longer.") LOG.info("Initializing LoRA weights using dora. This might take longer.")

View File

@@ -4,13 +4,17 @@ helper util to calculate dataset lengths
import numpy as np import numpy as np
def get_dataset_lengths(dataset): def get_dataset_lengths(dataset, from_arrow=False):
if "length" in dataset.column_names: if "length" in dataset.column_names:
lengths = np.array(dataset["length"]) lengths = np.array(dataset["length"])
elif "position_ids" in dataset.column_names: elif "position_ids" in dataset.column_names:
position_ids = dataset["position_ids"] position_ids = dataset["position_ids"]
lengths = np.array([x[-1] + 1 for x in position_ids]) lengths = np.array([x[-1] + 1 for x in position_ids])
else: else:
input_ids = dataset["input_ids"] if from_arrow:
lengths = np.array([len(seq) for seq in input_ids]) input_ids = dataset.data.column("input_ids")
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
else:
input_ids = dataset["input_ids"]
lengths = np.array([len(seq) for seq in input_ids])
return lengths return lengths