@@ -385,7 +385,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
|||||||
return ranges
|
return ranges
|
||||||
|
|
||||||
def log_table_from_dataloader(name: str, table_dataloader):
|
def log_table_from_dataloader(name: str, table_dataloader):
|
||||||
table = wandb.Table(
|
table = wandb.Table( # type: ignore[attr-defined]
|
||||||
columns=[
|
columns=[
|
||||||
"id",
|
"id",
|
||||||
"Prompt",
|
"Prompt",
|
||||||
@@ -506,7 +506,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
|||||||
)
|
)
|
||||||
row_index += 1
|
row_index += 1
|
||||||
|
|
||||||
wandb.run.log({f"{name} - Predictions vs Ground Truth": table})
|
wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) # type: ignore[attr-defined]
|
||||||
|
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
log_table_from_dataloader("Eval", eval_dataloader)
|
log_table_from_dataloader("Eval", eval_dataloader)
|
||||||
|
|||||||
Reference in New Issue
Block a user