fix wandb so mypy doesn't complain (#562)
* fix wandb so mypy doesn't complain * fix wandb so mypy doesn't complain * no need for mypy override anymore
This commit is contained in:
@@ -367,7 +367,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
||||
output_scores=False,
|
||||
)
|
||||
|
||||
def logits_to_tokens(logits) -> str:
|
||||
def logits_to_tokens(logits) -> torch.Tensor:
|
||||
probabilities = torch.softmax(logits, dim=-1)
|
||||
# Get the predicted token ids (the ones with the highest probability)
|
||||
predicted_token_ids = torch.argmax(probabilities, dim=-1)
|
||||
|
||||
Reference in New Issue
Block a user