fix: ds3 and fsdp lmbench eval (#2102) [skip ci]
* fix: ds3 and fsdp lmbench eval * chore: update comment * fix: test signature
This commit is contained in:
@@ -28,6 +28,7 @@ from transformers import (
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||
from trl.models import unwrap_model_for_generation
|
||||
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
@@ -46,6 +47,7 @@ from axolotl.utils.distributed import (
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
|
||||
@@ -64,7 +66,10 @@ class EvalFirstStepCallback(
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
|
||||
if (
|
||||
args.evaluation_strategy == IntervalStrategy.STEPS
|
||||
and state.global_step == 1
|
||||
):
|
||||
control.should_evaluate = True
|
||||
return control
|
||||
|
||||
@@ -375,7 +380,10 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
||||
for metric in self.cfg.eval_causal_lm_metrics:
|
||||
if metric == "perplexity":
|
||||
max_seq_len = self.cfg.eval_max_new_tokens
|
||||
metrics[metric] = Perplexity(trainer.model, tokenizer, max_seq_len)
|
||||
metrics[metric] = Perplexity(
|
||||
tokenizer=tokenizer,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
metrics[metric] = evaluate.load(metric)
|
||||
@@ -392,8 +400,11 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
||||
eval_dataloader,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
trainer.model.eval()
|
||||
device = torch.device(self.cfg.device)
|
||||
trainer.model_wrapped.eval()
|
||||
|
||||
device = torch.device(
|
||||
self.cfg.device
|
||||
) # Use this instead of trainer.model_wrapped.device as it may return cpu if fsdp offloaded
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
generation_config = GenerationConfig(
|
||||
@@ -430,6 +441,10 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
||||
for k in metric._feature_names() # pylint: disable=protected-access
|
||||
if k in kwargs
|
||||
}
|
||||
|
||||
if isinstance(metric, Perplexity):
|
||||
metric_kwargs["model"] = trainer.model_wrapped
|
||||
|
||||
metric_score = metric.compute(**metric_kwargs)
|
||||
return (
|
||||
metric_score["score"]
|
||||
@@ -465,89 +480,97 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
||||
def predict_with_generate():
|
||||
eval_src, eval_pred, eval_ref = [], [], []
|
||||
|
||||
for batch in tqdm(eval_dataloader):
|
||||
batch_labels = batch["labels"].to(device)
|
||||
batch_input_ids = batch["input_ids"].to(device)
|
||||
with unwrap_model_for_generation(
|
||||
trainer.model_wrapped, trainer.accelerator
|
||||
) as unwrapped_model:
|
||||
for batch in tqdm(eval_dataloader, disable=not is_main_process()):
|
||||
batch_labels = batch["labels"].to(device)
|
||||
batch_input_ids = batch["input_ids"].to(device)
|
||||
|
||||
if "position_ids" in batch:
|
||||
batch_pos_ids = batch["position_ids"].tolist()
|
||||
else:
|
||||
batch_pos_ids = [None] * len(batch["input_ids"])
|
||||
|
||||
prompt_token_ids_list = []
|
||||
completion_token_ids_list = []
|
||||
|
||||
for input_ids_all, labels_all, pos_ids in zip(
|
||||
batch_input_ids,
|
||||
batch_labels,
|
||||
batch_pos_ids,
|
||||
):
|
||||
if pos_ids is None:
|
||||
pos_ranges = [(0, len(input_ids_all) - 1)]
|
||||
if "position_ids" in batch:
|
||||
batch_pos_ids = batch["position_ids"].tolist()
|
||||
else:
|
||||
pos_ranges = find_ranges(pos_ids)
|
||||
batch_pos_ids = [None] * len(batch["input_ids"])
|
||||
|
||||
for pos_range in pos_ranges:
|
||||
start, end = pos_range
|
||||
if start == end:
|
||||
continue
|
||||
prompt_token_ids_list = []
|
||||
completion_token_ids_list = []
|
||||
|
||||
input_ids = input_ids_all[start : end + 1]
|
||||
labels = labels_all[start : end + 1]
|
||||
for input_ids_all, labels_all, pos_ids in zip(
|
||||
batch_input_ids,
|
||||
batch_labels,
|
||||
batch_pos_ids,
|
||||
):
|
||||
if pos_ids is None:
|
||||
pos_ranges = [(0, len(input_ids_all) - 1)]
|
||||
else:
|
||||
pos_ranges = find_ranges(pos_ids)
|
||||
|
||||
tokens_without_loss = labels == IGNORE_INDEX
|
||||
tokens_with_loss = labels != IGNORE_INDEX
|
||||
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
|
||||
prompt_token_includes = (
|
||||
tokens_without_loss & tokens_exclude_padding
|
||||
for pos_range in pos_ranges:
|
||||
start, end = pos_range
|
||||
if start == end:
|
||||
continue
|
||||
|
||||
input_ids = input_ids_all[start : end + 1]
|
||||
labels = labels_all[start : end + 1]
|
||||
|
||||
tokens_without_loss = labels == IGNORE_INDEX
|
||||
tokens_with_loss = labels != IGNORE_INDEX
|
||||
tokens_exclude_padding = (
|
||||
input_ids != tokenizer.pad_token_id
|
||||
)
|
||||
prompt_token_includes = (
|
||||
tokens_without_loss & tokens_exclude_padding
|
||||
)
|
||||
|
||||
prompt_token_ids = input_ids[prompt_token_includes]
|
||||
prompt_token_ids_list.append(prompt_token_ids)
|
||||
|
||||
completion_token_ids = input_ids[tokens_with_loss]
|
||||
completion_token_ids_list.append(completion_token_ids)
|
||||
|
||||
prompt_texts = tokenizer.batch_decode(
|
||||
prompt_token_ids_list, skip_special_tokens=True
|
||||
)
|
||||
completion_texts = tokenizer.batch_decode(
|
||||
completion_token_ids_list, skip_special_tokens=True
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
prompt_encoding = tokenizer(
|
||||
prompt_texts, padding=True, return_tensors="pt"
|
||||
).to(device)
|
||||
|
||||
predictions = unwrapped_model.generate(
|
||||
**prompt_encoding, generation_config=generation_config
|
||||
)
|
||||
|
||||
prompt_token_ids = input_ids[prompt_token_includes]
|
||||
prompt_token_ids_list.append(prompt_token_ids)
|
||||
del prompt_encoding
|
||||
|
||||
completion_token_ids = input_ids[tokens_with_loss]
|
||||
completion_token_ids_list.append(completion_token_ids)
|
||||
prediction_all_tokens = predictions["sequences"].cpu().tolist()
|
||||
prediction_without_prompt_tokens_list = []
|
||||
for prompt_token_ids, prediction_tokens in zip(
|
||||
prompt_token_ids_list, prediction_all_tokens
|
||||
):
|
||||
prediction_without_prompt_tokens = prediction_tokens[
|
||||
len(prompt_token_ids) :
|
||||
]
|
||||
prediction_without_prompt_tokens_list.append(
|
||||
prediction_without_prompt_tokens
|
||||
)
|
||||
|
||||
prompt_texts = tokenizer.batch_decode(
|
||||
prompt_token_ids_list, skip_special_tokens=True
|
||||
)
|
||||
completion_texts = tokenizer.batch_decode(
|
||||
completion_token_ids_list, skip_special_tokens=True
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
prompt_encoding = tokenizer(
|
||||
prompt_texts, padding=True, return_tensors="pt"
|
||||
).to(self.cfg.device)
|
||||
predictions = trainer.model.generate(
|
||||
**prompt_encoding, generation_config=generation_config
|
||||
predicted_texts = tokenizer.batch_decode(
|
||||
prediction_without_prompt_tokens_list,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
prediction_all_tokens = predictions["sequences"].cpu().tolist()
|
||||
prediction_without_prompt_tokens_list = []
|
||||
for prompt_token_ids, prediction_tokens in zip(
|
||||
prompt_token_ids_list, prediction_all_tokens
|
||||
):
|
||||
prediction_without_prompt_tokens = prediction_tokens[
|
||||
len(prompt_token_ids) :
|
||||
]
|
||||
prediction_without_prompt_tokens_list.append(
|
||||
prediction_without_prompt_tokens
|
||||
)
|
||||
|
||||
predicted_texts = tokenizer.batch_decode(
|
||||
prediction_without_prompt_tokens_list, skip_special_tokens=True
|
||||
)
|
||||
|
||||
eval_src.extend(prompt_texts)
|
||||
eval_pred.extend(predicted_texts)
|
||||
eval_ref.extend(completion_texts)
|
||||
eval_src.extend(prompt_texts)
|
||||
eval_pred.extend(predicted_texts)
|
||||
eval_ref.extend(completion_texts)
|
||||
|
||||
return eval_src, eval_pred, eval_ref
|
||||
|
||||
if is_main_process():
|
||||
eval_preds = predict_with_generate()
|
||||
trainer.log(evaluate_preds(*eval_preds))
|
||||
eval_preds = predict_with_generate()
|
||||
trainer.log(evaluate_preds(*eval_preds))
|
||||
|
||||
return control
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ from transformers.modeling_outputs import CausalLMOutput
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
|
||||
class Perplexity:
|
||||
"""
|
||||
@@ -17,16 +19,13 @@ class Perplexity:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_seq_len: int,
|
||||
stride: int = 512,
|
||||
) -> None:
|
||||
self.max_seq_len = max_seq_len
|
||||
self.stride = stride
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.device = model.device
|
||||
self.name = "perplexity"
|
||||
|
||||
def _feature_names(self) -> List[str]:
|
||||
@@ -34,6 +33,7 @@ class Perplexity:
|
||||
|
||||
def compute(
|
||||
self,
|
||||
model: PreTrainedModel,
|
||||
references: Optional[List[str]] = None,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
@@ -41,17 +41,21 @@ class Perplexity:
|
||||
"""
|
||||
assert references is not None, "Missing parameter: references"
|
||||
|
||||
model.eval()
|
||||
|
||||
references_tokenized = self.tokenizer(
|
||||
references, return_tensors="pt", padding=True, truncation=True
|
||||
)
|
||||
input_ids: Tensor = references_tokenized["input_ids"] # type: ignore
|
||||
input_ids = input_ids.to(self.device)
|
||||
input_ids = input_ids.to(model.device)
|
||||
|
||||
sequence_length = input_ids.size(1)
|
||||
|
||||
losses = []
|
||||
prev_end_loc = 0
|
||||
for begin_loc in tqdm(range(0, sequence_length, self.stride)):
|
||||
for begin_loc in tqdm(
|
||||
range(0, sequence_length, self.stride), disable=not is_main_process()
|
||||
):
|
||||
end_loc = min(begin_loc + self.max_seq_len, sequence_length)
|
||||
trg_len = end_loc - prev_end_loc
|
||||
input_ids_slice = input_ids[:, begin_loc:end_loc]
|
||||
@@ -59,7 +63,7 @@ class Perplexity:
|
||||
labels_slice[:, :-trg_len] = -100
|
||||
|
||||
with torch.no_grad():
|
||||
outputs: CausalLMOutput = self.model(
|
||||
outputs: CausalLMOutput = model(
|
||||
input_ids=input_ids_slice, labels=labels_slice
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user