validation fixes 20240923 (#1925)
* validation fixes 20240923
* fix run name for wandb and defaults for chat template fields
* fix gradio inference with llama chat template
@@ -30,6 +30,7 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
 from axolotl.integrations.base import PluginManager
 from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
+from axolotl.utils.chat_templates import chat_templates
 from axolotl.utils.config import (
     normalize_cfg_datasets,
     normalize_config,
@@ -234,7 +235,8 @@ def do_inference_gradio(

     model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
     prompter = cli_args.prompter
-    default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
+    # default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
+    default_tokens: Dict[str, str] = {}

     for token, symbol in default_tokens.items():
         # If the token isn't already specified in the config, add it
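The removed literals are Llama-1/Mistral defaults; forcing them onto every tokenizer silently overrode whatever special tokens a newer chat checkpoint declares for itself. A minimal sketch of the failure mode this avoids (the checkpoint name is illustrative, and the printed tokens vary by model):

```python
from transformers import AutoTokenizer

# Illustrative chat checkpoint whose special tokens differ from Llama-1's.
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
print(tok.eos_token)  # the token the model actually stops on, e.g. "<|im_end|>"

# The old hard-coded default would clobber it, so generation could run
# right past the model's real end-of-turn token:
tok.add_special_tokens({"eos_token": "</s>"})
print(tok.eos_token)  # now forced to "</s>" regardless of the model
```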
@@ -242,10 +244,13 @@ def do_inference_gradio(
             tokenizer.add_special_tokens({token: symbol})

     prompter_module = None
+    chat_template_str = None
     if prompter:
         prompter_module = getattr(
             importlib.import_module("axolotl.prompters"), prompter
         )
+    elif cfg.chat_template:
+        chat_template_str = chat_templates(cfg.chat_template)

     model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -259,7 +264,24 @@ def do_inference_gradio(
             )
         else:
             prompt = instruction.strip()
-        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+
+        if chat_template_str:
+            batch = tokenizer.apply_chat_template(
+                [
+                    {
+                        "role": "user",
+                        "content": prompt,
+                    }
+                ],
+                return_tensors="pt",
+                add_special_tokens=True,
+                add_generation_prompt=True,
+                chat_template=chat_template_str,
+                tokenize=True,
+                return_dict=True,
+            )
+        else:
+            batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

         model.eval()
         with torch.no_grad():
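`apply_chat_template` is the standard transformers API for rendering a message list through a Jinja chat template; with `tokenize=True` and `return_dict=True` it yields the same `input_ids`/`attention_mask` mapping as a plain tokenizer call, which is why the rest of the generation path needs no changes. A minimal sketch (checkpoint name is illustrative; `chat_template=` can be omitted when the tokenizer bundles its own template):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

batch = tok.apply_chat_template(
    [{"role": "user", "content": "Hello!"}],
    add_generation_prompt=True,  # append the assistant header so the model answers
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
# Same output shape as tok(prompt, return_tensors="pt"):
print(batch["input_ids"].shape, batch["attention_mask"].shape)
```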
@@ -282,6 +304,7 @@ def do_inference_gradio(
         streamer = TextIteratorStreamer(tokenizer)
         generation_kwargs = {
             "inputs": batch["input_ids"].to(cfg.device),
+            "attention_mask": batch["attention_mask"].to(cfg.device),
             "generation_config": generation_config,
             "streamer": streamer,
         }
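Without an explicit mask, `generate` infers padding positions from `pad_token_id`; for chat models where pad and eos share an id, that inference is ambiguous and transformers warns about unreliable results. A sketch of the corrected streaming call, reusing `model`, `tokenizer`, `batch`, `cfg`, and `generation_config` from the function above rather than redefining them:

```python
from threading import Thread
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {
    "inputs": batch["input_ids"].to(cfg.device),
    # Explicit mask: generate() no longer has to guess which tokens are padding.
    "attention_mask": batch["attention_mask"].to(cfg.device),
    "generation_config": generation_config,
    "streamer": streamer,
}
# Run generation on a worker thread; the streamer yields text as it decodes.
Thread(target=model.generate, kwargs=generation_kwargs).start()
for new_text in streamer:
    print(new_text, end="", flush=True)
```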
@@ -1417,6 +1417,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         report_to = []
         if self.cfg.use_wandb:
             report_to.append("wandb")
+            if self.cfg.wandb_name:
+                training_arguments_kwargs["run_name"] = self.cfg.wandb_name
         if self.cfg.use_mlflow:
             report_to.append("mlflow")
         if self.cfg.use_tensorboard:
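Previously `wandb_name` reached wandb only indirectly (via the `WANDB_NAME` environment variable); passing it through `TrainingArguments.run_name` lets the HF `WandbCallback` forward it as the run name directly. A self-contained sketch of the mapping, with a stand-in for axolotl's parsed config:

```python
from types import SimpleNamespace
from transformers import TrainingArguments

cfg = SimpleNamespace(use_wandb=True, wandb_name="my-experiment")  # stand-in config

training_arguments_kwargs = {}
report_to = []
if cfg.use_wandb:
    report_to.append("wandb")
    if cfg.wandb_name:
        training_arguments_kwargs["run_name"] = cfg.wandb_name

args = TrainingArguments(
    output_dir="./outputs",
    report_to=report_to,
    **training_arguments_kwargs,
)
print(args.run_name)  # "my-experiment"
```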
@@ -1574,6 +1576,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         )
         training_args = self.hook_post_create_training_args(training_args)

+        # unset run_name so wandb sets up experiment names
+        if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
+            training_args.run_name = (  # pylint: disable=attribute-defined-outside-init
+                None
+            )
+
         data_collator_kwargs = {
             "padding": True,  # True/"longest" is the default
         }
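The guard exists because `TrainingArguments.__post_init__` falls back to `run_name = output_dir` when no name is supplied; left alone, every wandb run would be named after the output directory. Resetting it to `None` lets wandb generate its own experiment names. The fallback is easy to verify:

```python
from transformers import TrainingArguments

args = TrainingArguments(output_dir="./outputs")
# No run_name given, so __post_init__ copies output_dir into it:
assert args.run_name == args.output_dir

# The diff above detects exactly this case and undoes it for wandb:
if args.run_name == args.output_dir:
    args.run_name = None
```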
@@ -375,8 +375,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     prompter_params = {
         "tokenizer": tokenizer,
         "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
-        "message_field_role": ds_cfg.get("message_field_role", "from"),
-        "message_field_content": ds_cfg.get("message_field_content", "value"),
+        "message_field_role": ds_cfg.get("message_field_role", "role"),
+        "message_field_content": ds_cfg.get("message_field_content", "content"),
         "message_field_training": ds_cfg.get("message_field_training", None),
         "message_field_training_detail": ds_cfg.get(
             "message_field_training_detail",
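The old defaults assumed ShareGPT-style records (`"from"`/`"value"`), while most chat datasets today use OpenAI-style `"role"`/`"content"` keys; after this change the common format loads with no extra config, and ShareGPT data simply spells its fields out. A sketch of the resolution logic (sample records are illustrative):

```python
# OpenAI-style message, handled by the new defaults with no overrides:
openai_msg = {"role": "user", "content": "hi"}

# ShareGPT-style message now needs explicit field names in the dataset config:
sharegpt_msg = {"from": "human", "value": "hi"}
ds_cfg = {
    "message_field_role": "from",
    "message_field_content": "value",
}

# Same .get() fallback as the diff above:
role_field = ds_cfg.get("message_field_role", "role")           # "from"
content_field = ds_cfg.get("message_field_content", "content")  # "value"
print(sharegpt_msg[role_field], sharegpt_msg[content_field])

empty_cfg = {}
print(openai_msg[empty_cfg.get("message_field_role", "role")])  # "user"
```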
@@ -1017,12 +1017,20 @@ class AxolotlInputConfig(
         return neftune_noise_alpha

     @model_validator(mode="after")
-    def check(self):
+    def check_rl_beta(self):
         if self.dpo_beta and not self.rl_beta:
             self.rl_beta = self.dpo_beta
             del self.dpo_beta
         return self

+    @model_validator(mode="after")
+    def check_simpo_warmup(self):
+        if self.rl == "simpo" and self.warmup_ratio:
+            raise ValueError(
+                "warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
+            )
+        return self
+
     @model_validator(mode="before")
     @classmethod
     def check_frozen(cls, data):
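Two notes on the validator changes: pydantic v2 registers `@model_validator` methods by name, so a second validator also named `check` would shadow the first through ordinary class-body shadowing; giving each rule its own descriptive method avoids that and reads better in tracebacks. A minimal standalone sketch of the new `check_simpo_warmup` pattern (the model here is simplified; field names mirror the config model):

```python
from typing import Optional
from pydantic import BaseModel, model_validator

class RLConfig(BaseModel):
    rl: Optional[str] = None
    warmup_ratio: Optional[float] = None
    warmup_steps: Optional[int] = None

    @model_validator(mode="after")
    def check_simpo_warmup(self):
        # Reject warmup_ratio early instead of letting the simpo trainer fail later.
        if self.rl == "simpo" and self.warmup_ratio:
            raise ValueError(
                "warmup_ratio is not supported with the simpo trainer. "
                "Please use `warmup_steps` instead"
            )
        return self

RLConfig(rl="simpo", warmup_steps=10)     # passes validation
# RLConfig(rl="simpo", warmup_ratio=0.1)  # raises pydantic.ValidationError
```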