amend model loading for hqq + fix hqq version
This commit is contained in:
@@ -1044,8 +1044,11 @@ class ModelLoader:
|
||||
config=self.model_config,
|
||||
)
|
||||
else:
|
||||
if self.cfg.hqq:
|
||||
# if using hqq, we need to set device_map to gpu otherwise the loading get stuck
|
||||
if self.cfg.hqq and torch.cuda.device_count() < 2:
|
||||
# for some reason on single gpu, we need to set device_map to auto/cuda
|
||||
# otherwise you run into tensors on two devices error during training
|
||||
# Doesn't affect multi-gpu tho
|
||||
|
||||
self.model_kwargs["device_map"] = "auto"
|
||||
self.model = self.auto_model_loader.from_pretrained(
|
||||
self.base_model,
|
||||
|
||||
Reference in New Issue
Block a user