amend model loading for hqq + fix hqq version
This commit is contained in:
@@ -1044,8 +1044,11 @@ class ModelLoader:
|
|||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if self.cfg.hqq:
|
if self.cfg.hqq and torch.cuda.device_count() < 2:
|
||||||
# if using hqq, we need to set device_map to gpu otherwise the loading get stuck
|
# for some reason on single gpu, we need to set device_map to auto/cuda
|
||||||
|
# otherwise you run into tensors on two devices error during training
|
||||||
|
# Doesn't affect multi-gpu tho
|
||||||
|
|
||||||
self.model_kwargs["device_map"] = "auto"
|
self.model_kwargs["device_map"] = "auto"
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
|
|||||||
Reference in New Issue
Block a user