fix: saving state dict and eval for Context Parallel (#3382) [skip ci]
* clone state_dict if none * patch calculating eval loss for cp
This commit is contained in:
@@ -719,6 +719,13 @@ class AxolotlTrainer(
|
|||||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
LOG.info(f"Saving model checkpoint to {output_dir}")
|
LOG.info(f"Saving model checkpoint to {output_dir}")
|
||||||
|
if state_dict is None:
|
||||||
|
state_dict = self.accelerator.get_state_dict(self.model)
|
||||||
|
if state_dict is not None:
|
||||||
|
state_dict = {
|
||||||
|
k: v.clone() if isinstance(v, torch.Tensor) else v
|
||||||
|
for k, v in state_dict.items()
|
||||||
|
}
|
||||||
supported_classes = (
|
supported_classes = (
|
||||||
(PreTrainedModel,)
|
(PreTrainedModel,)
|
||||||
if not is_peft_available()
|
if not is_peft_available()
|
||||||
|
|||||||
@@ -218,6 +218,9 @@ class SequenceParallelContextManager:
|
|||||||
self.original_seq_len = 0
|
self.original_seq_len = 0
|
||||||
self.pad_len = 0
|
self.pad_len = 0
|
||||||
|
|
||||||
|
# Track local valid token count for eval loss correction across CP ranks
|
||||||
|
self._local_valid_tokens: torch.Tensor | None = None
|
||||||
|
|
||||||
# Create a partially applied version of the apply_sequence_parallelism function
|
# Create a partially applied version of the apply_sequence_parallelism function
|
||||||
self.apply_sequence_parallelism = functools.partial(
|
self.apply_sequence_parallelism = functools.partial(
|
||||||
apply_sequence_parallelism,
|
apply_sequence_parallelism,
|
||||||
@@ -270,6 +273,18 @@ class SequenceParallelContextManager:
|
|||||||
self.apply_sequence_parallelism(updated_kwargs)
|
self.apply_sequence_parallelism(updated_kwargs)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Track local valid tokens for eval loss correction
|
||||||
|
if "labels" in updated_kwargs and not self.models[0].training:
|
||||||
|
self._local_valid_tokens = (
|
||||||
|
(updated_kwargs["labels"] != -100).sum().float()
|
||||||
|
)
|
||||||
|
# Strip num_items_in_batch during eval so the model uses
|
||||||
|
# reduction='mean', allowing the post-hook weighted all-reduce
|
||||||
|
# formula (loss * local_valid) to correctly recover the loss sum
|
||||||
|
updated_kwargs.pop("num_items_in_batch", None)
|
||||||
|
else:
|
||||||
|
self._local_valid_tokens = None
|
||||||
|
|
||||||
return remaining_args, updated_kwargs
|
return remaining_args, updated_kwargs
|
||||||
|
|
||||||
# Forward post-hook to gather outputs
|
# Forward post-hook to gather outputs
|
||||||
@@ -287,6 +302,44 @@ class SequenceParallelContextManager:
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
# Post-hook to correct eval loss via weighted all-reduce across CP ranks
|
||||||
|
def eval_loss_correction_post_hook(_, __, output: ModelOutput) -> ModelOutput:
|
||||||
|
if self._local_valid_tokens is None:
|
||||||
|
return output
|
||||||
|
if not hasattr(output, "loss") or output.loss is None:
|
||||||
|
return output
|
||||||
|
|
||||||
|
local_valid = self._local_valid_tokens.to(output.loss.device)
|
||||||
|
loss = output.loss.detach().clone()
|
||||||
|
|
||||||
|
# Handle rank with zero valid tokens (loss is NaN)
|
||||||
|
if local_valid.item() == 0:
|
||||||
|
weighted_loss = torch.zeros(1, device=loss.device, dtype=loss.dtype)
|
||||||
|
else:
|
||||||
|
weighted_loss = loss * local_valid
|
||||||
|
|
||||||
|
total_valid = local_valid.clone()
|
||||||
|
dist.all_reduce(
|
||||||
|
weighted_loss,
|
||||||
|
op=dist.ReduceOp.SUM,
|
||||||
|
group=self.process_group,
|
||||||
|
)
|
||||||
|
dist.all_reduce(
|
||||||
|
total_valid,
|
||||||
|
op=dist.ReduceOp.SUM,
|
||||||
|
group=self.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
if total_valid.item() > 0:
|
||||||
|
output["loss"] = (weighted_loss / total_valid).squeeze()
|
||||||
|
else:
|
||||||
|
output["loss"] = torch.tensor(
|
||||||
|
float("nan"), device=loss.device, dtype=loss.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
self._local_valid_tokens = None
|
||||||
|
return output
|
||||||
|
|
||||||
# Register hooks
|
# Register hooks
|
||||||
for model in self.models:
|
for model in self.models:
|
||||||
self.hook_handles.append(
|
self.hook_handles.append(
|
||||||
@@ -298,6 +351,10 @@ class SequenceParallelContextManager:
|
|||||||
self.hook_handles.append(
|
self.hook_handles.append(
|
||||||
model.register_forward_hook(sequence_parallel_post_hook)
|
model.register_forward_hook(sequence_parallel_post_hook)
|
||||||
)
|
)
|
||||||
|
# Always register eval loss correction hook
|
||||||
|
self.hook_handles.append(
|
||||||
|
model.register_forward_hook(eval_loss_correction_post_hook)
|
||||||
|
)
|
||||||
|
|
||||||
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
|
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
|
||||||
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
|
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
|
||||||
|
|||||||
Reference in New Issue
Block a user