ensure that the bias is also in the correct dtype (#1848) [skip ci]
* ensure that the bias is also in the correct dtype * add nightly for dpo-qlora-fsdp
This commit is contained in:
@@ -1846,6 +1846,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
)
|
||||
if self.cfg.fsdp:
|
||||
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
|
||||
if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model:
|
||||
ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype)
|
||||
|
||||
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
|
||||
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
|
||||
|
||||
@@ -1102,9 +1102,20 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
|
||||
def ensure_dtype(model, dtype=torch.bfloat16):
|
||||
for name, module in model.named_modules():
|
||||
weight_mismatch = False
|
||||
bias_mismatch = False
|
||||
try:
|
||||
if module.weight.dtype != dtype:
|
||||
print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
|
||||
module.to(dtype)
|
||||
weight_mismatch = module.weight.dtype != dtype
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
bias_mismatch = module.bias.dtype != dtype
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
if weight_mismatch:
|
||||
print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}")
|
||||
if bias_mismatch:
|
||||
print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}")
|
||||
if weight_mismatch or bias_mismatch:
|
||||
module.to(dtype)
|
||||
|
||||
Reference in New Issue
Block a user