diff --git a/src/axolotl/monkeypatch/trainer_grad_accum.py b/src/axolotl/monkeypatch/trainer_grad_accum.py index d45ca5ed3..600187c63 100644 --- a/src/axolotl/monkeypatch/trainer_grad_accum.py +++ b/src/axolotl/monkeypatch/trainer_grad_accum.py @@ -49,6 +49,7 @@ ORIGINAL_LLAMA_FCLM_CODE = """ cache_position=cache_position, **kwargs, ) + hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])