fix wandb so mypy doesn't complain (#562)
* fix wandb so mypy doesn't complain * fix wandb so mypy doesn't complain * no need for mypy override anymore
This commit is contained in:
@@ -30,3 +30,4 @@ scipy
|
|||||||
scikit-learn==1.2.2
|
scikit-learn==1.2.2
|
||||||
pynvml
|
pynvml
|
||||||
art
|
art
|
||||||
|
wandb
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
from axolotl.utils.models import load_tokenizer
|
from axolotl.utils.models import load_tokenizer
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
src_dir = os.path.join(project_root, "src")
|
src_dir = os.path.join(project_root, "src")
|
||||||
|
|||||||
@@ -367,7 +367,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
|||||||
output_scores=False,
|
output_scores=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def logits_to_tokens(logits) -> str:
|
def logits_to_tokens(logits) -> torch.Tensor:
|
||||||
probabilities = torch.softmax(logits, dim=-1)
|
probabilities = torch.softmax(logits, dim=-1)
|
||||||
# Get the predicted token ids (the ones with the highest probability)
|
# Get the predicted token ids (the ones with the highest probability)
|
||||||
predicted_token_ids = torch.argmax(probabilities, dim=-1)
|
predicted_token_ids = torch.argmax(probabilities, dim=-1)
|
||||||
|
|||||||
Reference in New Issue
Block a user