fix wandb so mypy doesn't complain (#562)

* fix wandb so mypy doesn't complain

* fix wandb so mypy doesn't complain

* no need for mypy override anymore
This commit is contained in:
Wing Lian
2023-09-13 10:36:16 -04:00
committed by GitHub
parent 5b67ea98a6
commit bf0804447c
4 changed files with 3 additions and 2 deletions

View File

@@ -30,3 +30,4 @@ scipy
scikit-learn==1.2.2
pynvml
art
wandb

View File

@@ -26,7 +26,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.wandb import setup_wandb_env_vars
from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")

View File

@@ -367,7 +367,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
output_scores=False,
)
def logits_to_tokens(logits) -> str:
def logits_to_tokens(logits) -> torch.Tensor:
probabilities = torch.softmax(logits, dim=-1)
# Get the predicted token ids (the ones with the highest probability)
predicted_token_ids = torch.argmax(probabilities, dim=-1)