diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index a45e246a1..77e7b573b 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -719,6 +719,13 @@ class AxolotlTrainer( output_dir = output_dir if output_dir is not None else self.args.output_dir os.makedirs(output_dir, exist_ok=True) LOG.info(f"Saving model checkpoint to {output_dir}") + if state_dict is None: + state_dict = self.accelerator.get_state_dict(self.model) + if state_dict is not None: + state_dict = { + k: v.clone() if isinstance(v, torch.Tensor) else v + for k, v in state_dict.items() + } supported_classes = ( (PreTrainedModel,) if not is_peft_available() diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py index 78b3d1cae..7f6af7d48 100644 --- a/src/axolotl/utils/ctx_managers/sequence_parallel.py +++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py @@ -218,6 +218,9 @@ class SequenceParallelContextManager: self.original_seq_len = 0 self.pad_len = 0 + # Track local valid token count for eval loss correction across CP ranks + self._local_valid_tokens: torch.Tensor | None = None + # Create a partially applied version of the apply_sequence_parallelism function self.apply_sequence_parallelism = functools.partial( apply_sequence_parallelism, @@ -270,6 +273,18 @@ class SequenceParallelContextManager: self.apply_sequence_parallelism(updated_kwargs) ) + # Track local valid tokens for eval loss correction + if "labels" in updated_kwargs and not self.models[0].training: + self._local_valid_tokens = ( + (updated_kwargs["labels"] != -100).sum().float() + ) + # Strip num_items_in_batch during eval so the model uses + # reduction='mean', allowing the post-hook weighted all-reduce + # formula (loss * local_valid) to correctly recover the loss sum + updated_kwargs.pop("num_items_in_batch", None) + else: + self._local_valid_tokens = None + return remaining_args, updated_kwargs # 
def eval_loss_correction_post_hook(_, __, output: ModelOutput) -> ModelOutput:
    """Replace the rank-local eval loss with the token-weighted group mean.

    The forward pre-hook stores the number of non-ignored labels seen by
    this rank in ``self._local_valid_tokens`` while the model is in eval
    mode. This hook turns the local mean loss back into a loss sum
    (``loss * local_valid``), sums both the loss sums and the token counts
    over ``self.process_group``, and writes ``sum(loss) / sum(tokens)`` to
    ``output["loss"]``.

    Args:
        _: Module the hook is registered on (unused).
        __: Positional inputs of the forward call (unused).
        output: Model output; modified in place when it carries a loss.

    Returns:
        The same ``output`` object. Its ``loss`` is a 0-dim tensor in the
        original loss dtype, or NaN when no rank in the group had a valid
        token. ``output`` is returned untouched when no token count was
        recorded for this forward pass or when it has no loss.
    """
    # Consume the recorded count up front so it is cleared on every exit
    # path and can never be applied to a later, unrelated forward pass.
    local_valid = self._local_valid_tokens
    self._local_valid_tokens = None
    if local_valid is None:
        return output

    local_loss = getattr(output, "loss", None)
    if local_loss is None:
        return output

    # Every rank must hand the collective a buffer of identical dtype and
    # shape, so do all arithmetic in float32 regardless of the loss dtype.
    # NOTE(review): assumes the loss holds a single element (mean-reduced).
    loss = local_loss.detach().to(torch.float32).reshape(())
    local_valid = local_valid.to(device=loss.device, dtype=torch.float32).reshape(())

    # A rank without valid tokens reports a NaN mean, and NaN * 0 is still
    # NaN, so its contribution has to be selected as zero explicitly.
    weighted_loss = torch.where(
        local_valid > 0, loss * local_valid, torch.zeros_like(loss)
    )

    # Pack loss sum and token count together: one collective instead of two
    # and no host synchronisation via ``.item()``.
    stats = torch.stack([weighted_loss, local_valid])
    dist.all_reduce(stats, op=dist.ReduceOp.SUM, group=self.process_group)
    loss_sum, total_valid = stats.unbind()

    corrected = torch.where(
        total_valid > 0,
        loss_sum / total_valid,
        torch.full_like(loss_sum, float("nan")),
    )
    output["loss"] = corrected.to(local_loss.dtype)
    return output