diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 354177baf..d055608dd 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -720,10 +720,14 @@ class AxolotlTrainer( os.makedirs(output_dir, exist_ok=True) LOG.info(f"Saving model checkpoint to {output_dir}") - # fix for Context Parallel save - if state_dict is None: - state_dict = self.accelerator.get_state_dict(self.model) - if state_dict is not None: + # fix for Context Parallel save: CP eval invalidates tensor storage + # pointers, so clone to CPU to get fresh valid storage for safetensors + if ( + state_dict is not None + and self.axolotl_cfg + and self.axolotl_cfg.context_parallel_size + and self.axolotl_cfg.context_parallel_size > 1 + ): state_dict = { k: v.detach().cpu() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()