sample gen support sft (#3240) [skip ci]
* add: parameters + callback
* sft core + logging
* indentation fix
* logger fix
* logger fix in sft
* gen sample on eval
* lint
* deprecation
@@ -122,6 +122,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        ColabCallback = colab_inference_post_train_callback(trainer)
        callbacks.append(ColabCallback(self.cfg))

        if getattr(self.cfg, "generate_samples", False):
            from axolotl.utils.callbacks.generation import SFTGenerationCallback

            callbacks.append(SFTGenerationCallback(trainer))
            LOG.info("SFT sample generation enabled")

        callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
        return callbacks

src/axolotl/utils/callbacks/generation.py (new file, 84 lines)
@@ -0,0 +1,84 @@
"""Callback for generating samples during SFT/Pretrain training."""

from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments

from axolotl.utils.generation.sft import generate_samples
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


class SFTGenerationCallback(TrainerCallback):
    """Callback for generating samples during SFT/Pretrain training."""

    def __init__(self, trainer):
        self.trainer = trainer

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Generate samples at specified intervals."""
        cfg = self.trainer.axolotl_cfg

        if not getattr(cfg, "generate_samples", False):
            return

        dataloader = None
        try:
            if getattr(self.trainer, "eval_dataset", None) is not None:
                dataloader = self.trainer.get_eval_dataloader()
                LOG.info(
                    f"Using eval dataloader for generation at step {state.global_step}"
                )
        except Exception as e:
            LOG.warning(f"Could not get eval dataloader: {e}")
            dataloader = None

        if dataloader is None:
            dataloader = self.trainer.get_train_dataloader()
            LOG.info(
                f"Using train dataloader for generation at step {state.global_step}"
            )

        samples = generate_samples(
            model=self.trainer.model,
            tokenizer=self.trainer.processing_class,
            dataloader=dataloader,
            num_generation_samples=getattr(cfg, "num_generation_samples", 3),
            max_new_tokens=getattr(cfg, "generation_max_new_tokens", 50),
            temperature=getattr(cfg, "generation_temperature", 0.7),
            top_p=getattr(cfg, "generation_top_p", None),
            top_k=getattr(cfg, "generation_top_k", None),
            do_sample=getattr(cfg, "generation_do_sample", True),
            prompt_ratio=getattr(cfg, "generation_prompt_ratio", 0.5),
        )
        self._log_samples(samples, state.global_step)

    def _log_samples(self, samples: list, step: int):
        """Log generated samples to console and W&B."""
        from axolotl.utils.generation.sft import format_generation_for_logging

        for i, sample in enumerate(samples):
            console_text, wandb_text = format_generation_for_logging(sample, i, step)

            LOG.info(console_text)

            try:
                import wandb

                if wandb.run is not None:
                    wandb.log(
                        {
                            f"samples/sample_{i + 1}": wandb.Html(
                                f"<pre>{wandb_text}</pre>"
                            )
                        },
                        step=step,
                    )
            except Exception:
                pass
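
For orientation, here is a minimal sketch of attaching the callback to a trainer by hand rather than through the builder. It assumes the trainer carries the axolotl config on an `axolotl_cfg` attribute (which is what `on_evaluate` reads); the helper name `enable_sample_generation` and the `DictDefault` values are illustrative only, and the builder path above remains the supported route.

    from transformers import Trainer

    from axolotl.utils.callbacks.generation import SFTGenerationCallback
    from axolotl.utils.dict import DictDefault


    def enable_sample_generation(trainer: Trainer) -> None:
        """Attach the generation callback to an existing trainer (illustrative)."""
        # on_evaluate reads its settings from trainer.axolotl_cfg
        trainer.axolotl_cfg = DictDefault(
            {
                "generate_samples": True,
                "num_generation_samples": 2,
                "generation_max_new_tokens": 32,
            }
        )
        trainer.add_callback(SFTGenerationCallback(trainer))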

src/axolotl/utils/generation/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
"""Generation utilities for monitoring during training."""

from .sft import format_generation_for_logging, generate_samples

__all__ = ["generate_samples", "format_generation_for_logging"]

src/axolotl/utils/generation/sft.py (new file, 174 lines)
@@ -0,0 +1,174 @@
"""Sample generation utilities for SFT/Pretrain training."""

from typing import Any, List, Optional

import torch
from accelerate.utils import extract_model_from_parallel
from colorama import Fore, Style

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def generate_samples(
    model: torch.nn.Module,
    tokenizer: Any,
    dataloader: Any,
    num_generation_samples: int = 3,
    max_new_tokens: int = 50,
    temperature: float = 0.7,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    do_sample: bool = True,
    prompt_ratio: float = 0.5,
) -> List[dict]:
    """
    Generate samples from the model during training for monitoring.

    Args:
        model: The model to generate from
        tokenizer: The tokenizer to use for encoding/decoding
        dataloader: Dataloader to sample prompts from
        num_generation_samples: Number of samples to generate
        max_new_tokens: Maximum new tokens to generate
        temperature: Sampling temperature (0.0 = greedy)
        top_p: Nucleus sampling parameter
        top_k: Top-k sampling parameter
        do_sample: Whether to use sampling vs greedy decoding
        prompt_ratio: Ratio of sequence to use as prompt (0.0-1.0)

    Returns:
        List of dicts with 'prompt', 'generated', and 'full_text' keys
    """
    unwrapped_model = extract_model_from_parallel(model)

    training = unwrapped_model.training
    unwrapped_model.eval()

    device = next(unwrapped_model.parameters()).device

    generations = []

    try:
        with torch.no_grad():
            samples_collected = 0

            for batch in dataloader:
                if samples_collected >= num_generation_samples:
                    break

                input_ids = batch["input_ids"].to(device)
                attention_mask = batch.get("attention_mask")
                if attention_mask is not None:
                    attention_mask = attention_mask.to(device)
                batch_size = input_ids.shape[0]

                indices = torch.randperm(batch_size)[
                    : num_generation_samples - samples_collected
                ]

                for idx in indices:
                    if samples_collected >= num_generation_samples:
                        break

                    sequence = input_ids[idx]

                    if attention_mask is not None:
                        seq_len = attention_mask[idx].sum().item()
                    else:
                        seq_len = sequence.shape[0]

                    if seq_len < 5:
                        continue

                    prompt_len = max(1, int(seq_len * prompt_ratio))
                    prompt_ids = sequence[:prompt_len].unsqueeze(0)

                    try:
                        generation_config = {
                            "max_new_tokens": max_new_tokens,
                            "do_sample": do_sample,
                            "pad_token_id": tokenizer.pad_token_id
                            if tokenizer.pad_token_id is not None
                            else tokenizer.eos_token_id,
                        }

                        if do_sample:
                            generation_config["temperature"] = temperature
                            if top_p is not None:
                                generation_config["top_p"] = top_p
                            if top_k is not None:
                                generation_config["top_k"] = top_k

                        generated_ids = unwrapped_model.generate(
                            prompt_ids, **generation_config
                        )

                        prompt_text = tokenizer.decode(
                            prompt_ids[0], skip_special_tokens=True
                        )
                        generated_text = tokenizer.decode(
                            generated_ids[0][prompt_len:], skip_special_tokens=True
                        )
                        full_text = tokenizer.decode(
                            generated_ids[0], skip_special_tokens=True
                        )

                        generations.append(
                            {
                                "prompt": prompt_text,
                                "generated": generated_text,
                                "full_text": full_text,
                            }
                        )

                        samples_collected += 1

                    except Exception as e:
                        LOG.warning(f"Failed to generate sample: {e}", exc_info=True)
                        continue

    except Exception as e:
        LOG.warning(f"Error during sample generation: {e}", exc_info=True)

    if training:
        unwrapped_model.train()
    else:
        unwrapped_model.eval()

    return generations


def format_generation_for_logging(
    sample: dict, sample_idx: int, step: int
) -> tuple[str, str]:
    """
    Format a generation sample for pretty logging.

    Args:
        sample: Dict with 'prompt', 'generated', and 'full_text' keys
        sample_idx: Index of the sample
        step: Current training step

    Returns:
        Tuple of (console_text, wandb_text)
    """
    console_text = (
        f"\n{Style.BRIGHT}{Fore.CYAN}{'=' * 80}{Style.RESET_ALL}\n"
        f"{Style.BRIGHT}{Fore.GREEN}Sample {sample_idx + 1} (Step {step}){Style.RESET_ALL}\n"
        f"{Style.BRIGHT}{Fore.CYAN}{'=' * 80}{Style.RESET_ALL}\n"
        f"{Style.BRIGHT}{Fore.YELLOW}[PROMPT]{Style.RESET_ALL}\n{sample['prompt']}\n\n"
        f"{Style.BRIGHT}{Fore.MAGENTA}[GENERATED]{Style.RESET_ALL}\n{sample['generated']}\n"
        f"{Style.BRIGHT}{Fore.CYAN}{'=' * 80}{Style.RESET_ALL}\n"
    )
    wandb_text = (
        f"\n{'=' * 80}\n"
        f"Sample {sample_idx + 1} (Step {step})\n"
        f"{'=' * 80}\n"
        f"[PROMPT]\n{sample['prompt']}\n\n"
        f"[GENERATED]\n{sample['generated']}\n"
        f"{'=' * 80}\n"
    )

    return console_text, wandb_text
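
A small, self-contained sketch of calling generate_samples directly, outside any trainer. The tiny checkpoint name and the single-batch "dataloader" are illustrative only; any iterable of tokenized batches with input_ids (and optionally attention_mask) works, since that is all the function reads.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    from axolotl.utils.generation.sft import (
        format_generation_for_logging,
        generate_samples,
    )

    # any small causal LM checkpoint is enough for a smoke test
    tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

    # generate_samples only iterates the dataloader and reads batch["input_ids"],
    # so a list holding one pre-tokenized batch stands in for a real DataLoader here
    batch = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt")
    dataloader = [{"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}]

    samples = generate_samples(
        model=model,
        tokenizer=tokenizer,
        dataloader=dataloader,
        num_generation_samples=1,
        max_new_tokens=16,
    )
    for i, sample in enumerate(samples):
        console_text, _ = format_generation_for_logging(sample, i, step=0)
        print(console_text)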
@@ -338,18 +338,6 @@ class AxolotlInputConfig(
    )
    ddp_find_unused_parameters: bool | None = None

    eval_table_size: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0"
        },
    )
    eval_max_new_tokens: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Total number of tokens generated for predictions sent to wandb. Default is 128"
        },
    )
    do_causal_lm_eval: bool | None = Field(
        default=None,
        json_schema_extra={
@@ -1106,6 +1094,46 @@ class AxolotlInputConfig(
            "description": "Add plugins to extend the pipeline. See `src/axolotl/integrations` for the available plugins or doc below for more details. https://docs.axolotl.ai/docs/custom_integrations.html"
        },
    )
    generate_samples: bool | None = Field(
        default=False,
        json_schema_extra={
            "description": "Enable sample generation during training for monitoring"
        },
    )
    num_generation_samples: int | None = Field(
        default=3,
        json_schema_extra={
            "description": "Number of samples to generate at each interval"
        },
    )
    generation_max_new_tokens: int | None = Field(
        default=50,
        json_schema_extra={"description": "Maximum new tokens to generate per sample"},
    )
    generation_temperature: float | None = Field(
        default=0.7,
        json_schema_extra={
            "description": "Temperature for sample generation (0.0 = greedy)"
        },
    )
    generation_top_p: float | None = Field(
        default=None,
        json_schema_extra={"description": "Nucleus sampling parameter for generation"},
    )
    generation_top_k: int | None = Field(
        default=None,
        json_schema_extra={"description": "Top-k sampling parameter for generation"},
    )
    generation_prompt_ratio: float | None = Field(
        default=0.5,
        json_schema_extra={"description": "Ratio of input to use as prompt (0.0-1.0)"},
    )
    generation_do_sample: bool | None = Field(
        default=True,
        json_schema_extra={
            "description": "Whether to use sampling (vs greedy decoding)"
        },
    )

    @field_serializer("datasets")
    def datasets_serializer(
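
Taken together, a typical setting of these options might look like the following (shown as a plain Python dict for illustration; in an actual run these keys live in the axolotl YAML config, and the values simply restate the field defaults above):

    sample_generation_options = {
        "generate_samples": True,          # master switch for the callback
        "num_generation_samples": 3,       # samples generated per evaluation
        "generation_max_new_tokens": 50,   # new tokens per sample
        "generation_temperature": 0.7,     # used only when sampling
        "generation_top_p": None,          # optional nucleus sampling
        "generation_top_k": None,          # optional top-k sampling
        "generation_do_sample": True,      # False means greedy decoding
        "generation_prompt_ratio": 0.5,    # first half of each sequence is the prompt
    }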
@@ -17,6 +17,8 @@ class DeprecatedParameters(BaseModel):
    noisy_embedding_alpha: float | None = None
    dpo_beta: float | None = None
    evaluation_strategy: str | None = None
    eval_table_size: int | None = None
    eval_max_new_tokens: int | None = None

    @field_validator("max_packed_sequence_len")
    @classmethod
@@ -55,6 +57,27 @@ class DeprecatedParameters(BaseModel):
        LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
        return evaluation_strategy

    @field_validator("eval_table_size")
    @classmethod
    def validate_eval_table_size(cls, eval_table_size):
        if eval_table_size is not None:
            LOG.warning(
                "eval_table_size is deprecated and superseded by generate_samples config. "
                "Please use generate_samples: true and num_generation_samples instead. "
                "The LogPredictionCallback is replaced by the new sample generation feature."
            )
        return eval_table_size

    @field_validator("eval_max_new_tokens")
    @classmethod
    def validate_eval_max_new_tokens(cls, eval_max_new_tokens):
        if eval_max_new_tokens is not None:
            LOG.warning(
                "eval_max_new_tokens is deprecated and superseded by generate_samples config. "
                "Please use generation_max_new_tokens instead."
            )
        return eval_max_new_tokens


class RemappedParameters(BaseModel):
    """Parameters that have been remapped to other names"""