@@ -385,7 +385,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
||||
return ranges
|
||||
|
||||
def log_table_from_dataloader(name: str, table_dataloader):
|
||||
table = wandb.Table(
|
||||
table = wandb.Table( # type: ignore[attr-defined]
|
||||
columns=[
|
||||
"id",
|
||||
"Prompt",
|
||||
@@ -506,7 +506,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
||||
)
|
||||
row_index += 1
|
||||
|
||||
wandb.run.log({f"{name} - Predictions vs Ground Truth": table})
|
||||
wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) # type: ignore[attr-defined]
|
||||
|
||||
if is_main_process():
|
||||
log_table_from_dataloader("Eval", eval_dataloader)
|
||||
|
||||
Reference in New Issue
Block a user