Add training callback to send predictions to WandB table (#521)
* WIP Add training callback to send predictions to WandB table * WIP improve wandb table reporting callback * WIP improve wandb table reporting callback (cont) * Add VSCode launching for debugging * Add tiny llama example * WIP attempt to improve post-eval prediction generation for table * WIP attempt to improve post-eval prediction generation for table - part 2 * WIP batch generation * WIP attempt to handle sample_packing using position_ids for wandb prediction table * WIP add code for debugging * Fix sample_packing support for wandb prediction table * Clean up code for PR review * Add eval_table_size, eval_table_max_new_tokens configs & clean up code * Clean up PR, delete VSCode config, add tiny-llama example * Add eval_table_size, eval_table_max_new_tokens documentation. Fix linting/formatting
This commit is contained in:
@@ -534,6 +534,9 @@ eval_steps: # leave empty to eval at each epoch
|
|||||||
save_total_limit: # checkpoints saved at a time
|
save_total_limit: # checkpoints saved at a time
|
||||||
max_steps:
|
max_steps:
|
||||||
|
|
||||||
|
eval_table_size: # approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||||
|
eval_table_max_new_tokens: # total number of tokens generated for predictions sent to wandb. Default is 128
|
||||||
|
|
||||||
# save model as safetensors (require safetensors package)
|
# save model as safetensors (require safetensors package)
|
||||||
save_safetensors:
|
save_safetensors:
|
||||||
|
|
||||||
|
|||||||
@@ -56,6 +56,8 @@ flash_attention: true
|
|||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
eval_steps: 20
|
||||||
|
eval_table_size: 5
|
||||||
|
eval_table_max_new_tokens: 128
|
||||||
save_steps:
|
save_steps:
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ flash_attention: true
|
|||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
eval_steps: 20
|
||||||
|
eval_table_size: 5
|
||||||
save_steps:
|
save_steps:
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|||||||
69
examples/llama-2/tiny-llama.yml
Normal file
69
examples/llama-2/tiny-llama.yml
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
base_model: PY007/TinyLlama-1.1B-step-50K-105b
|
||||||
|
base_model_config: PY007/TinyLlama-1.1B-step-50K-105b
|
||||||
|
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
is_llama_derived_model: true
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.01
|
||||||
|
output_dir: ./lora-out
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_run_id:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 3
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
eval_steps: 20
|
||||||
|
eval_table_size: 5
|
||||||
|
save_steps:
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
bos_token: "<s>"
|
||||||
|
eos_token: "</s>"
|
||||||
|
unk_token: "<unk>"
|
||||||
@@ -193,7 +193,7 @@ def flashattn_forward(
|
|||||||
# only on first autoregressive step q,k,v have same seqlen
|
# only on first autoregressive step q,k,v have same seqlen
|
||||||
is_causal = key_states.shape == query_states.shape
|
is_causal = key_states.shape == query_states.shape
|
||||||
|
|
||||||
if cu_seqlens is not None and max_seqlen is not None:
|
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||||
# special handling using sample packing
|
# special handling using sample packing
|
||||||
qkv = torch.stack(
|
qkv = torch.stack(
|
||||||
[query_states, key_states, value_states], dim=2
|
[query_states, key_states, value_states], dim=2
|
||||||
@@ -261,6 +261,8 @@ def flashattn_forward(
|
|||||||
if attention_mask is not None
|
if attention_mask is not None
|
||||||
else None,
|
else None,
|
||||||
)
|
)
|
||||||
|
if q_unpad.dtype != kv_unpad.dtype:
|
||||||
|
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||||
q_unpad,
|
q_unpad,
|
||||||
kv_unpad,
|
kv_unpad,
|
||||||
|
|||||||
@@ -11,10 +11,13 @@ import numpy as np
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
import wandb
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
GenerationConfig,
|
||||||
|
Trainer,
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
TrainerControl,
|
TrainerControl,
|
||||||
TrainerState,
|
TrainerState,
|
||||||
@@ -323,3 +326,191 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
metrics[key] = val
|
metrics[key] = val
|
||||||
|
|
||||||
return BenchEvalCallback
|
return BenchEvalCallback
|
||||||
|
|
||||||
|
|
||||||
|
def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
||||||
|
class LogPredictionCallback(TrainerCallback):
|
||||||
|
"""Callback to log prediction values during each evaluation"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
self.cfg = cfg
|
||||||
|
self.logged = False
|
||||||
|
|
||||||
|
def on_evaluate(
|
||||||
|
self,
|
||||||
|
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl,
|
||||||
|
train_dataloader, # pylint: disable=unused-argument
|
||||||
|
eval_dataloader,
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
eval_table_size = self.cfg.eval_table_size
|
||||||
|
|
||||||
|
if eval_table_size <= 0:
|
||||||
|
return control
|
||||||
|
|
||||||
|
trainer.model.eval()
|
||||||
|
device = torch.device(self.cfg.device)
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
generation_config = GenerationConfig(
|
||||||
|
max_new_tokens=self.cfg.eval_table_max_new_tokens,
|
||||||
|
bos_token_id=tokenizer.bos_token_id,
|
||||||
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
do_sample=False,
|
||||||
|
use_cache=True,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
output_attentions=False,
|
||||||
|
output_hidden_states=False,
|
||||||
|
output_scores=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def logits_to_tokens(logits) -> str:
|
||||||
|
probabilities = torch.softmax(logits, dim=-1)
|
||||||
|
# Get the predicted token ids (the ones with the highest probability)
|
||||||
|
predicted_token_ids = torch.argmax(probabilities, dim=-1)
|
||||||
|
return predicted_token_ids
|
||||||
|
|
||||||
|
def find_ranges(lst):
|
||||||
|
ranges = []
|
||||||
|
start = 0
|
||||||
|
for i in range(1, len(lst)):
|
||||||
|
if lst[i] == 0:
|
||||||
|
ranges.append((start, i - 1))
|
||||||
|
start = i
|
||||||
|
end = len(lst) - 1
|
||||||
|
ranges.append((start, end))
|
||||||
|
return ranges
|
||||||
|
|
||||||
|
def log_table_from_dataloader(name: str, table_dataloader):
|
||||||
|
table = wandb.Table(
|
||||||
|
columns=[
|
||||||
|
"id",
|
||||||
|
"Prompt",
|
||||||
|
"Correct Completion",
|
||||||
|
"Predicted Completion (model.generate)",
|
||||||
|
"Predicted Completion (trainer.prediction_step)",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
row_index = 0
|
||||||
|
|
||||||
|
for batch in tqdm(table_dataloader):
|
||||||
|
if row_index > eval_table_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
batch_labels = batch["labels"].to(device)
|
||||||
|
batch_input_ids = batch["input_ids"].to(device)
|
||||||
|
|
||||||
|
if "position_ids" in batch:
|
||||||
|
batch_pos_ids = batch["position_ids"].tolist()
|
||||||
|
else:
|
||||||
|
batch_pos_ids = [None] * len(batch["input_ids"])
|
||||||
|
|
||||||
|
(_, batch_logits, _) = trainer.prediction_step(
|
||||||
|
trainer.model,
|
||||||
|
batch,
|
||||||
|
prediction_loss_only=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_token_ids_list = []
|
||||||
|
pred_step_token_ids_list = []
|
||||||
|
completion_token_ids_list = []
|
||||||
|
|
||||||
|
for input_ids_all, labels_all, pos_ids, logits in zip(
|
||||||
|
batch_input_ids,
|
||||||
|
batch_labels,
|
||||||
|
batch_pos_ids,
|
||||||
|
batch_logits,
|
||||||
|
):
|
||||||
|
if pos_ids is None:
|
||||||
|
pos_ranges = [(0, len(input_ids_all) - 1)]
|
||||||
|
else:
|
||||||
|
pos_ranges = find_ranges(pos_ids)
|
||||||
|
|
||||||
|
for pos_range in pos_ranges:
|
||||||
|
start, end = pos_range
|
||||||
|
if start == end:
|
||||||
|
continue
|
||||||
|
|
||||||
|
input_ids = input_ids_all[start : end + 1]
|
||||||
|
labels = labels_all[start : end + 1]
|
||||||
|
|
||||||
|
tokens_without_loss = labels == IGNORE_INDEX
|
||||||
|
tokens_with_loss = labels != IGNORE_INDEX
|
||||||
|
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
|
||||||
|
prompt_token_includes = (
|
||||||
|
tokens_without_loss & tokens_exclude_padding
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_token_ids = input_ids[prompt_token_includes]
|
||||||
|
prompt_token_ids_list.append(prompt_token_ids)
|
||||||
|
|
||||||
|
completion_token_ids = input_ids[tokens_with_loss]
|
||||||
|
completion_token_ids_list.append(completion_token_ids)
|
||||||
|
|
||||||
|
pred_step_token_ids = logits_to_tokens(
|
||||||
|
logits[start : end + 1]
|
||||||
|
)[tokens_with_loss]
|
||||||
|
pred_step_token_ids_list.append(pred_step_token_ids)
|
||||||
|
|
||||||
|
prompt_texts = tokenizer.batch_decode(
|
||||||
|
prompt_token_ids_list, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
completion_texts = tokenizer.batch_decode(
|
||||||
|
completion_token_ids_list, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
pred_step_texts = tokenizer.batch_decode(
|
||||||
|
pred_step_token_ids_list, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
prompt_encoding = tokenizer(
|
||||||
|
prompt_texts, padding=True, return_tensors="pt"
|
||||||
|
).to(self.cfg.device)
|
||||||
|
predictions = trainer.model.generate(
|
||||||
|
**prompt_encoding, generation_config=generation_config
|
||||||
|
)
|
||||||
|
|
||||||
|
prediction_all_tokens = predictions["sequences"].cpu().tolist()
|
||||||
|
prediction_without_prompt_tokens_list = []
|
||||||
|
for prompt_token_ids, prediction_tokens in zip(
|
||||||
|
prompt_token_ids_list, prediction_all_tokens
|
||||||
|
):
|
||||||
|
prediction_without_prompt_tokens = prediction_tokens[
|
||||||
|
len(prompt_token_ids) :
|
||||||
|
]
|
||||||
|
prediction_without_prompt_tokens_list.append(
|
||||||
|
prediction_without_prompt_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
predicted_texts = tokenizer.batch_decode(
|
||||||
|
prediction_without_prompt_tokens_list, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
for (
|
||||||
|
prompt_text,
|
||||||
|
completion_text,
|
||||||
|
prediction_text,
|
||||||
|
pred_step_text,
|
||||||
|
) in zip(
|
||||||
|
prompt_texts, completion_texts, predicted_texts, pred_step_texts
|
||||||
|
):
|
||||||
|
table.add_data(
|
||||||
|
row_index,
|
||||||
|
prompt_text,
|
||||||
|
completion_text,
|
||||||
|
prediction_text,
|
||||||
|
pred_step_text,
|
||||||
|
)
|
||||||
|
row_index += 1
|
||||||
|
|
||||||
|
wandb.run.log({f"{name} - Predictions vs Ground Truth": table})
|
||||||
|
|
||||||
|
if is_main_process():
|
||||||
|
log_table_from_dataloader("Eval", eval_dataloader)
|
||||||
|
|
||||||
|
return control
|
||||||
|
|
||||||
|
return LogPredictionCallback
|
||||||
|
|||||||
@@ -48,6 +48,8 @@ def normalize_config(cfg):
|
|||||||
)
|
)
|
||||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
cfg.eval_table_size = cfg.eval_table_size or 0
|
||||||
|
cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
|
||||||
choose_device(cfg)
|
choose_device(cfg)
|
||||||
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
||||||
if cfg.ddp:
|
if cfg.ddp:
|
||||||
|
|||||||
@@ -296,10 +296,10 @@ def load_model(
|
|||||||
if (
|
if (
|
||||||
hasattr(model.config, "max_position_embeddings")
|
hasattr(model.config, "max_position_embeddings")
|
||||||
and model.config.max_position_embeddings
|
and model.config.max_position_embeddings
|
||||||
and cfg.sequence_len >= model.config.max_position_embeddings
|
and cfg.sequence_len > model.config.max_position_embeddings
|
||||||
):
|
):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
|
f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}"
|
||||||
)
|
)
|
||||||
model.config.max_position_embeddings = cfg.sequence_len
|
model.config.max_position_embeddings = cfg.sequence_len
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from axolotl.utils.callbacks import (
|
|||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
SavePeftModelCallback,
|
SavePeftModelCallback,
|
||||||
bench_eval_callback_factory,
|
bench_eval_callback_factory,
|
||||||
|
log_prediction_callback_factory,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||||
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
||||||
@@ -703,6 +704,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
**trainer_kwargs,
|
**trainer_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.use_wandb and cfg.eval_table_size > 0:
|
||||||
|
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
|
||||||
|
trainer.add_callback(LogPredictionCallback(cfg))
|
||||||
|
|
||||||
if cfg.do_bench_eval:
|
if cfg.do_bench_eval:
|
||||||
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user