diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 2aa319c1c..fe8c0b364 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1009,6 +1009,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             training_args_kwargs["dataloader_prefetch_factor"] = (
                 self.cfg.dataloader_prefetch_factor
             )
+        if self.cfg.seed is not None:
+            training_args_kwargs["seed"] = self.cfg.seed
         if self.cfg.gradient_checkpointing:
             training_args_kwargs["gradient_checkpointing"] = (
                 self.cfg.gradient_checkpointing
diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py
index 0186baacc..de7243dbd 100644
--- a/src/axolotl/core/trainers/grpo/trainer.py
+++ b/src/axolotl/core/trainers/grpo/trainer.py
@@ -319,576 +319,576 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
         if self.accelerator.is_main_process:
             self.vllm_client.reset_prefix_cache()

-    def _generate_and_score_completions(
-        self, inputs: list[dict[str, torch.Tensor | Any]]
-    ) -> dict[str, torch.Tensor | Any]:
-        device = self.accelerator.device
-        prompts = [x["prompt"] for x in inputs]
-        prompts_text = [
-            maybe_apply_chat_template(example, self.processing_class)["prompt"]
-            for example in inputs
-        ]
-        prompt_inputs = self.processing_class(
-            text=prompts_text,
-            return_tensors="pt",
-            padding=True,
-            padding_side="left",
-            add_special_tokens=False,
-        )
-        # pylint: disable=protected-access
-        prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)
+    # def _generate_and_score_completions(
+    #     self, inputs: list[dict[str, torch.Tensor | Any]]
+    # ) -> dict[str, torch.Tensor | Any]:
+    #     device = self.accelerator.device
+    #     prompts = [x["prompt"] for x in inputs]
+    #     prompts_text = [
+    #         maybe_apply_chat_template(example, self.processing_class)["prompt"]
+    #         for example in inputs
+    #     ]
+    #     prompt_inputs = self.processing_class(
+    #         text=prompts_text,
+    #         return_tensors="pt",
+    #         padding=True,
+    #         padding_side="left",
+    #         add_special_tokens=False,
+    #     )
+    #     # pylint: disable=protected-access
+    #     prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)

-        prompt_ids, prompt_mask = (
-            prompt_inputs["input_ids"],
-            prompt_inputs["attention_mask"],
-        )
+    #     prompt_ids, prompt_mask = (
+    #         prompt_inputs["input_ids"],
+    #         prompt_inputs["attention_mask"],
+    #     )

-        if self.max_prompt_length is not None:
-            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
-            prompt_mask = prompt_mask[:, -self.max_prompt_length :]
+    #     if self.max_prompt_length is not None:
+    #         prompt_ids = prompt_ids[:, -self.max_prompt_length :]
+    #         prompt_mask = prompt_mask[:, -self.max_prompt_length :]

-        # Generate completions using either vLLM or regular generation
-        if self.args.use_vllm:
-            # First, have main process load weights if needed
-            # pylint: disable=access-member-before-definition
-            if self.state.global_step != self._last_loaded_step:  # type: ignore[has-type]
-                self._move_model_to_vllm()
-                # pylint: disable=attribute-defined-outside-init
-                self._last_loaded_step = self.state.global_step
+    #     # Generate completions using either vLLM or regular generation
+    #     if self.args.use_vllm:
+    #         # First, have main process load weights if needed
+    #         # pylint: disable=access-member-before-definition
+    #         if self.state.global_step != self._last_loaded_step:  # type: ignore[has-type]
+    #             self._move_model_to_vllm()
+    #             # pylint: disable=attribute-defined-outside-init
+    #             self._last_loaded_step = self.state.global_step

-            all_prompts_text = gather_object(prompts_text)
-            if self.accelerator.is_main_process:
-                # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
-                # num_generations outputs for each one. This is faster than generating outputs for each duplicate
-                # prompt individually.
-                # ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
-                ordered_set_of_prompts = all_prompts_text[
-                    :: self.num_generations * self.args.sequence_parallel_degree
-                ]
-                with profiling_context(self, "vLLM.generate"):
-                    completion_ids = self.vllm_client.generate(
-                        prompts=ordered_set_of_prompts,
-                        n=self.num_generations,
-                        repetition_penalty=self.repetition_penalty,
-                        temperature=self.temperature,
-                        top_p=self.top_p,
-                        top_k=-1 if self.top_k is None else self.top_k,
-                        min_p=0.0 if self.min_p is None else self.min_p,
-                        max_tokens=self.max_completion_length,
-                        guided_decoding_regex=self.guided_decoding_regex,
-                    )
-            else:
-                completion_ids = [None] * (
-                    len(all_prompts_text) // self.args.sequence_parallel_degree
-                )
+    #         all_prompts_text = gather_object(prompts_text)
+    #         if self.accelerator.is_main_process:
+    #             # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
+    #             # num_generations outputs for each one. This is faster than generating outputs for each duplicate
+    #             # prompt individually.
+    #             # ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
+    #             ordered_set_of_prompts = all_prompts_text[
+    #                 :: self.num_generations * self.args.sequence_parallel_degree
+    #             ]
+    #             with profiling_context(self, "vLLM.generate"):
+    #                 completion_ids = self.vllm_client.generate(
+    #                     prompts=ordered_set_of_prompts,
+    #                     n=self.num_generations,
+    #                     repetition_penalty=self.repetition_penalty,
+    #                     temperature=self.temperature,
+    #                     top_p=self.top_p,
+    #                     top_k=-1 if self.top_k is None else self.top_k,
+    #                     min_p=0.0 if self.min_p is None else self.min_p,
+    #                     max_tokens=self.max_completion_length,
+    #                     guided_decoding_regex=self.guided_decoding_regex,
+    #                 )
+    #         else:
+    #             completion_ids = [None] * (
+    #                 len(all_prompts_text) // self.args.sequence_parallel_degree
+    #             )

-            # Broadcast the completions from the main process to all processes
-            completion_ids = broadcast_object_list(completion_ids, from_process=0)
+    #         # Broadcast the completions from the main process to all processes
+    #         completion_ids = broadcast_object_list(completion_ids, from_process=0)

-            # Determine the appropriate slice based on sequence parallelism
-            if self.args.sequence_parallel_degree > 1:
-                # Calculate SP group ID (which group of ranks this rank belongs to)
-                sp_group_id = self.accelerator.process_index // self.local_world_size
+    #         # Determine the appropriate slice based on sequence parallelism
+    #         if self.args.sequence_parallel_degree > 1:
+    #             # Calculate SP group ID (which group of ranks this rank belongs to)
+    #             sp_group_id = self.accelerator.process_index // self.local_world_size

-                # Calculate the start index for this SP group
-                sp_group_start = sp_group_id * len(prompts) * self.local_world_size
+    #             # Calculate the start index for this SP group
+    #             sp_group_start = sp_group_id * len(prompts) * self.local_world_size

-                # All ranks in the same SP group get the same data slice
-                process_slice = slice(
-                    sp_group_start,
-                    sp_group_start + len(prompts),
-                )
-                completion_ids = completion_ids[process_slice]
-            else:
-                # Original behavior for non-sequence parallel case
-                process_slice = slice(
-                    self.accelerator.process_index * len(prompts),
-                    (self.accelerator.process_index + 1) * len(prompts),
-                )
-                completion_ids = completion_ids[process_slice]
+    #             # All ranks in the same SP group get the same data slice
+    #             process_slice = slice(
+    #                 sp_group_start,
+    #                 sp_group_start + len(prompts),
+    #             )
+    #             completion_ids = completion_ids[process_slice]
+    #         else:
+    #             # Original behavior for non-sequence parallel case
+    #             process_slice = slice(
+    #                 self.accelerator.process_index * len(prompts),
+    #                 (self.accelerator.process_index + 1) * len(prompts),
+    #             )
+    #             completion_ids = completion_ids[process_slice]

-            # Pad the completions, and concatenate them with the prompts
-            completion_ids = [
-                torch.tensor(ids, device=device) for ids in completion_ids
-            ]
-            completion_ids = pad(
-                completion_ids, padding_value=self.processing_class.pad_token_id
-            )
-        else:
-            # Regular generation path
-            with unwrap_model_for_generation(
-                self.model_wrapped,
-                self.accelerator,
-                gather_deepspeed3_params=self.args.ds3_gather_for_generation,
-            ) as unwrapped_model:
-                prompt_completion_ids = unwrapped_model.generate(
-                    prompt_ids,
-                    attention_mask=prompt_mask,
-                    generation_config=self.generation_config,
-                )
+    #         # Pad the completions, and concatenate them with the prompts
+    #         completion_ids = [
+    #             torch.tensor(ids, device=device) for ids in completion_ids
+    #         ]
+    #         completion_ids = pad(
+    #             completion_ids, padding_value=self.processing_class.pad_token_id
+    #         )
+    #     else:
+    #         # Regular generation path
+    #         with unwrap_model_for_generation(
+    #             self.model_wrapped,
+    #             self.accelerator,
+    #             gather_deepspeed3_params=self.args.ds3_gather_for_generation,
+    #         ) as unwrapped_model:
+    #             prompt_completion_ids = unwrapped_model.generate(
+    #                 prompt_ids,
+    #                 attention_mask=prompt_mask,
+    #                 generation_config=self.generation_config,
+    #             )

-            # Compute prompt length and extract completion ids
-            prompt_length = prompt_ids.size(1)
-            prompt_ids = prompt_completion_ids[:, :prompt_length]
-            completion_ids = prompt_completion_ids[:, prompt_length:]
+    #         # Compute prompt length and extract completion ids
+    #         prompt_length = prompt_ids.size(1)
+    #         prompt_ids = prompt_completion_ids[:, :prompt_length]
+    #         completion_ids = prompt_completion_ids[:, prompt_length:]

-        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+    #     prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

-        # Mask everything after the first EOS token
-        is_eos = completion_ids == self.processing_class.eos_token_id
-        eos_idx = torch.full(
-            (is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device
-        )
-        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
-        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(
-            is_eos.size(0), -1
-        )
-        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
+    #     # Mask everything after the first EOS token
+    #     is_eos = completion_ids == self.processing_class.eos_token_id
+    #     eos_idx = torch.full(
+    #         (is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device
+    #     )
+    #     eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
+    #     sequence_indices = torch.arange(is_eos.size(1), device=device).expand(
+    #         is_eos.size(0), -1
+    #     )
+    #     completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

-        # Concatenate prompt_mask with completion_mask for logit computation
-        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
-        logits_to_keep = completion_ids.size(
-            1
-        )  # we only need to compute the logits for the completion tokens
+    #     # Concatenate prompt_mask with completion_mask for logit computation
+    #     attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
+    #     logits_to_keep = completion_ids.size(
+    #         1
+    #     )  # we only need to compute the logits for the completion tokens

-        with torch.no_grad():
-            # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's
-            # computation here, and use per_token_logps.detach() instead.
-            if self.num_iterations > 1:
-                if self.args.sequence_parallel_degree > 1:
-                    old_per_token_logps, _ = self._get_per_token_logps_v2(
-                        self.model,
-                        prompt_completion_ids,
-                        attention_mask,
-                        logits_to_keep,
-                    )
-                else:
-                    old_per_token_logps = super()._get_per_token_logps(
-                        self.model,
-                        prompt_completion_ids,
-                        attention_mask,
-                        logits_to_keep,
-                    )
-            else:
-                old_per_token_logps = None
+    #     with torch.no_grad():
+    #         # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
+    #         # computation here, and use per_token_logps.detach() instead.
+    #         if self.num_iterations > 1:
+    #             if self.args.sequence_parallel_degree > 1:
+    #                 old_per_token_logps, _ = self._get_per_token_logps_v2(
+    #                     self.model,
+    #                     prompt_completion_ids,
+    #                     attention_mask,
+    #                     logits_to_keep,
+    #                 )
+    #             else:
+    #                 old_per_token_logps = super()._get_per_token_logps(
+    #                     self.model,
+    #                     prompt_completion_ids,
+    #                     attention_mask,
+    #                     logits_to_keep,
+    #                 )
+    #         else:
+    #             old_per_token_logps = None

-            if self.beta == 0.0:
-                ref_per_token_logps = None
-            elif self.ref_model is not None:
-                if self.args.sequence_parallel_degree > 1:
-                    ref_per_token_logps, _ = self._get_per_token_logps_v2(
-                        self.ref_model,
-                        prompt_completion_ids,
-                        attention_mask,
-                        logits_to_keep,
-                    )
-                else:
-                    ref_per_token_logps = super()._get_per_token_logps(
-                        self.ref_model,
-                        prompt_completion_ids,
-                        attention_mask,
-                        logits_to_keep,
-                    )
-            else:
-                with self.accelerator.unwrap_model(self.model).disable_adapter():
-                    if self.args.sequence_parallel_degree > 1:
-                        ref_per_token_logps, _ = self._get_per_token_logps_v2(
-                            self.model,
-                            prompt_completion_ids,
-                            attention_mask,
-                            logits_to_keep,
-                        )
-                    else:
-                        ref_per_token_logps = super()._get_per_token_logps(
-                            self.model,
-                            prompt_completion_ids,
-                            attention_mask,
-                            logits_to_keep,
-                        )
+    #         if self.beta == 0.0:
+    #             ref_per_token_logps = None
+    #         elif self.ref_model is not None:
+    #             if self.args.sequence_parallel_degree > 1:
+    #                 ref_per_token_logps, _ = self._get_per_token_logps_v2(
+    #                     self.ref_model,
+    #                     prompt_completion_ids,
+    #                     attention_mask,
+    #                     logits_to_keep,
+    #                 )
+    #             else:
+    #                 ref_per_token_logps = super()._get_per_token_logps(
+    #                     self.ref_model,
+    #                     prompt_completion_ids,
+    #                     attention_mask,
+    #                     logits_to_keep,
+    #                 )
+    #         else:
+    #             with self.accelerator.unwrap_model(self.model).disable_adapter():
+    #                 if self.args.sequence_parallel_degree > 1:
+    #                     ref_per_token_logps, _ = self._get_per_token_logps_v2(
+    #                         self.model,
+    #                         prompt_completion_ids,
+    #                         attention_mask,
+    #                         logits_to_keep,
+    #                     )
+    #                 else:
+    #                     ref_per_token_logps = super()._get_per_token_logps(
+    #                         self.model,
+    #                         prompt_completion_ids,
+    #                         attention_mask,
+    #                         logits_to_keep,
+    #                     )

-        # Decode the generated completions
-        completions_text = self.processing_class.batch_decode(
-            completion_ids, skip_special_tokens=True
-        )
-        if is_conversational(inputs[0]):
-            completions = []
-            for prompt, completion in zip(prompts, completions_text):
-                bootstrap = (
-                    prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
-                )
-                completions.append(
-                    [{"role": "assistant", "content": bootstrap + completion}]
-                )
-        else:
-            completions = completions_text
+    #     # Decode the generated completions
+    #     completions_text = self.processing_class.batch_decode(
+    #         completion_ids, skip_special_tokens=True
+    #     )
+    #     if is_conversational(inputs[0]):
+    #         completions = []
+    #         for prompt, completion in zip(prompts, completions_text):
+    #             bootstrap = (
+    #                 prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
+    #             )
+    #             completions.append(
+    #                 [{"role": "assistant", "content": bootstrap + completion}]
+    #             )
+    #     else:
+    #         completions = completions_text

-        rewards_per_func = torch.zeros(
-            len(prompts), len(self.reward_funcs), device=device
-        )
-        for i, (reward_func, reward_processing_class) in enumerate(
-            zip(self.reward_funcs, self.reward_processing_classes)
-        ):
-            if isinstance(
-                reward_func, nn.Module
-            ):  # Module instead of PretrainedModel for compat with compiled models
-                reward_func_name = (
-                    f"reward {reward_func.config._name_or_path.split('/')[-1]}"
-                )
-            else:
-                # pylint: disable=protected-access
-                reward_func_name = reward_func.__name__
-            with profiling_context(self, reward_func_name):
-                if isinstance(
-                    reward_func, nn.Module
-                ):  # Module instead of PretrainedModel for compat with compiled models
-                    if is_conversational(inputs[0]):
-                        messages = [
-                            {"messages": p + c} for p, c in zip(prompts, completions)
-                        ]
-                        texts = [
-                            apply_chat_template(x, reward_processing_class)["text"]
-                            for x in messages
-                        ]
-                    else:
-                        texts = [p + c for p, c in zip(prompts, completions)]
-                    reward_inputs = reward_processing_class(
-                        text=texts,
-                        return_tensors="pt",
-                        padding=True,
-                        padding_side="right",
-                        add_special_tokens=False,
-                    )
-                    # pylint: disable=protected-access
-                    reward_inputs = Trainer._prepare_inputs(self, reward_inputs)
-                    with torch.inference_mode():
-                        rewards_per_func[:, i] = reward_func(**reward_inputs).logits[
-                            :, 0
-                        ]  # Shape (B*G,)
-                else:
-                    # Repeat all input columns (but "prompt" and "completion") to match the number of generations
-                    keys = [
-                        key for key in inputs[0] if key not in ["prompt", "completion"]
-                    ]
-                    reward_kwargs = {
-                        key: [example[key] for example in inputs] for key in keys
-                    }
-                    output_reward_func = reward_func(
-                        prompts=prompts, completions=completions, **reward_kwargs
-                    )
-                    # Convert None values to NaN
-                    output_reward_func = [
-                        reward if reward is not None else torch.nan
-                        for reward in output_reward_func
-                    ]
+    #     rewards_per_func = torch.zeros(
+    #         len(prompts), len(self.reward_funcs), device=device
+    #     )
+    #     for i, (reward_func, reward_processing_class) in enumerate(
+    #         zip(self.reward_funcs, self.reward_processing_classes)
+    #     ):
+    #         if isinstance(
+    #             reward_func, nn.Module
+    #         ):  # Module instead of PretrainedModel for compat with compiled models
+    #             reward_func_name = (
+    #                 f"reward {reward_func.config._name_or_path.split('/')[-1]}"
+    #             )
+    #         else:
+    #             # pylint: disable=protected-access
+    #             reward_func_name = reward_func.__name__
+    #         with profiling_context(self, reward_func_name):
+    #             if isinstance(
+    #                 reward_func, nn.Module
+    #             ):  # Module instead of PretrainedModel for compat with compiled models
+    #                 if is_conversational(inputs[0]):
+    #                     messages = [
+    #                         {"messages": p + c} for p, c in zip(prompts, completions)
+    #                     ]
+    #                     texts = [
+    #                         apply_chat_template(x, reward_processing_class)["text"]
+    #                         for x in messages
+    #                     ]
+    #                 else:
+    #                     texts = [p + c for p, c in zip(prompts, completions)]
+    #                 reward_inputs = reward_processing_class(
+    #                     text=texts,
+    #                     return_tensors="pt",
+    #                     padding=True,
+    #                     padding_side="right",
+    #                     add_special_tokens=False,
+    #                 )
+    #                 # pylint: disable=protected-access
+    #                 reward_inputs = Trainer._prepare_inputs(self, reward_inputs)
+    #                 with torch.inference_mode():
+    #                     rewards_per_func[:, i] = reward_func(**reward_inputs).logits[
+    #                         :, 0
+    #                     ]  # Shape (B*G,)
+    #             else:
+    #                 # Repeat all input columns (but "prompt" and "completion") to match the number of generations
+    #                 keys = [
+    #                     key for key in inputs[0] if key not in ["prompt", "completion"]
+    #                 ]
+    #                 reward_kwargs = {
+    #                     key: [example[key] for example in inputs] for key in keys
+    #                 }
+    #                 output_reward_func = reward_func(
+    #                     prompts=prompts, completions=completions, **reward_kwargs
+    #                 )
+    #                 # Convert None values to NaN
+    #                 output_reward_func = [
+    #                     reward if reward is not None else torch.nan
+    #                     for reward in output_reward_func
+    #                 ]

-                    rewards_per_func[:, i] = torch.tensor(
-                        output_reward_func, dtype=torch.float32, device=device
-                    )
+    #                 rewards_per_func[:, i] = torch.tensor(
+    #                     output_reward_func, dtype=torch.float32, device=device
+    #                 )

-        # If all reward functions return None for a given row, issue a detailed warning
-        if torch.isnan(rewards_per_func).all(dim=1).any():
-            nan_row_idx = (
-                torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
-            )
-            row_reward_kwargs = {
-                key: value[nan_row_idx] for key, value in reward_kwargs.items()
-            }
-            row_reward_kwargs["prompt"] = prompts[nan_row_idx]
-            row_reward_kwargs["completion"] = completions[nan_row_idx]
-            warnings.warn(
-                f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
-                "Please ensure that at least one reward function returns a valid reward."
-            )
+    #     # If all reward functions return None for a given row, issue a detailed warning
+    #     if torch.isnan(rewards_per_func).all(dim=1).any():
+    #         nan_row_idx = (
+    #             torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
+    #         )
+    #         row_reward_kwargs = {
+    #             key: value[nan_row_idx] for key, value in reward_kwargs.items()
+    #         }
+    #         row_reward_kwargs["prompt"] = prompts[nan_row_idx]
+    #         row_reward_kwargs["completion"] = completions[nan_row_idx]
+    #         warnings.warn(
+    #             f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
+    #             "Please ensure that at least one reward function returns a valid reward."
+    #         )

-        # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
-        # completions may be distributed across processes
-        rewards_per_func = gather(rewards_per_func)
+    #     # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
+    #     # completions may be distributed across processes
+    #     rewards_per_func = gather(rewards_per_func)

-        # Apply weights to each reward function's output and sum
-        rewards = (
-            rewards_per_func * self.reward_weights.to(device).unsqueeze(0)
-        ).nansum(dim=1)
+    #     # Apply weights to each reward function's output and sum
+    #     rewards = (
+    #         rewards_per_func * self.reward_weights.to(device).unsqueeze(0)
+    #     ).nansum(dim=1)

-        # Compute grouped-wise rewards
-        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
-        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+    #     # Compute grouped-wise rewards
+    #     mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
+    #     std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

-        # Normalize the rewards to compute the advantages
-        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
-            self.num_generations, dim=0
-        )
-        std_grouped_rewards = std_grouped_rewards.repeat_interleave(
-            self.num_generations, dim=0
-        )
-        advantages = rewards - mean_grouped_rewards
-        if self.args.scale_rewards:
-            advantages = advantages / (std_grouped_rewards + 1e-4)
+    #     # Normalize the rewards to compute the advantages
+    #     mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
+    #         self.num_generations, dim=0
+    #     )
+    #     std_grouped_rewards = std_grouped_rewards.repeat_interleave(
+    #         self.num_generations, dim=0
+    #     )
+    #     advantages = rewards - mean_grouped_rewards
+    #     if self.args.scale_rewards:
+    #         advantages = advantages / (std_grouped_rewards + 1e-4)

-        # Slice to keep only the local part of the data
-        process_slice = slice(
-            self.accelerator.process_index * len(prompts),
-            (self.accelerator.process_index + 1) * len(prompts),
-        )
-        advantages = advantages[process_slice]
+    #     # Slice to keep only the local part of the data
+    #     process_slice = slice(
+    #         self.accelerator.process_index * len(prompts),
+    #         (self.accelerator.process_index + 1) * len(prompts),
+    #     )
+    #     advantages = advantages[process_slice]

-        # Log the metrics
-        mode = "eval" if self.control.should_evaluate else "train"
+    #     # Log the metrics
+    #     mode = "eval" if self.control.should_evaluate else "train"

-        if mode == "train":
-            # pylint: disable=no-member
-            self._total_train_tokens += (
-                self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
-            )
-            # pylint: disable=no-member
-            self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
+    #     if mode == "train":
+    #         # pylint: disable=no-member
+    #         self._total_train_tokens += (
+    #             self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
+    #         )
+    #         # pylint: disable=no-member
+    #         self._metrics[mode]["num_tokens"] = [self._total_train_tokens]

-        completion_length = (
-            self.accelerator.gather_for_metrics(completion_mask.sum(1))
-            .float()
-            .mean()
-            .item()
-        )
-        self._metrics[mode]["completion_length"].append(completion_length)
+    #     completion_length = (
+    #         self.accelerator.gather_for_metrics(completion_mask.sum(1))
+    #         .float()
+    #         .mean()
+    #         .item()
+    #     )
+    #     self._metrics[mode]["completion_length"].append(completion_length)

-        # Calculate mean reward per function, but only for samples where the function was applied
-        for i, reward_func in enumerate(self.reward_funcs):
-            if isinstance(
-                reward_func, nn.Module
-            ):  # Module instead of PretrainedModel for compat with compiled models
-                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
-            else:
-                # pylint: disable=protected-access
-                reward_func_name = reward_func.__name__
-            # Only calculate mean for samples where this reward function was applied (non-NaN values)
-            mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
-            self._metrics[mode][f"rewards/{reward_func_name}"].append(mean_rewards)
-        self._metrics[mode]["reward"].append(rewards.mean().item())
-        self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
+    #     # Calculate mean reward per function, but only for samples where the function was applied
+    #     for i, reward_func in enumerate(self.reward_funcs):
+    #         if isinstance(
+    #             reward_func, nn.Module
+    #         ):  # Module instead of PretrainedModel for compat with compiled models
+    #             reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+    #         else:
+    #             # pylint: disable=protected-access
+    #             reward_func_name = reward_func.__name__
+    #         # Only calculate mean for samples where this reward function was applied (non-NaN values)
+    #         mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
+    #         self._metrics[mode][f"rewards/{reward_func_name}"].append(mean_rewards)
+    #     self._metrics[mode]["reward"].append(rewards.mean().item())
+    #     self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())

-        if (
-            self.log_completions
-            and self.state.global_step % self.args.logging_steps == 0
-        ):
-            prompts_to_log = gather_object(prompts_text)
-            completions_to_log = gather_object(completions_text)
-            rewards_to_log = rewards.tolist()
+    #     if (
+    #         self.log_completions
+    #         and self.state.global_step % self.args.logging_steps == 0
+    #     ):
+    #         prompts_to_log = gather_object(prompts_text)
+    #         completions_to_log = gather_object(completions_text)
+    #         rewards_to_log = rewards.tolist()

-            if self.accelerator.is_main_process:
-                if is_rich_available():
-                    print_prompt_completions_sample(
-                        prompts_to_log,
-                        completions_to_log,
-                        rewards_to_log,
-                        self.state.global_step,
-                    )
-                if (
-                    self.args.report_to
-                    and "wandb" in self.args.report_to
-                    and wandb.run is not None
-                ):
-                    import pandas as pd
+    #         if self.accelerator.is_main_process:
+    #             if is_rich_available():
+    #                 print_prompt_completions_sample(
+    #                     prompts_to_log,
+    #                     completions_to_log,
+    #                     rewards_to_log,
+    #                     self.state.global_step,
+    #                 )
+    #             if (
+    #                 self.args.report_to
+    #                 and "wandb" in self.args.report_to
+    #                 and wandb.run is not None
+    #             ):
+    #                 import pandas as pd

-                    # For logging
-                    table = {
-                        "step": [str(self.state.global_step)] * len(rewards),
-                        "prompt": prompts_to_log,
-                        "completion": completions_to_log,
-                        "reward": rewards.tolist(),
-                    }
-                    df = pd.DataFrame(table)
-                    wandb.log({"completions": wandb.Table(dataframe=df)})
+    #                 # For logging
+    #                 table = {
+    #                     "step": [str(self.state.global_step)] * len(rewards),
+    #                     "prompt": prompts_to_log,
+    #                     "completion": completions_to_log,
+    #                     "reward": rewards.tolist(),
+    #                 }
+    #                 df = pd.DataFrame(table)
+    #                 wandb.log({"completions": wandb.Table(dataframe=df)})

-        return {
-            "prompt_ids": prompt_ids,
-            "prompt_mask": prompt_mask,
-            "completion_ids": completion_ids,
-            "completion_mask": completion_mask,
-            "old_per_token_logps": old_per_token_logps,
-            "ref_per_token_logps": ref_per_token_logps,
-            "advantages": advantages,
-        }
+    #     return {
+    #         "prompt_ids": prompt_ids,
+    #         "prompt_mask": prompt_mask,
+    #         "completion_ids": completion_ids,
+    #         "completion_mask": completion_mask,
+    #         "old_per_token_logps": old_per_token_logps,
+    #         "ref_per_token_logps": ref_per_token_logps,
+    #         "advantages": advantages,
+    #     }

-    def _get_per_token_logps_v2(
-        self, model, input_ids, attention_mask, logits_to_keep, completion_mask=None
-    ):
-        # Pad sequence to be divisible by SP degree if needed
-        total_seq_len = input_ids.shape[1]
-        if total_seq_len % self.local_world_size != 0:
-            pad_len = self.local_world_size - (total_seq_len % self.local_world_size)
-            pad_token_id = self.processing_class.pad_token_id or 0
+    # def _get_per_token_logps_v2(
+    #     self, model, input_ids, attention_mask, logits_to_keep, completion_mask=None
+    # ):
+    #     # Pad sequence to be divisible by SP degree if needed
+    #     total_seq_len = input_ids.shape[1]
+    #     if total_seq_len % self.local_world_size != 0:
+    #         pad_len = self.local_world_size - (total_seq_len % self.local_world_size)
+    #         pad_token_id = self.processing_class.pad_token_id or 0

-            # Pad input_ids and attention_mask
-            padding = torch.full(
-                (input_ids.shape[0], pad_len),
-                pad_token_id,
-                dtype=input_ids.dtype,
-                device=input_ids.device,
-            )
-            input_ids = torch.cat([input_ids, padding], dim=1)
+    #         # Pad input_ids and attention_mask
+    #         padding = torch.full(
+    #             (input_ids.shape[0], pad_len),
+    #             pad_token_id,
+    #             dtype=input_ids.dtype,
+    #             device=input_ids.device,
+    #         )
+    #         input_ids = torch.cat([input_ids, padding], dim=1)

-            attn_padding = torch.zeros(
-                (attention_mask.shape[0], pad_len),
-                dtype=attention_mask.dtype,
-                device=attention_mask.device,
-            )
-            attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
-            if completion_mask is not None:
-                completion_mask = torch.cat([completion_mask, attn_padding], dim=1)
+    #         attn_padding = torch.zeros(
+    #             (attention_mask.shape[0], pad_len),
+    #             dtype=attention_mask.dtype,
+    #             device=attention_mask.device,
+    #         )
+    #         attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
+    #         if completion_mask is not None:
+    #             completion_mask = torch.cat([completion_mask, attn_padding], dim=1)

-            total_seq_len += pad_len
-            logits_to_keep += pad_len
+    #         total_seq_len += pad_len
+    #         logits_to_keep += pad_len

-        # Split the sequence
-        slice_size = total_seq_len // self.local_world_size
-        start = self.local_rank * slice_size
-        end = start + slice_size
+    #     # Split the sequence
+    #     slice_size = total_seq_len // self.local_world_size
+    #     start = self.local_rank * slice_size
+    #     end = start + slice_size

-        # Get our slice
-        input_ids_slice = input_ids[:, start:end]
-        attention_mask_slice = attention_mask[:, start:end]
+    #     # Get our slice
+    #     input_ids_slice = input_ids[:, start:end]
+    #     attention_mask_slice = attention_mask[:, start:end]

-        # Calculate where our slice starts and ends relative to the completion tokens
-        local_completion_mask = None
-        prompt_len = input_ids.size(1) - logits_to_keep
-        if start >= prompt_len:
-            # Slice starts within the completion section
-            start_in_completion = start - prompt_len
-            end_in_completion = min(end - prompt_len, logits_to_keep)
-            local_logits_to_keep = end_in_completion - start_in_completion
-            if completion_mask is not None:
-                local_completion_mask = completion_mask[
-                    :, start_in_completion:end_in_completion
-                ]
-        elif end <= prompt_len:
-            # Slice is entirely within the prompt section (no completion tokens)
-            local_logits_to_keep = 0
-            if completion_mask is not None:
-                local_completion_mask = torch.zeros(
-                    (completion_mask.size(0), 0), device=completion_mask.device
-                )
-        else:
-            # Slice contains the boundary between prompt and completion
-            start_in_completion = 0
-            end_in_completion = min(end - prompt_len, logits_to_keep)
-            local_logits_to_keep = end_in_completion - start_in_completion
-            if completion_mask is not None:
-                local_completion_mask = completion_mask[
-                    :, start_in_completion:end_in_completion
-                ]
+    #     # Calculate where our slice starts and ends relative to the completion tokens
+    #     local_completion_mask = None
+    #     prompt_len = input_ids.size(1) - logits_to_keep
+    #     if start >= prompt_len:
+    #         # Slice starts within the completion section
+    #         start_in_completion = start - prompt_len
+    #         end_in_completion = min(end - prompt_len, logits_to_keep)
+    #         local_logits_to_keep = end_in_completion - start_in_completion
+    #         if completion_mask is not None:
+    #             local_completion_mask = completion_mask[
+    #                 :, start_in_completion:end_in_completion
+    #             ]
+    #     elif end <= prompt_len:
+    #         # Slice is entirely within the prompt section (no completion tokens)
+    #         local_logits_to_keep = 0
+    #         if completion_mask is not None:
+    #             local_completion_mask = torch.zeros(
+    #                 (completion_mask.size(0), 0), device=completion_mask.device
+    #             )
+    #     else:
+    #         # Slice contains the boundary between prompt and completion
+    #         start_in_completion = 0
+    #         end_in_completion = min(end - prompt_len, logits_to_keep)
+    #         local_logits_to_keep = end_in_completion - start_in_completion
+    #         if completion_mask is not None:
+    #             local_completion_mask = completion_mask[
+    #                 :, start_in_completion:end_in_completion
+    #             ]

-        # Get logits with enough context to compute log probs
-        logits = model(
-            input_ids=input_ids_slice,
-            attention_mask=attention_mask_slice,
-            logits_to_keep=local_logits_to_keep + 1,
-        ).logits
+    #     # Get logits with enough context to compute log probs
+    #     logits = model(
+    #         input_ids=input_ids_slice,
+    #         attention_mask=attention_mask_slice,
+    #         logits_to_keep=local_logits_to_keep + 1,
+    #     ).logits

-        # Only the last rank that contains completion tokens needs to remove the last logit
-        is_last_rank_with_completions = (
-            self.local_rank == self.local_world_size - 1  # Last rank overall
-            or end
-            >= prompt_len
-            + logits_to_keep  # Our slice includes the last completion token
-        )
+    #     # Only the last rank that contains completion tokens needs to remove the last logit
+    #     is_last_rank_with_completions = (
+    #         self.local_rank == self.local_world_size - 1  # Last rank overall
+    #         or end
+    #         >= prompt_len
+    #         + logits_to_keep  # Our slice includes the last completion token
+    #     )

-        if is_last_rank_with_completions:
-            logits = logits[:, :-1]
-            if local_completion_mask is not None:
-                local_completion_mask = local_completion_mask[:, :-1]
-            local_logits_to_keep -= 1
+    #     if is_last_rank_with_completions:
+    #         logits = logits[:, :-1]
+    #         if local_completion_mask is not None:
+    #             local_completion_mask = local_completion_mask[:, :-1]
+    #         local_logits_to_keep -= 1

-        if start >= prompt_len:
-            # For ranks where slice is all completion tokens,
-            # we need to offset to match the logits (which predict the next token)
-            offset = 1  # Skip the first token as it's predicted by the last token of the previous rank
-            local_input_ids = input_ids_slice[:, offset : offset + local_logits_to_keep]
-        else:
-            # For the rank that contains the prompt-completion boundary,
-            # we need to take completion tokens only
-            offset = prompt_len - start  # Where completions start in our slice
-            local_input_ids = input_ids_slice[:, offset : offset + local_logits_to_keep]
+    #     if start >= prompt_len:
+    #         # For ranks where slice is all completion tokens,
+    #         # we need to offset to match the logits (which predict the next token)
+    #         offset = 1  # Skip the first token as it's predicted by the last token of the previous rank
+    #         local_input_ids = input_ids_slice[:, offset : offset + local_logits_to_keep]
+    #     else:
+    #         # For the rank that contains the prompt-completion boundary,
+    #         # we need to take completion tokens only
+    #         offset = prompt_len - start  # Where completions start in our slice
+    #         local_input_ids = input_ids_slice[:, offset : offset + local_logits_to_keep]

-        logits = logits[
-            :, -local_logits_to_keep:
-        ]  # Take only logits for completion tokens
-        logits = logits / self.temperature
-        per_token_logps = selective_log_softmax(logits, local_input_ids)
+    #     logits = logits[
+    #         :, -local_logits_to_keep:
+    #     ]  # Take only logits for completion tokens
+    #     logits = logits / self.temperature
+    #     per_token_logps = selective_log_softmax(logits, local_input_ids)

-        return per_token_logps, local_completion_mask
+    #     return per_token_logps, local_completion_mask

-    # pylint: disable=unused-argument
-    @profiling_decorator
-    def compute_loss(
-        self, model, inputs, return_outputs=False, num_items_in_batch=None
-    ):
-        if return_outputs:
-            raise ValueError("The GRPOTrainer does not support returning outputs")
+    # # pylint: disable=unused-argument
+    # @profiling_decorator
+    # def compute_loss(
+    #     self, model, inputs, return_outputs=False, num_items_in_batch=None
+    # ):
+    #     if return_outputs:
+    #         raise ValueError("The GRPOTrainer does not support returning outputs")

-        # Unpack inputs
-        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
-        completion_ids, completion_mask = (
-            inputs["completion_ids"],
-            inputs["completion_mask"],
-        )
-        prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
-        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
-        logits_to_keep = completion_ids.size(1)
+    #     # Unpack inputs
+    #     prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+    #     completion_ids, completion_mask = (
+    #         inputs["completion_ids"],
+    #         inputs["completion_mask"],
+    #     )
+    #     prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+    #     attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+    #     logits_to_keep = completion_ids.size(1)

-        if self.args.sequence_parallel_degree > 1:
-            per_token_logps, completion_mask = self._get_per_token_logps_v2(
-                model,
-                prompt_completion_ids,
-                attention_mask,
-                logits_to_keep,
-                completion_mask,
-            )
-        else:
-            per_token_logps = super()._get_per_token_logps(
-                model, prompt_completion_ids, attention_mask, logits_to_keep
-            )
+    #     if self.args.sequence_parallel_degree > 1:
+    #         per_token_logps, completion_mask = self._get_per_token_logps_v2(
+    #             model,
+    #             prompt_completion_ids,
+    #             attention_mask,
+    #             logits_to_keep,
+    #             completion_mask,
+    #         )
+    #     else:
+    #         per_token_logps = super()._get_per_token_logps(
+    #             model, prompt_completion_ids, attention_mask, logits_to_keep
+    #         )

-        # Compute the KL divergence between the model and the reference model
-        if self.beta != 0.0:
-            ref_per_token_logps = inputs["ref_per_token_logps"]
-            per_token_kl = (
-                torch.exp(ref_per_token_logps - per_token_logps)
-                - (ref_per_token_logps - per_token_logps)
-                - 1
-            )
+    #     # Compute the KL divergence between the model and the reference model
+    #     if self.beta != 0.0:
+    #         ref_per_token_logps = inputs["ref_per_token_logps"]
+    #         per_token_kl = (
+    #             torch.exp(ref_per_token_logps - per_token_logps)
+    #             - (ref_per_token_logps - per_token_logps)
+    #             - 1
+    #         )

-        # Compute the loss
-        advantages = inputs["advantages"]
-        # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation
-        # and use per_token_logps.detach() instead.
-        old_per_token_logps = (
-            inputs["old_per_token_logps"]
-            if self.num_iterations > 1
-            else per_token_logps.detach()
-        )
-        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
-        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
-        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
-        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
-        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+    #     # Compute the loss
+    #     advantages = inputs["advantages"]
+    #     # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation
+    #     # and use per_token_logps.detach() instead.
+    #     old_per_token_logps = (
+    #         inputs["old_per_token_logps"]
+    #         if self.num_iterations > 1
+    #         else per_token_logps.detach()
+    #     )
+    #     coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+    #     coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
+    #     per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+    #     per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+    #     per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

-        if self.beta != 0.0:
-            per_token_loss = per_token_loss + self.beta * per_token_kl
+    #     if self.beta != 0.0:
+    #         per_token_loss = per_token_loss + self.beta * per_token_kl

-        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
+    #     loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

-        # Log metrics
-        mode = "eval" if self.control.should_evaluate else "train"
+    #     # Log metrics
+    #     mode = "eval" if self.control.should_evaluate else "train"

-        if self.beta != 0.0:
-            mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
-            self._metrics[mode]["kl"].append(
-                self.accelerator.gather_for_metrics(mean_kl).mean().item()
-            )
+    #     if self.beta != 0.0:
+    #         mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
+    #         self._metrics[mode]["kl"].append(
+    #             self.accelerator.gather_for_metrics(mean_kl).mean().item()
+    #         )

-        is_clipped = (per_token_loss1 < per_token_loss2).float()
-        clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
-        self._metrics[mode]["clip_ratio"].append(
-            self.accelerator.gather_for_metrics(clip_ratio).mean().item()
-        )
+    #     is_clipped = (per_token_loss1 < per_token_loss2).float()
+    #     clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
+    #     self._metrics[mode]["clip_ratio"].append(
+    #         self.accelerator.gather_for_metrics(clip_ratio).mean().item()
+    #     )

-        return loss
+    #     return loss
diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py
index 2b782fece..0d71e7057 100644
--- a/src/axolotl/core/trainers/mixins/sequence_parallel.py
+++ b/src/axolotl/core/trainers/mixins/sequence_parallel.py
@@ -4,7 +4,6 @@ Module for Axolotl trainer sequence parallelism mixin and training context manag

 import functools
 import logging
-from contextlib import contextmanager

 import torch
 import torch.distributed as dist
@@ -14,14 +13,70 @@ from torch.utils.data import DistributedSampler, Sampler
 from torch.utils.hooks import RemovableHandle

 from axolotl.monkeypatch.attention.ring_attn import (
-    RingAttnFunc,
     get_ring_attn_group,
     update_ring_attn_params,
 )
+from axolotl.utils.schemas.enums import RingAttnFunc

 LOG = logging.getLogger(__name__)


+def _handle_logits_to_keep(
+    logits_to_keep,
+    local_rank: int,
+    local_world_size: int,
+    ring_attn_func: RingAttnFunc,
+    total_seq_len: int,
+):
+    """
+    Handle logits_to_keep parameter for sequence parallelism.
+
+    Args:
+        logits_to_keep: Integer or tensor indicating which positions to compute logits
+            for.
+        local_rank: Rank in the sequence parallel group.
+        local_world_size: World size of the sequence parallel group.
+        ring_attn_func: Ring attention function being used.
+        total_seq_len: Full sequence length.
+
+    Returns:
+        Adjusted logits_to_keep appropriate for this rank's sharded sequence.
+    """
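+    # Worked example (hypothetical numbers): with total_seq_len=8192 and
+    # local_world_size=4, each rank owns a contiguous 2048-position chunk;
+    # rank 1 owns [2048, 4096), so logits_to_keep=3000 maps to 3000 - 2048 = 952
+    # on rank 1, while every other rank returns the -1 sentinel.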
+    LOG.debug("start of _handle_logits_to_keep")
+    LOG.debug("rank %s: logits_to_keep=%s", dist.get_rank(), logits_to_keep)
+
+    # No transformation needed if logits_to_keep is None
+    if logits_to_keep is None:
+        return None
+
+    assert isinstance(
+        logits_to_keep, int
+    ), "sequence parallelism currently only supports integer logits_to_keep"
+    assert ring_attn_func in [
+        RingAttnFunc.VARLEN_LLAMA3,
+        RingAttnFunc.BATCH_RING,
+    ], "if specifying logits_to_keep, sequence parallelism currently only supports 'batch_ring' and 'varlen_llama3' `ring_attn_func`s"
+
+    # For standard sharding, each rank gets a contiguous chunk
+    chunk_size = total_seq_len // local_world_size
+    start_idx = local_rank * chunk_size
+    end_idx = start_idx + chunk_size
+
+    # Check if logits_to_keep is in this rank's range
+    if start_idx <= logits_to_keep < end_idx:
+        LOG.debug("end of _handle_logits_to_keep")
+        LOG.debug("rank %s: local logits_to_keep=%s", dist.get_rank(), logits_to_keep - start_idx)
+        return logits_to_keep - start_idx
+    else:
+        LOG.debug("end of _handle_logits_to_keep")
+        LOG.debug("rank %s: logits_to_keep not in local shard, returning -1", dist.get_rank())
+        return -1
+
+
 def apply_sequence_parallelism(
     batch: dict[str, torch.Tensor],
     local_rank: int,
@@ -32,10 +87,10 @@
     Apply sequence parallelism slicing to a batch.

     Args:
-        batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
-        local_rank: Local rank in the sequence parallel group
-        local_world_size: World size of the sequence parallel group
-        ring_attn_func: The ring attention function to use
+        batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
+        local_rank: Local rank in the sequence parallel group.
+        local_world_size: World size of the sequence parallel group.
+        ring_attn_func: The ring attention function to use.

     Returns:
         Sliced batch dictionary.
@@ -48,12 +103,10 @@ def apply_sequence_parallelism(
     total_seq_len = batch["input_ids"].size(1)
     for key in batch:
         if (
-            key in batch
-            and isinstance(batch[key], torch.Tensor)
+            isinstance(batch[key], torch.Tensor)
             and batch[key].dim() > 1
             and batch[key].size(1) == total_seq_len
         ):
-
             if ring_attn_func in [
                 RingAttnFunc.VARLEN_LLAMA3,
                 RingAttnFunc.BATCH_RING,
@@ -78,6 +131,14 @@ def apply_sequence_parallelism(
                     dim=1,
                 ).transpose(1, 2)
                 batch[key] = tensor[:, local_rank].contiguous()
+            if key == "logits_to_keep":
+                batch[key] = _handle_logits_to_keep(
+                    logits_to_keep=batch[key],
+                    local_rank=local_rank,
+                    local_world_size=local_world_size,
+                    ring_attn_func=ring_attn_func,
+                    total_seq_len=total_seq_len,
+                )

     return batch

@@ -205,8 +266,11 @@ class SequenceParallelContextManager:

         # Forward post-hook to gather outputs
         def sequence_parallel_post_hook(_, __, output):
+            LOG.debug("start of sequence_parallel_post_hook")
             # Gather the sharded outputs
-            return self.gather_outputs(output)
+            output = self.gather_outputs(output)
+            LOG.debug("end of sequence_parallel_post_hook")
+            return output

         # Register both hooks
         self.hook_handles.append(
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 37aefaabc..fd3d5aa1f 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -18,7 +18,6 @@ from pydantic import (
 )
 from transformers.utils.import_utils import is_torch_npu_available

-from axolotl.monkeypatch.attention.ring_attn import RingAttnFunc
 from axolotl.utils.distributed import is_main_process
 from axolotl.utils.schemas.datasets import (
     DatasetConfig,
diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py
index 2d4f97084..73087ee9d 100644
--- a/tests/e2e/patched/test_mixtral_samplepack.py
+++ b/tests/e2e/patched/test_mixtral_samplepack.py
@@ -1,6 +1,4 @@
-"""
-E2E tests for mixtral
-"""
+"""E2E tests for mixtral"""

 import logging
 import os
@@ -99,6 +97,7 @@ class TestMixtral(unittest.TestCase):
                 "bf16": "auto",
             }
         )
+        cfg = validate_config(cfg)
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py
index 6e1e2f2cb..e3a0ec0a7 100644
--- a/tests/e2e/patched/test_sp.py
+++ b/tests/e2e/patched/test_sp.py
@@ -12,12 +12,12 @@ from accelerate.state import PartialState

 from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism
 from axolotl.monkeypatch.attention.ring_attn import (
-    RingAttnFunc,
     get_ring_attn_group,
     register_ring_attn,
     set_ring_attn_group,
 )
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.schemas.enums import RingAttnFunc


 @pytest.fixture