Remove FP32 cast
This commit is contained in:
@@ -692,7 +692,6 @@ def mistral_causallm_forward(
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
||||
Reference in New Issue
Block a user