From 14d274efe6642df0d15bf3fa2dd1ef7ef3230af4 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 19 Feb 2025 15:34:32 +0000 Subject: [PATCH] WIP liger support --- src/axolotl/core/trainers/grpo/trainer.py | 264 ++++++++++++++++++++-- 1 file changed, 239 insertions(+), 25 deletions(-) diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index 8f8b9fcf9..43d8c892b 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -1,13 +1,20 @@ """ Axolotl GRPO trainer """ + +from contextlib import contextmanager, nullcontext from accelerate.utils import is_peft_model from accelerate.utils.other import is_compiled_module +import torch from transformers import PreTrainedModel from trl import GRPOConfig, GRPOTrainer from trl.models import unwrap_model_for_generation from axolotl.core.trainers.base import SchedulerMixin +from transformers.utils import is_liger_kernel_available + +if is_liger_kernel_available(): + from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # mypy: ignore-errors @@ -21,6 +28,19 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + # Import Liger loss if enabled + if self.args.use_liger_loss: + if not is_liger_kernel_available(): + raise ValueError( + "You set `use_liger_loss=True` but the liger kernel is not available. " + "Please install liger-kernel first: `pip install liger-kernel`" + ) + self.grpo_loss_fn = LigerFusedLinearGRPOLoss( + beta=self.beta, + compiled=is_compiled_module(self.model), + use_ref_model=True, + num_generations=self.args.num_generations, + ) # pylint: disable=access-member-before-definition # Enable gradient checkpointing if requested if kwargs["args"].gradient_checkpointing: @@ -29,9 +49,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer): self.model.config.use_cache = False # Enable gradient checkpointing on the base model for PEFT - if is_peft_model(self.model) and hasattr( - self.model.base_model, "gradient_checkpointing_enable" - ): + if is_peft_model(self.model) and hasattr(self.model.base_model, "gradient_checkpointing_enable"): self.model.base_model.gradient_checkpointing_enable() # Enable gradient checkpointing for non-PEFT models elif hasattr(self.model, "gradient_checkpointing_enable"): @@ -39,15 +57,12 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer): self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"]) # pylint: enable=access-member-before-definition - def _enable_gradient_checkpointing( - self, model: PreTrainedModel, args: GRPOConfig - ) -> PreTrainedModel: + def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel: """Enables gradient checkpointing for the model.""" # pylint: disable=unused-argument,redefined-builtin gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} use_reentrant = ( - "use_reentrant" not in gradient_checkpointing_kwargs - or gradient_checkpointing_kwargs["use_reentrant"] + "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"] ) if use_reentrant: @@ -58,9 +73,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer): def make_inputs_require_grad(module, input, output): output.requires_grad_(True) - model.get_input_embeddings().register_forward_hook( - make_inputs_require_grad - ) + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) return model # pylint: enable=unused-argument,redefined-builtin @@ -72,26 +85,18 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer): gather_deepspeed3_params=self.args.ds3_gather_for_generation, ) as unwrapped_model: if is_compiled_module(unwrapped_model): - unwrapped_model = ( - unwrapped_model._orig_mod # pylint: disable=protected-access - ) + unwrapped_model = unwrapped_model._orig_mod # pylint: disable=protected-access if is_peft_model(unwrapped_model): unwrapped_model.merge_adapter() state_dict = unwrapped_model.state_dict() unwrapped_model.unmerge_adapter() # Remove base_model and base_layer prefixes state_dict = { - k.removeprefix("base_model.model.") - .removeprefix("base_model.model.") - .replace(".base_layer", ""): v + k.removeprefix("base_model.model.").removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items() } # Remove values with adapter prefix (example: "_lora") - state_dict = { - k: v - for k, v in state_dict.items() - if unwrapped_model.prefix not in k - } + state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k} # When module to save, remove its prefix and discard the original module state_dict = { k.replace("modules_to_save.default.", ""): v @@ -101,7 +106,216 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer): else: state_dict = unwrapped_model.state_dict() if self.accelerator.is_main_process: - llm_model = ( - self.llm.llm_engine.model_executor.driver_worker.model_runner.model - ) + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model llm_model.load_weights(state_dict.items()) + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if self.args.use_liger_loss: + if return_outputs: + raise ValueError("The GRPOTrainer does not support returning outputs") + + device = self.accelerator.device + prompts = [x["prompt"] for x in inputs] + prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] + prompt_inputs = self.processing_class( + prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False + ) + prompt_inputs = super()._prepare_inputs(prompt_inputs) + + if self.max_prompt_length is not None: + prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :] + prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :] + + # Generate completions using either vLLM or regular generation + if self.args.use_vllm: + # First, have main process load weights if needed + if self.state.global_step != self._last_loaded_step: + with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + state_dict = unwrapped_model.state_dict() + if self.accelerator.is_main_process: + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights(state_dict.items()) + self._last_loaded_step = self.state.global_step + + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process + all_prompts_text = gather_object(prompts_text) + if self.accelerator.is_main_process: + outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False) + completion_ids = [out.token_ids for completions in outputs for out in completions.outputs] + else: + completion_ids = [None] * len(all_prompts_text) * self.num_generations + + # Broadcast the completions from the main process to all processes, ensuring each process receives its + # corresponding slice. + completion_ids = broadcast_object_list(completion_ids, from_process=0) + process_slice = slice( + self.accelerator.process_index * len(prompts) * self.num_generations, + (self.accelerator.process_index + 1) * len(prompts) * self.num_generations, + ) + completion_ids = completion_ids[process_slice] + + # Pad the completions, and concatenate them with the prompts + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id) + prompt_inputs_repeated = torch.repeat_interleave( + prompt_inputs["input_ids"], self.num_generations, dim=0 + ) + prompt_completion_ids = torch.cat([prompt_inputs_repeated, completion_ids], dim=1) + else: + # Regular generation path + with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + prompt_completion_ids = unwrapped_model.generate( + **prompt_inputs, generation_config=self.generation_config + ) + + prompt_length = prompt_inputs["input_ids"].size(1) + completion_ids = prompt_completion_ids[:, prompt_length:] + + # Get the per-token log probabilities for the completions for the model and the reference model + def get_per_token_logps(model, input_ids, num_logits_to_keep): + # We add 1 to `num_logits_to_keep` because the last logits of the sequence is later excluded + outputs = model(input_ids, num_logits_to_keep=num_logits_to_keep + 1) + hidden_states = outputs.last_hidden_state[:, :-1] + logits = outputs.logits # (B, L, V) + logits = logits[ + :, :-1, : + ] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + + # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak. + per_token_logps = [] + for logits_row, input_ids_row in zip(logits, input_ids[:, -num_logits_to_keep:]): + log_probs = logits_row.log_softmax(dim=-1) + token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1) + per_token_logps.append(token_log_prob) + return torch.stack(per_token_logps), hidden_states + + num_logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + per_token_logps, hidden_states = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep) + + with torch.inference_mode(): + if self.ref_model is not None: + ref_per_token_logps, ref_hidden_states = get_per_token_logps( + self.ref_model, prompt_completion_ids, num_logits_to_keep + ) + else: + with self.accelerator.unwrap_model(model).disable_adapter(): + ref_per_token_logps, ref_hidden_states = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep) + + # done in liger + # Compute the KL divergence between the model and the reference model + # per_token_kl = ( + # torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + # ) + + # Mask everything after the first EOS token + is_eos = completion_ids == self.processing_class.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + + # Decode the generated completions + completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + if is_conversational(inputs[0]): + completions = [[{"role": "assistant", "content": completion}] for completion in completions] + + # Compute the rewards + prompts = [prompt for prompt in prompts for _ in range(self.num_generations)] + + rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) + for i, (reward_func, reward_processing_class) in enumerate( + zip(self.reward_funcs, self.reward_processing_classes) + ): + if isinstance(reward_func, PreTrainedModel): + if is_conversational(inputs[0]): + messages = [{"messages": p + c} for p, c in zip(prompts, completions)] + texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] + else: + texts = [p + c for p, c in zip(prompts, completions)] + reward_inputs = reward_processing_class( + texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False + ) + reward_inputs = super()._prepare_inputs(reward_inputs) + with torch.inference_mode(): + rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) + else: + # Repeat all input columns (but "prompt" and "completion") to match the number of generations + reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]} + for key in reward_kwargs: + for example in inputs: + # Repeat each value in the column for `num_generations` times + reward_kwargs[key].extend([example[key]] * self.num_generations) + output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs) + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + # Sum the rewards from all reward functions + rewards = rewards_per_func.sum(dim=1) + + # done in liger + # # Compute grouped-wise rewards + # mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) + # std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1) + + # done in liger + # # Normalize the rewards to compute the advantages + # mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + # std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + # advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) + + # done in liger + # x - x.detach() allows for preserving gradients from x + # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) + # per_token_loss = -(per_token_loss - self.beta * per_token_kl) + # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() + + # Log the metrics + completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() + self._metrics["completion_length"].append(completion_length) + + reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0) + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, PreTrainedModel): + reward_func_name = reward_func.config._name_or_path.split("/")[-1] + else: + reward_func_name = reward_func.__name__ + self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item()) + + self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item()) + + + lm_head = model.get_output_embeddings() + + if self.ref_model is not None: + ref_lm_head = self.ref_model.get_output_embeddings() + else: + with self.null_ref_context(): + ref_lm_head = model.get_output_embeddings() + ref_weight = ref_lm_head.weight + ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None + + loss, metrics = self.grpo_loss_fn( + lm_head, + hidden_states, # this is the hidden states from the model + completion_mask, + rewards, + bias=lm_head.bias if hasattr(lm_head, "bias") else None, + ref_input=ref_hidden_states, # this is the hidden states from the ref model + ref_weight=ref_weight, + ref_bias=ref_bias, + ) + else: + super().compute_loss(model, inputs, return_outputs, num_items_in_batch) + + @contextmanager + def null_ref_context(self): + """Context manager for handling null reference model (that is, peft adapter manipulation).""" + with ( + self.accelerator.unwrap_model(self.model).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): + if self.ref_adapter_name: + self.model.set_adapter(self.ref_adapter_name) + yield + if self.ref_adapter_name: + self.model.set_adapter(self.model_adapter_name or "default")