diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 414abeb4d..354177baf 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -725,7 +725,7 @@ class AxolotlTrainer( state_dict = self.accelerator.get_state_dict(self.model) if state_dict is not None: state_dict = { - k: v.clone() if isinstance(v, torch.Tensor) else v + k: v.detach().cpu() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() }