Remove FP32 cast

This commit is contained in:
Casper
2023-12-07 16:28:25 +01:00
parent 8671ed5a0c
commit 4f9b172c47

View File

@@ -692,7 +692,6 @@ def mistral_causallm_forward(
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None: