From f0a189131b160d9012e56ab5f0e1b28e55dc72a9 Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Mon, 21 Apr 2025 15:53:29 -0400 Subject: [PATCH] amend model loading for hqq + fix hqq version --- src/axolotl/utils/models.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index ac14e37c5..c5b168ff7 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1044,8 +1044,11 @@ class ModelLoader: config=self.model_config, ) else: - if self.cfg.hqq: - # if using hqq, we need to set device_map to gpu otherwise the loading get stuck + if self.cfg.hqq and torch.cuda.device_count() < 2: + # for some reason on single gpu, we need to set device_map to auto/cuda + # otherwise you run into tensors on two devices error during training + # Doesn't affect multi-gpu tho + self.model_kwargs["device_map"] = "auto" self.model = self.auto_model_loader.from_pretrained( self.base_model,