Compare commits
3 Commits
flx_attn_s
...
grpo_liger
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1a09d5e844 | ||
|
|
cf61b4aba7 | ||
|
|
14d274efe6 |
@@ -407,10 +407,7 @@ save_total_limit: # Checkpoints saved at a time
|
|||||||
max_steps:
|
max_steps:
|
||||||
|
|
||||||
# bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
|
# bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
|
||||||
include_tokens_per_second: # Optional[bool]
|
include_tokens_per_second:
|
||||||
|
|
||||||
# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
|
|
||||||
auto_find_batch_size: # Optional[bool]
|
|
||||||
|
|
||||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||||
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||||
|
|||||||
@@ -13,12 +13,12 @@ liger-kernel==0.5.2
|
|||||||
packaging==23.2
|
packaging==23.2
|
||||||
|
|
||||||
peft==0.14.0
|
peft==0.14.0
|
||||||
transformers==4.49.0
|
transformers==4.48.3
|
||||||
tokenizers>=0.21.0
|
tokenizers>=0.21.0
|
||||||
accelerate==1.3.0
|
accelerate==1.3.0
|
||||||
datasets==3.2.0
|
datasets==3.2.0
|
||||||
deepspeed==0.16.1
|
deepspeed==0.16.1
|
||||||
trl==0.15.1
|
trl==0.15.0
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
|
|||||||
@@ -831,9 +831,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if "max_length" in kwargs:
|
if "max_length" in kwargs:
|
||||||
kwargs.pop("max_length")
|
kwargs.pop("max_length")
|
||||||
elif use_batch_sampler_collator:
|
elif use_batch_sampler_collator:
|
||||||
if self.cfg.flex_attention is True:
|
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
|
||||||
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
elif (
|
elif (
|
||||||
self.cfg.model_config_type in ["llama"]
|
self.cfg.model_config_type in ["llama"]
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import logging
|
|||||||
from trl.trainer.grpo_trainer import RewardFunc
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
|
|
||||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -30,31 +31,21 @@ class GRPOStrategy:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_training_args_kwargs(cls, cfg):
|
def set_training_args_kwargs(cls, cfg):
|
||||||
grpo_args_kwargs = {}
|
training_kwargs = [
|
||||||
if cfg.trl and cfg.trl.use_vllm:
|
"use_vllm",
|
||||||
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
|
"vllm_device",
|
||||||
if cfg.trl and cfg.trl.vllm_device:
|
"vllm_gpu_memory_utilization",
|
||||||
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
|
"vllm_max_model_len",
|
||||||
else:
|
"vllm_dtype",
|
||||||
grpo_args_kwargs["vllm_device"] = "auto"
|
"use_liger_loss",
|
||||||
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
|
"num_generations",
|
||||||
grpo_args_kwargs[
|
"log_completions",
|
||||||
"vllm_gpu_memory_utilization"
|
"sync_ref_model",
|
||||||
] = cfg.trl.vllm_gpu_memory_utilization
|
"ref_model_mixup_alpha",
|
||||||
if cfg.trl and cfg.trl.vllm_max_model_len:
|
"ref_model_sync_steps",
|
||||||
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
|
"max_completion_length",
|
||||||
if cfg.trl and cfg.trl.num_generations:
|
]
|
||||||
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
|
grpo_args_kwargs = {k: cfg.trl[k] for k in training_kwargs if cfg.trl[k]}
|
||||||
if cfg.trl and cfg.trl.sync_ref_model:
|
|
||||||
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
|
|
||||||
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
|
|
||||||
grpo_args_kwargs[
|
|
||||||
"ref_model_mixup_alpha"
|
|
||||||
] = cfg.trl.ref_model_mixup_alpha
|
|
||||||
if cfg.trl and cfg.trl.ref_model_sync_steps:
|
|
||||||
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
|
|
||||||
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
|
|
||||||
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
|
|
||||||
return grpo_args_kwargs
|
return grpo_args_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -71,9 +62,7 @@ class GRPOStrategy:
|
|||||||
def set_trainer_kwargs(cls, cfg):
|
def set_trainer_kwargs(cls, cfg):
|
||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
if cfg.trl and cfg.trl.reward_processing_classes:
|
if cfg.trl and cfg.trl.reward_processing_classes:
|
||||||
trainer_kwargs[
|
trainer_kwargs["reward_processing_classes"] = cfg.trl.reward_processing_classes
|
||||||
"reward_processing_classes"
|
|
||||||
] = cfg.trl.reward_processing_classes
|
|
||||||
return trainer_kwargs
|
return trainer_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -13,3 +13,4 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
|||||||
"""
|
"""
|
||||||
Axolotl GRPO Config for GRPO training
|
Axolotl GRPO Config for GRPO training
|
||||||
"""
|
"""
|
||||||
|
use_liger_loss: bool = False
|
||||||
|
|||||||
@@ -1,13 +1,24 @@
|
|||||||
"""
|
"""
|
||||||
Axolotl GRPO trainer
|
Axolotl GRPO trainer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from contextlib import contextmanager, nullcontext
|
||||||
from accelerate.utils import is_peft_model
|
from accelerate.utils import is_peft_model
|
||||||
from accelerate.utils.other import is_compiled_module
|
from accelerate.utils.other import is_compiled_module
|
||||||
|
import torch
|
||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
from trl import GRPOConfig, GRPOTrainer
|
from trl import GRPOConfig, GRPOTrainer
|
||||||
from trl.models import unwrap_model_for_generation
|
from trl.models import unwrap_model_for_generation
|
||||||
|
|
||||||
from axolotl.core.trainers.base import SchedulerMixin
|
from axolotl.core.trainers.base import SchedulerMixin
|
||||||
|
from transformers.utils import is_liger_kernel_available
|
||||||
|
|
||||||
|
if is_liger_kernel_available():
|
||||||
|
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss
|
||||||
|
|
||||||
|
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
|
||||||
|
from accelerate.utils import broadcast_object_list, gather_object
|
||||||
|
from trl.trainer.utils import pad
|
||||||
|
|
||||||
|
|
||||||
# mypy: ignore-errors
|
# mypy: ignore-errors
|
||||||
@@ -21,6 +32,19 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
self.use_liger_loss = kwargs["args"].use_liger_loss
|
||||||
|
if self.use_liger_loss:
|
||||||
|
if not is_liger_kernel_available():
|
||||||
|
raise ValueError(
|
||||||
|
"You set `use_liger_loss=True` but the liger kernel is not available. "
|
||||||
|
"Please install liger-kernel first: `pip install liger-kernel`"
|
||||||
|
)
|
||||||
|
self.grpo_loss_fn = LigerFusedLinearGRPOLoss(
|
||||||
|
beta=self.beta,
|
||||||
|
compiled=is_compiled_module(self.model),
|
||||||
|
use_ref_model=True,
|
||||||
|
num_generations=self.args.num_generations,
|
||||||
|
)
|
||||||
# pylint: disable=access-member-before-definition
|
# pylint: disable=access-member-before-definition
|
||||||
# Enable gradient checkpointing if requested
|
# Enable gradient checkpointing if requested
|
||||||
if kwargs["args"].gradient_checkpointing:
|
if kwargs["args"].gradient_checkpointing:
|
||||||
@@ -29,9 +53,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
self.model.config.use_cache = False
|
self.model.config.use_cache = False
|
||||||
|
|
||||||
# Enable gradient checkpointing on the base model for PEFT
|
# Enable gradient checkpointing on the base model for PEFT
|
||||||
if is_peft_model(self.model) and hasattr(
|
if is_peft_model(self.model) and hasattr(self.model.base_model, "gradient_checkpointing_enable"):
|
||||||
self.model.base_model, "gradient_checkpointing_enable"
|
|
||||||
):
|
|
||||||
self.model.base_model.gradient_checkpointing_enable()
|
self.model.base_model.gradient_checkpointing_enable()
|
||||||
# Enable gradient checkpointing for non-PEFT models
|
# Enable gradient checkpointing for non-PEFT models
|
||||||
elif hasattr(self.model, "gradient_checkpointing_enable"):
|
elif hasattr(self.model, "gradient_checkpointing_enable"):
|
||||||
@@ -39,15 +61,12 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
||||||
# pylint: enable=access-member-before-definition
|
# pylint: enable=access-member-before-definition
|
||||||
|
|
||||||
def _enable_gradient_checkpointing(
|
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
|
||||||
self, model: PreTrainedModel, args: GRPOConfig
|
|
||||||
) -> PreTrainedModel:
|
|
||||||
"""Enables gradient checkpointing for the model."""
|
"""Enables gradient checkpointing for the model."""
|
||||||
# pylint: disable=unused-argument,redefined-builtin
|
# pylint: disable=unused-argument,redefined-builtin
|
||||||
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
||||||
use_reentrant = (
|
use_reentrant = (
|
||||||
"use_reentrant" not in gradient_checkpointing_kwargs
|
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
|
||||||
or gradient_checkpointing_kwargs["use_reentrant"]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_reentrant:
|
if use_reentrant:
|
||||||
@@ -58,9 +77,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
def make_inputs_require_grad(module, input, output):
|
def make_inputs_require_grad(module, input, output):
|
||||||
output.requires_grad_(True)
|
output.requires_grad_(True)
|
||||||
|
|
||||||
model.get_input_embeddings().register_forward_hook(
|
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||||
make_inputs_require_grad
|
|
||||||
)
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
# pylint: enable=unused-argument,redefined-builtin
|
# pylint: enable=unused-argument,redefined-builtin
|
||||||
@@ -72,25 +89,18 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
|
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
|
||||||
) as unwrapped_model:
|
) as unwrapped_model:
|
||||||
if is_compiled_module(unwrapped_model):
|
if is_compiled_module(unwrapped_model):
|
||||||
unwrapped_model = (
|
unwrapped_model = unwrapped_model._orig_mod # pylint: disable=protected-access
|
||||||
unwrapped_model._orig_mod # pylint: disable=protected-access
|
|
||||||
)
|
|
||||||
if is_peft_model(unwrapped_model):
|
if is_peft_model(unwrapped_model):
|
||||||
unwrapped_model.merge_adapter()
|
unwrapped_model.merge_adapter()
|
||||||
state_dict = unwrapped_model.state_dict()
|
state_dict = unwrapped_model.state_dict()
|
||||||
|
unwrapped_model.unmerge_adapter()
|
||||||
# Remove base_model and base_layer prefixes
|
# Remove base_model and base_layer prefixes
|
||||||
state_dict = {
|
state_dict = {
|
||||||
k.removeprefix("base_model.model.")
|
k.removeprefix("base_model.model.").removeprefix("base_model.model.").replace(".base_layer", ""): v
|
||||||
.removeprefix("base_model.model.")
|
|
||||||
.replace(".base_layer", ""): v
|
|
||||||
for k, v in state_dict.items()
|
for k, v in state_dict.items()
|
||||||
}
|
}
|
||||||
# Remove values with adapter prefix (example: "_lora")
|
# Remove values with adapter prefix (example: "_lora")
|
||||||
state_dict = {
|
state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
|
||||||
k: v
|
|
||||||
for k, v in state_dict.items()
|
|
||||||
if unwrapped_model.prefix not in k
|
|
||||||
}
|
|
||||||
# When module to save, remove its prefix and discard the original module
|
# When module to save, remove its prefix and discard the original module
|
||||||
state_dict = {
|
state_dict = {
|
||||||
k.replace("modules_to_save.default.", ""): v
|
k.replace("modules_to_save.default.", ""): v
|
||||||
@@ -99,10 +109,218 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
state_dict = unwrapped_model.state_dict()
|
state_dict = unwrapped_model.state_dict()
|
||||||
if self.accelerator.is_main_process:
|
if self.accelerator.is_main_process:
|
||||||
llm_model = (
|
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||||
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
llm_model.load_weights(state_dict.items())
|
||||||
|
|
||||||
|
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||||
|
if self.use_liger_loss:
|
||||||
|
if return_outputs:
|
||||||
|
raise ValueError("The GRPOTrainer does not support returning outputs")
|
||||||
|
|
||||||
|
device = self.accelerator.device
|
||||||
|
prompts = [x["prompt"] for x in inputs]
|
||||||
|
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
|
||||||
|
prompt_inputs = self.processing_class(
|
||||||
|
prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
|
||||||
|
)
|
||||||
|
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
||||||
|
|
||||||
|
if self.max_prompt_length is not None:
|
||||||
|
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
|
||||||
|
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
|
||||||
|
|
||||||
|
# Generate completions using either vLLM or regular generation
|
||||||
|
if self.args.use_vllm:
|
||||||
|
# First, have main process load weights if needed
|
||||||
|
if self.state.global_step != self._last_loaded_step:
|
||||||
|
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
||||||
|
state_dict = unwrapped_model.state_dict()
|
||||||
|
if self.accelerator.is_main_process:
|
||||||
|
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||||
|
llm_model.load_weights(state_dict.items())
|
||||||
|
self._last_loaded_step = self.state.global_step
|
||||||
|
|
||||||
|
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
||||||
|
all_prompts_text = gather_object(prompts_text)
|
||||||
|
if self.accelerator.is_main_process:
|
||||||
|
outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
|
||||||
|
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
|
||||||
|
else:
|
||||||
|
completion_ids = [None] * len(all_prompts_text) * self.num_generations
|
||||||
|
|
||||||
|
# Broadcast the completions from the main process to all processes, ensuring each process receives its
|
||||||
|
# corresponding slice.
|
||||||
|
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
||||||
|
process_slice = slice(
|
||||||
|
self.accelerator.process_index * len(prompts) * self.num_generations,
|
||||||
|
(self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
|
||||||
)
|
)
|
||||||
llm_model.load_weights(state_dict.items())
|
completion_ids = completion_ids[process_slice]
|
||||||
if is_peft_model(unwrapped_model):
|
|
||||||
unwrapped_model.unmerge_adapter()
|
# Pad the completions, and concatenate them with the prompts
|
||||||
|
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
|
||||||
|
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
|
||||||
|
prompt_inputs_repeated = torch.repeat_interleave(
|
||||||
|
prompt_inputs["input_ids"], self.num_generations, dim=0
|
||||||
|
)
|
||||||
|
prompt_completion_ids = torch.cat([prompt_inputs_repeated, completion_ids], dim=1)
|
||||||
|
else:
|
||||||
|
# Regular generation path
|
||||||
|
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
||||||
|
prompt_completion_ids = unwrapped_model.generate(
|
||||||
|
**prompt_inputs, generation_config=self.generation_config
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_length = prompt_inputs["input_ids"].size(1)
|
||||||
|
completion_ids = prompt_completion_ids[:, prompt_length:]
|
||||||
|
|
||||||
|
# Get the per-token log probabilities for the completions for the model and the reference model
|
||||||
|
def get_per_token_logps(model, input_ids, num_logits_to_keep):
|
||||||
|
# We add 1 to `num_logits_to_keep` because the last logits of the sequence is later excluded
|
||||||
|
outputs = model(input_ids, num_logits_to_keep=num_logits_to_keep + 1)
|
||||||
|
hidden_states = outputs.last_hidden_state[:, :-1]
|
||||||
|
logits = outputs.logits # (B, L, V)
|
||||||
|
logits = logits[
|
||||||
|
:, :-1, :
|
||||||
|
] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
||||||
|
|
||||||
|
# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
|
||||||
|
per_token_logps = []
|
||||||
|
for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]):
|
||||||
|
log_probs = logits_row.log_softmax(dim=-1)
|
||||||
|
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
|
||||||
|
per_token_logps.append(token_log_prob)
|
||||||
|
return torch.stack(per_token_logps), hidden_states
|
||||||
|
|
||||||
|
num_logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
||||||
|
per_token_logps, hidden_states = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
if self.ref_model is not None:
|
||||||
|
ref_per_token_logps, ref_hidden_states = get_per_token_logps(
|
||||||
|
self.ref_model, prompt_completion_ids, num_logits_to_keep
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
with self.accelerator.unwrap_model(model).disable_adapter():
|
||||||
|
ref_per_token_logps, ref_hidden_states = get_per_token_logps(
|
||||||
|
model, prompt_completion_ids, num_logits_to_keep
|
||||||
|
)
|
||||||
|
|
||||||
|
# done in liger
|
||||||
|
# Compute the KL divergence between the model and the reference model
|
||||||
|
# per_token_kl = (
|
||||||
|
# torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
|
||||||
|
# )
|
||||||
|
|
||||||
|
# Mask everything after the first EOS token
|
||||||
|
is_eos = completion_ids == self.processing_class.eos_token_id
|
||||||
|
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
|
||||||
|
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
||||||
|
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
|
||||||
|
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
|
||||||
|
|
||||||
|
# Decode the generated completions
|
||||||
|
completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
||||||
|
if is_conversational(inputs[0]):
|
||||||
|
completions = [[{"role": "assistant", "content": completion}] for completion in completions]
|
||||||
|
|
||||||
|
# Compute the rewards
|
||||||
|
prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
|
||||||
|
|
||||||
|
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
|
||||||
|
for i, (reward_func, reward_processing_class) in enumerate(
|
||||||
|
zip(self.reward_funcs, self.reward_processing_classes)
|
||||||
|
):
|
||||||
|
if isinstance(reward_func, PreTrainedModel):
|
||||||
|
if is_conversational(inputs[0]):
|
||||||
|
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
|
||||||
|
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
|
||||||
|
else:
|
||||||
|
texts = [p + c for p, c in zip(prompts, completions)]
|
||||||
|
reward_inputs = reward_processing_class(
|
||||||
|
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
|
||||||
|
)
|
||||||
|
reward_inputs = super()._prepare_inputs(reward_inputs)
|
||||||
|
with torch.inference_mode():
|
||||||
|
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
|
||||||
|
else:
|
||||||
|
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
|
||||||
|
reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
|
||||||
|
for key in reward_kwargs:
|
||||||
|
for example in inputs:
|
||||||
|
# Repeat each value in the column for `num_generations` times
|
||||||
|
reward_kwargs[key].extend([example[key]] * self.num_generations)
|
||||||
|
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
|
||||||
|
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
|
||||||
|
|
||||||
|
# Sum the rewards from all reward functions
|
||||||
|
rewards = rewards_per_func.sum(dim=1)
|
||||||
|
|
||||||
|
# done in liger
|
||||||
|
# # Compute grouped-wise rewards
|
||||||
|
# mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
|
||||||
|
# std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
|
||||||
|
|
||||||
|
# done in liger
|
||||||
|
# # Normalize the rewards to compute the advantages
|
||||||
|
# mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
||||||
|
# std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
||||||
|
# advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
|
||||||
|
|
||||||
|
# done in liger
|
||||||
|
# x - x.detach() allows for preserving gradients from x
|
||||||
|
# per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
|
||||||
|
# per_token_loss = -(per_token_loss - self.beta * per_token_kl)
|
||||||
|
# loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
||||||
|
|
||||||
|
# Log the metrics
|
||||||
|
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
|
||||||
|
self._metrics["completion_length"].append(completion_length)
|
||||||
|
|
||||||
|
reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
|
||||||
|
for i, reward_func in enumerate(self.reward_funcs):
|
||||||
|
if isinstance(reward_func, PreTrainedModel):
|
||||||
|
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
|
||||||
|
else:
|
||||||
|
reward_func_name = reward_func.__name__
|
||||||
|
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
|
||||||
|
|
||||||
|
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
|
||||||
|
|
||||||
|
lm_head = model.get_output_embeddings()
|
||||||
|
|
||||||
|
if self.ref_model is not None:
|
||||||
|
ref_lm_head = self.ref_model.get_output_embeddings()
|
||||||
|
else:
|
||||||
|
with self.null_ref_context():
|
||||||
|
ref_lm_head = model.get_output_embeddings()
|
||||||
|
ref_weight = ref_lm_head.weight
|
||||||
|
ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None
|
||||||
|
|
||||||
|
loss, metrics = self.grpo_loss_fn(
|
||||||
|
lm_head,
|
||||||
|
hidden_states, # this is the hidden states from the model
|
||||||
|
completion_mask,
|
||||||
|
rewards,
|
||||||
|
bias=lm_head.bias if hasattr(lm_head, "bias") else None,
|
||||||
|
ref_input=ref_hidden_states, # this is the hidden states from the ref model
|
||||||
|
ref_weight=ref_weight,
|
||||||
|
ref_bias=ref_bias,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
super().compute_loss(model, inputs, return_outputs, num_items_in_batch)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def null_ref_context(self):
|
||||||
|
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
|
||||||
|
with (
|
||||||
|
self.accelerator.unwrap_model(self.model).disable_adapter()
|
||||||
|
if self.is_peft_model and not self.ref_adapter_name
|
||||||
|
else nullcontext()
|
||||||
|
):
|
||||||
|
if self.ref_adapter_name:
|
||||||
|
self.model.set_adapter(self.ref_adapter_name)
|
||||||
|
yield
|
||||||
|
if self.ref_adapter_name:
|
||||||
|
self.model.set_adapter(self.model_adapter_name or "default")
|
||||||
|
|||||||
@@ -127,8 +127,6 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
**_kwargs,
|
**_kwargs,
|
||||||
):
|
):
|
||||||
if not optimizer:
|
|
||||||
optimizer = state.optimizer
|
|
||||||
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
|
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
|
||||||
checkpoint_folder = os.path.join(
|
checkpoint_folder = os.path.join(
|
||||||
args.output_dir,
|
args.output_dir,
|
||||||
|
|||||||
@@ -95,103 +95,6 @@ def get_cu_seqlens(attn_mask):
|
|||||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||||
|
|
||||||
|
|
||||||
def get_packed_mask_from_pos_ids(position_ids):
|
|
||||||
if len(position_ids.shape) == 1:
|
|
||||||
position_ids = position_ids.unsqueeze(0)
|
|
||||||
|
|
||||||
device = position_ids.device
|
|
||||||
results = []
|
|
||||||
|
|
||||||
for i, row in enumerate(position_ids):
|
|
||||||
# Count the number of consecutive zeros from the right side
|
|
||||||
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
|
|
||||||
|
|
||||||
# Adjust the row to exclude padding
|
|
||||||
adjusted_row = row[:-padding_length] if padding_length else row.clone()
|
|
||||||
|
|
||||||
# Find where the position resets to 0 (indicating a new sequence)
|
|
||||||
seq_starts = torch.cat(
|
|
||||||
[
|
|
||||||
torch.tensor([True], dtype=torch.bool, device=device),
|
|
||||||
adjusted_row[1:] == 0,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# Get the indices where the sequence starts
|
|
||||||
start_indices = torch.cat(
|
|
||||||
[
|
|
||||||
torch.nonzero(seq_starts).unbind(dim=1)[0],
|
|
||||||
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# Calculate the sequence lengths
|
|
||||||
seq_lengths = start_indices[1:] - start_indices[:-1]
|
|
||||||
# Append the padding length to the sequence lengths
|
|
||||||
doc_mask = torch.ones(len(row), dtype=torch.int32, device=device)
|
|
||||||
for i, seq_len in enumerate(seq_lengths):
|
|
||||||
start_id = start_indices[i]
|
|
||||||
doc_mask[start_id : start_id + seq_len] = (
|
|
||||||
(i+1) * doc_mask[start_id : start_id + seq_len]
|
|
||||||
)
|
|
||||||
if padding_length:
|
|
||||||
doc_mask[len(adjusted_row) :] = 0 * doc_mask[len(adjusted_row) :]
|
|
||||||
|
|
||||||
results.append(doc_mask)
|
|
||||||
|
|
||||||
return torch.stack(results)
|
|
||||||
|
|
||||||
|
|
||||||
def get_seqlens_from_pos_ids(position_ids):
|
|
||||||
"""generate a sequence length set using pos ids for doc mask creation in flex attention"""
|
|
||||||
if len(position_ids.shape) == 1:
|
|
||||||
position_ids = position_ids.unsqueeze(0)
|
|
||||||
max_seq_len = position_ids.shape[1]
|
|
||||||
|
|
||||||
device = position_ids.device
|
|
||||||
results = []
|
|
||||||
totalseqlens = []
|
|
||||||
|
|
||||||
for row in position_ids:
|
|
||||||
# Count the number of consecutive zeros from the right side
|
|
||||||
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
|
|
||||||
|
|
||||||
# Adjust the row to exclude padding
|
|
||||||
adjusted_row = row[:-padding_length] if padding_length else row.clone()
|
|
||||||
|
|
||||||
# Find where the position resets to 0 (indicating a new sequence)
|
|
||||||
seq_starts = torch.cat(
|
|
||||||
[
|
|
||||||
torch.tensor([True], dtype=torch.bool, device=device),
|
|
||||||
adjusted_row[1:] == 0,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# Get the indices where the sequence starts
|
|
||||||
start_indices = torch.cat(
|
|
||||||
[
|
|
||||||
torch.nonzero(seq_starts).unbind(dim=1)[0],
|
|
||||||
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# Calculate the sequence lengths
|
|
||||||
seq_lengths = start_indices[1:] - start_indices[:-1]
|
|
||||||
# Append the padding length to the sequence lengths
|
|
||||||
if padding_length:
|
|
||||||
seq_lengths = torch.cat(
|
|
||||||
[
|
|
||||||
seq_lengths,
|
|
||||||
torch.tensor(
|
|
||||||
[len(row) - torch.sum(seq_lengths)],
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=device,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
results.append(seq_lengths)
|
|
||||||
totalseqlens.append(len(adjusted_row))
|
|
||||||
|
|
||||||
return results, torch.tensor(totalseqlens, dtype=torch.int32, device=device)
|
|
||||||
|
|
||||||
|
|
||||||
def get_cu_seqlens_from_pos_ids(position_ids):
|
def get_cu_seqlens_from_pos_ids(position_ids):
|
||||||
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
||||||
if len(position_ids.shape) == 1:
|
if len(position_ids.shape) == 1:
|
||||||
@@ -273,10 +176,7 @@ def mask_2d_to_4d(
|
|||||||
when they attend to each other within that sequence.
|
when they attend to each other within that sequence.
|
||||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||||
"""
|
"""
|
||||||
|
bsz, src_len = mask.size()
|
||||||
if len(mask.size()) == 4:
|
|
||||||
return mask
|
|
||||||
bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
|
|
||||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||||
|
|
||||||
mask = mask.unsqueeze(1).unsqueeze(2)
|
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||||
|
|||||||
@@ -272,7 +272,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
dict(zip(feature_names, row))
|
dict(zip(feature_names, row))
|
||||||
)
|
)
|
||||||
for key, val in tokenized_prompt.items():
|
for key, val in tokenized_prompt.items():
|
||||||
res[key].append(val)
|
for i in range(0, len(val), self.sequence_len):
|
||||||
|
res[key].append(val[i : i + self.sequence_len])
|
||||||
|
|
||||||
# If there are no examples left, return an empty dictionary
|
# If there are no examples left, return an empty dictionary
|
||||||
if not res:
|
if not res:
|
||||||
|
|||||||
@@ -342,7 +342,6 @@ class LoraConfig(BaseModel):
|
|||||||
peft_use_dora: Optional[bool] = None
|
peft_use_dora: Optional[bool] = None
|
||||||
peft_use_rslora: Optional[bool] = None
|
peft_use_rslora: Optional[bool] = None
|
||||||
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
|
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
|
||||||
peft_init_lora_weights: Optional[Union[bool, str]] = None
|
|
||||||
|
|
||||||
qlora_sharded_model_loading: Optional[bool] = Field(
|
qlora_sharded_model_loading: Optional[bool] = Field(
|
||||||
default=False,
|
default=False,
|
||||||
@@ -823,7 +822,6 @@ class AxolotlInputConfig(
|
|||||||
xformers_attention: Optional[bool] = None
|
xformers_attention: Optional[bool] = None
|
||||||
sdp_attention: Optional[bool] = None
|
sdp_attention: Optional[bool] = None
|
||||||
s2_attention: Optional[bool] = None
|
s2_attention: Optional[bool] = None
|
||||||
flex_attention: Optional[bool] = None
|
|
||||||
flash_attention: Optional[bool] = None
|
flash_attention: Optional[bool] = None
|
||||||
flash_attn_cross_entropy: Optional[bool] = None
|
flash_attn_cross_entropy: Optional[bool] = None
|
||||||
flash_attn_rms_norm: Optional[bool] = None
|
flash_attn_rms_norm: Optional[bool] = None
|
||||||
@@ -1790,26 +1788,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_flex_torch_version(cls, data):
|
|
||||||
if (data.get("flex_attention") is not None) and (
|
|
||||||
data.get("flex_attention") is True
|
|
||||||
):
|
|
||||||
env_capabilities = data.get("env_capabilities", {})
|
|
||||||
torch_version = env_capabilities.get("torch_version")
|
|
||||||
|
|
||||||
if torch_version is None:
|
|
||||||
import torch
|
|
||||||
|
|
||||||
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
|
||||||
|
|
||||||
if version.parse(torch_version) < version.parse("2.5.1"):
|
|
||||||
raise ValueError(
|
|
||||||
"Flex attention is not supported on torch version < 2.5.1"
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_torch_compile_auto(cls, data):
|
def check_torch_compile_auto(cls, data):
|
||||||
|
|||||||
@@ -33,3 +33,4 @@ class TRLConfig(BaseModel):
|
|||||||
sync_ref_model: Optional[bool] = False
|
sync_ref_model: Optional[bool] = False
|
||||||
ref_model_mixup_alpha: Optional[float] = 0.9
|
ref_model_mixup_alpha: Optional[float] = 0.9
|
||||||
ref_model_sync_steps: Optional[int] = 64
|
ref_model_sync_steps: Optional[int] = 64
|
||||||
|
use_liger_loss: Optional[bool] = False
|
||||||
|
|||||||
@@ -172,11 +172,10 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
min_input_len = np.min(get_dataset_lengths(dataset))
|
||||||
min_input_len = np.min(ds_lengths)
|
LOG.debug(f"min_input_len: {min_input_len}")
|
||||||
LOG.info(f"min_input_len: {min_input_len}")
|
max_input_len = np.max(get_dataset_lengths(dataset))
|
||||||
max_input_len = np.max(ds_lengths)
|
LOG.debug(f"max_input_len: {max_input_len}")
|
||||||
LOG.info(f"max_input_len: {max_input_len}")
|
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -403,7 +403,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
and (self.cfg.flash_attention or self.cfg.flex_attention)
|
and self.cfg.flash_attention
|
||||||
and self.cfg.sample_packing
|
and self.cfg.sample_packing
|
||||||
):
|
):
|
||||||
if "auto_map" in self.model_config:
|
if "auto_map" in self.model_config:
|
||||||
@@ -707,13 +707,7 @@ class ModelLoader:
|
|||||||
"""
|
"""
|
||||||
sample packing uses custom FA2 patch
|
sample packing uses custom FA2 patch
|
||||||
"""
|
"""
|
||||||
|
if self.cfg.flash_attention:
|
||||||
if self.cfg.flex_attention:
|
|
||||||
self.model_kwargs["attn_implementation"] = "flex_attention"
|
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
|
||||||
"flex_attention"
|
|
||||||
)
|
|
||||||
elif self.cfg.flash_attention:
|
|
||||||
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
if not self.cfg.sample_packing and self.cfg.s2_attention:
|
||||||
pass
|
pass
|
||||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
@@ -1119,7 +1113,7 @@ class ModelLoader:
|
|||||||
should_convert = (
|
should_convert = (
|
||||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||||
((needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) and not qlora_fsdp)
|
((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
|
||||||
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
|
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1327,8 +1321,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
|||||||
if loftq_bits:
|
if loftq_bits:
|
||||||
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
|
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
|
||||||
lora_config_kwargs["init_lora_weights"] = "loftq"
|
lora_config_kwargs["init_lora_weights"] = "loftq"
|
||||||
if cfg.peft_init_lora_weights:
|
|
||||||
lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
|
|
||||||
if cfg.peft_use_dora:
|
if cfg.peft_use_dora:
|
||||||
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
|
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
|
||||||
LOG.info("Initializing LoRA weights using dora. This might take longer.")
|
LOG.info("Initializing LoRA weights using dora. This might take longer.")
|
||||||
|
|||||||
@@ -4,17 +4,13 @@ helper util to calculate dataset lengths
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def get_dataset_lengths(dataset, from_arrow=False):
|
def get_dataset_lengths(dataset):
|
||||||
if "length" in dataset.column_names:
|
if "length" in dataset.column_names:
|
||||||
lengths = np.array(dataset["length"])
|
lengths = np.array(dataset["length"])
|
||||||
elif "position_ids" in dataset.column_names:
|
elif "position_ids" in dataset.column_names:
|
||||||
position_ids = dataset["position_ids"]
|
position_ids = dataset["position_ids"]
|
||||||
lengths = np.array([x[-1] + 1 for x in position_ids])
|
lengths = np.array([x[-1] + 1 for x in position_ids])
|
||||||
else:
|
else:
|
||||||
if from_arrow:
|
input_ids = dataset["input_ids"]
|
||||||
input_ids = dataset.data.column("input_ids")
|
lengths = np.array([len(seq) for seq in input_ids])
|
||||||
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
|
||||||
else:
|
|
||||||
input_ids = dataset["input_ids"]
|
|
||||||
lengths = np.array([len(seq) for seq in input_ids])
|
|
||||||
return lengths
|
return lengths
|
||||||
|
|||||||
Reference in New Issue
Block a user