diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 61eea48d8..2aa319c1c 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -538,8 +538,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): report_to = [] if self.cfg.use_wandb: report_to.append("wandb") - if self.cfg.wandb_name: - training_arguments_kwargs["run_name"] = self.cfg.wandb_name if self.cfg.use_mlflow: report_to.append("mlflow") if self.cfg.use_tensorboard: diff --git a/src/axolotl/core/trainers/grpo/sampler.py b/src/axolotl/core/trainers/grpo/sampler.py index b2010e9b7..33e2e462b 100644 --- a/src/axolotl/core/trainers/grpo/sampler.py +++ b/src/axolotl/core/trainers/grpo/sampler.py @@ -5,7 +5,7 @@ sequence parallelism functionality; i.e., duplicating data across ranks in the s sequencee parallel group. """ -from typing import Optional, Sized +from typing import Sized import torch from torch.utils.data import Sampler @@ -23,11 +23,11 @@ class SequenceParallelRepeatRandomSampler(Sampler): self, dataset: Sized, mini_repeat_count: int, + world_size: int, + rank: int, batch_size: int = 1, repeat_count: int = 1, sequence_parallel_degree: int = 1, - world_size: Optional[int] = None, - rank: Optional[int] = None, shuffle: bool = True, seed: int = 0, drop_last: bool = False, diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index 48e28c22f..b7e374356 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -1,5 +1,7 @@ """Axolotl GRPO trainer""" +# pylint: disable=too-many-lines,duplicate-code + import warnings from collections import defaultdict from contextlib import nullcontext @@ -68,6 +70,7 @@ from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group if is_peft_available(): + # pylint: disable=unused-import from peft import PeftConfig, get_peft_model if is_deepspeed_available(): @@ -101,6 +104,9 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): ] = (None, None), peft_config: "PeftConfig | None" = None, ): + RngLoaderMixin.__init__(self) + SchedulerMixin.__init__(self) + # Args if args is None: model_name = model if isinstance(model, str) else model.config._name_or_path @@ -264,7 +270,10 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): model.warnings_issued["estimate_tokens"] = True # Initialize the metrics - self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._metrics: dict[str, dict[str, list]] = { + "train": defaultdict(list), + "eval": defaultdict(list), + } self._total_train_tokens = 0 self.log_completions = args.log_completions @@ -437,13 +446,13 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): return SequenceParallelRepeatRandomSampler( dataset=self.train_dataset, mini_repeat_count=self.num_generations, + world_size=world_size, + rank=rank, batch_size=effective_batch_size // self.num_generations // self.args.sequence_parallel_degree, repeat_count=self.num_iterations, sequence_parallel_degree=self.args.sequence_parallel_degree, - world_size=world_size, - rank=rank, shuffle=True, seed=self.args.seed, drop_last=True, @@ -514,6 +523,7 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): def get_train_dataloader(self) -> DataLoader: """Get dataloader for training""" train_dataset = self.train_dataset + # pylint: disable=access-member-before-definition data_collator = self.data_collator # type: ignore # Initialize SP group attributes if sequence parallelism is enabled @@ -532,7 +542,7 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): train_dataset, description="training" ) else: - self.data_collator = self._get_collator_with_removed_columns( + self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init data_collator, description="training", ) @@ -591,7 +601,7 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): self.vllm_client.reset_prefix_cache() def _generate_and_score_completions( - self, inputs: dict[str | torch.Tensor | Any] + self, inputs: list[dict[str, torch.Tensor | Any]] ) -> dict[str, torch.Tensor | Any]: device = self.accelerator.device prompts = [x["prompt"] for x in inputs] @@ -606,6 +616,7 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): padding_side="left", add_special_tokens=False, ) + # pylint: disable=protected-access prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs) prompt_ids, prompt_mask = ( @@ -682,7 +693,6 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): completion_ids = pad( completion_ids, padding_value=self.processing_class.pad_token_id ) - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) else: # Regular generation path with unwrap_model_for_generation( @@ -701,6 +711,8 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): prompt_ids = prompt_completion_ids[:, :prompt_length] completion_ids = prompt_completion_ids[:, prompt_length:] + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + # Mask everything after the first EOS token is_eos = completion_ids == self.processing_class.eos_token_id eos_idx = torch.full( @@ -719,105 +731,59 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): ) # we only need to compute the logits for the completion tokens with torch.no_grad(): - if self.args.sequence_parallel_degree > 1: - # Pad sequence to be divisible by SP degree if needed - total_seq_len = prompt_completion_ids.shape[1] - if total_seq_len % self.local_world_size != 0: - pad_len = self.local_world_size - ( - total_seq_len % self.local_world_size - ) - pad_token_id = self.processing_class.pad_token_id or 0 - - # Pad input_ids and attention_mask - padding = torch.full( - (prompt_completion_ids.shape[0], pad_len), - pad_token_id, - dtype=prompt_completion_ids.dtype, - device=prompt_completion_ids.device, - ) - prompt_completion_ids = torch.cat( - [prompt_completion_ids, padding], dim=1 - ) - - attn_padding = torch.zeros( - (attention_mask.shape[0], pad_len), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - attention_mask = torch.cat([attention_mask, attn_padding], dim=1) - - total_seq_len += pad_len - logits_to_keep += pad_len - - # Split the sequence - slice_size = total_seq_len // self.local_world_size - start = self.local_rank * slice_size - end = start + slice_size - - # Get our slice - prompt_completion_ids = prompt_completion_ids[:, start:end] - attention_mask = attention_mask[:, start:end] - - # Calculate how many completion tokens each rank should process - prompt_len = prompt_ids.size(1) - completion_len = completion_ids.size( - 1 - ) # This is equal to logits_to_keep - - # Calculate where our slice starts and ends relative to the completion tokens - if start >= prompt_len: - # Slice starts within the completion section - start_in_completion = start - prompt_len - end_in_completion = min(end - prompt_len, completion_len) - logits_to_keep = end_in_completion - start_in_completion - completion_mask = completion_mask[ - :, start_in_completion:end_in_completion - ] - elif end <= prompt_len: - # Slice is entirely within the prompt section (no completion tokens) - logits_to_keep = 0 - completion_mask = torch.zeros( - (completion_mask.size(0), 0), device=completion_mask.device - ) - else: - # Slice contains the boundary between prompt and completion - start_in_completion = 0 - end_in_completion = min(end - prompt_len, completion_len) - logits_to_keep = end_in_completion - start_in_completion - completion_mask = completion_mask[ - :, start_in_completion:end_in_completion - ] - # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's # computation here, and use per_token_logps.detach() instead. if self.num_iterations > 1: - old_per_token_logps = self._get_per_token_logps( - self.model, prompt_completion_ids, attention_mask, logits_to_keep - ) + if self.args.sequence_parallel_degree > 1: + old_per_token_logps, _ = self._get_per_token_logps_v2( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + ) + else: + old_per_token_logps = super()._get_per_token_logps( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + ) else: old_per_token_logps = None if self.beta == 0.0: ref_per_token_logps = None elif self.ref_model is not None: - print(f"{dist.get_rank()}: prompt_completion_ids.shape: {prompt_completion_ids.shape}") - print(f"{dist.get_rank()}: attention_mask.shape: {attention_mask.shape}") - print(f"{dist.get_rank()}: logits_to_keep: {logits_to_keep}") - - ref_per_token_logps = self._get_per_token_logps( - self.ref_model, - prompt_completion_ids, - attention_mask, - logits_to_keep, - ) - else: - with self.accelerator.unwrap_model(self.model).disable_adapter(): - ref_per_token_logps = self._get_per_token_logps( - self.model, + if self.args.sequence_parallel_degree > 1: + ref_per_token_logps, _ = self._get_per_token_logps_v2( + self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep, ) + else: + ref_per_token_logps = super()._get_per_token_logps( + self.ref_model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + ) + else: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + if self.args.sequence_parallel_degree > 1: + ref_per_token_logps, _ = self._get_per_token_logps_v2( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + ) + else: + ref_per_token_logps = super()._get_per_token_logps( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + ) # Decode the generated completions completions_text = self.processing_class.batch_decode( @@ -848,6 +814,7 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): f"reward {reward_func.config._name_or_path.split('/')[-1]}" ) else: + # pylint: disable=protected-access reward_func_name = reward_func.__name__ with profiling_context(self, reward_func_name): if isinstance( @@ -870,6 +837,7 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): padding_side="right", add_special_tokens=False, ) + # pylint: disable=protected-access reward_inputs = Trainer._prepare_inputs(self, reward_inputs) with torch.inference_mode(): rewards_per_func[:, i] = reward_func(**reward_inputs).logits[ @@ -966,6 +934,7 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): ): # Module instead of PretrainedModel for compat with compiled models reward_func_name = reward_func.config._name_or_path.split("/")[-1] else: + # pylint: disable=protected-access reward_func_name = reward_func.__name__ # Only calculate mean for samples where this reward function was applied (non-NaN values) mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() @@ -1016,6 +985,115 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): "advantages": advantages, } + def _get_per_token_logps_v2( + self, model, input_ids, attention_mask, logits_to_keep, completion_mask=None + ): + # Pad sequence to be divisible by SP degree if needed + total_seq_len = input_ids.shape[1] + if total_seq_len % self.local_world_size != 0: + pad_len = self.local_world_size - (total_seq_len % self.local_world_size) + pad_token_id = self.processing_class.pad_token_id or 0 + + # Pad input_ids and attention_mask + padding = torch.full( + (input_ids.shape[0], pad_len), + pad_token_id, + dtype=input_ids.dtype, + device=input_ids.device, + ) + input_ids = torch.cat([input_ids, padding], dim=1) + + attn_padding = torch.zeros( + (attention_mask.shape[0], pad_len), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + attention_mask = torch.cat([attention_mask, attn_padding], dim=1) + if completion_mask is not None: + completion_mask = torch.cat([completion_mask, attn_padding], dim=1) + + total_seq_len += pad_len + logits_to_keep += pad_len + + # Split the sequence + slice_size = total_seq_len // self.local_world_size + start = self.local_rank * slice_size + end = start + slice_size + + # Get our slice + input_ids_slice = input_ids[:, start:end] + attention_mask_slice = attention_mask[:, start:end] + + # Calculate where our slice starts and ends relative to the completion tokens + local_completion_mask = None + prompt_len = input_ids.size(1) - logits_to_keep + if start >= prompt_len: + # Slice starts within the completion section + start_in_completion = start - prompt_len + end_in_completion = min(end - prompt_len, logits_to_keep) + local_logits_to_keep = end_in_completion - start_in_completion + if completion_mask is not None: + local_completion_mask = completion_mask[ + :, start_in_completion:end_in_completion + ] + elif end <= prompt_len: + # Slice is entirely within the prompt section (no completion tokens) + local_logits_to_keep = 0 + if completion_mask is not None: + local_completion_mask = torch.zeros( + (completion_mask.size(0), 0), device=completion_mask.device + ) + else: + # Slice contains the boundary between prompt and completion + start_in_completion = 0 + end_in_completion = min(end - prompt_len, logits_to_keep) + local_logits_to_keep = end_in_completion - start_in_completion + if completion_mask is not None: + local_completion_mask = completion_mask[ + :, start_in_completion:end_in_completion + ] + + # Get logits with enough context to compute log probs + logits = model( + input_ids=input_ids_slice, + attention_mask=attention_mask_slice, + logits_to_keep=local_logits_to_keep + 1, + ).logits + + # Only the last rank that contains completion tokens needs to remove the last logit + is_last_rank_with_completions = ( + self.local_rank == self.local_world_size - 1 # Last rank overall + or end + >= prompt_len + + logits_to_keep # Our slice includes the last completion token + ) + + if is_last_rank_with_completions: + logits = logits[:, :-1] + if local_completion_mask is not None: + local_completion_mask = local_completion_mask[:, :-1] + local_logits_to_keep -= 1 + + if start >= prompt_len: + # For ranks where slice is all completion tokens, + # we need to offset to match the logits (which predict the next token) + offset = 1 # Skip the first token as it's predicted by the last token of the previous rank + local_input_ids = input_ids_slice[:, offset : offset + local_logits_to_keep] + else: + # For the rank that contains the prompt-completion boundary, + # we need to take completion tokens only + offset = prompt_len - start # Where completions start in our slice + local_input_ids = input_ids_slice[:, offset : offset + local_logits_to_keep] + + logits = logits[ + :, -local_logits_to_keep: + ] # Take only logits for completion tokens + logits = logits / self.temperature + per_token_logps = selective_log_softmax(logits, local_input_ids) + + return per_token_logps, local_completion_mask + + # pylint: disable=unused-argument @profiling_decorator def compute_loss( self, model, inputs, return_outputs=False, num_items_in_batch=None @@ -1029,103 +1107,21 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): inputs["completion_ids"], inputs["completion_mask"], ) - input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) logits_to_keep = completion_ids.size(1) if self.args.sequence_parallel_degree > 1: - # Pad sequence to be divisible by SP degree if needed - total_seq_len = input_ids.shape[1] - if total_seq_len % self.local_world_size != 0: - pad_len = self.local_world_size - ( - total_seq_len % self.local_world_size - ) - pad_token_id = self.processing_class.pad_token_id or 0 - - # Pad input_ids and attention_mask - padding = torch.full( - (input_ids.shape[0], pad_len), - pad_token_id, - dtype=input_ids.dtype, - device=input_ids.device, - ) - input_ids = torch.cat([input_ids, padding], dim=1) - - attn_padding = torch.zeros( - (attention_mask.shape[0], pad_len), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - attention_mask = torch.cat([attention_mask, attn_padding], dim=1) - - total_seq_len += pad_len - logits_to_keep += pad_len - - # Split the sequence - slice_size = total_seq_len // self.local_world_size - start = self.local_rank * slice_size - end = start + slice_size - - # Get our slice - input_ids_slice = input_ids[:, start:end] - attention_mask_slice = attention_mask[:, start:end] - - # Calculate how many completion tokens each rank should process - prompt_len = prompt_ids.size(1) - completion_len = completion_ids.size(1) # This is equal to logits_to_keep - - # Calculate where our slice starts and ends relative to the completion tokens - if start >= prompt_len: - # Slice starts within the completion section - start_in_completion = start - prompt_len - end_in_completion = min(end - prompt_len, completion_len) - local_logits_to_keep = end_in_completion - start_in_completion - completion_mask = completion_mask[ - :, start_in_completion:end_in_completion - ] - elif end <= prompt_len: - # Slice is entirely within the prompt section (no completion tokens) - local_logits_to_keep = 0 - completion_mask = torch.zeros( - (completion_mask.size(0), 0), device=completion_mask.device - ) - else: - # Slice contains the boundary between prompt and completion - start_in_completion = 0 - end_in_completion = min(end - prompt_len, completion_len) - local_logits_to_keep = end_in_completion - start_in_completion - completion_mask = completion_mask[ - :, start_in_completion:end_in_completion - ] - - # Run model on our slice - if local_logits_to_keep > 0: - # Get logits with enough context to compute log probs - logits = model( - input_ids=input_ids_slice, - attention_mask=attention_mask_slice, - logits_to_keep=local_logits_to_keep + 1, - ).logits - - # Remove the last prediction token on the last rank - # if self.local_rank == self.local_world_size - 1: - # logits = logits[:, :-1, :] - logits = logits[:, :-1, :] - - # Compute log probabilities on our local slice - local_input_ids = input_ids_slice[:, -local_logits_to_keep:] - logits = logits / self.temperature - per_token_logps = selective_log_softmax(logits, local_input_ids) - else: - # This rank doesn't have any tokens to keep - per_token_logps = torch.zeros( - (input_ids.shape[0], 0), - dtype=torch.float32, - device=input_ids.device, - ) + per_token_logps, completion_mask = self._get_per_token_logps_v2( + model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + completion_mask, + ) else: per_token_logps = super()._get_per_token_logps( - model, input_ids, attention_mask, logits_to_keep + model, prompt_completion_ids, attention_mask, logits_to_keep ) # Compute the KL divergence between the model and the reference model @@ -1155,17 +1151,6 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): if self.beta != 0.0: per_token_loss = per_token_loss + self.beta * per_token_kl - if dist.get_rank() == 0: - import ipdb - - ipdb.set_trace() - dist.barrier() - if dist.get_rank() == 1: - import ipdb - - ipdb.set_trace() - dist.barrier() - loss = (per_token_loss * completion_mask).sum() / completion_mask.sum() # Log metrics @@ -1184,155 +1169,3 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): ) return loss - - # def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - # # if self.args.sequence_parallel_degree > 1: - # if False: - # # Handle padding to make sequence length divisible by world size - # total_seq_len = input_ids.shape[1] - # if total_seq_len % self.local_world_size != 0: - # # Pad to make divisible - # pad_len = self.local_world_size - ( - # total_seq_len % self.local_world_size - # ) - # pad_token_id = self.processing_class.pad_token_id or 0 - - # # Pad input_ids - # padding = torch.full( - # (input_ids.shape[0], pad_len), - # pad_token_id, - # dtype=input_ids.dtype, - # device=input_ids.device, - # ) - # input_ids = torch.cat([input_ids, padding], dim=1) - - # # Pad attention mask - # if attention_mask is not None: - # attn_padding = torch.zeros( - # (attention_mask.shape[0], pad_len), - # dtype=attention_mask.dtype, - # device=attention_mask.device, - # ) - # attention_mask = torch.cat([attention_mask, attn_padding], dim=1) - - # total_seq_len += pad_len - # logits_to_keep += pad_len - - # # Share logits_to_keep across ranks to ensure consistency - # lt_keep = torch.tensor([logits_to_keep], device=input_ids.device) - # dist.broadcast(lt_keep, src=0, group=self.sp_group) - # logits_to_keep = lt_keep.item() - - # # Split the sequence across ranks - # slice_size = total_seq_len // self.local_world_size - # start = self.local_rank * slice_size - # end = start + slice_size - - # # Slice for this rank - # input_ids_slice = input_ids[:, start:end] - # attention_mask_slice = ( - # attention_mask[:, start:end] if attention_mask is not None else None - # ) - - # # Calculate how many tokens this rank needs to keep - # tokens_before_slice = self.local_rank * slice_size - # local_logits_to_keep = 0 - - # if tokens_before_slice < logits_to_keep: - # # This rank has tokens we need to keep - # local_logits_to_keep = min( - # slice_size, logits_to_keep - tokens_before_slice - # ) - - # # Run the model on our slice - # if local_logits_to_keep > 0: - # logits = model( - # input_ids=input_ids_slice, - # attention_mask=attention_mask_slice, - # logits_to_keep=local_logits_to_keep + 1, - # ).logits - # if self.local_rank == self.local_world_size - 1: - # logits = logits[:, :-1, :] - - # # Get the relevant input_ids for computing log probs - # # Ensure this is the correct slice that corresponds to the logits - # relevant_input_ids = input_ids_slice[:, -local_logits_to_keep:] - # else: - # # Create empty logits with correct shape if we don't need to keep any - # vocab_size = model.config.vocab_size - # logits = torch.zeros( - # (input_ids.shape[0], 0, vocab_size), - # dtype=torch.float32, - # device=input_ids.device, - # ) - # relevant_input_ids = torch.zeros( - # (input_ids.shape[0], 0), - # dtype=torch.float32, - # device=input_ids.device, - # ) - - # # Temperature scaling - # logits = logits / self.temperature - - # print(f"{dist.get_rank()}: logits.shape: {logits.shape}") - # print( - # f"{dist.get_rank()}: relevant_input_ids.shape: {relevant_input_ids.shape}" - # ) - - # return selective_log_softmax(logits, relevant_input_ids) - - # # All-gather results across SP group with proper shape handling - # # local_shape = torch.tensor([logits.shape[1]], device=logits.device) - # # all_shapes = [ - # # torch.zeros_like(local_shape) for _ in range(self.local_world_size) - # # ] - # # dist.all_gather(all_shapes, local_shape, group=self.sp_group) - - # # # Create full tensor to hold the complete result - # # full_logits = torch.zeros( - # # (input_ids.shape[0], logits_to_keep, model.config.vocab_size), - # # dtype=torch.float32, - # # device=input_ids.device, - # # ) - - # # # Calculate positions for each rank's contribution - # # position = 0 - # # for i in range(self.local_world_size): - # # shape = all_shapes[i].item() - # # if i < self.local_rank: - # # position += shape - - # # # Add our contribution to the full tensor - # # if local_logits_to_keep > 0: - # # # Make sure we're not exceeding bounds - # # end_pos = min(position + local_logits_to_keep, logits_to_keep) - # # copy_size = end_pos - position - - # # if dist.get_rank() == 0: - # # import ipdb; ipdb.set_trace() - # # dist.barrier() - # # if dist.get_rank() == 1: - # # import ipdb; ipdb.set_trace() - # # dist.barrier() - - # # if copy_size > 0: - # # full_logits[:, position:end_pos, :] = logits[:, :copy_size, :] - - # # # Combine results via all-reduce - # # dist.all_reduce(full_logits, op=dist.ReduceOp.SUM, group=self.sp_group) - - # # # Remove the last prediction token - # # full_logits = full_logits[:, :-1, :] - - # # # Get the relevant input_ids for computing log probs - # # # Ensure this is the correct slice that corresponds to the logits - # # relevant_input_ids = input_ids[:, -logits_to_keep:] - - # # # Temperature scaling - # # full_logits = full_logits / self.temperature - - # # return selective_log_softmax(full_logits, relevant_input_ids) - # else: - # return super()._get_per_token_logps( - # model, input_ids, attention_mask, logits_to_keep - # ) diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py index 87e385b68..2b782fece 100644 --- a/src/axolotl/core/trainers/mixins/sequence_parallel.py +++ b/src/axolotl/core/trainers/mixins/sequence_parallel.py @@ -160,7 +160,6 @@ class SequenceParallelMixin: ) -<<<<<<< HEAD class SequenceParallelContextManager: """ Context manager for sequence parallelism operations. @@ -313,40 +312,3 @@ class SequenceParallelContextManager: result[:, pos] = gathered_tensor[:, i] return result -======= -class SequenceParallelismManager: - def __init__(self, local_rank, local_world_size): - self.local_rank = local_rank - self.local_world_size = local_world_size - - @contextmanager - def apply(self, batch): - """ - Context manager that applies sequence parallelism slicing to a batch, - and restores the original batch afterward if needed. - - Args: - batch: Batch dictionary from parent collator. - - Yields: - Sliced batch dictionary for use in the model. - """ - # Get local (start, end) for sequence parallelism slicing - total_seq_len = batch["input_ids"].size(1) - slice_size = total_seq_len // self.local_world_size - start = self.local_rank * slice_size - end = start + slice_size - - # Update params for varlen ring attention calculation - if batch.get("position_ids") is not None: - update_ring_attn_params( - input_ids=batch["input_ids"], position_ids=batch["position_ids"] - ) - - # Slice batch for sequence parallel processing - for key in batch: - if isinstance(batch[key], torch.Tensor) and batch[key].size(1) == total_seq_len: - batch[key] = batch[key][:, start:end] - - yield batch ->>>>>>> c0054f07 (progress)