make sure everything stays in the same dtype when using dpo + FSDP (#1559)
This commit is contained in:
@@ -993,3 +993,13 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
setup_quantized_peft_meta_for_training(model)
|
||||
|
||||
return model, lora_config
|
||||
|
||||
|
||||
def ensure_dtype(model, dtype=torch.bfloat16):
|
||||
for name, module in model.named_modules():
|
||||
try:
|
||||
if module.weight.dtype != dtype:
|
||||
print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
|
||||
module.to(dtype)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user