validation fixes 20240923 (#1925)
* validation fixes 20240923 * fix run name for wandb and defaults for chat template fields * fix gradio inference with llama chat template
This commit is contained in:
@@ -30,6 +30,7 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.train import TrainDatasetMeta
|
from axolotl.train import TrainDatasetMeta
|
||||||
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
from axolotl.utils.config import (
|
from axolotl.utils.config import (
|
||||||
normalize_cfg_datasets,
|
normalize_cfg_datasets,
|
||||||
normalize_config,
|
normalize_config,
|
||||||
@@ -234,7 +235,8 @@ def do_inference_gradio(
|
|||||||
|
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||||
prompter = cli_args.prompter
|
prompter = cli_args.prompter
|
||||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
# default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||||
|
default_tokens: Dict[str, str] = {}
|
||||||
|
|
||||||
for token, symbol in default_tokens.items():
|
for token, symbol in default_tokens.items():
|
||||||
# If the token isn't already specified in the config, add it
|
# If the token isn't already specified in the config, add it
|
||||||
@@ -242,10 +244,13 @@ def do_inference_gradio(
|
|||||||
tokenizer.add_special_tokens({token: symbol})
|
tokenizer.add_special_tokens({token: symbol})
|
||||||
|
|
||||||
prompter_module = None
|
prompter_module = None
|
||||||
|
chat_template_str = None
|
||||||
if prompter:
|
if prompter:
|
||||||
prompter_module = getattr(
|
prompter_module = getattr(
|
||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
|
elif cfg.chat_template:
|
||||||
|
chat_template_str = chat_templates(cfg.chat_template)
|
||||||
|
|
||||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
|
|
||||||
@@ -259,7 +264,24 @@ def do_inference_gradio(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prompt = instruction.strip()
|
prompt = instruction.strip()
|
||||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
|
||||||
|
if chat_template_str:
|
||||||
|
batch = tokenizer.apply_chat_template(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
return_tensors="pt",
|
||||||
|
add_special_tokens=True,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
chat_template=chat_template_str,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -282,6 +304,7 @@ def do_inference_gradio(
|
|||||||
streamer = TextIteratorStreamer(tokenizer)
|
streamer = TextIteratorStreamer(tokenizer)
|
||||||
generation_kwargs = {
|
generation_kwargs = {
|
||||||
"inputs": batch["input_ids"].to(cfg.device),
|
"inputs": batch["input_ids"].to(cfg.device),
|
||||||
|
"attention_mask": batch["attention_mask"].to(cfg.device),
|
||||||
"generation_config": generation_config,
|
"generation_config": generation_config,
|
||||||
"streamer": streamer,
|
"streamer": streamer,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1417,6 +1417,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
report_to = []
|
report_to = []
|
||||||
if self.cfg.use_wandb:
|
if self.cfg.use_wandb:
|
||||||
report_to.append("wandb")
|
report_to.append("wandb")
|
||||||
|
if self.cfg.wandb_name:
|
||||||
|
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
|
||||||
if self.cfg.use_mlflow:
|
if self.cfg.use_mlflow:
|
||||||
report_to.append("mlflow")
|
report_to.append("mlflow")
|
||||||
if self.cfg.use_tensorboard:
|
if self.cfg.use_tensorboard:
|
||||||
@@ -1574,6 +1576,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
)
|
)
|
||||||
training_args = self.hook_post_create_training_args(training_args)
|
training_args = self.hook_post_create_training_args(training_args)
|
||||||
|
|
||||||
|
# unset run_name so wandb sets up experiment names
|
||||||
|
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
|
||||||
|
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
None
|
||||||
|
)
|
||||||
|
|
||||||
data_collator_kwargs = {
|
data_collator_kwargs = {
|
||||||
"padding": True, # True/"longest" is the default
|
"padding": True, # True/"longest" is the default
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -375,8 +375,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
prompter_params = {
|
prompter_params = {
|
||||||
"tokenizer": tokenizer,
|
"tokenizer": tokenizer,
|
||||||
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
|
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
|
||||||
"message_field_role": ds_cfg.get("message_field_role", "from"),
|
"message_field_role": ds_cfg.get("message_field_role", "role"),
|
||||||
"message_field_content": ds_cfg.get("message_field_content", "value"),
|
"message_field_content": ds_cfg.get("message_field_content", "content"),
|
||||||
"message_field_training": ds_cfg.get("message_field_training", None),
|
"message_field_training": ds_cfg.get("message_field_training", None),
|
||||||
"message_field_training_detail": ds_cfg.get(
|
"message_field_training_detail": ds_cfg.get(
|
||||||
"message_field_training_detail",
|
"message_field_training_detail",
|
||||||
|
|||||||
@@ -1017,12 +1017,20 @@ class AxolotlInputConfig(
|
|||||||
return neftune_noise_alpha
|
return neftune_noise_alpha
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check(self):
|
def check_rl_beta(self):
|
||||||
if self.dpo_beta and not self.rl_beta:
|
if self.dpo_beta and not self.rl_beta:
|
||||||
self.rl_beta = self.dpo_beta
|
self.rl_beta = self.dpo_beta
|
||||||
del self.dpo_beta
|
del self.dpo_beta
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_simpo_warmup(self):
|
||||||
|
if self.rl == "simpo" and self.warmup_ratio:
|
||||||
|
raise ValueError(
|
||||||
|
"warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_frozen(cls, data):
|
def check_frozen(cls, data):
|
||||||
|
|||||||
Reference in New Issue
Block a user