diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index ac14e37c5..c5b168ff7 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1044,8 +1044,11 @@ class ModelLoader: config=self.model_config, ) else: - if self.cfg.hqq: - # if using hqq, we need to set device_map to gpu otherwise the loading get stuck + if self.cfg.hqq and torch.cuda.device_count() < 2: + # for some reason on single gpu, we need to set device_map to auto/cuda + # otherwise you run into tensors on two devices error during training + # Doesn't affect multi-gpu tho + self.model_kwargs["device_map"] = "auto" self.model = self.auto_model_loader.from_pretrained( self.base_model,