ensure that the bias is also in the correct dtype (#1848) [skip ci]

* ensure that the bias is also in the correct dtype

* add nightly test for dpo-qlora-fsdp
Author: Wing Lian
Date: 2024-08-22 11:45:00 -04:00
Committed by: GitHub
Parent: c3fc529bfc
Commit: 5b0b774e38
4 changed files with 115 additions and 3 deletions


@@ -1846,6 +1846,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         )
         if self.cfg.fsdp:
             ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
+            if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model:
+                ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype)

         dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
         for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
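In isolation, the guarded casts above behave roughly like the sketch below. The SimpleNamespace objects are hypothetical stand-ins for axolotl's parsed config and TRL's DPOTrainer; only the two ensure_dtype calls mirror the diff, and ensure_dtype itself is the helper patched in the next hunk.

# Hedged sketch of the guarded casts above; assumes the ensure_dtype
# helper from the hunk below is in scope. The SimpleNamespace objects
# are made-up stand-ins, not the real trainer or config types.
import torch
import torch.nn as nn
from types import SimpleNamespace

cfg = SimpleNamespace(fsdp=True, rl="dpo", torch_dtype=torch.bfloat16)
dpo_trainer = SimpleNamespace(
    model=nn.Linear(8, 8),      # stand-in policy model, fp32 by default
    ref_model=nn.Linear(8, 8),  # stand-in frozen reference model
)

if cfg.fsdp:
    ensure_dtype(dpo_trainer.model, dtype=cfg.torch_dtype)
    # The lines added by this commit: DPO/IPO keep a separate reference
    # model, which is cast to the same dtype as the policy model.
    if cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model:
        ensure_dtype(dpo_trainer.ref_model, dtype=cfg.torch_dtype)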


@@ -1102,9 +1102,20 @@ def load_lora(model, cfg, inference=False, config_only=False):
 def ensure_dtype(model, dtype=torch.bfloat16):
     for name, module in model.named_modules():
+        weight_mismatch = False
+        bias_mismatch = False
         try:
-            if module.weight.dtype != dtype:
-                print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
-                module.to(dtype)
+            weight_mismatch = module.weight.dtype != dtype
         except AttributeError:
             pass
+        try:
+            bias_mismatch = module.bias.dtype != dtype
+        except AttributeError:
+            pass
+
+        if weight_mismatch:
+            print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}")
+        if bias_mismatch:
+            print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}")
+        if weight_mismatch or bias_mismatch:
+            module.to(dtype)
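The rework matters when only the bias is off: previously a module whose weight was already in the target dtype was skipped entirely, even if its bias was not. A minimal check of the new behavior, assuming the ensure_dtype above is in scope (the nn.Linear is just a stand-in for a real submodule):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4).to(torch.bfloat16)
# Reproduce the bug this commit fixes: weight already bf16, bias fp32.
layer.bias.data = layer.bias.data.float()

ensure_dtype(layer, dtype=torch.bfloat16)
# The loop reports the bias mismatch (the root module's name is the
# empty string) and casts the whole module:
#   Converting module .bias: torch.float32 -> torch.bfloat16
assert layer.weight.dtype == torch.bfloat16
assert layer.bias.dtype == torch.bfloat16

With the pre-patch helper, this module would have been left untouched, since only module.weight.dtype was ever inspected.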