diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py index 81c804309..f9479cb59 100644 --- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py @@ -692,7 +692,6 @@ def mistral_causallm_forward( hidden_states = outputs[0] logits = self.lm_head(hidden_states) - logits = logits.float() loss = None if labels is not None: