amend model loading for hqq + fix hqq version

This commit is contained in:
Sunny Liu
2025-04-21 15:53:29 -04:00
parent c8fb5baad6
commit f0a189131b

View File

@@ -1044,8 +1044,11 @@ class ModelLoader:
config=self.model_config,
)
else:
if self.cfg.hqq:
# if using hqq, we need to set device_map to gpu otherwise the loading get stuck
if self.cfg.hqq and torch.cuda.device_count() < 2:
# for some reason on single gpu, we need to set device_map to auto/cuda
# otherwise you run into tensors on two devices error during training
# Doesn't affect multi-gpu tho
self.model_kwargs["device_map"] = "auto"
self.model = self.auto_model_loader.from_pretrained(
self.base_model,