fix: apply the checkpoint-save state_dict fix only in Context Parallel (CP) mode

This commit is contained in:
NanoCode012
2026-02-25 14:49:46 +07:00
parent 0d0122cabe
commit eb13054672

View File

@@ -720,10 +720,14 @@ class AxolotlTrainer(
os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}")
# fix for Context Parallel save
if state_dict is None:
state_dict = self.accelerator.get_state_dict(self.model)
if state_dict is not None:
# fix for Context Parallel save: CP eval invalidates tensor storage
# pointers, so clone to CPU to get fresh valid storage for safetensors
if (
state_dict is not None
and self.axolotl_cfg
and self.axolotl_cfg.context_parallel_size
and self.axolotl_cfg.context_parallel_size > 1
):
state_dict = {
k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
for k, v in state_dict.items()